readysite / pkg / assistant / providers / anthropic / anthropic.go
9.2 KB
anthropic.go
// Package anthropic provides an Anthropic implementation of assistant.Backend.
package anthropic

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"

	"github.com/readysite/readysite/pkg/assistant"
)

// Defaults applied by New. defaultAPIVersion is sent as the anthropic-version
// header. Both can be overridden with WithBaseURL and WithAPIVersion.
const (
	defaultBaseURL    = "https://api.anthropic.com/v1"
	defaultAPIVersion = "2023-06-01"
)

// backend implements assistant.Backend on top of the Anthropic Messages API.
// Values are built by New; the zero value has a nil client and is not usable.
type backend struct {
	// client performs every HTTP request made by the backend.
	client *http.Client

	// Endpoint and credentials, set by New and its options.
	apiKey     string // sent as the x-api-key header
	apiVersion string // sent as the anthropic-version header
	baseURL    string // API root; "/messages" is appended for each call
}

// Option configures the Anthropic backend. New applies options in the order
// given, after the defaults are in place, so a later option overrides an
// earlier one.
type Option func(*backend)

// WithBaseURL sets a custom base URL (for example a proxy or a test server).
//
// The request path "/messages" is appended to this value verbatim, so any
// trailing slashes are stripped to avoid producing "//messages". An empty
// value (or one made only of slashes) is ignored and the current base URL
// is kept.
func WithBaseURL(url string) Option {
	for len(url) > 0 && url[len(url)-1] == '/' {
		url = url[:len(url)-1]
	}
	return func(b *backend) {
		if url != "" {
			b.baseURL = url
		}
	}
}

// WithAPIVersion sets a custom value for the anthropic-version header.
// An empty value is ignored and the current version is kept.
func WithAPIVersion(version string) Option {
	return func(b *backend) {
		if version != "" {
			b.apiVersion = version
		}
	}
}

// WithHTTPClient sets the HTTP client used for all requests, e.g. one with a
// timeout or a custom transport. A nil client is ignored and the current
// client is kept, because the backend calls client.Do unconditionally.
func WithHTTPClient(client *http.Client) Option {
	return func(b *backend) {
		if client != nil {
			b.client = client
		}
	}
}

// New creates an assistant.Assistant backed by the Anthropic Messages API.
//
// It returns assistant.ErrNoAPIKey when apiKey is empty. The backend starts
// with defaultBaseURL, defaultAPIVersion and http.DefaultClient. opts are
// then applied in order, so later options win. http.DefaultClient has no
// timeout; rely on ctx deadlines or supply WithHTTPClient.
func New(apiKey string, opts ...Option) (*assistant.Assistant, error) {
	if len(apiKey) == 0 {
		return nil, assistant.ErrNoAPIKey
	}

	var b backend
	b.apiKey = apiKey
	b.baseURL = defaultBaseURL
	b.apiVersion = defaultAPIVersion
	b.client = http.DefaultClient

	for i := range opts {
		opts[i](&b)
	}

	return &assistant.Assistant{Backend: &b}, nil
}

// Chat sends a non-streaming request to the Messages endpoint. It decodes the
// JSON reply and converts it into an assistant.ChatResponse. The response body
// is always closed before returning.
func (b *backend) Chat(ctx context.Context, req assistant.ChatRequest) (*assistant.ChatResponse, error) {
	rc, err := b.doRequest(ctx, b.buildRequest(req, false))
	if err != nil {
		return nil, err
	}
	defer rc.Close()

	var decoded messageResponse
	if decodeErr := json.NewDecoder(rc).Decode(&decoded); decodeErr != nil {
		return nil, fmt.Errorf("failed to decode response: %w", decodeErr)
	}
	return b.convertResponse(&decoded), nil
}

// Stream sends a streaming request ("stream": true) to the Messages endpoint.
// It hands the open response body to an assistant.StreamReader, which decodes
// each event with parseStreamEvent. The body is not closed here; it is passed
// to the returned reader. NOTE(review): confirm StreamReader closes it.
func (b *backend) Stream(ctx context.Context, req assistant.ChatRequest) (*assistant.StreamReader, error) {
	rc, err := b.doRequest(ctx, b.buildRequest(req, true))
	if err != nil {
		return nil, err
	}
	return assistant.NewStreamReader(rc, parseStreamEvent), nil
}

// buildRequest converts a provider-neutral ChatRequest into the JSON body of
// an Anthropic Messages API call (POST {baseURL}/messages).
//
// Mapping rules:
//   - Anthropic takes the system prompt as a top-level "system" field, not as
//     a message. req.System is used when set. Otherwise the contents of any
//     system-role messages are joined with blank lines and used instead, so
//     they are not silently dropped.
//   - Tool results (RoleTool) become "user" messages carrying tool_result
//     blocks. Consecutive tool results are grouped into a single user message,
//     as Anthropic recommends for answering parallel tool calls.
//   - Messages with tool calls become an optional text block followed by one
//     tool_use block per call.
//   - All other messages keep their role and plain-string content.
//   - max_tokens is always set, because Anthropic requires it.
func (b *backend) buildRequest(req assistant.ChatRequest, stream bool) map[string]any {
	messages := make([]map[string]any, 0, len(req.Messages))

	// Text of system-role messages; used only when req.System is empty.
	var systemFromMessages string
	// True while the last appended message is a user turn of tool results.
	inToolResults := false

	for _, m := range req.Messages {
		switch {
		case m.Role == assistant.RoleSystem:
			if m.Content != "" {
				if systemFromMessages != "" {
					systemFromMessages += "\n\n"
				}
				systemFromMessages += m.Content
			}
			// Nothing is emitted, so a run of tool results is not interrupted.
			continue

		case m.Role == assistant.RoleTool:
			result := map[string]any{
				"type":        "tool_result",
				"tool_use_id": m.ToolCallID,
				"content":     m.Content,
			}
			if inToolResults {
				// Safe: inToolResults is only true right after the else branch
				// below appended a message whose content is []map[string]any.
				last := messages[len(messages)-1]
				last["content"] = append(last["content"].([]map[string]any), result)
			} else {
				// Anthropic has no tool role; tool results travel in a user turn.
				messages = append(messages, map[string]any{
					"role":    "user",
					"content": []map[string]any{result},
				})
			}
			inToolResults = true
			continue

		case len(m.ToolCalls) > 0:
			content := make([]map[string]any, 0, len(m.ToolCalls)+1)
			if m.Content != "" {
				content = append(content, map[string]any{
					"type": "text",
					"text": m.Content,
				})
			}
			for _, tc := range m.ToolCalls {
				content = append(content, map[string]any{
					"type":  "tool_use",
					"id":    tc.ID,
					"name":  tc.Name,
					"input": toolUseInput(tc.Arguments),
				})
			}
			messages = append(messages, map[string]any{"role": m.Role, "content": content})

		default:
			messages = append(messages, map[string]any{"role": m.Role, "content": m.Content})
		}
		inToolResults = false
	}

	body := map[string]any{
		"model":    req.Model,
		"messages": messages,
		"stream":   stream,
	}

	switch {
	case req.System != "":
		body["system"] = req.System
	case systemFromMessages != "":
		body["system"] = systemFromMessages
	}

	// Anthropic requires max_tokens on every request.
	const defaultMaxTokens = 4096
	if req.MaxTokens > 0 {
		body["max_tokens"] = req.MaxTokens
	} else {
		body["max_tokens"] = defaultMaxTokens
	}

	if req.Temperature != nil {
		body["temperature"] = *req.Temperature
	}

	if len(req.Tools) > 0 {
		tools := make([]map[string]any, len(req.Tools))
		for i, t := range req.Tools {
			tool := map[string]any{
				"name":        t.Name,
				"description": t.Description,
			}
			// input_schema is mandatory; fall back to an empty object schema
			// when the tool declares no parameters.
			if t.Parameters != nil {
				tool["input_schema"] = t.Parameters
			} else {
				tool["input_schema"] = map[string]any{"type": "object"}
			}
			tools[i] = tool
		}
		body["tools"] = tools
	}

	return body
}

// toolUseInput decodes a tool call's JSON arguments into the object used as a
// tool_use block's "input". Anthropic requires a valid dictionary there, so
// empty, "null", malformed or non-object arguments all yield an empty object.
// This is deliberately best-effort: the arguments come from an earlier model
// turn, and a bad value should not abort the whole request.
func toolUseInput(arguments string) map[string]any {
	if arguments == "" {
		return map[string]any{}
	}
	var input map[string]any
	if err := json.Unmarshal([]byte(arguments), &input); err != nil || input == nil {
		return map[string]any{}
	}
	return input
}

// doRequest marshals body and POSTs it to {baseURL}/messages with the
// x-api-key and anthropic-version headers. On HTTP 200 it returns the response
// body, which the caller owns and must close.
//
// Any other status is converted to *assistant.APIError, and the body is
// closed here. The error type and message come from Anthropic's JSON error
// envelope when it parses and has a message. Otherwise the raw (size-capped)
// body text is used, or the HTTP status text if no body could be read.
func (b *backend) doRequest(ctx context.Context, body map[string]any) (io.ReadCloser, error) {
	payload, err := json.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, b.baseURL+"/messages", bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("x-api-key", b.apiKey)
	httpReq.Header.Set("anthropic-version", b.apiVersion)

	resp, err := b.client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()

		// Error envelopes are small; cap the read so a misbehaving server or
		// proxy cannot make us buffer an unbounded body in memory.
		const maxErrorBodyBytes = 1 << 20
		raw, readErr := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))

		var errResp errorResponse
		if json.Unmarshal(raw, &errResp) == nil && errResp.Error.Message != "" {
			return nil, &assistant.APIError{
				StatusCode: resp.StatusCode,
				Type:       errResp.Error.Type,
				Message:    errResp.Error.Message,
			}
		}

		message := string(raw)
		if message == "" {
			// Never report a blank message: fall back to the status text and
			// mention the read failure, if there was one.
			message = http.StatusText(resp.StatusCode)
			if readErr != nil {
				message += " (reading error body: " + readErr.Error() + ")"
			}
		}

		return nil, &assistant.APIError{
			StatusCode: resp.StatusCode,
			Message:    message,
		}
	}

	return resp.Body, nil
}

// convertResponse maps a decoded Messages API reply onto assistant.ChatResponse.
//
// All "text" blocks are concatenated in order, with no separator. A reply may
// contain several text blocks, e.g. around tool_use blocks, and this matches
// how the streaming path delivers text as consecutive deltas. Each "tool_use"
// block becomes a ToolCall whose Arguments hold the JSON-encoded input object,
// or "{}" when the block carried no input. Other block types are ignored.
// stop_reason "tool_use" maps to FinishReason "tool_calls"; anything else
// maps to "stop".
func (b *backend) convertResponse(resp *messageResponse) *assistant.ChatResponse {
	var content string
	var toolCalls []assistant.ToolCall

	for _, block := range resp.Content {
		switch block.Type {
		case "text":
			content += block.Text
		case "tool_use":
			args := "{}"
			if block.Input != nil {
				// Input was produced by JSON decoding, so re-encoding should
				// not fail; keep the "{}" fallback if it ever does.
				if encoded, err := json.Marshal(block.Input); err == nil {
					args = string(encoded)
				}
			}
			toolCalls = append(toolCalls, assistant.ToolCall{
				ID:        block.ID,
				Name:      block.Name,
				Arguments: args,
			})
		}
	}

	finishReason := "stop"
	if resp.StopReason == "tool_use" {
		finishReason = "tool_calls"
	}

	return &assistant.ChatResponse{
		Content:      content,
		ToolCalls:    toolCalls,
		FinishReason: finishReason,
		Usage: assistant.Usage{
			PromptTokens:     resp.Usage.InputTokens,
			CompletionTokens: resp.Usage.OutputTokens,
			TotalTokens:      resp.Usage.InputTokens + resp.Usage.OutputTokens,
		},
	}
}

// --- Anthropic API types ---

// messageResponse is the subset of a non-streaming Messages API reply that
// this backend reads: the content blocks, the stop reason and token usage.
type messageResponse struct {
	Content    []messageContentBlock `json:"content"`
	StopReason string                `json:"stop_reason"`
	Usage      messageUsage          `json:"usage"`
}

// messageContentBlock is one entry of messageResponse.Content. convertResponse
// reads Text from "text" blocks and ID, Name and Input from "tool_use" blocks.
type messageContentBlock struct {
	Type  string         `json:"type"`
	Text  string         `json:"text,omitempty"`
	ID    string         `json:"id,omitempty"`
	Name  string         `json:"name,omitempty"`
	Input map[string]any `json:"input,omitempty"`
}

// messageUsage holds the token counts reported for a reply.
type messageUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

// streamEvent is one decoded server-sent event from a streaming Messages API
// response. Which optional part is populated depends on Type. parseStreamEvent
// reads ContentBlock for "content_block_start", Delta for
// "content_block_delta" and Usage for "message_delta".
type streamEvent struct {
	Type         string              `json:"type"`
	Index        int                 `json:"index,omitempty"`
	ContentBlock *streamContentBlock `json:"content_block,omitempty"`
	Delta        *streamDelta        `json:"delta,omitempty"`
	Usage        *streamUsage        `json:"usage,omitempty"`
}

// streamContentBlock describes the block opened by a content_block_start event.
type streamContentBlock struct {
	Type  string         `json:"type"`
	ID    string         `json:"id,omitempty"`
	Name  string         `json:"name,omitempty"`
	Text  string         `json:"text,omitempty"`
	Input map[string]any `json:"input,omitempty"`
}

// streamDelta is an incremental update: Text for "text_delta", PartialJSON
// (a fragment of tool-call arguments) for "input_json_delta".
type streamDelta struct {
	Type        string `json:"type"`
	Text        string `json:"text,omitempty"`
	PartialJSON string `json:"partial_json,omitempty"`
}

// streamUsage holds the token counts carried by a stream event.
type streamUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

// errorResponse is Anthropic's JSON error envelope,
// {"error": {"type": ..., "message": ...}}.
type errorResponse struct {
	Error errorDetail `json:"error"`
}

// errorDetail carries the error type and message strings of an errorResponse.
type errorDetail struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}

// parseStreamEvent converts the JSON payload of one server-sent event from a
// streaming Messages API response into an assistant.StreamEvent.
//
//   - content_block_start for a tool_use block becomes EventToolCallStart,
//     carrying the tool's ID and name.
//   - content_block_delta becomes EventContentDelta for "text_delta", or
//     EventToolCallDelta for "input_json_delta". The latter carries a
//     fragment of the tool arguments JSON. ToolIndex is the content-block
//     "index" reported by the API.
//   - message_delta with usage, and message_stop, become EventDone.
//   - error becomes an *assistant.APIError. These events arrive after the
//     HTTP 200 has been sent (e.g. overloaded_error), so doRequest's status
//     check cannot catch them.
//   - Anything else (ping, message_start, content_block_stop, ...) becomes an
//     EventContentDelta with empty content.
//
// A payload that is not valid JSON is returned as an error.
func parseStreamEvent(data string) (*assistant.StreamEvent, error) {
	var event streamEvent
	if err := json.Unmarshal([]byte(data), &event); err != nil {
		return nil, err
	}

	switch event.Type {
	case "content_block_start":
		if event.ContentBlock != nil && event.ContentBlock.Type == "tool_use" {
			return &assistant.StreamEvent{
				Type: assistant.EventToolCallStart,
				ToolCall: &assistant.ToolCall{
					ID:   event.ContentBlock.ID,
					Name: event.ContentBlock.Name,
				},
				ToolIndex: event.Index,
			}, nil
		}

	case "content_block_delta":
		if event.Delta != nil {
			switch event.Delta.Type {
			case "text_delta":
				return &assistant.StreamEvent{
					Type:    assistant.EventContentDelta,
					Content: event.Delta.Text,
				}, nil
			case "input_json_delta":
				return &assistant.StreamEvent{
					Type: assistant.EventToolCallDelta,
					ToolCall: &assistant.ToolCall{
						Arguments: event.Delta.PartialJSON,
					},
					ToolIndex: event.Index,
				}, nil
			}
		}

	case "message_stop":
		return &assistant.StreamEvent{Type: assistant.EventDone}, nil

	case "message_delta":
		if event.Usage != nil {
			return &assistant.StreamEvent{
				Type: assistant.EventDone,
				Usage: &assistant.Usage{
					PromptTokens:     event.Usage.InputTokens,
					CompletionTokens: event.Usage.OutputTokens,
					TotalTokens:      event.Usage.InputTokens + event.Usage.OutputTokens,
				},
			}, nil
		}

	case "error":
		// The payload uses the same {"error": {...}} envelope as HTTP errors.
		var errResp errorResponse
		if err := json.Unmarshal([]byte(data), &errResp); err != nil {
			return nil, err
		}
		message := errResp.Error.Message
		if message == "" {
			message = data
		}
		return nil, &assistant.APIError{
			Type:    errResp.Error.Type,
			Message: message,
		}
	}

	return &assistant.StreamEvent{Type: assistant.EventContentDelta}, nil
}
← Back