openai.go
// Package openai provides an OpenAI implementation of assistant.Backend.
package openai
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/readysite/readysite/pkg/assistant"
)
// defaultBaseURL is the public OpenAI API root used unless WithBaseURL
// overrides it. Endpoint paths such as "/chat/completions" are appended to it.
const defaultBaseURL = "https://api.openai.com/v1"

// backend implements assistant.Backend for OpenAI.
type backend struct {
	client  *http.Client // transport used for every request
	baseURL string       // API root; endpoint paths are appended verbatim
	apiKey  string       // sent as a Bearer token in the Authorization header
}

// Option configures the OpenAI backend. New applies options in the order
// given, after the defaults have been filled in.
type Option func(*backend)
// WithBaseURL sets a custom base URL (for Azure OpenAI or proxies).
//
// Trailing slashes are removed, so "https://host/v1/" and "https://host/v1"
// both lead to requests against "https://host/v1/chat/completions". Without
// this, the endpoint path would be joined with a double slash, which some
// servers and proxies do not route.
func WithBaseURL(url string) Option {
	return func(b *backend) {
		b.baseURL = strings.TrimRight(url, "/")
	}
}
// WithHTTPClient sets a custom HTTP client.
//
// A nil client is ignored and the previously configured client is kept
// (http.DefaultClient unless changed). Storing nil would make every request
// panic when the backend calls Do on it.
func WithHTTPClient(client *http.Client) Option {
	return func(b *backend) {
		if client != nil {
			b.client = client
		}
	}
}
// New creates an OpenAI assistant authenticated with apiKey.
//
// It returns assistant.ErrNoAPIKey when apiKey is empty. The backend starts
// with the public OpenAI endpoint and http.DefaultClient, then opts are
// applied in the order given.
func New(apiKey string, opts ...Option) (*assistant.Assistant, error) {
	if len(apiKey) == 0 {
		return nil, assistant.ErrNoAPIKey
	}

	b := new(backend)
	b.apiKey = apiKey
	b.baseURL = defaultBaseURL
	b.client = http.DefaultClient
	for i := range opts {
		opts[i](b)
	}

	return &assistant.Assistant{Backend: b}, nil
}
// Chat sends a non-streaming chat completion request and returns the first
// choice together with the reported token usage.
//
// Errors from the HTTP layer are returned unchanged. A body that cannot be
// decoded yields a wrapped error, and a response without any choices yields
// assistant.ErrEmptyResponse.
func (b *backend) Chat(ctx context.Context, req assistant.ChatRequest) (*assistant.ChatResponse, error) {
	body, err := b.doRequest(ctx, b.buildRequest(req, false))
	if err != nil {
		return nil, err
	}
	defer body.Close()

	var decoded chatResponse
	if err = json.NewDecoder(body).Decode(&decoded); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	if len(decoded.Choices) == 0 {
		return nil, assistant.ErrEmptyResponse
	}

	first := decoded.Choices[0]
	usage := assistant.Usage{
		PromptTokens:     decoded.Usage.PromptTokens,
		CompletionTokens: decoded.Usage.CompletionTokens,
		TotalTokens:      decoded.Usage.TotalTokens,
	}
	return &assistant.ChatResponse{
		Content:      first.Message.Content,
		ToolCalls:    convertToolCalls(first.Message.ToolCalls),
		FinishReason: first.FinishReason,
		Usage:        usage,
	}, nil
}
// Stream sends a streaming chat completion request ("stream": true) and
// returns a reader that decodes each event with parseStreamEvent.
func (b *backend) Stream(ctx context.Context, req assistant.ChatRequest) (*assistant.StreamReader, error) {
	payload := b.buildRequest(req, true)
	rc, err := b.doRequest(ctx, payload)
	if err != nil {
		return nil, err
	}
	// The open body is handed to the StreamReader; it is not closed here.
	return assistant.NewStreamReader(rc, parseStreamEvent), nil
}
// buildRequest converts req into the JSON body sent to the /chat/completions
// endpoint.
//
// A non-empty req.System becomes a leading "system" message. Message tool
// calls and tool-result IDs are mapped to the "tool_calls" and
// "tool_call_id" fields. max_completion_tokens, temperature and tools are
// only included when set. stream is copied into the "stream" flag.
func (b *backend) buildRequest(req assistant.ChatRequest, stream bool) map[string]any {
	// +1 leaves room for the optional system message.
	messages := make([]map[string]any, 0, len(req.Messages)+1)
	if req.System != "" {
		messages = append(messages, map[string]any{"role": "system", "content": req.System})
	}

	for _, m := range req.Messages {
		entry := map[string]any{"role": m.Role, "content": m.Content}

		if n := len(m.ToolCalls); n > 0 {
			calls := make([]map[string]any, 0, n)
			for _, tc := range m.ToolCalls {
				calls = append(calls, map[string]any{
					"id":   tc.ID,
					"type": "function",
					"function": map[string]any{
						"name":      tc.Name,
						"arguments": tc.Arguments,
					},
				})
			}
			entry["tool_calls"] = calls
		}
		if m.ToolCallID != "" {
			entry["tool_call_id"] = m.ToolCallID
		}

		messages = append(messages, entry)
	}

	body := map[string]any{
		"model":    req.Model,
		"messages": messages,
		"stream":   stream,
	}
	if req.MaxTokens > 0 {
		body["max_completion_tokens"] = req.MaxTokens
	}
	if req.Temperature != nil {
		body["temperature"] = *req.Temperature
	}
	if len(req.Tools) > 0 {
		tools := make([]map[string]any, 0, len(req.Tools))
		for _, t := range req.Tools {
			// Build the function object first so "parameters" can be added
			// without a type assertion on the enclosing map.
			fn := map[string]any{
				"name":        t.Name,
				"description": t.Description,
			}
			if t.Parameters != nil {
				fn["parameters"] = t.Parameters
			}
			tools = append(tools, map[string]any{"type": "function", "function": fn})
		}
		body["tools"] = tools
	}
	return body
}
// doRequest POSTs body as JSON to baseURL+"/chat/completions", authenticating
// with the backend's API key as a Bearer token.
//
// On HTTP 200 it returns the open response body, which the caller must close.
// For any other status the body is read (up to maxErrorBodyBytes) and closed,
// and the result is an *assistant.APIError. If the body contains an
// OpenAI-style {"error": {...}} object, its message and type are used.
// Otherwise the raw body text becomes the message, or the HTTP status text
// when the body is empty.
func (b *backend) doRequest(ctx context.Context, body map[string]any) (io.ReadCloser, error) {
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, b.baseURL+"/chat/completions", bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+b.apiKey)

	resp, err := b.client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	if resp.StatusCode == http.StatusOK {
		return resp.Body, nil
	}
	defer resp.Body.Close()

	// Cap the read so a misbehaving server or proxy cannot make us buffer an
	// unbounded error payload. A read error is ignored on purpose: whatever
	// was read so far is still the most useful error text available.
	const maxErrorBodyBytes = 1 << 20
	raw, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))

	// The Unmarshal error is deliberately ignored. encoding/json keeps
	// populating the remaining fields when one field has an unexpected type
	// (e.g. a numeric "code" from an OpenAI-compatible server), so a non-empty
	// Message is the real signal that the body was an OpenAI error object.
	var errResp errorResponse
	_ = json.Unmarshal(raw, &errResp)
	if errResp.Error.Message != "" {
		return nil, &assistant.APIError{
			StatusCode: resp.StatusCode,
			Type:       errResp.Error.Type,
			Message:    errResp.Error.Message,
		}
	}

	msg := string(raw)
	if len(bytes.TrimSpace(raw)) == 0 {
		// An empty message is useless to callers; report the status instead.
		msg = http.StatusText(resp.StatusCode)
	}
	return nil, &assistant.APIError{
		StatusCode: resp.StatusCode,
		Message:    msg,
	}
}
// --- OpenAI API types ---

// chatResponse is the subset of a non-streaming /chat/completions response
// body that Chat reads: the returned choices and the token usage counters.
type chatResponse struct {
	// Usage holds the token counters copied into assistant.Usage.
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`

	// Choices holds the returned alternatives; Chat only uses the first.
	Choices []struct {
		// FinishReason is the API's finish_reason string, passed through unchanged.
		FinishReason string `json:"finish_reason"`

		Message struct {
			Content string `json:"content"`
			// The element type of ToolCalls must stay identical (field order,
			// names, types and tags) to convertToolCalls' parameter type so
			// the slice can be passed to it directly.
			ToolCalls []struct {
				ID       string `json:"id"`
				Type     string `json:"type"`
				Function struct {
					Name      string `json:"name"`
					Arguments string `json:"arguments"`
				} `json:"function"`
			} `json:"tool_calls"`
		} `json:"message"`
	} `json:"choices"`
}
// streamChunk is the decoded JSON payload of one event from a streaming
// /chat/completions response. parseStreamEvent only inspects the first choice.
type streamChunk struct {
	Choices []struct {
		// FinishReason stays empty while the API sends null (encoding/json
		// leaves a string untouched for JSON null).
		FinishReason string `json:"finish_reason"`
		// Delta carries the incremental content or tool-call fragments.
		Delta struct {
			Content string `json:"content"`
			// ToolCalls are fragments identified by Index. parseStreamEvent
			// treats a fragment with a non-empty ID as the start of a new call.
			ToolCalls []struct {
				Index    int    `json:"index"`
				ID       string `json:"id"`
				Type     string `json:"type"`
				Function struct {
					Name      string `json:"name"`
					Arguments string `json:"arguments"`
				} `json:"function"`
			} `json:"tool_calls"`
		} `json:"delta"`
	} `json:"choices"`
}
// errorResponse is the error envelope decoded from non-200 responses:
// {"error": {"message": ..., "type": ..., "code": ...}}.
//
// Code is kept as raw JSON because OpenAI documents it as a string or null,
// while some OpenAI-compatible servers send a number. A string-typed field
// would make json.Unmarshal report an error for the numeric form.
type errorResponse struct {
	Error struct {
		Message string          `json:"message"`
		Type    string          `json:"type"`
		Code    json.RawMessage `json:"code"`
	} `json:"error"`
}
// convertToolCalls maps tool calls from a chat completion message to
// assistant.ToolCall values. It copies the ID, the function name and the raw
// argument string; the call's "type" field is not carried over.
//
// The result always has len(calls) elements. For empty input it is an empty,
// non-nil slice.
func convertToolCalls(calls []struct {
	ID       string `json:"id"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}) []assistant.ToolCall {
	out := make([]assistant.ToolCall, 0, len(calls))
	for i := range calls {
		fn := calls[i].Function
		out = append(out, assistant.ToolCall{
			ID:        calls[i].ID,
			Name:      fn.Name,
			Arguments: fn.Arguments,
		})
	}
	return out
}
// parseStreamEvent converts the data of one server-sent event from a
// streaming chat completion into an assistant.StreamEvent.
//
// At most one event is produced per payload, chosen in this order:
//   - the "[DONE]" terminator: EventDone.
//   - no choices (e.g. a usage-only chunk): an empty EventContentDelta.
//   - non-empty content in the first choice: EventContentDelta with the text.
//   - a tool-call fragment: EventToolCallStart when the fragment has an ID,
//     otherwise EventToolCallDelta carrying only the argument text.
//   - a non-empty finish_reason: EventDone.
//   - anything else: an empty EventContentDelta.
//
// Malformed JSON is returned as an error.
func parseStreamEvent(data string) (*assistant.StreamEvent, error) {
	// OpenAI ends the stream with "data: [DONE]", which is not JSON. Map it
	// to end-of-stream instead of a decode error in case the caller forwards
	// it here.
	if strings.TrimSpace(data) == "[DONE]" {
		return &assistant.StreamEvent{Type: assistant.EventDone}, nil
	}

	var chunk streamChunk
	if err := json.Unmarshal([]byte(data), &chunk); err != nil {
		return nil, err
	}
	if len(chunk.Choices) == 0 {
		return &assistant.StreamEvent{Type: assistant.EventContentDelta}, nil
	}
	choice := chunk.Choices[0]

	switch {
	case choice.Delta.Content != "":
		return &assistant.StreamEvent{
			Type:    assistant.EventContentDelta,
			Content: choice.Delta.Content,
		}, nil

	case len(choice.Delta.ToolCalls) > 0:
		// NOTE(review): only the first fragment is reported; if a chunk ever
		// carries several tool-call fragments, the others are dropped.
		tc := choice.Delta.ToolCalls[0]
		if tc.ID != "" {
			// A fragment carrying an ID opens a new tool call.
			return &assistant.StreamEvent{
				Type: assistant.EventToolCallStart,
				ToolCall: &assistant.ToolCall{
					ID:        tc.ID,
					Name:      tc.Function.Name,
					Arguments: tc.Function.Arguments,
				},
				ToolIndex: tc.Index,
			}, nil
		}
		// Otherwise it continues the arguments of the call at tc.Index.
		return &assistant.StreamEvent{
			Type:      assistant.EventToolCallDelta,
			ToolCall:  &assistant.ToolCall{Arguments: tc.Function.Arguments},
			ToolIndex: tc.Index,
		}, nil

	case choice.FinishReason != "":
		return &assistant.StreamEvent{Type: assistant.EventDone}, nil
	}

	return &assistant.StreamEvent{Type: assistant.EventContentDelta}, nil
}