// views.go
package query
import (
"database/sql"
"encoding/json"
"fmt"
"regexp"
"strings"
"time"
"github.com/readysite/readysite/website/models"
)
// ViewRecord represents a record from a view collection query.
//
// ExecuteView and ExecuteViewSingle build one ViewRecord per result row,
// routing the id/created/updated columns (matched case-insensitively) into the
// dedicated fields and every other column into Data.
type ViewRecord struct {
// ID is the row's "id" column value. ExecuteView falls back to a positional
// "view-N" identifier when the row yields no ID.
ID string `json:"id"`
// Data holds every column that is not one of the special id/created/updated
// columns, keyed by the column name as reported by the driver.
Data map[string]any `json:"data"`
// CreatedAt comes from a "created", "createdat" or "created_at" column when
// its value can be parsed as a time; otherwise it is the time the record
// was built.
CreatedAt time.Time `json:"created"`
// UpdatedAt comes from an "updated", "updatedat" or "updated_at" column when
// its value can be parsed as a time; otherwise it is the time the record
// was built.
UpdatedAt time.Time `json:"updated"`
}
// MaxQueryLength is the maximum allowed length, in bytes, for view queries.
const MaxQueryLength = 4096

// viewQueryRule pairs the name reported in an error message with the compiled
// pattern that detects it inside a view query.
type viewQueryRule struct {
	name    string
	pattern *regexp.Regexp
}

// wholeWordViewRule builds a case-insensitive rule matching name as a whole word.
func wholeWordViewRule(name string) viewQueryRule {
	return viewQueryRule{
		name:    name,
		pattern: regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(name) + `\b`),
	}
}

// wholeWordViewRules builds one whole-word rule per name, preserving order.
func wholeWordViewRules(names ...string) []viewQueryRule {
	rules := make([]viewQueryRule, 0, len(names))
	for _, name := range names {
		rules = append(rules, wholeWordViewRule(name))
	}
	return rules
}

// The patterns below are compiled once at package initialisation instead of
// on every ValidateQuery call.
var (
	// viewDisallowedKeywords lists the keywords that may not appear as whole
	// words in a view query, in the order they are checked and reported.
	viewDisallowedKeywords = wholeWordViewRules(
		"INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER",
		"TRUNCATE", "GRANT", "REVOKE", "EXEC", "EXECUTE",
		"INTO",    // SELECT INTO
		"ATTACH",  // ATTACH DATABASE
		"DETACH",  // DETACH DATABASE
		"PRAGMA",  // SQLite PRAGMA commands
		"VACUUM",  // VACUUM command
		"REINDEX", //
		"ANALYZE", //
		"REPLACE", // INSERT OR REPLACE
	)

	// viewSystemTables lists SQLite-internal and application tables a view
	// query may not reference, in the order they are checked and reported.
	viewSystemTables = []viewQueryRule{
		wholeWordViewRule("sqlite_master"),
		wholeWordViewRule("sqlite_sequence"),
		// The statistics tables are named sqlite_stat1, sqlite_stat3 and
		// sqlite_stat4; a plain whole-word match on "sqlite_stat" misses all
		// of them because the trailing digit is a word character.
		{name: "sqlite_stat", pattern: regexp.MustCompile(`(?i)\bsqlite_stat\d*\b`)},
		wholeWordViewRule("sqlite_temp_master"),
		wholeWordViewRule("sqlite_schema"),
		// App system tables
		wholeWordViewRule("users"),
		wholeWordViewRule("acl_rules"),
		wholeWordViewRule("settings"),
	}

	// viewUnionPattern matches the UNION keyword as a whole word.
	viewUnionPattern = regexp.MustCompile(`(?i)\bUNION\b`)
)

// ValidateQuery validates a view query to ensure it's safe.
// Only SELECT statements are allowed - no INSERT, UPDATE, DELETE, DROP, etc.
// Uses defense-in-depth: multiple validation layers to prevent SQL injection.
//
// The query is trimmed and then rejected, in this order, if it: is empty;
// is longer than MaxQueryLength bytes; does not start with SELECT; contains a
// semicolon; contains a comment marker ("--", "/*", "*/"); contains a
// disallowed keyword; references a system table; has unbalanced parentheses
// or more than 10 opening parentheses; or contains UNION.
//
// All checks are purely textual, so they also apply to text inside string
// literals. It returns nil when the query passes every check.
func ValidateQuery(query string) error {
	query = strings.TrimSpace(query)
	if query == "" {
		return fmt.Errorf("query cannot be empty")
	}

	// Limit query length to prevent DoS
	if len(query) > MaxQueryLength {
		return fmt.Errorf("query exceeds maximum length of %d characters", MaxQueryLength)
	}

	// Must start with SELECT (case-insensitive)
	if !strings.HasPrefix(strings.ToUpper(query), "SELECT") {
		return fmt.Errorf("view query must start with SELECT")
	}

	// Block stacked queries (semicolons indicate multiple statements)
	if strings.Contains(query, ";") {
		return fmt.Errorf("view query cannot contain semicolons (multiple statements not allowed)")
	}

	// Block SQL comments which can be used to hide malicious code
	if strings.Contains(query, "--") || strings.Contains(query, "/*") || strings.Contains(query, "*/") {
		return fmt.Errorf("view query cannot contain SQL comments")
	}

	// Check for disallowed keywords (case-insensitive as whole words)
	for _, rule := range viewDisallowedKeywords {
		if rule.pattern.MatchString(query) {
			return fmt.Errorf("view query cannot contain %s", rule.name)
		}
	}

	// Disallow system table access (SQLite internal tables and app tables)
	for _, rule := range viewSystemTables {
		if rule.pattern.MatchString(query) {
			return fmt.Errorf("view query cannot access system table: %s", rule.name)
		}
	}

	// Track the running depth rather than comparing totals: callers embed the
	// query inside "(...)", so a ")" that appears before its "(" would close
	// that wrapper early even though the totals are equal.
	depth, openParens := 0, 0
	for i := 0; i < len(query); i++ {
		switch query[i] {
		case '(':
			depth++
			openParens++
		case ')':
			depth--
			if depth < 0 {
				return fmt.Errorf("view query has unbalanced parentheses")
			}
		}
	}
	if depth != 0 {
		return fmt.Errorf("view query has unbalanced parentheses")
	}

	// Limit query complexity, approximated by the number of opening parentheses
	if openParens > 10 {
		return fmt.Errorf("view query is too complex (max 10 levels of nesting)")
	}

	// Block UNION which can be used for injection attacks
	if viewUnionPattern.MatchString(query) {
		return fmt.Errorf("view query cannot contain UNION")
	}

	return nil
}
// ExecuteView executes a view collection query and returns records.
// Supports filter, sort, limit, and offset for pagination.
//
// The collection's query is validated with ValidateQuery and wrapped as a
// sub-select. filter, when non-empty, becomes the WHERE clause of both the
// data query and the count query; sort, when non-empty, becomes the ORDER BY
// clause. LIMIT is added only when limit > 0, and OFFSET only when limit > 0
// and offset > 0.
//
// It returns the page of records, the number of rows matching the filter
// before pagination, and an error if the collection is not a view, the query
// is invalid, or any database step fails.
//
// NOTE(review): filter and sort are concatenated into the SQL text verbatim
// and are not bound as parameters, so callers must pass only trusted or
// already-validated fragments.
func ExecuteView(collection *models.Collection, filter, sort string, limit, offset int) ([]ViewRecord, int, error) {
	if !collection.IsView() {
		return nil, 0, fmt.Errorf("collection %s is not a view", collection.ID)
	}
	if err := ValidateQuery(collection.Query); err != nil {
		return nil, 0, fmt.Errorf("invalid view query: %w", err)
	}

	// Build the outer data query and the matching count query around the view.
	query := fmt.Sprintf("SELECT * FROM (%s) AS view_data", collection.Query)
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS count_data", collection.Query)

	// No placeholders are produced yet, so params stays empty.
	var params []any
	if filter != "" {
		query += " WHERE " + filter
		countQuery += " WHERE " + filter
	}

	// Get total count before pagination
	var totalCount int
	if err := models.DB.QueryRow(countQuery, params...).Scan(&totalCount); err != nil {
		return nil, 0, fmt.Errorf("failed to count view records: %w", err)
	}

	// Apply sort
	if sort != "" {
		query += " ORDER BY " + sort
	}

	// Apply pagination
	if limit > 0 {
		query += fmt.Sprintf(" LIMIT %d", limit)
		if offset > 0 {
			query += fmt.Sprintf(" OFFSET %d", offset)
		}
	}

	rows, err := models.DB.Query(query, params...)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to execute view query: %w", err)
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, 0, fmt.Errorf("failed to get columns: %w", err)
	}

	var records []ViewRecord
	for rows.Next() {
		// Scan every column into an untyped slot; the column set is only
		// known at run time.
		values := make([]any, len(columns))
		valuePtrs := make([]any, len(columns))
		for i := range values {
			valuePtrs[i] = &values[i]
		}
		if err := rows.Scan(valuePtrs...); err != nil {
			return nil, 0, fmt.Errorf("failed to scan row: %w", err)
		}

		// Timestamps default to "now" and are overridden by matching columns.
		record := ViewRecord{
			Data:      make(map[string]any),
			CreatedAt: time.Now(),
			UpdatedAt: time.Now(),
		}
		for i, col := range columns {
			val := values[i]
			// Convert bytes to string
			if b, ok := val.([]byte); ok {
				val = string(b)
			}

			// Handle special columns
			switch strings.ToLower(col) {
			case "id":
				switch id := val.(type) {
				case nil:
					// NULL id: leave it empty so a positional ID is
					// generated below instead of the literal "<nil>".
				case string:
					record.ID = id
				default:
					record.ID = fmt.Sprintf("%v", id)
				}
			case "created", "createdat", "created_at":
				if t, ok := parseTime(val); ok {
					record.CreatedAt = t
				}
			case "updated", "updatedat", "updated_at":
				if t, ok := parseTime(val); ok {
					record.UpdatedAt = t
				}
			default:
				record.Data[col] = val
			}
		}

		// Generate ID if not present; it is the row's 1-based position
		// within this result page.
		if record.ID == "" {
			record.ID = fmt.Sprintf("view-%d", len(records)+1)
		}
		records = append(records, record)
	}
	// rows.Next also returns false on an iteration error; without this check
	// a failure part-way through would be returned as a truncated page.
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("failed to read view rows: %w", err)
	}

	return records, totalCount, nil
}
// ExecuteViewSingle executes a view query and returns a single record by ID.
//
// The collection's query is validated with ValidateQuery, wrapped as a
// sub-select and filtered with "WHERE id = ?", with id bound as a parameter.
// Only the first matching row is read.
//
// It returns sql.ErrNoRows when no row matches, and a wrapped error when the
// collection is not a view, the query is invalid, or a database step fails.
func ExecuteViewSingle(collection *models.Collection, id string) (*ViewRecord, error) {
	if !collection.IsView() {
		return nil, fmt.Errorf("collection %s is not a view", collection.ID)
	}
	if err := ValidateQuery(collection.Query); err != nil {
		return nil, fmt.Errorf("invalid view query: %w", err)
	}

	// Build query to find by ID
	query := fmt.Sprintf("SELECT * FROM (%s) AS view_data WHERE id = ?", collection.Query)
	rows, err := models.DB.Query(query, id)
	if err != nil {
		return nil, fmt.Errorf("failed to execute view query: %w", err)
	}
	defer rows.Close()

	if !rows.Next() {
		// Next returns false both when there are no rows and when iteration
		// failed; only the former is a genuine "not found".
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("failed to read view row: %w", err)
		}
		return nil, sql.ErrNoRows
	}

	columns, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("failed to get columns: %w", err)
	}

	// Scan every column into an untyped slot; the column set is only known
	// at run time.
	values := make([]any, len(columns))
	valuePtrs := make([]any, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}
	if err := rows.Scan(valuePtrs...); err != nil {
		return nil, fmt.Errorf("failed to scan row: %w", err)
	}

	// The ID is taken from the argument; timestamps default to "now" and are
	// overridden by matching columns.
	record := &ViewRecord{
		ID:        id,
		Data:      make(map[string]any),
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
	}
	for i, col := range columns {
		val := values[i]
		// Convert bytes to string
		if b, ok := val.([]byte); ok {
			val = string(b)
		}

		switch strings.ToLower(col) {
		case "id":
			// Already set
		case "created", "createdat", "created_at":
			if t, ok := parseTime(val); ok {
				record.CreatedAt = t
			}
		case "updated", "updatedat", "updated_at":
			if t, ok := parseTime(val); ok {
				record.UpdatedAt = t
			}
		default:
			record.Data[col] = val
		}
	}

	return record, nil
}
// DeriveSchema derives a schema from the view query results.
//
// The validated view query is run with LIMIT 1 purely to obtain column
// metadata; no row is read. Each column other than the id/created/updated
// system columns (matched case-insensitively) yields a map with "name",
// "type" (from sqlTypeToSchemaType applied to the driver's type name) and
// "required" set to false. The result is nil when no column qualifies.
func DeriveSchema(collection *models.Collection) ([]map[string]any, error) {
	if !collection.IsView() {
		return nil, fmt.Errorf("collection %s is not a view", collection.ID)
	}
	if err := ValidateQuery(collection.Query); err != nil {
		return nil, fmt.Errorf("invalid view query: %w", err)
	}

	// A single-row probe is enough to learn the column names and types.
	probe := fmt.Sprintf("SELECT * FROM (%s) AS view_data LIMIT 1", collection.Query)
	rows, err := models.DB.Query(probe)
	if err != nil {
		return nil, fmt.Errorf("failed to execute view query: %w", err)
	}
	defer rows.Close()

	names, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("failed to get columns: %w", err)
	}
	types, err := rows.ColumnTypes()
	if err != nil {
		return nil, fmt.Errorf("failed to get column types: %w", err)
	}

	var fields []map[string]any
	for idx, name := range names {
		// System columns are carried on the record itself, not in the schema.
		switch strings.ToLower(name) {
		case "id", "created", "updated",
			"createdat", "updatedat",
			"created_at", "updated_at":
			continue
		}
		fields = append(fields, map[string]any{
			"name":     name,
			"type":     sqlTypeToSchemaType(types[idx].DatabaseTypeName()),
			"required": false,
		})
	}
	return fields, nil
}
// ToJSON converts a ViewRecord to a JSON-compatible map.
//
// The map carries the fixed keys "id", "collectionId", "collectionName",
// "created" and "updated" (timestamps formatted as RFC 3339), followed by
// every entry of r.Data flattened into the top level. Data entries are copied
// last, so a Data key equal to one of the fixed keys replaces that value.
func (r *ViewRecord) ToJSON(collectionID, collectionName string) map[string]any {
	// Room for the five fixed keys plus all data fields.
	out := make(map[string]any, len(r.Data)+5)
	out["id"] = r.ID
	out["collectionId"] = collectionID
	out["collectionName"] = collectionName
	out["created"] = r.CreatedAt.Format(time.RFC3339)
	out["updated"] = r.UpdatedAt.Format(time.RFC3339)

	// Flatten data fields
	for key, value := range r.Data {
		out[key] = value
	}
	return out
}
// SetSchemaFromDerived sets the collection schema from derived schema.
//
// The schema is JSON-encoded and stored as a string in collection.Schema.
// If encoding fails the error is returned and the collection is left
// untouched.
func SetSchemaFromDerived(collection *models.Collection, schema []map[string]any) error {
	encoded, err := json.Marshal(schema)
	if err != nil {
		return err
	}

	collection.Schema = string(encoded)
	return nil
}
// parseTime converts a scanned column value into a time.Time.
//
// A time.Time is returned unchanged. A string is tried against RFC 3339,
// then "2006-01-02 15:04:05", then "2006-01-02", and the first layout that
// parses wins. Any other value, or a string matching none of the layouts,
// yields the zero time and false.
func parseTime(v any) (time.Time, bool) {
	if ts, isTime := v.(time.Time); isTime {
		return ts, true
	}

	text, isString := v.(string)
	if !isString {
		return time.Time{}, false
	}

	layouts := [...]string{
		time.RFC3339,
		"2006-01-02 15:04:05",
		"2006-01-02",
	}
	for _, layout := range layouts {
		if ts, err := time.Parse(layout, text); err == nil {
			return ts, true
		}
	}
	return time.Time{}, false
}
// sqlTypeToSchemaType maps a SQL type name to a schema field type.
//
// Matching is by case-insensitive substring, checked in this order: INT,
// REAL, FLOAT or DOUBLE give "number"; BOOL gives "bool"; DATE or TIME give
// "date"; JSON gives "json". Anything else, including the empty string,
// gives "text".
func sqlTypeToSchemaType(sqlType string) string {
	name := strings.ToUpper(sqlType)

	// Rules are ordered; the first rule with a matching marker wins.
	rules := []struct {
		schemaType string
		markers    []string
	}{
		{"number", []string{"INT", "REAL", "FLOAT", "DOUBLE"}},
		{"bool", []string{"BOOL"}},
		{"date", []string{"DATE", "TIME"}},
		{"json", []string{"JSON"}},
	}
	for _, rule := range rules {
		for _, marker := range rule.markers {
			if strings.Contains(name, marker) {
				return rule.schemaType
			}
		}
	}
	return "text"
}