readysite / pkg / assistant / assistant_test.go
4.0 KB
assistant_test.go
package assistant_test

import (
	"context"
	"testing"

	"github.com/readysite/readysite/pkg/assistant"
	"github.com/readysite/readysite/pkg/assistant/providers/mock"
)

// TestMockChat checks that the mock provider answers a plain user turn with
// its configured response text and a "stop" finish reason.
func TestMockChat(t *testing.T) {
	const want = "Hello, World!"
	provider := mock.New(mock.Config{Response: want})

	req := assistant.ChatRequest{
		Messages: []assistant.Message{assistant.NewUserMessage("Hi")},
	}
	got, err := provider.Chat(context.Background(), req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if got.Content != want {
		t.Errorf("expected 'Hello, World!', got %q", got.Content)
	}
	if got.FinishReason != "stop" {
		t.Errorf("expected finish reason 'stop', got %q", got.FinishReason)
	}
}

func TestMockChatWithToolCalls(t *testing.T) {
	ai := mock.New(mock.Config{
		Response: "Let me help you with that.",
		ToolCalls: []assistant.ToolCall{
			{ID: "call_1", Name: "create_page", Arguments: `{"title":"Test"}`},
		},
	})

	resp, err := ai.Chat(context.Background(), assistant.ChatRequest{
		Messages: []assistant.Message{
			assistant.NewUserMessage("Create a page"),
		},
		Tools: []assistant.Tool{
			assistant.NewTool("create_page", "Create a new page").
				String("title", "Page title", true).
				Build(),
		},
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(resp.ToolCalls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(resp.ToolCalls))
	}

	if resp.ToolCalls[0].Name != "create_page" {
		t.Errorf("expected tool name 'create_page', got %q", resp.ToolCalls[0].Name)
	}

	if resp.FinishReason != "tool_calls" {
		t.Errorf("expected finish reason 'tool_calls', got %q", resp.FinishReason)
	}
}

// TestMockStream checks that draining a mock stream with Collect yields the
// full configured response text.
func TestMockStream(t *testing.T) {
	const want = "Hello World"
	provider := mock.New(mock.Config{Response: want})

	req := assistant.ChatRequest{
		Messages: []assistant.Message{assistant.NewUserMessage("Hi")},
	}
	stream, err := provider.Stream(context.Background(), req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer stream.Close()

	// Read the stream to completion; the streamed pieces should add up to
	// the whole response.
	got, err := stream.Collect()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if got.Content != want {
		t.Errorf("expected 'Hello World', got %q", got.Content)
	}
}

// TestToolBuilder checks that the fluent tool builder keeps the tool name,
// records one property per parameter call, and lists only the parameters
// flagged as required.
func TestToolBuilder(t *testing.T) {
	const (
		wantProps    = 5 // name, count, enabled, status, tags
		wantRequired = 3 // name, enabled, status
	)

	tool := assistant.NewTool("test_tool", "A test tool").
		String("name", "The name", true).
		Int("count", "The count", false).
		Bool("enabled", "Whether enabled", true).
		Enum("status", "Status value", []string{"active", "inactive"}, true).
		Array("tags", "List of tags", "string", false).
		Build()

	if tool.Name != "test_tool" {
		t.Errorf("expected name 'test_tool', got %q", tool.Name)
	}

	params := tool.Parameters
	// Bail out before dereferencing a missing parameter schema.
	if params == nil {
		t.Fatal("expected parameters, got nil")
	}
	if n := len(params.Properties); n != wantProps {
		t.Errorf("expected 5 properties, got %d", n)
	}
	if n := len(params.Required); n != wantRequired {
		t.Errorf("expected 3 required, got %d", n)
	}
}

// TestToolCallParseArguments checks that ParseArguments decodes a tool call's
// JSON argument string into a caller-supplied struct.
func TestToolCallParseArguments(t *testing.T) {
	type createPageArgs struct {
		Title     string `json:"title"`
		Published bool   `json:"published"`
	}

	call := assistant.ToolCall{
		ID:        "call_1",
		Name:      "create_page",
		Arguments: `{"title":"My Page","published":true}`,
	}

	var args createPageArgs
	err := call.ParseArguments(&args)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if args.Title != "My Page" {
		t.Errorf("expected title 'My Page', got %q", args.Title)
	}
	if !args.Published {
		t.Error("expected published to be true")
	}
}

// TestMessageConstructors checks that each message helper assigns the
// matching role constant to the message it builds, one subtest per helper.
func TestMessageConstructors(t *testing.T) {
	for _, tc := range []struct {
		label string
		msg   assistant.Message
		role  string
	}{
		{label: "system", msg: assistant.NewSystemMessage("system content"), role: assistant.RoleSystem},
		{label: "user", msg: assistant.NewUserMessage("user content"), role: assistant.RoleUser},
		{label: "assistant", msg: assistant.NewAssistantMessage("assistant content"), role: assistant.RoleAssistant},
	} {
		// Subtests run sequentially (no t.Parallel), so capturing tc is safe.
		t.Run(tc.label, func(t *testing.T) {
			if got := tc.msg.Role; got != tc.role {
				t.Errorf("expected role %q, got %q", tc.role, got)
			}
		})
	}
}
← Back