readysite / pkg / database / collection.go
6.6 KB
collection.go
package database

import (
	"errors"
	"fmt"
	"log"
	"reflect"
	"strings"
	"time"
)

var (
	// ErrNotFound is returned by First (and therefore Get) when no row
	// matches, and by Update and Delete when the driver reports that the
	// statement affected zero rows. Compare with errors.Is.
	ErrNotFound = errors.New("record not found")

	// ErrDuplicate is returned by Insert in place of the driver error when
	// that error's message contains "UNIQUE constraint", i.e. a unique
	// constraint or unique index was violated. Compare with errors.Is.
	ErrDuplicate = errors.New("duplicate record")
)

// Collection is a generic, type-safe repository for CRUD operations on the
// table that backs entity type E. Build one with Manage.
//
// It embeds *Mirror, whose reflection helpers (Columns, Field, Values, Scan)
// map E's struct fields to table columns and back.
type Collection[E any] struct {
	db    *Database // connection every query is issued against
	table string    // table name; Manage sets it from the Mirror's Name()
	*Mirror
}

// ManageOption configures a Collection during Manage. Options are applied in
// order after the table has been created or migrated, so they may issue
// statements against it (see WithIndex and WithUniqueIndex).
type ManageOption[E any] func(*Collection[E])

// WithIndex returns an option that, when passed to Manage, creates a
// non-unique index over columns in the order given.
//
// The index is named idx_<table>_<col1>_<col2>... (lower-cased) and is
// created with IF NOT EXISTS, so calling Manage again is harmless. Table and
// column names are interpolated into the SQL verbatim: they must be trusted
// identifiers, never user input.
//
// Options cannot return errors, so failures are logged instead. That
// includes calling WithIndex with no columns, which would otherwise produce
// an invalid CREATE INDEX statement.
func WithIndex[E any](columns ...string) ManageOption[E] {
	return func(c *Collection[E]) {
		if len(columns) == 0 {
			log.Printf("Failed to create index on %s: no columns given", c.table)
			return
		}

		indexName := fmt.Sprintf("idx_%s_%s", strings.ToLower(c.table), strings.ToLower(strings.Join(columns, "_")))
		query := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s(%s)", indexName, c.table, strings.Join(columns, ", "))
		if _, err := c.db.Exec(query); err != nil {
			log.Printf("Failed to create index %s: %v", indexName, err)
		}
	}
}

// WithUniqueIndex returns an option that, when passed to Manage, creates a
// unique index on a single column.
//
// The index is named idx_<table>_<column>_unique (lower-cased) and is
// created with IF NOT EXISTS. The table and column names are spliced into
// the SQL as-is, so they must be trusted identifiers. A failing statement
// (for example, when existing rows already hold duplicate values) is logged
// rather than returned.
func WithUniqueIndex[E any](column string) ManageOption[E] {
	return func(c *Collection[E]) {
		name := "idx_" + strings.ToLower(c.table) + "_" + strings.ToLower(column) + "_unique"
		stmt := "CREATE UNIQUE INDEX IF NOT EXISTS " + name + " ON " + c.table + "(" + column + ")"

		_, err := c.db.Exec(stmt)
		if err == nil {
			return
		}
		log.Printf("Failed to create unique index %s: %v", name, err)
	}
}

// Manage returns a Collection bound to db for entity type E. As a side
// effect, it creates or migrates the backing table.
//
// The table name is Reflect[E]().Name(). model is never read: it only lets
// callers write Manage(db, &User{}) and have the compiler infer E.
//
// Schema handling is additive only:
//   - if the table is missing, CreateTable is called with the mirror's
//     columns (its result is not checked here);
//   - if the table exists, each mirror column not reported by GetColumns is
//     added with AddColumn, and failures are logged, not returned.
//
// Existing columns are never dropped or altered. Finally, opts are applied
// in order, after the table is in place.
func Manage[E any](db *Database, model *E, opts ...ManageOption[E]) *Collection[E] {
	m := Reflect[E]()
	c := &Collection[E]{db: db, table: m.Name(), Mirror: m}

	if c.db.TableExists(c.table) {
		// Additive migration: add whatever the struct has that the table lacks.
		present := make(map[string]bool)
		for _, name := range c.db.GetColumns(c.table) {
			present[name] = true
		}
		for _, col := range c.Columns() {
			if present[col.Name] {
				continue
			}
			if err := c.db.AddColumn(c.table, col); err != nil {
				log.Printf("database: failed to add column %s.%s: %v", c.table, col.Name, err)
			}
		}
	} else {
		c.db.CreateTable(c.table, c.Columns())
	}

	for _, apply := range opts {
		apply(c)
	}
	return c
}

// Get fetches the entity whose ID column equals id. It returns ErrNotFound
// when no such row exists.
func (c *Collection[E]) Get(id string) (*E, error) {
	const byID = "WHERE ID = ?"
	return c.First(byID, id)
}

// First returns the first entity matching where, or ErrNotFound when
// nothing matches.
//
// where is passed to Search with " LIMIT 1" appended. It may therefore
// contain WHERE and ORDER BY clauses, but not a LIMIT of its own. Bind
// values through args rather than formatting them into where.
func (c *Collection[E]) First(where string, args ...any) (*E, error) {
	matches, err := c.Search(where+" LIMIT 1", args...)
	switch {
	case err != nil:
		return nil, err
	case len(matches) == 0:
		return nil, ErrNotFound
	default:
		return matches[0], nil
	}
}

// Search returns every entity matching where.
//
// where is appended verbatim after "SELECT <columns> FROM <table>". It is a
// raw SQL fragment (e.g. "WHERE Status = ? ORDER BY CreatedAt"), so it must
// come from trusted code; pass values through args as bind parameters.
//
// On any error (query, scan, or row iteration) the returned slice is nil,
// so callers never see a partially filled result. When nothing matches, the
// result is a nil slice with a nil error.
func (c *Collection[E]) Search(where string, args ...any) ([]*E, error) {
	cols := c.Columns()
	names := make([]string, len(cols))
	for i, col := range cols {
		names[i] = col.Name
	}

	query := fmt.Sprintf("SELECT %s FROM %s %s", strings.Join(names, ", "), c.table, where)

	rows, err := c.db.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var results []*E
	for rows.Next() {
		entity := new(E)
		if err := c.Scan(rows, entity); err != nil {
			return nil, err
		}
		results = append(results, entity)
	}
	// rows.Next returning false can mean "done" or "failed"; only Err tells
	// them apart, and a failure means results is incomplete.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return results, nil
}

// All returns every entity stored in the collection's table. It runs Search
// with no filter, so the row order is whatever the database returns.
func (c *Collection[E]) All() ([]*E, error) {
	const unfiltered = ""
	return c.Search(unfiltered)
}

// Insert creates a new record from entity and returns its ID.
//
// Before writing, Insert fills in bookkeeping fields on the caller's value:
//   - an empty string ID field is assigned c.db.GenerateID();
//   - a zero time.Time CreatedAt field is set to the current time;
//   - a time.Time UpdatedAt field is always set to the current time.
//
// Fields that are absent, not settable, or of another type are left alone
// instead of triggering a reflect panic. These assignments happen even if
// the INSERT itself then fails.
//
// A driver error whose message contains "UNIQUE constraint" is reported as
// ErrDuplicate; any other Exec error is returned unchanged. The returned ID
// is "" when E has no ID field, and the formatted value for non-string IDs.
func (c *Collection[E]) Insert(entity *E) (string, error) {
	v := reflect.ValueOf(entity).Elem()

	timeType := reflect.TypeOf(time.Time{})
	settableTime := func(f reflect.Value) bool {
		return f.IsValid() && f.Type() == timeType && f.CanSet()
	}

	// Only string IDs can be generated here; other kinds are written as-is.
	idField := c.Field(v, "ID")
	if idField.IsValid() && idField.Kind() == reflect.String && idField.CanSet() && idField.String() == "" {
		idField.SetString(c.db.GenerateID())
	}

	// Set timestamps. A caller-supplied CreatedAt is preserved.
	now := time.Now()
	if createdAt := c.Field(v, "CreatedAt"); settableTime(createdAt) && createdAt.Interface().(time.Time).IsZero() {
		createdAt.Set(reflect.ValueOf(now))
	}
	if updatedAt := c.Field(v, "UpdatedAt"); settableTime(updatedAt) {
		updatedAt.Set(reflect.ValueOf(now))
	}

	// Build INSERT query
	cols := c.Columns()
	names := make([]string, len(cols))
	placeholders := make([]string, len(cols))
	for i, col := range cols {
		names[i] = col.Name
		placeholders[i] = "?"
	}
	// NOTE(review): assumes c.Values yields one value per column in
	// c.Columns() order; the placeholders rely on that.
	values := c.Values(v)

	query := fmt.Sprintf(
		"INSERT INTO %s (%s) VALUES (%s)",
		c.table,
		strings.Join(names, ", "),
		strings.Join(placeholders, ", "),
	)

	if _, err := c.db.Exec(query, values...); err != nil {
		// String match on the driver message; assumes a driver that reports
		// "UNIQUE constraint ..." (SQLite does) — confirm for other backends.
		if strings.Contains(err.Error(), "UNIQUE constraint") {
			return "", ErrDuplicate
		}
		return "", err
	}

	switch {
	case !idField.IsValid():
		return "", nil
	case idField.Kind() == reflect.String:
		return idField.String(), nil
	default:
		return fmt.Sprint(idField.Interface()), nil
	}
}

// Update writes every column of entity except ID and CreatedAt to the row
// whose ID matches entity's ID field.
//
// If entity has a settable time.Time UpdatedAt field, Update sets it to the
// current time on the caller's value before writing.
//
// Errors:
//   - a driver error whose message contains "UNIQUE constraint" is reported
//     as ErrDuplicate, matching Insert; other Exec errors are returned as-is;
//   - ErrNotFound is returned when the driver reports zero affected rows, but
//     if the driver cannot report affected rows the successful Exec is trusted;
//   - an error is returned when E has no ID field or no updatable columns.
func (c *Collection[E]) Update(entity *E) error {
	v := reflect.ValueOf(entity).Elem()

	idField := c.Field(v, "ID")
	if !idField.IsValid() {
		return fmt.Errorf("database: %s has no ID field", c.table)
	}

	// Update timestamp; skipped when the field is absent or not a time.Time.
	updatedAt := c.Field(v, "UpdatedAt")
	if updatedAt.IsValid() && updatedAt.Type() == reflect.TypeOf(time.Time{}) && updatedAt.CanSet() {
		updatedAt.Set(reflect.ValueOf(time.Now()))
	}

	// Build UPDATE query (skip ID and CreatedAt)
	cols := c.Columns()
	sets := make([]string, 0, len(cols))
	values := make([]any, 0, len(cols)+1)
	for _, col := range cols {
		if col.Name == "ID" || col.Name == "CreatedAt" {
			continue
		}
		sets = append(sets, col.Name+" = ?")
		values = append(values, c.Field(v, col.Name).Interface())
	}
	if len(sets) == 0 {
		// "UPDATE t SET  WHERE ..." would be a syntax error; say why instead.
		return fmt.Errorf("database: %s has no updatable columns", c.table)
	}

	// Add ID for WHERE clause
	values = append(values, idField.Interface())

	query := fmt.Sprintf("UPDATE %s SET %s WHERE ID = ?", c.table, strings.Join(sets, ", "))

	result, err := c.db.Exec(query, values...)
	if err != nil {
		if strings.Contains(err.Error(), "UNIQUE constraint") {
			return ErrDuplicate
		}
		return err
	}

	// RowsAffected is not supported by every driver; an error from it means
	// "unknown", not "no rows", so it must not become ErrNotFound.
	if n, affErr := result.RowsAffected(); affErr == nil && n == 0 {
		return ErrNotFound
	}

	return nil
}

// Delete removes the row whose ID matches entity's ID field.
//
// Errors:
//   - ErrNotFound is returned when the driver reports that no row was
//     deleted; if the driver cannot report affected rows, the successful Exec
//     is trusted and nil is returned;
//   - an error is returned when E has no ID field;
//   - any Exec error is returned unchanged.
func (c *Collection[E]) Delete(entity *E) error {
	idField := c.Field(reflect.ValueOf(entity).Elem(), "ID")
	if !idField.IsValid() {
		return fmt.Errorf("database: %s has no ID field", c.table)
	}

	result, err := c.db.Exec(fmt.Sprintf("DELETE FROM %s WHERE ID = ?", c.table), idField.Interface())
	if err != nil {
		return err
	}

	// An error from RowsAffected means "unknown", not "no rows".
	if n, affErr := result.RowsAffected(); affErr == nil && n == 0 {
		return ErrNotFound
	}

	return nil
}

// Count returns how many rows match where. where is a raw SQL fragment
// appended after "SELECT COUNT(*) FROM <table>" (e.g. "WHERE Status = ?"),
// with values bound through args.
//
// This is best-effort: a failing query is logged and reported as 0, so a
// failure cannot be told apart from "no matches".
func (c *Collection[E]) Count(where string, args ...any) int {
	var n int
	row := c.db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s %s", c.table, where), args...)

	err := row.Scan(&n)
	if err == nil {
		return n
	}
	log.Printf("Count query failed on %s: %v", c.table, err)
	return 0
}

// DB returns the Database this collection issues all of its queries
// against, e.g. for hand-written statements the collection's helpers don't
// cover.
func (c *Collection[E]) DB() *Database {
	return c.db
}

// Table returns the name of the table backing this collection. Manage sets
// it from the entity's Mirror name.
func (c *Collection[E]) Table() string {
	return c.table
}
← Back