readysite / pkg / assistant / providers / mock / mock.go
4.6 KB
mock.go
// Package mock provides a mock AI provider for testing.
package mock

import (
	"context"
	"io"
	"strings"
	"time"

	"github.com/readysite/readysite/pkg/assistant"
)

// Config configures the mock provider behavior. The zero value produces an
// empty response with a "stop" finish reason, no tool calls, and the default
// 10ms stream delay.
type Config struct {
	// Response is the static response to return from Chat. Stream emits it
	// one whitespace-separated word at a time (via strings.Fields), so leading
	// and trailing whitespace is dropped and interior runs of whitespace,
	// including newlines, are replayed as single spaces.
	Response string

	// ToolCalls are the tool calls to return. Chat returns them as-is and
	// reports a "tool_calls" finish reason when any are present; Stream emits
	// one tool_call event per entry, after the response text.
	ToolCalls []assistant.ToolCall

	// StreamDelay is the delay between streaming events. Zero selects a
	// default of 10ms.
	StreamDelay time.Duration

	// Error is an error to return from all calls. When non-nil, Chat and
	// Stream return it immediately without producing a response.
	Error error
}

// backend is the test double installed as the Backend of assistants built by
// New. It keeps the Config passed to New by value and never modifies it.
type backend struct {
	config Config
}

// New builds an *assistant.Assistant whose Backend answers Chat and Stream
// from config instead of contacting a real model.
func New(config Config) *assistant.Assistant {
	be := &backend{config: config}
	return &assistant.Assistant{Backend: be}
}

// Chat returns the configured response.
//
// If Config.Error is set it is returned first. Otherwise, if ctx is already
// done, ctx.Err() is returned, as a real provider would. Otherwise the
// response carries:
//   - Config.Response as Content;
//   - Config.ToolCalls as ToolCalls;
//   - a "tool_calls" finish reason when tool calls are present, "stop" otherwise;
//   - word-count usage figures: prompt tokens are the whitespace-separated
//     words across req.Messages, completion tokens are the words in
//     Config.Response.
func (b *backend) Chat(ctx context.Context, req assistant.ChatRequest) (*assistant.ChatResponse, error) {
	if b.config.Error != nil {
		return nil, b.config.Error
	}
	// Honor cancellation so callers can test their cancellation paths
	// against the mock.
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	// Compute each count once; both feed TotalTokens as well.
	promptTokens := countTokens(req)
	completionTokens := len(strings.Fields(b.config.Response))

	return &assistant.ChatResponse{
		Content:      b.config.Response,
		ToolCalls:    b.config.ToolCalls,
		FinishReason: finishReason(b.config),
		Usage: assistant.Usage{
			PromptTokens:     promptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      promptTokens + completionTokens,
		},
	}, nil
}

// Stream returns a stream that emits the configured response as SSE-style
// "data: ..." events, in this order:
//
//   - one content event per whitespace-separated word of Config.Response,
//     with a separate single-space content event between consecutive words;
//   - one tool_call event per entry in Config.ToolCalls;
//   - a final "data: [DONE]" line.
//
// Config.StreamDelay (10ms when zero) is paused after every event except
// [DONE]. If Config.Error is set, it is returned instead of a stream.
//
// If ctx is cancelled mid-stream, the producer stops. The underlying pipe is
// closed with ctx.Err() rather than a clean EOF.
func (b *backend) Stream(ctx context.Context, req assistant.ChatRequest) (*assistant.StreamReader, error) {
	if b.config.Error != nil {
		return nil, b.config.Error
	}

	delay := b.config.StreamDelay
	if delay == 0 {
		delay = 10 * time.Millisecond
	}

	pr, pw := io.Pipe()
	done := make(chan struct{})

	// An io.Pipe write blocks until the reader consumes it. A consumer that
	// abandons the stream would otherwise strand the producer forever.
	// Closing the writer on cancellation makes any pending Write return.
	go func() {
		select {
		case <-ctx.Done():
			pw.CloseWithError(ctx.Err())
		case <-done:
		}
	}()

	go func() {
		defer close(done)
		// CloseWithError(nil) behaves like Close, so the reader sees io.EOF.
		// If the watcher already closed the pipe with ctx.Err(), that first
		// error is kept.
		pw.CloseWithError(b.writeStreamEvents(ctx, pw, delay))
	}()

	return assistant.NewStreamReader(pr, parseMockEvent), nil
}

// writeStreamEvents writes the event sequence described on Stream to w.
// It returns the first write or context error, or nil once [DONE] is written.
func (b *backend) writeStreamEvents(ctx context.Context, w io.Writer, delay time.Duration) error {
	// emit writes a single event, then waits for delay or cancellation.
	emit := func(payload string) error {
		if err := ctx.Err(); err != nil {
			return err
		}
		if _, err := io.WriteString(w, "data: "+payload+"\n\n"); err != nil {
			return err
		}
		return sleepContext(ctx, delay)
	}

	for i, word := range strings.Fields(b.config.Response) {
		if i > 0 {
			if err := emit(`{"type":"content","text":" "}`); err != nil {
				return err
			}
		}
		// Words never contain whitespace, but they may contain quotes or
		// backslashes, which must be escaped to keep the payload well-formed.
		if err := emit(`{"type":"content","text":"` + escapeJSON(word) + `"}`); err != nil {
			return err
		}
	}

	for _, tc := range b.config.ToolCalls {
		payload := `{"type":"tool_call","id":"` + escapeJSON(tc.ID) +
			`","name":"` + escapeJSON(tc.Name) +
			`","arguments":"` + escapeJSON(tc.Arguments) + `"}`
		if err := emit(payload); err != nil {
			return err
		}
	}

	_, err := io.WriteString(w, "data: [DONE]\n\n")
	return err
}

// sleepContext waits for d or until ctx is done, whichever comes first.
// It returns ctx.Err() in the latter case.
func sleepContext(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}

// finishReason reports why a mock completion ended. It returns "tool_calls"
// when config carries at least one tool call, and "stop" otherwise.
func finishReason(config Config) string {
	reason := "stop"
	if len(config.ToolCalls) != 0 {
		reason = "tool_calls"
	}
	return reason
}

// countTokens approximates prompt size as the total number of
// whitespace-separated words across the Content of every message in req.
func countTokens(req assistant.ChatRequest) int {
	var total int
	for i := range req.Messages {
		total += len(strings.Fields(req.Messages[i].Content))
	}
	return total
}

// jsonEscaper applies escapeJSON's substitutions in one pass. Each input byte
// is matched once, so the backslashes it emits are never escaped a second time.
var jsonEscaper = strings.NewReplacer(
	`\`, `\\`,
	`"`, `\"`,
	"\n", `\n`,
)

// escapeJSON escapes backslashes, double quotes and newlines so that s can be
// embedded inside a hand-built JSON string literal. Other control characters,
// such as tabs, pass through unchanged.
func escapeJSON(s string) string {
	return jsonEscaper.Replace(s)
}

// parseMockEvent parses a mock SSE event payload into a StreamEvent. It
// recognizes only the two hand-built shapes the mock stream emits:
//
//	{"type":"content","text":"..."}                          -> EventContentDelta
//	{"type":"tool_call","id":"...","name":"...","arguments":"..."} -> EventToolCallStart
//
// String values are decoded with unescapeJSON. Any other payload (for
// example a "[DONE]" sentinel, should the reader pass one through) yields an
// EventContentDelta with empty content. The returned error is always nil.
func parseMockEvent(data string) (*assistant.StreamEvent, error) {
	if strings.Contains(data, `"type":"content"`) {
		const textKey = `"text":"`
		if i := strings.Index(data, textKey); i >= 0 {
			start := i + len(textKey)
			// text is the final field, so its closing quote is the last quote
			// in the payload. This also tolerates unescaped quotes in the text.
			if end := strings.LastIndex(data, `"`); end > start {
				return &assistant.StreamEvent{
					Type:    assistant.EventContentDelta,
					Content: unescapeJSON(data[start:end]),
				}, nil
			}
		}
	}

	if strings.Contains(data, `"type":"tool_call"`) {
		return &assistant.StreamEvent{
			Type: assistant.EventToolCallStart,
			ToolCall: &assistant.ToolCall{
				ID:        unescapeJSON(extractField(data, "id")),
				Name:      unescapeJSON(extractField(data, "name")),
				Arguments: unescapeJSON(extractField(data, "arguments")),
			},
		}, nil
	}

	return &assistant.StreamEvent{Type: assistant.EventContentDelta}, nil
}

// unescapeJSON decodes the three sequences escapeJSON produces (\\, \" and
// \n) in a single left-to-right pass. Any other backslash sequence, and a
// trailing lone backslash, are kept verbatim. The single pass matters:
// chained ReplaceAll calls would turn an escaped backslash followed by 'n'
// (as in `C:\\new`) into a backslash plus a newline.
func unescapeJSON(s string) string {
	if !strings.Contains(s, `\`) {
		return s
	}
	var sb strings.Builder
	sb.Grow(len(s))
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c != '\\' || i+1 == len(s) {
			sb.WriteByte(c)
			continue
		}
		switch s[i+1] {
		case '\\', '"':
			sb.WriteByte(s[i+1])
			i++
		case 'n':
			sb.WriteByte('\n')
			i++
		default:
			// Unknown escape: keep the backslash; the next byte is written
			// on the following iteration.
			sb.WriteByte(c)
		}
	}
	return sb.String()
}

// extractField returns the raw, still backslash-escaped contents of the
// first `"field":"..."` string value in data, or "" if that key is absent.
// The value ends at the first double quote not escaped by a preceding
// backslash. If no such quote exists, everything after the key is returned.
func extractField(data, field string) string {
	prefix := `"` + field + `":"`
	at := strings.Index(data, prefix)
	if at == -1 {
		return ""
	}
	value := data[at+len(prefix):]
	for pos := 0; pos < len(value); pos++ {
		switch value[pos] {
		case '\\':
			pos++ // the following byte is escaped; step over it
		case '"':
			return value[:pos]
		}
	}
	return value
}
← Back