readysite / pkg / database / database.go
3.2 KB
database.go
package database

import (
	"database/sql"
	"fmt"
	"strings"

	"github.com/google/uuid"
)

// Database wraps a *sql.DB together with an optional Syncer.
// It embeds *sql.DB so all standard database methods (Query, QueryRow, Exec,
// Begin, ...) are available directly on a Database value; the helpers in this
// file (TableExists, CreateTable, GetColumns, AddColumn, Transaction) are
// built on top of those methods.
type Database struct {
	*sql.DB
	// Sync, when set, syncs this database with a remote.
	// NOTE(review): presumably nil for backends without a remote — confirm
	// against the constructors before calling Sync.Sync() unguarded.
	Sync Syncer
}

// Syncer is an optional interface for databases that support syncing with a remote.
// A Database carries one in its Sync field.
type Syncer interface {
	// Sync synchronizes with the remote and returns any error encountered.
	Sync() error
}

// GenerateID returns a newly generated random UUID formatted as its canonical
// string representation. The receiver is not consulted; it is a method only so
// that callers holding a *Database have ID generation at hand.
func (db *Database) GenerateID() string {
	id := uuid.New()
	return id.String()
}

// TableExists reports whether sqlite_master lists a table (type 'table') with
// the given name. Views and indexes do not count. Any failure while querying,
// not only "no matching row", is reported as false.
//
// NOTE(review): the name is compared with SQL '=', which under SQLite's default
// BINARY collation is case-sensitive, while SQLite table names themselves are
// case-insensitive — verify callers always pass the canonical spelling.
func (db *Database) TableExists(name string) bool {
	const lookup = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"

	var found string
	if err := db.QueryRow(lookup, name).Scan(&found); err != nil {
		return false
	}
	return true
}

// Column describes a database column for table creation and migration.
// Name, Type and Default are spliced verbatim into generated SQL text (they are
// not bound as parameters), so they must come from trusted code, never from
// user input.
type Column struct {
	Name    string
	Type    string // SQL type: TEXT, INTEGER, REAL, DATETIME, BLOB
	Primary bool   // when true, CreateTable declares the column PRIMARY KEY
	Default string // Default value as SQL literal: '', 0, CURRENT_TIMESTAMP
}

// CreateTable creates the named table with the given columns, unless a table of
// that name already exists (CREATE TABLE IF NOT EXISTS).
//
// Each column is rendered as "<Name> <Type>", followed by " PRIMARY KEY" when
// Primary is set and by " DEFAULT <Default>" when Default is non-empty. This
// gives freshly created tables the same defaults that AddColumn applies when
// migrating an existing table. Unlike ALTER TABLE ADD COLUMN, CREATE TABLE
// accepts non-constant defaults such as CURRENT_TIMESTAMP, so Default is used
// unchanged here.
//
// An empty column list is rejected with an error rather than sent to the
// database as invalid SQL.
//
// The table name and column fields are interpolated verbatim into the SQL
// text, so they must come from trusted code.
func (db *Database) CreateTable(name string, columns []Column) error {
	if len(columns) == 0 {
		return fmt.Errorf("create table %q: no columns given", name)
	}

	cols := make([]string, 0, len(columns))
	for _, col := range columns {
		def := col.Name + " " + col.Type
		if col.Primary {
			def += " PRIMARY KEY"
		}
		if col.Default != "" {
			def += " DEFAULT " + col.Default
		}
		cols = append(cols, def)
	}

	query := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", name, strings.Join(cols, ", "))
	_, err := db.Exec(query)
	return err
}

// GetColumns returns the names of all columns in table, in the order reported
// by SQLite's PRAGMA table_info.
//
// The result is nil in these cases:
//   - the pragma query fails;
//   - reading its result rows fails partway through, since a partial list would
//     be misleading for migrations;
//   - the table has no columns to report, for example because it does not exist.
//
// Rows that cannot be scanned are skipped.
//
// The table name is interpolated verbatim into the pragma and must come from
// trusted code.
func (db *Database) GetColumns(table string) []string {
	rows, err := db.Query(fmt.Sprintf("PRAGMA table_info(%s)", table))
	if err != nil {
		return nil
	}
	defer rows.Close()

	var columns []string
	for rows.Next() {
		// table_info yields six columns: cid, name, type, notnull, dflt_value, pk.
		var (
			cid, notNull, pk int
			name, colType    string
			dfltValue        any
		)
		if err := rows.Scan(&cid, &name, &colType, &notNull, &dfltValue, &pk); err != nil {
			continue
		}
		columns = append(columns, name)
	}
	// rows.Next returning false can mean "done" or "failed"; treat failure the
	// same as a failed query instead of returning a truncated list.
	if rows.Err() != nil {
		return nil
	}
	return columns
}

// AddColumn adds col to an existing table using ALTER TABLE ... ADD COLUMN.
//
// SQLite rejects the non-constant defaults CURRENT_TIMESTAMP, CURRENT_DATE and
// CURRENT_TIME in ALTER TABLE ADD COLUMN. They are therefore replaced (matched
// case-insensitively) with Unix-epoch constants in the same text format.
//
// When col.Default is empty, no DEFAULT clause is emitted, and existing rows
// read the new column as NULL. col.Primary is not applied, because SQLite does
// not allow adding a PRIMARY KEY column to an existing table.
//
// The table name and column fields are interpolated verbatim into the SQL
// text, so they must come from trusted code. Errors from the database are
// wrapped with the table and column name.
func (db *Database) AddColumn(table string, col Column) error {
	query := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, col.Name, col.Type)
	// Appending "DEFAULT " with nothing after it would be a syntax error, so
	// the clause is only added for a non-empty default.
	if def := addColumnDefault(col.Default); def != "" {
		query += " DEFAULT " + def
	}
	if _, err := db.Exec(query); err != nil {
		return fmt.Errorf("add column %s.%s: %w", table, col.Name, err)
	}
	return nil
}

// addColumnDefault converts a Column.Default into a literal accepted by
// ALTER TABLE ADD COLUMN, mapping SQLite's time keywords to epoch constants.
func addColumnDefault(def string) string {
	def = strings.TrimSpace(def)
	switch strings.ToUpper(def) {
	case "CURRENT_TIMESTAMP":
		return "'1970-01-01 00:00:00'"
	case "CURRENT_DATE":
		return "'1970-01-01'"
	case "CURRENT_TIME":
		return "'00:00:00'"
	}
	return def
}

// Transaction runs fn inside a database transaction.
//
// The outcome depends on how fn ends:
//   - fn returns an error: the transaction is rolled back and fn's error is
//     returned. If the rollback itself fails, both errors are reported and
//     fn's error stays wrapped (%w).
//   - fn succeeds: the transaction is committed, and a commit failure is
//     returned wrapped.
//   - fn panics: the transaction is rolled back first, then the panic
//     continues to propagate unchanged.
//
// Errors from Begin are returned wrapped.
func (db *Database) Transaction(fn func(tx *sql.Tx) error) error {
	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}

	// A *sql.Tx keeps its underlying connection until Commit or Rollback is
	// called. Without this, a panic in fn that is later recovered higher up
	// (e.g. by net/http's per-request recovery) would leak the connection
	// along with any lock the open transaction holds.
	defer func() {
		if p := recover(); p != nil {
			_ = tx.Rollback() // best effort; the panic is the failure being reported
			panic(p)
		}
	}()

	if err := fn(tx); err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("rollback failed: %v (original error: %w)", rbErr, err)
		}
		return err
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}
	return nil
}
← Back