readysite / website / internal / assist / executor_test.go
8.7 KB
executor_test.go
package assist_test

import (
	"encoding/json"
	"testing"

	"github.com/readysite/readysite/pkg/assistant"
	"github.com/readysite/readysite/website/internal/assist"
	"github.com/readysite/readysite/website/internal/content"
	"github.com/readysite/readysite/website/models"
)

// TestCreatePage verifies that the create_page tool persists a page with the
// requested ID, title, and HTML, and echoes the ID back in its JSON result.
func TestCreatePage(t *testing.T) {
	exec := assist.NewExecutor(nil)

	result, err := exec.Execute(assistant.ToolCall{
		ID:        "call_1",
		Name:      "create_page",
		Arguments: `{"title":"Test Page","html":"<h1>Hello</h1>","id":"test-page"}`,
	})
	// Register teardown before any assertion can abort the test, so a failure
	// below never leaves the fixed "test-page" ID behind for the next run.
	defer func() {
		page, err := models.Pages.Get("test-page")
		if err != nil {
			return // nothing was created
		}
		if contents, err := page.Contents(); err == nil {
			for _, c := range contents {
				models.PageContents.Delete(c)
			}
		}
		models.Pages.Delete(page)
	}()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	var resp map[string]any
	if err := json.Unmarshal([]byte(result), &resp); err != nil {
		t.Fatalf("failed to parse result: %v", err)
	}

	if resp["id"] != "test-page" {
		t.Errorf("expected id 'test-page', got %v", resp["id"])
	}

	// Verify the page was actually persisted with the requested content.
	page, err := models.Pages.Get("test-page")
	if err != nil {
		t.Fatalf("failed to get page: %v", err)
	}

	if page.Title() != "Test Page" {
		t.Errorf("expected title 'Test Page', got %q", page.Title())
	}

	if page.HTML() != "<h1>Hello</h1>" {
		t.Errorf("expected HTML '<h1>Hello</h1>', got %q", page.HTML())
	}
}

// TestUpdatePage verifies that update_page replaces the title and HTML of an
// existing page.
func TestUpdatePage(t *testing.T) {
	// Seed a page with an initial content version.
	page := &models.Page{}
	page.ID = "update-test"
	models.Pages.Insert(page)
	content.SavePageContent(page, "Original", "", "<p>Original</p>", "")
	// Remove content versions before the page itself; runs even on t.Fatalf.
	defer func() {
		if contents, err := page.Contents(); err == nil {
			for _, c := range contents {
				models.PageContents.Delete(c)
			}
		}
		models.Pages.Delete(page)
	}()

	exec := assist.NewExecutor(nil)

	result, err := exec.Execute(assistant.ToolCall{
		ID:        "call_1",
		Name:      "update_page",
		Arguments: `{"id":"update-test","title":"Updated","html":"<p>Updated</p>"}`,
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	var resp map[string]any
	if err := json.Unmarshal([]byte(result), &resp); err != nil {
		t.Fatalf("failed to parse result: %v", err)
	}

	// Re-read from the store. A failed lookup must stop the test instead of
	// dereferencing a nil page below.
	updated, err := models.Pages.Get("update-test")
	if err != nil {
		t.Fatalf("failed to get page: %v", err)
	}
	if updated.Title() != "Updated" {
		t.Errorf("expected title 'Updated', got %q", updated.Title())
	}
	if updated.HTML() != "<p>Updated</p>" {
		t.Errorf("expected HTML '<p>Updated</p>', got %q", updated.HTML())
	}
}

// TestDeletePage verifies that delete_page removes an existing page so that a
// subsequent lookup fails.
func TestDeletePage(t *testing.T) {
	// Seed a page with one content version.
	page := &models.Page{}
	page.ID = "delete-test"
	models.Pages.Insert(page)
	content.SavePageContent(page, "To Delete", "", "<p>Delete me</p>", "")
	// Safety net: if the tool errors or leaves anything behind, remove it here
	// so the fixed ID does not collide with the next run.
	defer func() {
		if contents, err := page.Contents(); err == nil {
			for _, c := range contents {
				models.PageContents.Delete(c)
			}
		}
		if leftover, err := models.Pages.Get("delete-test"); err == nil {
			models.Pages.Delete(leftover)
		}
	}()

	exec := assist.NewExecutor(nil)

	_, err := exec.Execute(assistant.ToolCall{
		ID:        "call_1",
		Name:      "delete_page",
		Arguments: `{"id":"delete-test"}`,
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Verify page was deleted.
	if _, err := models.Pages.Get("delete-test"); err == nil {
		t.Error("expected page to be deleted")
	}
}

// TestCreateCollection verifies that create_collection persists a collection
// under the requested ID and echoes that ID in its JSON result.
func TestCreateCollection(t *testing.T) {
	exec := assist.NewExecutor(nil)

	result, err := exec.Execute(assistant.ToolCall{
		ID:        "call_1",
		Name:      "create_collection",
		Arguments: `{"name":"Blog Posts","id":"blog-posts","description":"Blog post entries"}`,
	})
	// Only delete what actually exists; passing a nil collection to Delete
	// (as the unchecked lookup could) is never safe.
	defer func() {
		if col, err := models.Collections.Get("blog-posts"); err == nil {
			models.Collections.Delete(col)
		}
	}()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	var resp map[string]any
	if err := json.Unmarshal([]byte(result), &resp); err != nil {
		t.Fatalf("failed to parse result: %v", err)
	}

	if resp["id"] != "blog-posts" {
		t.Errorf("expected id 'blog-posts', got %v", resp["id"])
	}

	// Verify the collection was persisted, not merely reported.
	if _, err := models.Collections.Get("blog-posts"); err != nil {
		t.Errorf("failed to get collection: %v", err)
	}
}

// TestCreateDocument verifies that create_document stores a document in the
// given collection and returns its generated ID.
func TestCreateDocument(t *testing.T) {
	// Create collection first.
	col := &models.Collection{Name: "Test Collection"}
	col.ID = "test-col"
	models.Collections.Insert(col)
	defer models.Collections.Delete(col)

	exec := assist.NewExecutor(nil)

	result, err := exec.Execute(assistant.ToolCall{
		ID:        "call_1",
		Name:      "create_document",
		Arguments: `{"collection_id":"test-col","data":"{\"title\":\"Test Doc\",\"content\":\"Hello\"}"}`,
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	var resp map[string]any
	if err := json.Unmarshal([]byte(result), &resp); err != nil {
		t.Fatalf("failed to parse result: %v", err)
	}

	// Without an ID there is nothing to look up or clean up, so stop here
	// rather than querying with an empty ID.
	id, ok := resp["id"].(string)
	if !ok || id == "" {
		t.Fatalf("expected document ID, got %v", resp["id"])
	}

	doc, err := models.Documents.Get(id)
	if err != nil {
		t.Fatalf("failed to get document %q: %v", id, err)
	}
	// Deferred after the collection's delete, so it runs first (LIFO).
	defer models.Documents.Delete(doc)

	if doc.CollectionID != col.ID {
		t.Errorf("expected document in collection %q, got %q", col.ID, doc.CollectionID)
	}
}

// TestAllTools verifies that every expected tool is registered and that the
// registry holds at least as many tools as this list names.
func TestAllTools(t *testing.T) {
	expected := []string{
		"create_page", "update_page", "delete_page", "delete_pages", "get_page", "list_pages",
		"create_collection", "update_collection", "delete_collection", "delete_collections", "get_collection", "list_collections",
		"create_document", "create_documents", "update_document", "delete_document", "get_document", "query_documents",
		"create_partial", "update_partial", "delete_partial", "get_partial", "list_partials",
		"update_file", "get_file", "list_files", "read_file",
		"create_note", "list_notes", "get_note", "update_note", "delete_note",
		"create_user", "update_user", "delete_user", "list_users",
		"validate_template", "navigate_user",
	}

	// Derive the minimum from the list itself, so adding a tool here does not
	// also require bumping a separate hard-coded count.
	if got := len(assist.All()); got < len(expected) {
		t.Errorf("expected at least %d tools, got %d", len(expected), got)
	}

	for _, name := range expected {
		if assist.GetTool(name) == nil {
			t.Errorf("tool %q not found", name)
		}
	}
}

// TestMutationTracking verifies that an executor bound to a message records
// the mutations made by its tool calls, including a "create" mutation for the
// page produced by create_page.
func TestMutationTracking(t *testing.T) {
	// Create a conversation and message to attach mutations to.
	conv := &models.Conversation{Title: "Test"}
	models.Conversations.Insert(conv)
	defer models.Conversations.Delete(conv)

	msg := &models.Message{ConversationID: conv.ID, Role: "assistant", Content: "Creating page"}
	models.Messages.Insert(msg)
	defer models.Messages.Delete(msg)

	// Tear down whatever the tool call produced. Deferred last, so it runs
	// first (before the message and conversation are deleted), and it still
	// runs if an assertion below calls t.Fatalf.
	defer func() {
		if page, err := models.Pages.Get("mutation-test"); err == nil {
			if contents, err := page.Contents(); err == nil {
				for _, c := range contents {
					models.PageContents.Delete(c)
				}
			}
			models.Pages.Delete(page)
		}
		if mutations, err := msg.Mutations(); err == nil {
			for _, m := range mutations {
				models.Mutations.Delete(m)
			}
		}
	}()

	// Execute with mutation tracking.
	exec := assist.NewExecutor(msg)

	_, err := exec.Execute(assistant.ToolCall{
		ID:        "call_1",
		Name:      "create_page",
		Arguments: `{"title":"Mutation Test","html":"<p>Test</p>","id":"mutation-test"}`,
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Verify mutations were recorded (page + page_content).
	mutations, err := msg.Mutations()
	if err != nil {
		t.Fatalf("failed to get mutations: %v", err)
	}

	if len(mutations) < 2 {
		t.Errorf("expected at least 2 mutations (page + page_content), got %d", len(mutations))
	}

	// Find the page mutation.
	var pageMutation *models.Mutation
	for _, m := range mutations {
		if m.EntityType == "page" {
			pageMutation = m
			break
		}
	}

	switch {
	case pageMutation == nil:
		t.Error("expected a page mutation")
	case pageMutation.Action != "create":
		t.Errorf("expected action 'create', got %q", pageMutation.Action)
	}
}

// TestSQLInjectionPrevention verifies that query_documents rejects filters
// containing statement separators, UNION, or SQL comments, while accepting
// ordinary comparison and LIKE filters.
func TestSQLInjectionPrevention(t *testing.T) {
	// Create collection for testing.
	col := &models.Collection{Name: "SQL Test"}
	col.ID = "sql-test-col"
	models.Collections.Insert(col)
	defer models.Collections.Delete(col)

	// Create a document so queries have something to match.
	doc := &models.Document{
		CollectionID: col.ID,
		Data:         `{"title":"Test"}`,
	}
	models.Documents.Insert(doc)
	defer models.Documents.Delete(doc)

	exec := assist.NewExecutor(nil)

	// query runs query_documents with the given filter. The arguments are
	// encoded with encoding/json rather than string concatenation: a filter
	// containing '"' or '\' would otherwise yield malformed JSON, and a
	// "blocked" case would pass on the parse error instead of the SQL check.
	query := func(filter string) error {
		args, err := json.Marshal(map[string]string{
			"collection_id": "sql-test-col",
			"filter":        filter,
		})
		if err != nil {
			t.Fatalf("failed to encode arguments for %q: %v", filter, err)
		}
		_, err = exec.Execute(assistant.ToolCall{
			ID:        "call_1",
			Name:      "query_documents",
			Arguments: string(args),
		})
		return err
	}

	// Test cases that should be blocked.
	blockedFilters := []string{
		"1=1; DROP TABLE Documents",
		"1=1 UNION SELECT * FROM Users",
		"1=1--comment",
		"title = 'test'; DELETE FROM Documents",
		"title LIKE '%' OR 1=1/*",
	}

	for _, filter := range blockedFilters {
		if err := query(filter); err == nil {
			t.Errorf("expected error for SQL injection attempt: %s", filter)
		}
	}

	// Test cases that should be allowed.
	allowedFilters := []string{
		"Data LIKE '%test%'",
		"CreatedAt > '2024-01-01'",
		"ID = 'abc123'",
	}

	for _, filter := range allowedFilters {
		if err := query(filter); err != nil {
			t.Errorf("unexpected error for valid filter %q: %v", filter, err)
		}
	}
}

// TestPageVersioning verifies that each SavePageContent call appends a new
// content version, and that the page's accessors reflect the latest one.
func TestPageVersioning(t *testing.T) {
	// Create a page.
	page := &models.Page{}
	page.ID = "version-test"
	models.Pages.Insert(page)
	defer models.Pages.Delete(page)
	// Deferred after the page delete, so it runs first (LIFO). Content rows
	// are removed before their page even if an assertion below aborts.
	defer func() {
		if contents, err := page.Contents(); err == nil {
			for _, c := range contents {
				models.PageContents.Delete(c)
			}
		}
	}()

	// Save initial content.
	content.SavePageContent(page, "Version 1", "Desc 1", "<p>V1</p>", "")

	if page.Title() != "Version 1" {
		t.Errorf("expected title 'Version 1', got %q", page.Title())
	}

	// Save new content version.
	content.SavePageContent(page, "Version 2", "Desc 2", "<p>V2</p>", "")

	if page.Title() != "Version 2" {
		t.Errorf("expected title 'Version 2', got %q", page.Title())
	}
	if page.HTML() != "<p>V2</p>" {
		t.Errorf("expected HTML '<p>V2</p>', got %q", page.HTML())
	}

	// Verify history has 2 versions.
	contents, err := page.Contents()
	if err != nil {
		t.Fatalf("failed to get contents: %v", err)
	}
	if len(contents) != 2 {
		t.Errorf("expected 2 content versions, got %d", len(contents))
	}
}
← Back