collection.go
package database
import (
"errors"
"fmt"
"log"
"reflect"
"strings"
"time"
)
// Sentinel errors returned by Collection operations. Callers compare against
// these values to tell an expected outcome apart from other database failures.
var (
	// ErrDuplicate reports that a write violated a unique constraint.
	ErrDuplicate = errors.New("duplicate record")

	// ErrNotFound reports that no record matched the request.
	ErrNotFound = errors.New("record not found")
)
// Collection is a generic repository for type-safe CRUD operations on
// entities of type E, each stored as one row of a single table.
// It embeds Mirror to inherit reflection methods for field access; within
// this file those are Name, Columns, Field, Values and Scan.
type Collection[E any] struct {
// db is the handle every statement in this file is executed through.
db *Database
// table is the table name, set by Manage from the mirror's Name().
table string
// Mirror supplies the reflection helpers for E (embedded, so its methods
// are promoted onto Collection).
*Mirror
}
// ManageOption configures a Collection during creation.
// Manage calls each option with the new Collection, in the order given,
// after the table has been created or migrated. WithIndex and
// WithUniqueIndex return options of this type.
type ManageOption[E any] func(*Collection[E])
// WithIndex returns an option that creates a non-unique index covering the
// given columns, in the order supplied.
//
// The index is named idx_<table>_<col1>_<col2>..., lowercased, and is created
// with CREATE INDEX IF NOT EXISTS. Column names are interpolated into the
// statement verbatim, so they must be trusted identifiers, never user input.
//
// Index creation is best-effort: a failure is logged and does not prevent the
// Collection from being used. Calling WithIndex with no columns logs a message
// and issues no statement, because the resulting SQL would be invalid.
func WithIndex[E any](columns ...string) ManageOption[E] {
	return func(c *Collection[E]) {
		// An empty column list would produce "CREATE INDEX ... ON t()" and
		// the dangling name "idx_t_"; refuse before touching the database.
		if len(columns) == 0 {
			log.Printf("Failed to create index on %s: no columns specified", c.table)
			return
		}

		indexName := fmt.Sprintf("idx_%s_%s", strings.ToLower(c.table), strings.ToLower(strings.Join(columns, "_")))
		query := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s(%s)", indexName, c.table, strings.Join(columns, ", "))
		if _, err := c.db.Exec(query); err != nil {
			log.Printf("Failed to create index %s: %v", indexName, err)
		}
	}
}
// WithUniqueIndex returns an option that adds a unique index on a single
// column. The index is named idx_<table>_<column>_unique (lowercased) and is
// created with CREATE UNIQUE INDEX IF NOT EXISTS. A failure is logged rather
// than returned.
func WithUniqueIndex[E any](column string) ManageOption[E] {
	return func(c *Collection[E]) {
		name := fmt.Sprintf("idx_%s_%s_unique", strings.ToLower(c.table), strings.ToLower(column))
		stmt := fmt.Sprintf("CREATE UNIQUE INDEX IF NOT EXISTS %s ON %s(%s)", name, c.table, column)

		_, err := c.db.Exec(stmt)
		if err == nil {
			return
		}
		log.Printf("Failed to create unique index %s: %v", name, err)
	}
}
// Manage builds a Collection for entity type E backed by db.
//
// The table name is whatever Reflect[E]().Name() reports. If that table does
// not exist it is created from the mirror's columns; otherwise every mirror
// column missing from the table is added. A failed column addition is logged
// and skipped. Options run last, in the order given, so index-creating
// options run after the table has been created or migrated.
//
// model is not read; it appears to exist only so callers can have E inferred
// from an argument.
func Manage[E any](db *Database, model *E, opts ...ManageOption[E]) *Collection[E] {
	m := Reflect[E]()
	coll := &Collection[E]{
		Mirror: m,
		table:  m.Name(),
		db:     db,
	}

	if coll.db.TableExists(coll.table) {
		// Existing table: add only the columns it does not have yet.
		present := make(map[string]struct{})
		for _, name := range coll.db.GetColumns(coll.table) {
			present[name] = struct{}{}
		}
		for _, col := range coll.Columns() {
			if _, ok := present[col.Name]; ok {
				continue
			}
			if err := coll.db.AddColumn(coll.table, col); err != nil {
				log.Printf("database: failed to add column %s.%s: %v", coll.table, col.Name, err)
			}
		}
	} else {
		// Fresh table: create it with the full column set.
		coll.db.CreateTable(coll.table, coll.Columns())
	}

	// Options come after migration so indexes target the final schema.
	for i := range opts {
		opts[i](coll)
	}
	return coll
}
// Get retrieves the entity whose ID column equals id.
// It delegates to First with the clause "WHERE ID = ?", so it returns
// ErrNotFound when no row matches and passes through any query error.
func (c *Collection[E]) Get(id string) (*E, error) {
return c.First("WHERE ID = ?", id)
}
// First returns the first entity matching the given clause.
//
// where is appended to the SELECT as-is and then suffixed with " LIMIT 1",
// so it must not carry its own LIMIT. args are bound to the placeholders in
// where. ErrNotFound is returned when the query yields no rows.
func (c *Collection[E]) First(where string, args ...any) (*E, error) {
	matches, err := c.Search(where+" LIMIT 1", args...)
	switch {
	case err != nil:
		return nil, err
	case len(matches) == 0:
		return nil, ErrNotFound
	default:
		return matches[0], nil
	}
}
// Search returns every entity matching the given clause.
//
// The statement is "SELECT <mirror columns> FROM <table> <where>"; where is
// inserted verbatim (it may be empty, or hold WHERE / ORDER BY / LIMIT text)
// and args are bound to its placeholders. Each row is scanned into a freshly
// allocated E. A scan failure aborts the search and returns nil results; an
// iteration error is returned alongside the rows read so far.
func (c *Collection[E]) Search(where string, args ...any) ([]*E, error) {
	columns := c.Columns()
	selected := make([]string, 0, len(columns))
	for _, column := range columns {
		selected = append(selected, column.Name)
	}

	stmt := fmt.Sprintf("SELECT %s FROM %s %s", strings.Join(selected, ", "), c.table, where)
	rows, err := c.db.Query(stmt, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var found []*E
	for rows.Next() {
		item := new(E)
		scanErr := c.Scan(rows, item)
		if scanErr != nil {
			return nil, scanErr
		}
		found = append(found, item)
	}

	err = rows.Err()
	return found, err
}
// All returns all entities in the collection.
// It calls Search with an empty clause, so no filter, ordering or limit is
// requested from the database.
func (c *Collection[E]) All() ([]*E, error) {
return c.Search("")
}
// Insert writes entity as a new row and returns its ID.
//
// Before the row is written, Insert fills in bookkeeping fields on entity
// itself, so the caller sees them afterwards:
//   - an empty ID field of string kind receives c.db.GenerateID();
//   - a CreatedAt field holding a zero time.Time is set to the current time;
//   - an UpdatedAt field that can hold a time.Time is set to the current time.
//
// Fields that are absent or of an incompatible type are left untouched
// instead of causing a panic.
//
// The returned ID is the string ID field, the default formatting of a
// non-string ID field, or "" when the entity has no ID field. Insert returns
// ErrDuplicate when the database error mentions a UNIQUE constraint, an error
// for a nil entity, and the underlying error for any other failure.
func (c *Collection[E]) Insert(entity *E) (string, error) {
	if entity == nil {
		return "", errors.New("database: cannot insert nil entity")
	}
	v := reflect.ValueOf(entity).Elem()

	// Only a string-kind ID can be generated. Calling String() on any other
	// Value yields a placeholder such as "<int Value>", not real data.
	idField := c.Field(v, "ID")
	hasStringID := idField.IsValid() && idField.Kind() == reflect.String
	if hasStringID && idField.String() == "" {
		idField.SetString(c.db.GenerateID())
	}

	// Stamp timestamps, using the same instant for both fields.
	stamp := reflect.ValueOf(time.Now())
	if createdAt := c.Field(v, "CreatedAt"); createdAt.IsValid() {
		// Preserve a caller-supplied creation time; only fill in a zero one.
		if t, ok := createdAt.Interface().(time.Time); ok && t.IsZero() {
			createdAt.Set(stamp)
		}
	}
	if updatedAt := c.Field(v, "UpdatedAt"); updatedAt.IsValid() && stamp.Type().AssignableTo(updatedAt.Type()) {
		updatedAt.Set(stamp)
	}

	// Build "INSERT INTO t (a, b, ...) VALUES (?, ?, ...)".
	cols := c.Columns()
	names := make([]string, len(cols))
	placeholders := make([]string, len(cols))
	for i, col := range cols {
		names[i] = col.Name
		placeholders[i] = "?"
	}
	// Values are read after the fields above were populated.
	values := c.Values(v)

	query := fmt.Sprintf(
		"INSERT INTO %s (%s) VALUES (%s)",
		c.table,
		strings.Join(names, ", "),
		strings.Join(placeholders, ", "),
	)
	if _, err := c.db.Exec(query, values...); err != nil {
		// The driver's error text is the only visible signal of a duplicate.
		if strings.Contains(err.Error(), "UNIQUE constraint") {
			return "", ErrDuplicate
		}
		return "", err
	}

	switch {
	case hasStringID:
		return idField.String(), nil
	case idField.IsValid():
		return fmt.Sprint(idField.Interface()), nil
	default:
		return "", nil
	}
}
// Update writes the current field values of entity to the row whose ID
// column matches the entity's ID field.
//
// An UpdatedAt field that can hold a time.Time is set to the current time on
// entity before the statement runs. The ID and CreatedAt columns are never
// written. Update returns ErrNotFound when the statement affects no rows, an
// error when entity is nil or has no ID field, and the underlying error when
// the statement or the affected-row count fails.
func (c *Collection[E]) Update(entity *E) error {
	if entity == nil {
		return errors.New("database: cannot update nil entity")
	}
	v := reflect.ValueOf(entity).Elem()

	// Without an ID there is no row to target; report it instead of letting
	// Interface() panic on an invalid Value.
	idField := c.Field(v, "ID")
	if !idField.IsValid() {
		return fmt.Errorf("database: %s has no ID field to update by", c.table)
	}

	// Refresh the modification timestamp when the field can hold one.
	stamp := reflect.ValueOf(time.Now())
	if updatedAt := c.Field(v, "UpdatedAt"); updatedAt.IsValid() && stamp.Type().AssignableTo(updatedAt.Type()) {
		updatedAt.Set(stamp)
	}

	// Build the SET list, skipping the immutable ID and CreatedAt columns.
	cols := c.Columns()
	sets := make([]string, 0, len(cols))
	values := make([]any, 0, len(cols)+1)
	for _, col := range cols {
		if col.Name == "ID" || col.Name == "CreatedAt" {
			continue
		}
		sets = append(sets, col.Name+" = ?")
		values = append(values, c.Field(v, col.Name).Interface())
	}
	// The ID binds to the trailing WHERE placeholder.
	values = append(values, idField.Interface())

	query := fmt.Sprintf("UPDATE %s SET %s WHERE ID = ?", c.table, strings.Join(sets, ", "))
	result, err := c.db.Exec(query, values...)
	if err != nil {
		return err
	}

	// A failed count must not be mistaken for "no such record".
	affected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("database: checking rows affected by update on %s: %w", c.table, err)
	}
	if affected == 0 {
		return ErrNotFound
	}
	return nil
}
// Delete removes the row whose ID column matches the entity's ID field.
//
// It returns ErrNotFound when the statement affects no rows, an error when
// entity is nil or has no ID field, and the underlying error when the
// statement or the affected-row count fails.
func (c *Collection[E]) Delete(entity *E) error {
	if entity == nil {
		return errors.New("database: cannot delete nil entity")
	}

	// Without an ID there is no row to target; report it instead of letting
	// Interface() panic on an invalid Value.
	idField := c.Field(reflect.ValueOf(entity).Elem(), "ID")
	if !idField.IsValid() {
		return fmt.Errorf("database: %s has no ID field to delete by", c.table)
	}

	result, err := c.db.Exec(fmt.Sprintf("DELETE FROM %s WHERE ID = ?", c.table), idField.Interface())
	if err != nil {
		return err
	}

	// A failed count must not be mistaken for "no such record".
	affected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("database: checking rows affected by delete on %s: %w", c.table, err)
	}
	if affected == 0 {
		return ErrNotFound
	}
	return nil
}
// Count reports how many rows match the given clause.
//
// where is appended verbatim to "SELECT COUNT(*) FROM <table>" and args are
// bound to its placeholders. A failed query is logged and reported as 0, so
// the caller cannot tell a failure apart from an empty result.
func (c *Collection[E]) Count(where string, args ...any) int {
	var total int

	stmt := fmt.Sprintf("SELECT COUNT(*) FROM %s %s", c.table, where)
	err := c.db.QueryRow(stmt, args...).Scan(&total)
	if err == nil {
		return total
	}

	log.Printf("Count query failed on %s: %v", c.table, err)
	return 0
}
// DB returns the *Database this Collection runs its statements through,
// i.e. the handle that was passed to Manage.
func (c *Collection[E]) DB() *Database {
return c.db
}
// Table returns the name of the table this Collection reads and writes,
// as set by Manage from the mirror's Name().
func (c *Collection[E]) Table() string {
return c.table
}