readysite / website / internal / content / query / views.go
10.9 KB
views.go
package query

import (
	"database/sql"
	"encoding/json"
	"fmt"
	"regexp"
	"strings"
	"time"

	"github.com/readysite/readysite/website/models"
)

// ViewRecord represents a record from a view collection query.
//
// Views have no fixed record layout, so the fields are filled from whatever
// columns the view's SELECT produces:
//   - ID comes from a column named "id" (case-insensitive); when the view has
//     no usable id, ExecuteView assigns a synthetic "view-N" value.
//   - CreatedAt/UpdatedAt come from created/updated-style columns when they
//     parse as timestamps, and otherwise hold the time the row was read.
//   - Data holds every remaining column, keyed by its column name.
type ViewRecord struct {
	ID        string         `json:"id"`
	Data      map[string]any `json:"data"`    // non-system columns, keyed by column name
	CreatedAt time.Time      `json:"created"` // from a created/createdat/created_at column, else read time
	UpdatedAt time.Time      `json:"updated"` // from an updated/updatedat/updated_at column, else read time
}

// MaxQueryLength is the maximum allowed length for view queries.
const MaxQueryLength = 4096

// maxViewParenDepth is the deepest parenthesis nesting ValidateQuery accepts.
const maxViewParenDepth = 10

// viewWordRule pairs a forbidden word with its precompiled case-insensitive,
// whole-word matcher.
type viewWordRule struct {
	word    string
	pattern *regexp.Regexp
}

// compileViewWordRules builds one viewWordRule per word. Each word is
// regexp-quoted, so it is matched literally.
func compileViewWordRules(words ...string) []viewWordRule {
	rules := make([]viewWordRule, 0, len(words))
	for _, w := range words {
		rules = append(rules, viewWordRule{
			word:    w,
			pattern: regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(w) + `\b`),
		})
	}
	return rules
}

// Matchers used by ValidateQuery. They are compiled once at package
// initialisation instead of on every call.
var (
	// viewSelectPrefixPattern requires the query to begin with the word SELECT
	// (so "SELECTED ..." is not mistaken for a SELECT statement).
	viewSelectPrefixPattern = regexp.MustCompile(`(?i)^SELECT\b`)

	// viewDisallowedKeywords are statements/clauses that can modify data or
	// the database environment.
	viewDisallowedKeywords = compileViewWordRules(
		"INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER",
		"TRUNCATE", "GRANT", "REVOKE", "EXEC", "EXECUTE",
		"INTO",   // SELECT INTO
		"ATTACH", // ATTACH DATABASE
		"DETACH", // DETACH DATABASE
		"PRAGMA", // SQLite PRAGMA commands
		"VACUUM", // VACUUM command
		"REINDEX",
		"ANALYZE",
		"REPLACE", // INSERT OR REPLACE
	)

	// viewPragmaFunctionPattern matches SQLite's table-valued pragma functions
	// (pragma_table_info, pragma_database_list, ...). The \bPRAGMA\b rule does
	// not catch them: "_" is a word character, so there is no word boundary
	// right after "pragma".
	viewPragmaFunctionPattern = regexp.MustCompile(`(?i)\bpragma_\w+`)

	// viewSystemTables are SQLite internal tables and the app's own system
	// tables, none of which a view may read.
	viewSystemTables = compileViewWordRules(
		"sqlite_master", "sqlite_sequence", "sqlite_stat",
		// The real statistics tables carry a numeric suffix, which the
		// whole-word "sqlite_stat" rule does not match.
		"sqlite_stat1", "sqlite_stat2", "sqlite_stat3", "sqlite_stat4",
		"sqlite_temp_master", "sqlite_schema", "sqlite_temp_schema",
		"users", "acl_rules", "settings", // App system tables
	)

	viewUnionPattern = regexp.MustCompile(`(?i)\bUNION\b`)
)

// ValidateQuery validates a view query to ensure it's safe.
// Only SELECT statements are allowed - no INSERT, UPDATE, DELETE, DROP, etc.
// Uses defense-in-depth: multiple validation layers to prevent SQL injection.
//
// The query is trimmed and then checked, in order, for: emptiness, length
// (MaxQueryLength), a leading SELECT keyword, semicolons (stacked statements),
// SQL comments, disallowed keywords (including pragma_* functions), system
// table references, parenthesis balance and nesting depth (maxViewParenDepth),
// and UNION. The first failing check's error is returned; nil means the query
// passed every check.
//
// The checks are lexical, not a SQL parse. Keywords or table names that appear
// inside string literals or as parts of identifiers at word boundaries are
// rejected too. That is deliberate: false positives are preferred over
// letting something through.
func ValidateQuery(query string) error {
	query = strings.TrimSpace(query)
	if query == "" {
		return fmt.Errorf("query cannot be empty")
	}

	// Limit query length to prevent DoS
	if len(query) > MaxQueryLength {
		return fmt.Errorf("query exceeds maximum length of %d characters", MaxQueryLength)
	}

	// Must start with the SELECT keyword (case-insensitive)
	if !viewSelectPrefixPattern.MatchString(query) {
		return fmt.Errorf("view query must start with SELECT")
	}

	// Block stacked queries (semicolons indicate multiple statements)
	if strings.Contains(query, ";") {
		return fmt.Errorf("view query cannot contain semicolons (multiple statements not allowed)")
	}

	// Block SQL comments which can be used to hide malicious code
	if strings.Contains(query, "--") || strings.Contains(query, "/*") || strings.Contains(query, "*/") {
		return fmt.Errorf("view query cannot contain SQL comments")
	}

	for _, rule := range viewDisallowedKeywords {
		if rule.pattern.MatchString(query) {
			return fmt.Errorf("view query cannot contain %s", rule.word)
		}
	}
	if viewPragmaFunctionPattern.MatchString(query) {
		return fmt.Errorf("view query cannot contain PRAGMA")
	}

	for _, rule := range viewSystemTables {
		if rule.pattern.MatchString(query) {
			return fmt.Errorf("view query cannot access system table: %s", rule.word)
		}
	}

	// Limit query complexity by true nesting depth. Tracking depth also
	// rejects misordered parentheses such as ")(", which equal counts miss.
	depth, deepest := 0, 0
	for _, ch := range query {
		switch ch {
		case '(':
			depth++
			if depth > deepest {
				deepest = depth
			}
		case ')':
			depth--
			if depth < 0 {
				return fmt.Errorf("view query has unbalanced parentheses")
			}
		}
	}
	if depth != 0 {
		return fmt.Errorf("view query has unbalanced parentheses")
	}
	if deepest > maxViewParenDepth {
		return fmt.Errorf("view query is too complex (max %d levels of nesting)", maxViewParenDepth)
	}

	// Block UNION which can be used for injection attacks
	if viewUnionPattern.MatchString(query) {
		return fmt.Errorf("view query cannot contain UNION")
	}

	return nil
}

// viewClauseSelectPattern detects a SELECT keyword (i.e. a subquery) inside a
// caller-supplied filter or sort fragment.
var viewClauseSelectPattern = regexp.MustCompile(`(?i)\bSELECT\b`)

// validateViewClause screens a raw SQL fragment (a filter or a sort) before
// ExecuteView splices it into a query. An empty fragment is accepted.
//
// Fragments are concatenated into the SQL text rather than bound as
// parameters. Without this check, a caller-controlled value could read other
// tables through subqueries or otherwise alter the statement. Subqueries are
// rejected outright. The remaining checks (statement separators, comments,
// write/DDL keywords, system tables, UNION, parenthesis balance) reuse
// ValidateQuery on a synthetic SELECT that contains the fragment. The checks
// are lexical, so keywords inside string literals are rejected as well.
func validateViewClause(name, keyword, clause string) error {
	if clause == "" {
		return nil
	}
	if viewClauseSelectPattern.MatchString(clause) {
		return fmt.Errorf("view %s cannot contain subqueries", name)
	}
	if err := ValidateQuery("SELECT * FROM view_data " + keyword + " " + clause); err != nil {
		return fmt.Errorf("invalid view %s: %w", name, err)
	}
	return nil
}

// ExecuteView executes a view collection query and returns records.
// Supports filter, sort, limit, and offset for pagination.
//
// filter is a SQL boolean expression applied as a WHERE clause over the view's
// rows. sort is a SQL ORDER BY list. Both are validated by validateViewClause
// and treated as absent when blank. limit <= 0 means "no limit". offset > 0
// skips that many rows, with or without a limit.
//
// It returns the page of records and the total number of rows that match
// filter, ignoring limit/offset. Records take their ID from an "id" column
// when it holds a non-NULL value. Otherwise they get a synthetic "view-N" ID,
// numbered by position in the unpaginated result so IDs stay unique across
// pages. created/updated-style columns set the timestamps, which default to
// the time the query ran. All other columns go into Data.
func ExecuteView(collection *models.Collection, filter, sort string, limit, offset int) ([]ViewRecord, int, error) {
	if !collection.IsView() {
		return nil, 0, fmt.Errorf("collection %s is not a view", collection.ID)
	}

	if err := ValidateQuery(collection.Query); err != nil {
		return nil, 0, fmt.Errorf("invalid view query: %w", err)
	}

	// Blank fragments are treated as absent rather than emitted as empty clauses.
	filter = strings.TrimSpace(filter)
	sort = strings.TrimSpace(sort)
	if err := validateViewClause("filter", "WHERE", filter); err != nil {
		return nil, 0, err
	}
	if err := validateViewClause("sort", "ORDER BY", sort); err != nil {
		return nil, 0, err
	}

	where := ""
	if filter != "" {
		where = " WHERE " + filter
	}

	// Total count honours the filter but not pagination.
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS count_data%s", collection.Query, where)
	var totalCount int
	if err := models.DB.QueryRow(countQuery).Scan(&totalCount); err != nil {
		return nil, 0, fmt.Errorf("failed to count view records: %w", err)
	}

	query := fmt.Sprintf("SELECT * FROM (%s) AS view_data%s", collection.Query, where)
	if sort != "" {
		query += " ORDER BY " + sort
	}

	// SQLite only accepts OFFSET after a LIMIT clause. LIMIT -1 means "no
	// limit", so an offset is honoured even when no page size was requested.
	skipped := 0
	switch {
	case limit > 0 && offset > 0:
		query += fmt.Sprintf(" LIMIT %d OFFSET %d", limit, offset)
		skipped = offset
	case limit > 0:
		query += fmt.Sprintf(" LIMIT %d", limit)
	case offset > 0:
		query += fmt.Sprintf(" LIMIT -1 OFFSET %d", offset)
		skipped = offset
	}

	rows, err := models.DB.Query(query)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to execute view query: %w", err)
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, 0, fmt.Errorf("failed to get columns: %w", err)
	}

	// One set of scan buffers is reused for every row. Each value is copied
	// into its record before the next Scan overwrites the buffer.
	values := make([]any, len(columns))
	valuePtrs := make([]any, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}

	now := time.Now()
	var records []ViewRecord
	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			return nil, 0, fmt.Errorf("failed to scan row: %w", err)
		}

		record := ViewRecord{
			Data:      make(map[string]any, len(columns)),
			CreatedAt: now,
			UpdatedAt: now,
		}

		for i, col := range columns {
			val := values[i]
			// Convert bytes to string
			if b, ok := val.([]byte); ok {
				val = string(b)
			}

			// Handle special columns
			switch strings.ToLower(col) {
			case "id":
				// A NULL id stays empty so a synthetic ID is assigned below,
				// instead of the literal "<nil>" that %v would produce.
				switch v := val.(type) {
				case nil:
				case string:
					record.ID = v
				default:
					record.ID = fmt.Sprintf("%v", v)
				}
			case "created", "createdat", "created_at":
				if t, ok := parseTime(val); ok {
					record.CreatedAt = t
				}
			case "updated", "updatedat", "updated_at":
				if t, ok := parseTime(val); ok {
					record.UpdatedAt = t
				}
			default:
				record.Data[col] = val
			}
		}

		// Number synthetic IDs from the start of the unpaginated result so
		// different pages never produce the same ID.
		if record.ID == "" {
			record.ID = fmt.Sprintf("view-%d", skipped+len(records)+1)
		}

		records = append(records, record)
	}

	// rows.Next returns false on errors as well as at the end of the result set.
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("failed to read view rows: %w", err)
	}

	return records, totalCount, nil
}

// ExecuteViewSingle executes a view query and returns a single record by ID.
//
// The view is wrapped in a subquery filtered with "id = ?", with id bound as a
// query parameter, so the view is expected to expose an id column. When no
// row matches, sql.ErrNoRows is returned unwrapped so callers can compare
// against it directly. Only the first matching row is read.
//
// The returned record's ID is the requested id. created/updated-style
// columns set the timestamps, which default to the current time when absent
// or unparseable. All other columns go into Data.
func ExecuteViewSingle(collection *models.Collection, id string) (*ViewRecord, error) {
	if !collection.IsView() {
		return nil, fmt.Errorf("collection %s is not a view", collection.ID)
	}

	if err := ValidateQuery(collection.Query); err != nil {
		return nil, fmt.Errorf("invalid view query: %w", err)
	}

	// Build query to find by ID; only one row is ever consumed.
	query := fmt.Sprintf("SELECT * FROM (%s) AS view_data WHERE id = ? LIMIT 1", collection.Query)

	rows, err := models.DB.Query(query, id)
	if err != nil {
		return nil, fmt.Errorf("failed to execute view query: %w", err)
	}
	defer rows.Close()

	if !rows.Next() {
		// Next also returns false when iteration fails. Surface that error
		// instead of misreporting it as "not found".
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("failed to read view record: %w", err)
		}
		return nil, sql.ErrNoRows
	}

	// Get column names
	columns, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("failed to get columns: %w", err)
	}

	// Create scan destinations
	values := make([]any, len(columns))
	valuePtrs := make([]any, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}

	if err := rows.Scan(valuePtrs...); err != nil {
		return nil, fmt.Errorf("failed to scan row: %w", err)
	}

	// Build record
	now := time.Now()
	record := &ViewRecord{
		ID:        id,
		Data:      make(map[string]any, len(columns)),
		CreatedAt: now,
		UpdatedAt: now,
	}

	for i, col := range columns {
		val := values[i]
		if b, ok := val.([]byte); ok {
			val = string(b)
		}

		switch strings.ToLower(col) {
		case "id":
			// Already set from the lookup key.
		case "created", "createdat", "created_at":
			if t, ok := parseTime(val); ok {
				record.CreatedAt = t
			}
		case "updated", "updatedat", "updated_at":
			if t, ok := parseTime(val); ok {
				record.UpdatedAt = t
			}
		default:
			record.Data[col] = val
		}
	}

	return record, nil
}

// DeriveSchema derives a schema from the view query results.
//
// It runs the view query (wrapped with LIMIT 1) and inspects the result
// columns. Record-metadata columns are skipped: id and the
// created/createdat/created_at and updated/updatedat/updated_at variants,
// matched case-insensitively. Every remaining column becomes a field map with
// keys "name", "type" (the declared database type mapped through
// sqlTypeToSchemaType) and "required" (always false). When no fields remain,
// the result is nil with a nil error.
func DeriveSchema(collection *models.Collection) ([]map[string]any, error) {
	if !collection.IsView() {
		return nil, fmt.Errorf("collection %s is not a view", collection.ID)
	}
	if err := ValidateQuery(collection.Query); err != nil {
		return nil, fmt.Errorf("invalid view query: %w", err)
	}

	probe := fmt.Sprintf("SELECT * FROM (%s) AS view_data LIMIT 1", collection.Query)
	rows, err := models.DB.Query(probe)
	if err != nil {
		return nil, fmt.Errorf("failed to execute view query: %w", err)
	}
	defer rows.Close()

	names, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("failed to get columns: %w", err)
	}

	colTypes, err := rows.ColumnTypes()
	if err != nil {
		return nil, fmt.Errorf("failed to get column types: %w", err)
	}

	var fields []map[string]any
	for idx, name := range names {
		switch strings.ToLower(name) {
		case "id",
			"created", "createdat", "created_at",
			"updated", "updatedat", "updated_at":
			continue // system column, not part of the user-facing schema
		}

		fields = append(fields, map[string]any{
			"name":     name,
			"type":     sqlTypeToSchemaType(colTypes[idx].DatabaseTypeName()),
			"required": false,
		})
	}
	return fields, nil
}

// ToJSON converts a ViewRecord to a JSON-compatible map.
//
// The map holds "id", "collectionId", "collectionName", and "created"/"updated"
// formatted as RFC 3339 strings, plus every entry of r.Data flattened to the
// top level. Data entries are written last, so a Data key that collides with
// one of the fixed keys replaces it.
func (r *ViewRecord) ToJSON(collectionID, collectionName string) map[string]any {
	out := make(map[string]any, len(r.Data)+5)
	out["id"] = r.ID
	out["collectionId"] = collectionID
	out["collectionName"] = collectionName
	out["created"] = r.CreatedAt.Format(time.RFC3339)
	out["updated"] = r.UpdatedAt.Format(time.RFC3339)

	for key, val := range r.Data {
		out[key] = val
	}
	return out
}

// SetSchemaFromDerived sets the collection schema from derived schema.
//
// The schema is JSON-encoded and stored in collection.Schema. If encoding
// fails, the collection is left unchanged and the json.Marshal error is
// returned as-is. The collection is not persisted here.
func SetSchemaFromDerived(collection *models.Collection, schema []map[string]any) error {
	encoded, err := json.Marshal(schema)
	if err == nil {
		collection.Schema = string(encoded)
	}
	return err
}

// viewTimeLayouts are the string timestamp layouts parseTime accepts, tried in
// order. Go's time.Parse also accepts fractional seconds right after the
// seconds field even when a layout does not spell them out.
var viewTimeLayouts = []string{
	time.RFC3339,                // 2006-01-02T15:04:05Z07:00
	"2006-01-02 15:04:05Z07:00", // space-separated with zone, e.g. "2024-01-02 15:04:05.000Z"
	"2006-01-02T15:04:05",       // ISO 8601 without zone
	"2006-01-02 15:04:05",       // SQLite CURRENT_TIMESTAMP
	"2006-01-02 15:04",
	"2006-01-02",
}

// parseTime converts a scanned column value into a time.Time.
//
// time.Time values are returned unchanged. Strings are parsed with the first
// matching layout in viewTimeLayouts, and strings without a zone are parsed
// as UTC. The bool reports success. Any other type, or a string matching no
// layout, yields the zero time and false. Callers convert []byte to string
// before calling.
func parseTime(v any) (time.Time, bool) {
	switch t := v.(type) {
	case time.Time:
		return t, true
	case string:
		for _, layout := range viewTimeLayouts {
			if parsed, err := time.Parse(layout, t); err == nil {
				return parsed, true
			}
		}
	}
	return time.Time{}, false
}

// sqlTypeToSchemaType maps a declared database column type name (as reported
// by database/sql's ColumnType.DatabaseTypeName) to a schema field type.
//
// Matching is case-insensitive substring matching, checked in this order:
//   - INT, REAL, FLOAT, DOUBLE, NUMERIC, DECIMAL -> "number"
//   - BOOL -> "bool"
//   - DATE, TIME -> "date"
//   - JSON -> "json"
//
// Anything else, including an empty name (e.g. computed columns with no
// declared type), maps to "text". NUMERIC and DECIMAL are included because
// SQLite gives them numeric affinity, so their values are numbers, not text.
func sqlTypeToSchemaType(sqlType string) string {
	upper := strings.ToUpper(sqlType)
	containsAny := func(subs ...string) bool {
		for _, s := range subs {
			if strings.Contains(upper, s) {
				return true
			}
		}
		return false
	}

	switch {
	case containsAny("INT", "REAL", "FLOAT", "DOUBLE", "NUMERIC", "DECIMAL"):
		return "number"
	case containsAny("BOOL"):
		return "bool"
	case containsAny("DATE", "TIME"):
		return "date"
	case containsAny("JSON"):
		return "json"
	default:
		return "text"
	}
}
← Back