// Package openai provides an OpenAI implementation of assistant.Backend.
package openai

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"

	"github.com/readysite/readysite/pkg/assistant"
)

// defaultBaseURL is the OpenAI API root used when no WithBaseURL option is supplied.
const defaultBaseURL = "https://api.openai.com/v1"

// backend implements assistant.Backend for OpenAI.
type backend struct {
	apiKey  string       // bearer token sent in the Authorization header
	baseURL string       // API root; defaults to defaultBaseURL
	client  *http.Client // HTTP client used for all requests
}

// Option configures the OpenAI backend. Options are applied in order by New.
type Option func(*backend)

// WithBaseURL sets a custom base URL (for Azure OpenAI or proxies).
func WithBaseURL(url string) Option {
	return func(be *backend) { be.baseURL = url }
}

// WithHTTPClient sets a custom HTTP client.
func WithHTTPClient(client *http.Client) Option {
	return func(be *backend) { be.client = client }
}

// New creates an OpenAI assistant. apiKey must be non-empty; opts may
// override the base URL and HTTP client.
func New(apiKey string, opts ...Option) (*assistant.Assistant, error) {
	if apiKey == "" {
		return nil, assistant.ErrNoAPIKey
	}

	be := &backend{
		apiKey:  apiKey,
		baseURL: defaultBaseURL,
		client:  http.DefaultClient,
	}
	for _, apply := range opts {
		apply(be)
	}

	return &assistant.Assistant{Backend: be}, nil
}

// Chat sends a non-streaming chat completion request and returns the first
// choice of the response along with token usage. It returns
// assistant.ErrEmptyResponse when the API returns no choices.
func (b *backend) Chat(ctx context.Context, req assistant.ChatRequest) (*assistant.ChatResponse, error) {
	body := b.buildRequest(req, false)

	respBody, err := b.doRequest(ctx, body)
	if err != nil {
		return nil, err
	}
	defer func() {
		// json.Decoder stops at the end of the first JSON value; drain any
		// trailing bytes so the transport can reuse the connection, then close.
		_, _ = io.Copy(io.Discard, respBody)
		respBody.Close()
	}()

	var resp chatResponse
	if err := json.NewDecoder(respBody).Decode(&resp); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}

	if len(resp.Choices) == 0 {
		return nil, assistant.ErrEmptyResponse
	}

	choice := resp.Choices[0]
	return &assistant.ChatResponse{
		Content:      choice.Message.Content,
		ToolCalls:    convertToolCalls(choice.Message.ToolCalls),
		FinishReason: choice.FinishReason,
		Usage: assistant.Usage{
			PromptTokens:     resp.Usage.PromptTokens,
			CompletionTokens: resp.Usage.CompletionTokens,
			TotalTokens:      resp.Usage.TotalTokens,
		},
	}, nil
}

// Stream sends a streaming chat completion request. The response body is
// handed to the returned StreamReader, which parses events via
// parseStreamEvent.
func (b *backend) Stream(ctx context.Context, req assistant.ChatRequest) (*assistant.StreamReader, error) {
	payload := b.buildRequest(req, true)

	respBody, err := b.doRequest(ctx, payload)
	if err != nil {
		return nil, err
	}

	return assistant.NewStreamReader(respBody, parseStreamEvent), nil
}

// buildRequest converts an assistant.ChatRequest into the JSON-serializable
// payload expected by the OpenAI chat completions endpoint. stream toggles
// server-sent-event streaming.
func (b *backend) buildRequest(req assistant.ChatRequest, stream bool) map[string]any {
	msgs := make([]map[string]any, 0, len(req.Messages)+1)

	// A non-empty system prompt becomes a leading system-role message.
	if req.System != "" {
		msgs = append(msgs, map[string]any{
			"role":    "system",
			"content": req.System,
		})
	}

	for _, m := range req.Messages {
		entry := map[string]any{
			"role":    m.Role,
			"content": m.Content,
		}

		// Assistant messages that issued tool calls carry them in OpenAI's
		// {id, type, function:{name, arguments}} shape.
		if len(m.ToolCalls) > 0 {
			calls := make([]map[string]any, 0, len(m.ToolCalls))
			for _, tc := range m.ToolCalls {
				calls = append(calls, map[string]any{
					"id":   tc.ID,
					"type": "function",
					"function": map[string]any{
						"name":      tc.Name,
						"arguments": tc.Arguments,
					},
				})
			}
			entry["tool_calls"] = calls
		}

		// Tool-result messages reference the call they answer.
		if m.ToolCallID != "" {
			entry["tool_call_id"] = m.ToolCallID
		}

		msgs = append(msgs, entry)
	}

	payload := map[string]any{
		"model":    req.Model,
		"messages": msgs,
		"stream":   stream,
	}

	if req.MaxTokens > 0 {
		payload["max_completion_tokens"] = req.MaxTokens
	}
	if req.Temperature != nil {
		payload["temperature"] = *req.Temperature
	}

	if len(req.Tools) > 0 {
		tools := make([]map[string]any, 0, len(req.Tools))
		for _, t := range req.Tools {
			// Build the function object first so optional parameters can be
			// added without a type assertion on the outer map.
			fn := map[string]any{
				"name":        t.Name,
				"description": t.Description,
			}
			if t.Parameters != nil {
				fn["parameters"] = t.Parameters
			}
			tools = append(tools, map[string]any{
				"type":     "function",
				"function": fn,
			})
		}
		payload["tools"] = tools
	}

	return payload
}

// doRequest marshals body, POSTs it to the chat completions endpoint, and
// returns the raw response body on HTTP 200; the caller must close it.
// Non-200 responses are converted into an *assistant.APIError, using the
// structured error detail when the payload parses as an OpenAI error object.
func (b *backend) doRequest(ctx context.Context, body map[string]any) (io.ReadCloser, error) {
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, b.baseURL+"/chat/completions", bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+b.apiKey)

	resp, err := b.client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()

		// Cap the error payload read so a misbehaving server cannot make us
		// buffer an unbounded body. The read error is deliberately ignored:
		// the HTTP status is reported either way. (Renamed from `body` to
		// avoid shadowing the request-body parameter.)
		errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))

		var errResp errorResponse
		if json.Unmarshal(errBody, &errResp) == nil && errResp.Error.Message != "" {
			return nil, &assistant.APIError{
				StatusCode: resp.StatusCode,
				Type:       errResp.Error.Type,
				Message:    errResp.Error.Message,
			}
		}

		return nil, &assistant.APIError{
			StatusCode: resp.StatusCode,
			Message:    string(errBody),
		}
	}

	return resp.Body, nil
}

// --- OpenAI API types ---

// chatResponse mirrors the subset of the OpenAI /chat/completions response
// body that this backend consumes.
type chatResponse struct {
	// Choices holds the completion candidates; only the first is used.
	Choices []struct {
		Message struct {
			Content   string `json:"content"`
			ToolCalls []struct {
				ID       string `json:"id"`
				Type     string `json:"type"`
				Function struct {
					Name      string `json:"name"`
					Arguments string `json:"arguments"`
				} `json:"function"`
			} `json:"tool_calls"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	// Usage reports token accounting for the request.
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}

// streamChunk mirrors one SSE data payload of a streaming chat completion.
// Deltas carry incremental content or tool-call fragments.
type streamChunk struct {
	Choices []struct {
		Delta struct {
			Content   string `json:"content"`
			ToolCalls []struct {
				// Index identifies which tool call a fragment belongs to
				// when arguments arrive across multiple chunks.
				Index    int    `json:"index"`
				ID       string `json:"id"`
				Type     string `json:"type"`
				Function struct {
					Name      string `json:"name"`
					Arguments string `json:"arguments"`
				} `json:"function"`
			} `json:"tool_calls"`
		} `json:"delta"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
}

// errorResponse mirrors the OpenAI error envelope returned on non-200
// responses.
type errorResponse struct {
	Error struct {
		Message string `json:"message"`
		Type    string `json:"type"`
		Code    string `json:"code"`
	} `json:"error"`
}

// convertToolCalls maps OpenAI wire-format tool calls (the anonymous struct
// nested in chatResponse) into assistant.ToolCall values.
func convertToolCalls(calls []struct {
	ID       string `json:"id"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}) []assistant.ToolCall {
	out := make([]assistant.ToolCall, 0, len(calls))
	for _, call := range calls {
		out = append(out, assistant.ToolCall{
			ID:        call.ID,
			Name:      call.Function.Name,
			Arguments: call.Function.Arguments,
		})
	}
	return out
}

// parseStreamEvent converts one SSE data payload into a StreamEvent.
// Chunks with no choices, and chunks carrying neither content, tool calls,
// nor a finish reason, map to an empty content delta.
func parseStreamEvent(data string) (*assistant.StreamEvent, error) {
	var chunk streamChunk
	if err := json.Unmarshal([]byte(data), &chunk); err != nil {
		return nil, err
	}

	if len(chunk.Choices) == 0 {
		return &assistant.StreamEvent{Type: assistant.EventContentDelta}, nil
	}
	first := chunk.Choices[0]

	switch {
	case first.Delta.Content != "":
		// Incremental assistant text.
		return &assistant.StreamEvent{
			Type:    assistant.EventContentDelta,
			Content: first.Delta.Content,
		}, nil

	case len(first.Delta.ToolCalls) > 0:
		tc := first.Delta.ToolCalls[0]
		if tc.ID != "" {
			// A non-empty ID marks the start of a new tool call.
			return &assistant.StreamEvent{
				Type: assistant.EventToolCallStart,
				ToolCall: &assistant.ToolCall{
					ID:        tc.ID,
					Name:      tc.Function.Name,
					Arguments: tc.Function.Arguments,
				},
				ToolIndex: tc.Index,
			}, nil
		}
		// Otherwise this fragment continues the arguments of the call at
		// ToolIndex.
		return &assistant.StreamEvent{
			Type: assistant.EventToolCallDelta,
			ToolCall: &assistant.ToolCall{
				Arguments: tc.Function.Arguments,
			},
			ToolIndex: tc.Index,
		}, nil

	case first.FinishReason != "":
		return &assistant.StreamEvent{Type: assistant.EventDone}, nil

	default:
		return &assistant.StreamEvent{Type: assistant.EventContentDelta}, nil
	}
}