mock.go
// Package mock provides a mock AI provider for testing.
package mock
import (
"context"
"io"
"strings"
"time"
"github.com/readysite/readysite/pkg/assistant"
)
// Config configures the mock provider behavior. Its zero value is usable:
// Chat returns empty content with finish reason "stop", and Stream emits only
// the end-of-stream marker.
type Config struct {
	// Response is the static response to return from Chat. Stream splits it
	// with strings.Fields and re-joins the words with single-space deltas, so
	// runs of whitespace (including newlines) are not reproduced verbatim.
	Response string
	// ToolCalls are the tool calls to return. Chat returns them (and reports
	// finish reason "tool_calls" when any are set); Stream emits one
	// tool_call event per entry after the content.
	ToolCalls []assistant.ToolCall
	// StreamDelay is the delay between streaming events. Zero selects a
	// default of 10ms. Chat does not use it.
	StreamDelay time.Duration
	// Error is an error to return from all calls. When non-nil, Chat and
	// Stream return it immediately without producing any output.
	Error error
}
// backend implements assistant.Backend for testing. New installs it as the
// Backend of an *assistant.Assistant. Its only state is the Config it was
// created with, which nothing in this package modifies afterwards.
type backend struct {
	config Config
}
// New returns an *assistant.Assistant backed by a mock whose Chat and Stream
// behavior is driven entirely by config. It is intended for tests.
func New(config Config) *assistant.Assistant {
	mock := &backend{config: config}
	return &assistant.Assistant{Backend: mock}
}
// Chat returns a single, non-streamed response built from the configured
// Config. If Config.Error is set, it is returned and no response is built.
//
// Usage is approximated with whitespace-separated word counts: prompt tokens
// are the words across all request messages (see countTokens), and completion
// tokens are the words in Config.Response. ctx is not consulted.
func (b *backend) Chat(ctx context.Context, req assistant.ChatRequest) (*assistant.ChatResponse, error) {
	if b.config.Error != nil {
		return nil, b.config.Error
	}

	promptTokens := countTokens(req)
	completionTokens := len(strings.Fields(b.config.Response))

	return &assistant.ChatResponse{
		Content: b.config.Response,
		// Copy so a caller mutating the response cannot alter what later
		// calls to this mock return.
		ToolCalls:    cloneToolCalls(b.config.ToolCalls),
		FinishReason: finishReason(b.config),
		Usage: assistant.Usage{
			PromptTokens:     promptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      promptTokens + completionTokens,
		},
	}, nil
}

// cloneToolCalls returns a shallow copy of calls, preserving nil-ness (a nil
// input yields nil, an empty non-nil input yields an empty non-nil slice).
// Elements are copied by value; any reference-typed fields ToolCall may have
// beyond ID/Name/Arguments would still be shared.
func cloneToolCalls(calls []assistant.ToolCall) []assistant.ToolCall {
	if calls == nil {
		return nil
	}
	out := make([]assistant.ToolCall, len(calls))
	copy(out, calls)
	return out
}
// Stream returns a reader that replays the configured response as mock SSE
// events, in this order:
//   - the words of Config.Response (split with strings.Fields), separated by
//     single-space content deltas;
//   - one tool_call event per Config.ToolCalls entry;
//   - a "[DONE]" sentinel.
//
// Config.StreamDelay (10ms when zero) is waited after every event except the
// sentinel. If Config.Error is set, it is returned instead of a reader.
//
// Cancelling ctx stops the producer goroutine promptly, even if the consumer
// has stopped reading. The underlying pipe is then closed with ctx.Err().
func (b *backend) Stream(ctx context.Context, req assistant.ChatRequest) (*assistant.StreamReader, error) {
	if b.config.Error != nil {
		return nil, b.config.Error
	}

	delay := b.config.StreamDelay
	if delay == 0 {
		delay = 10 * time.Millisecond
	}

	pr, pw := io.Pipe()
	go b.produceStream(ctx, pw, delay)
	return assistant.NewStreamReader(pr, parseMockEvent), nil
}

// produceStream writes every mock event to pw and then closes it. It closes
// with ctx.Err() on cancellation, or with the write error if the reader went
// away.
func (b *backend) produceStream(ctx context.Context, pw *io.PipeWriter, delay time.Duration) {
	// Pipe writes block until the reader consumes them. Closing the writer
	// from this watcher is the only way to unblock a pending Write when ctx
	// is cancelled while the consumer has stopped reading.
	finished := make(chan struct{})
	defer close(finished)
	go func() {
		select {
		case <-ctx.Done():
			pw.CloseWithError(ctx.Err())
		case <-finished:
		}
	}()

	// CloseWithError never overwrites an earlier close, so whichever of the
	// watcher or this goroutine closes first determines the reader's error.
	if err := b.writeEvents(ctx, w(pw), delay); err != nil {
		pw.CloseWithError(err)
		return
	}
	pw.Close()
}

// w narrows a *io.PipeWriter to io.Writer for writeEvents.
func w(pw *io.PipeWriter) io.Writer { return pw }

// writeEvents emits the content, tool-call and sentinel events to out. Each
// non-sentinel event is followed by a cancellable delay. It stops at the
// first write failure or cancellation and returns that error.
func (b *backend) writeEvents(ctx context.Context, out io.Writer, delay time.Duration) error {
	emit := func(event string) error {
		if err := ctx.Err(); err != nil {
			return err
		}
		if _, err := io.WriteString(out, event); err != nil {
			return err
		}
		timer := time.NewTimer(delay)
		defer timer.Stop()
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-timer.C:
			return nil
		}
	}

	for i, word := range strings.Fields(b.config.Response) {
		if i > 0 {
			if err := emit("data: {\"type\":\"content\",\"text\":\" \"}\n\n"); err != nil {
				return err
			}
		}
		// Words may contain quotes or backslashes; escape them so the
		// payload stays parseable.
		if err := emit("data: {\"type\":\"content\",\"text\":\"" + escapeJSON(word) + "\"}\n\n"); err != nil {
			return err
		}
	}

	for _, tc := range b.config.ToolCalls {
		event := "data: {\"type\":\"tool_call\",\"id\":\"" + escapeJSON(tc.ID) +
			"\",\"name\":\"" + escapeJSON(tc.Name) +
			"\",\"arguments\":\"" + escapeJSON(tc.Arguments) + "\"}\n\n"
		if err := emit(event); err != nil {
			return err
		}
	}

	_, err := io.WriteString(out, "data: [DONE]\n\n")
	return err
}
// finishReason reports why the mock "model" stopped: "tool_calls" whenever
// any tool calls are configured, and "stop" otherwise.
func finishReason(config Config) string {
	reason := "stop"
	if len(config.ToolCalls) != 0 {
		reason = "tool_calls"
	}
	return reason
}
// countTokens approximates prompt size as the number of whitespace-separated
// words across the Content of every request message. It stands in for a real
// tokenizer and only produces plausible Usage numbers.
func countTokens(req assistant.ChatRequest) int {
	var total int
	for i := range req.Messages {
		total += len(strings.Fields(req.Messages[i].Content))
	}
	return total
}
// jsonEscaper backslash-escapes backslashes, double quotes and newlines in a
// single left-to-right pass.
var jsonEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`, "\n", `\n`)

// escapeJSON makes s safe to embed between double quotes in a mock SSE
// payload. It escapes backslashes, double quotes and newlines only. Other
// control characters (such as tab or carriage return) pass through
// unchanged, so the result is not fully JSON-compliant.
func escapeJSON(s string) string {
	return jsonEscaper.Replace(s)
}
// mockUnescaper reverses the escapes produced by escapeJSON (\\, \" and \n)
// in a single left-to-right pass. Chained strings.ReplaceAll calls cannot do
// this correctly: in the escaped text `\\n` (an escaped backslash followed by
// 'n'), a "\n" pass would fire first and yield a backslash plus a newline.
var mockUnescaper = strings.NewReplacer(`\\`, `\`, `\"`, `"`, `\n`, "\n")

// parseMockEvent converts the data payload of one mock SSE event into a
// StreamEvent:
//   - payloads containing "type":"content" become EventContentDelta events
//     carrying the unescaped text;
//   - payloads containing "type":"tool_call" become EventToolCallStart events
//     carrying the unescaped id, name and arguments;
//   - anything else, including a content payload whose text cannot be
//     located, yields an empty EventContentDelta.
//
// The returned error is always nil.
func parseMockEvent(data string) (*assistant.StreamEvent, error) {
	if strings.Contains(data, `"type":"content"`) {
		const key = `"text":"`
		if idx := strings.Index(data, key); idx >= 0 {
			start := idx + len(key)
			// The text is the payload's last field, so it ends at the final
			// quote. Using LastIndex also tolerates unescaped quotes inside it.
			end := strings.LastIndex(data, `"`)
			if end > start {
				return &assistant.StreamEvent{
					Type:    assistant.EventContentDelta,
					Content: mockUnescaper.Replace(data[start:end]),
				}, nil
			}
		}
	}

	if strings.Contains(data, `"type":"tool_call"`) {
		return &assistant.StreamEvent{
			Type: assistant.EventToolCallStart,
			ToolCall: &assistant.ToolCall{
				ID:        mockUnescaper.Replace(extractField(data, "id")),
				Name:      mockUnescaper.Replace(extractField(data, "name")),
				Arguments: mockUnescaper.Replace(extractField(data, "arguments")),
			},
		}, nil
	}

	return &assistant.StreamEvent{Type: assistant.EventContentDelta}, nil
}
// extractField returns the raw, still-escaped value of the first
// `"field":"` occurrence in data. The value ends at the first double quote
// that is not consumed by a preceding backslash escape. It returns "" when
// the key is absent; if the value is never terminated, everything after the
// key is returned.
func extractField(data, field string) string {
	prefix := `"` + field + `":"`
	idx := strings.Index(data, prefix)
	if idx == -1 {
		return ""
	}
	rest := data[idx+len(prefix):]
	for i := 0; i < len(rest); i++ {
		switch rest[i] {
		case '\\':
			i++ // the next byte is escaped; never treat it as a terminator
		case '"':
			return rest[:i]
		}
	}
	return rest
}