readysite / website / internal / content / rules / eval.go
20.5 KB
eval.go
package rules

import (
	"fmt"
	"net/http"
	"regexp"
	"strconv"
	"strings"

	"github.com/readysite/readysite/website/models"
)

// RuleType classifies how a rule string is enforced; see ParseRuleType.
type RuleType int

// Rule types returned by ParseRuleType. The values are spelled out so that
// reordering these lines can never renumber them.
const (
	RuleTypeLocked     RuleType = 0 // admins only ("null" or "locked" rule)
	RuleTypePublic     RuleType = 1 // anyone (empty rule)
	RuleTypeExpression RuleType = 2 // rule expression must evaluate to true
)

// Context is the request-scoped data that @request.* references in a rule
// expression resolve against. NewContext builds one from an *http.Request.
type Context struct {
	// User is the authenticated user, or nil for a guest.
	User *models.User

	// Method is the HTTP method of the request.
	Method string

	// Query and Headers hold one value per URL query parameter and request
	// header. NewContext keeps the first value, lowercases header names and
	// replaces '-' with '_'.
	Query, Headers map[string]string

	// Body is the body map handed to NewContext; it may be nil.
	Body map[string]any
}

// NewContext builds the rule-evaluation Context for r.
//
// user is the authenticated user (nil for a guest) and body is stored as-is.
// Only the first value of each query parameter and header is kept. Header
// names are lowercased and '-' is replaced with '_', so "X-Api-Key" is exposed
// as "x_api_key".
func NewContext(r *http.Request, user *models.User, body map[string]any) *Context {
	query := make(map[string]string)
	for name, values := range r.URL.Query() {
		if len(values) == 0 {
			continue
		}
		query[name] = values[0]
	}

	headers := make(map[string]string)
	for name, values := range r.Header {
		if len(values) == 0 {
			continue
		}
		normalized := strings.ReplaceAll(strings.ToLower(name), "-", "_")
		headers[normalized] = values[0]
	}

	return &Context{
		User:    user,
		Method:  r.Method,
		Query:   query,
		Headers: headers,
		Body:    body,
	}
}


// ParseRuleType classifies a raw rule string:
//
//   - ""                 -> RuleTypePublic
//   - "null" or "locked" -> RuleTypeLocked
//   - anything else      -> RuleTypeExpression
func ParseRuleType(rule string) RuleType {
	switch rule {
	case "":
		return RuleTypePublic
	case "null", "locked":
		return RuleTypeLocked
	default:
		return RuleTypeExpression
	}
}

// Evaluate decides whether rule grants access for the request described by
// ctx against record.
//
//   - Locked rules ("null"/"locked") pass only when ctx.User has Role "admin".
//   - Public rules ("") always pass.
//   - Any other rule is parsed with Parse and evaluated against ctx and record.
//
// A nil ctx is treated as an unauthenticated request with no query, headers or
// body, instead of panicking on the first context dereference. Parse failures
// are returned wrapped, with a false result.
func Evaluate(rule string, ctx *Context, record map[string]any) (bool, error) {
	if ctx == nil {
		ctx = &Context{}
	}

	switch ParseRuleType(rule) {
	case RuleTypeLocked:
		// Only admins may access a locked resource.
		return ctx.User != nil && ctx.User.Role == "admin", nil

	case RuleTypePublic:
		return true, nil

	case RuleTypeExpression:
		expr, err := Parse(rule)
		if err != nil {
			return false, fmt.Errorf("failed to parse rule: %w", err)
		}
		return expr.Evaluate(ctx, record)
	}

	// Unreachable for the rule types ParseRuleType produces; deny by default.
	return false, nil
}

// ToSQLFilter converts rule into a SQL WHERE fragment (without the "WHERE"
// keyword) plus its bound parameters.
//
// Public rules need no filtering and yield ("", nil, nil). Locked rules cannot
// be expressed as a filter and return an error; callers are expected to check
// them before querying. Expression rules are parsed and translated with
// Expr.ToSQL.
//
// A nil ctx is treated as an unauthenticated request with no query, headers or
// body, so @request.* references resolve to guest values instead of panicking.
func ToSQLFilter(rule string, ctx *Context) (string, []any, error) {
	if ctx == nil {
		ctx = &Context{}
	}

	switch ParseRuleType(rule) {
	case RuleTypeLocked:
		return "", nil, fmt.Errorf("locked rule cannot be converted to SQL filter")

	case RuleTypePublic:
		return "", nil, nil

	case RuleTypeExpression:
		expr, err := Parse(rule)
		if err != nil {
			return "", nil, fmt.Errorf("failed to parse rule: %w", err)
		}
		return expr.ToSQL(ctx)
	}

	return "", nil, nil
}

// --- Expression AST ---

// Expr is a node of a parsed rule expression. Every node can be checked
// directly against a record or compiled into a SQL WHERE fragment.
type Expr interface {
	// Evaluate reports whether the expression holds for record under ctx.
	Evaluate(ctx *Context, record map[string]any) (bool, error)

	// ToSQL returns an equivalent WHERE fragment and its bound parameters.
	ToSQL(ctx *Context) (string, []any, error)
}

// BinaryExpr joins two sub-expressions with a logical connective.
type BinaryExpr struct {
	// Left is always evaluated first.
	Left Expr

	// Op is "&&" (both must hold) or "||" (either may hold).
	Op string

	// Right is evaluated by Evaluate only when Left does not already decide
	// the result.
	Right Expr
}

// Evaluate computes Left Op Right with short-circuiting. Left is evaluated
// first, and its error (if any) is returned before the operator is examined.
// An operator other than "&&" or "||" is an error.
func (e *BinaryExpr) Evaluate(ctx *Context, record map[string]any) (bool, error) {
	lhs, err := e.Left.Evaluate(ctx, record)
	if err != nil {
		return false, err
	}

	switch e.Op {
	case "&&":
		if !lhs {
			return false, nil // false && x
		}
	case "||":
		if lhs {
			return true, nil // true || x
		}
	default:
		return false, fmt.Errorf("unknown binary operator: %s", e.Op)
	}

	// Left did not decide the outcome, so Right does.
	return e.Right.Evaluate(ctx, record)
}

// ToSQL renders the expression as "(left AND right)" or "(left OR right)".
//
// The operator is validated first. An operator other than "&&" or "||" is an
// error, matching Evaluate, instead of being silently treated as AND. The
// returned parameters are the left parameters followed by the right ones, in a
// freshly allocated slice so neither child's slice is written through.
func (e *BinaryExpr) ToSQL(ctx *Context) (string, []any, error) {
	var sqlOp string
	switch e.Op {
	case "&&":
		sqlOp = "AND"
	case "||":
		sqlOp = "OR"
	default:
		return "", nil, fmt.Errorf("unknown binary operator: %s", e.Op)
	}

	leftSQL, leftParams, err := e.Left.ToSQL(ctx)
	if err != nil {
		return "", nil, err
	}

	rightSQL, rightParams, err := e.Right.ToSQL(ctx)
	if err != nil {
		return "", nil, err
	}

	params := make([]any, 0, len(leftParams)+len(rightParams))
	params = append(params, leftParams...)
	params = append(params, rightParams...)

	return fmt.Sprintf("(%s %s %s)", leftSQL, sqlOp, rightSQL), params, nil
}

// CompareExpr is a leaf comparison between two operands.
type CompareExpr struct {
	// Left is the left-hand operand.
	Left Value

	// Op is one of =, !=, >, >=, <, <=, ~ (LIKE), !~ (NOT LIKE), or the
	// collection-existence operators ?= and ?!=, which require one side to
	// be a *CollectionValue.
	Op string

	// Right is the right-hand operand.
	Right Value
}

// Evaluate checks the comparison against record.
//
// Dispatch order:
//   - ?= and ?!= go to evaluateCollectionCheck;
//   - a left-hand field carrying the :each modifier goes to evaluateEach;
//   - otherwise both operands are resolved (left first) and passed to compare.
func (e *CompareExpr) Evaluate(ctx *Context, record map[string]any) (bool, error) {
	switch e.Op {
	case "?=", "?!=":
		return e.evaluateCollectionCheck(ctx, record)
	}

	if field, isField := e.Left.(*FieldValue); isField && field.Modifier == "each" {
		return e.evaluateEach(ctx, record, field)
	}

	lhs, err := e.Left.Resolve(ctx, record)
	if err != nil {
		return false, err
	}

	rhs, err := e.Right.Resolve(ctx, record)
	if err != nil {
		return false, err
	}

	return compare(lhs, e.Op, rhs)
}

// evaluateEach implements the :each modifier. The comparison is applied to
// every element of the field's value, and the result is true only if all
// elements satisfy it.
//
//   - A nil record yields false.
//   - A missing or nil field yields true (vacuous truth), as does an empty slice.
//   - Values of type []any, []string, []int and []float64 are iterated.
//   - Any other value is compared once, as a single element.
//
// The right operand is resolved once, before iteration. The first comparison
// error aborts the check.
func (e *CompareExpr) evaluateEach(ctx *Context, record map[string]any, fv *FieldValue) (bool, error) {
	if record == nil {
		return false, nil
	}

	// Read the raw value; the :each modifier itself is not applied here.
	value := record[fv.Name]
	if value == nil {
		return true, nil
	}

	right, err := e.Right.Resolve(ctx, record)
	if err != nil {
		return false, err
	}

	switch items := value.(type) {
	case []any:
		return allItemsMatch(items, e.Op, right)
	case []string:
		return allItemsMatch(items, e.Op, right)
	case []int:
		return allItemsMatch(items, e.Op, right)
	case []float64:
		return allItemsMatch(items, e.Op, right)
	default:
		return compare(value, e.Op, right)
	}
}

// allItemsMatch reports whether compare(item, op, right) holds for every item.
// An empty slice matches vacuously; the first comparison error is returned.
func allItemsMatch[T any](items []T, op string, right any) (bool, error) {
	for _, item := range items {
		ok, err := compare(item, op, right)
		if err != nil {
			return false, err
		}
		if !ok {
			return false, nil
		}
	}
	return true, nil
}

// evaluateCollectionCheck implements ?= and ?!= during in-memory evaluation.
//
// Exactly one side must be a *CollectionValue. The other operand is resolved
// against ctx and record, then looked up with CheckCollectionExists using "="
// on the collection's first FieldPath segment (further segments are not
// consulted). Because the lookup is an equality test, it makes no difference
// which side holds the collection reference. ?= is true when a matching
// document exists; ?!= is true when none does.
func (e *CompareExpr) evaluateCollectionCheck(ctx *Context, record map[string]any) (bool, error) {
	var (
		colVal   *CollectionValue
		otherVal Value
	)
	if cv, ok := e.Left.(*CollectionValue); ok {
		colVal, otherVal = cv, e.Right
	} else if cv, ok := e.Right.(*CollectionValue); ok {
		colVal, otherVal = cv, e.Left
	} else {
		return false, fmt.Errorf("?= and ?!= operators require a @collection reference")
	}

	other, err := otherVal.Resolve(ctx, record)
	if err != nil {
		return false, err
	}

	if len(colVal.FieldPath) == 0 {
		return false, fmt.Errorf("@collection reference missing field path")
	}

	exists, err := CheckCollectionExists(&CollectionQuery{
		Collection: colVal.Collection,
		Field:      colVal.FieldPath[0],
		Op:         "=",
		Value:      other,
	})
	if err != nil {
		return false, err
	}

	if e.Op == "?=" {
		return exists, nil
	}
	// ?!=
	return !exists, nil
}

// ToSQL renders the comparison as a SQLite WHERE fragment.
//
// ?= and ?!= are delegated to toSQLCollectionCheck. Otherwise the operator is
// validated against a fixed whitelist, so arbitrary operator text can never be
// spliced into the SQL. Unknown operators return the same error text that
// compare uses.
//
// "=" and "!=" map to SQLite's null-safe IS / IS NOT. These behave exactly
// like = and != for non-NULL operands, but match compare's nil handling:
// NULL IS NULL is true, and NULL IS NOT x is true for non-NULL x. Plain = and
// != would yield NULL (false) in those cases, so SQL filtering and in-memory
// evaluation would disagree. "~" and "!~" map to LIKE / NOT LIKE; ordering
// operators pass through unchanged.
func (e *CompareExpr) ToSQL(ctx *Context) (string, []any, error) {
	if e.Op == "?=" || e.Op == "?!=" {
		return e.toSQLCollectionCheck(ctx)
	}

	var sqlOp string
	switch e.Op {
	case "=":
		sqlOp = "IS"
	case "!=":
		sqlOp = "IS NOT"
	case ">", ">=", "<", "<=":
		sqlOp = e.Op
	case "~":
		sqlOp = "LIKE"
	case "!~":
		sqlOp = "NOT LIKE"
	default:
		return "", nil, fmt.Errorf("unknown comparison operator: %s", e.Op)
	}

	leftSQL, leftParams, err := e.Left.ToSQLValue(ctx)
	if err != nil {
		return "", nil, err
	}

	rightSQL, rightParams, err := e.Right.ToSQLValue(ctx)
	if err != nil {
		return "", nil, err
	}

	params := make([]any, 0, len(leftParams)+len(rightParams))
	params = append(params, leftParams...)
	params = append(params, rightParams...)

	return fmt.Sprintf("%s %s %s", leftSQL, sqlOp, rightSQL), params, nil
}

// toSQLCollectionCheck renders ?= / ?!= as an EXISTS subquery over Documents:
//
//	EXISTS (SELECT 1 FROM Documents WHERE CollectionID = ? AND json_extract(Data, '$.<field>') = <other>)
//
// The result is prefixed with NOT for ?!=. <field> is the first FieldPath
// segment of the *CollectionValue operand, and <other> is the SQL form of the
// opposite operand. Parameters are the collection name followed by the other
// operand's parameters.
//
// NOTE(review): <other> is spliced inside the subquery. A FieldValue operand
// (json_extract(Data, ...)) therefore most likely resolves Data against the
// subquery's Documents row, not the outer record, which would defeat
// correlated checks. Confirm against real queries.
func (e *CompareExpr) toSQLCollectionCheck(ctx *Context) (string, []any, error) {
	colVal, onLeft := e.Left.(*CollectionValue)
	otherVal := e.Right
	if !onLeft {
		var onRight bool
		if colVal, onRight = e.Right.(*CollectionValue); !onRight {
			return "", nil, fmt.Errorf("?= and ?!= operators require a @collection reference")
		}
		otherVal = e.Left
	}

	otherSQL, otherParams, err := otherVal.ToSQLValue(ctx)
	if err != nil {
		return "", nil, err
	}

	if len(colVal.FieldPath) == 0 {
		return "", nil, fmt.Errorf("@collection reference missing field path")
	}

	clause := fmt.Sprintf(
		"EXISTS (SELECT 1 FROM Documents WHERE CollectionID = ? AND json_extract(Data, '$.%s') = %s)",
		colVal.FieldPath[0],
		otherSQL,
	)
	if e.Op == "?!=" {
		clause = "NOT " + clause
	}

	return clause, append([]any{colVal.Collection}, otherParams...), nil
}

// --- Values ---

// Value is an operand of a comparison: a literal, a record field, an
// @request.* reference, or an @collection.* reference.
type Value interface {
	// Resolve produces the operand's Go value for in-memory evaluation.
	Resolve(ctx *Context, record map[string]any) (any, error)

	// ToSQLValue produces the operand's SQL fragment (a column expression or
	// a "?" placeholder) together with any bound parameters.
	ToSQLValue(ctx *Context) (string, []any, error)
}

// LiteralValue is a constant operand, such as a string, number, boolean or null.
type LiteralValue struct {
	Value any
}

// Resolve returns the stored constant; ctx and record are not consulted.
func (v *LiteralValue) Resolve(ctx *Context, record map[string]any) (any, error) {
	literal := v.Value
	return literal, nil
}

// ToSQLValue binds the constant as a single "?" parameter.
func (v *LiteralValue) ToSQLValue(ctx *Context) (string, []any, error) {
	const placeholder = "?"
	return placeholder, []any{v.Value}, nil
}

// FieldValue is an operand that reads a top-level field of the record being
// checked, optionally transformed by a modifier.
type FieldValue struct {
	Name     string // record key to read
	Modifier string // "", "isset", "length", "lower" or "each"
}

// Resolve reads v.Name from record and applies v.Modifier:
//
//   - ""       the raw value (nil when the field is absent);
//   - "isset"  whether the key is present, even if its value is nil;
//   - "length" the value's length as computed by getLength (0 when absent);
//   - "lower"  the lowercased value when it is a string, otherwise unchanged;
//   - "each"   the raw value; CompareExpr applies :each element-wise itself.
//
// A nil record behaves like an empty record (indexing a nil map is safe in
// Go). ":isset" therefore reports false and ":length" reports 0, instead of
// both resolving to nil. An unrecognised modifier is always an error.
func (v *FieldValue) Resolve(ctx *Context, record map[string]any) (any, error) {
	value, exists := record[v.Name]

	switch v.Modifier {
	case "", "each":
		return value, nil
	case "isset":
		return exists, nil
	case "length":
		return getLength(value), nil
	case "lower":
		if s, ok := value.(string); ok {
			return strings.ToLower(s), nil
		}
		return value, nil
	default:
		return nil, fmt.Errorf("unknown modifier: %s", v.Modifier)
	}
}

// ToSQLValue renders the field as a SQLite expression over the Data JSON
// column, applying v.Modifier:
//
//   - ""       json_extract(Data, '$.<name>')
//   - "isset"  whether the key exists. json_type returns NULL only for a
//     missing path, and 'null' for a key holding JSON null, so this matches
//     Resolve's "key present" semantics. json_extract would wrongly treat a
//     JSON-null value as unset.
//   - "length" the element count for arrays, otherwise the character length of
//     the extracted value, and 0 when the field is absent or null. Calling
//     json_array_length on the extracted value returned 0 for every
//     non-array, and failed on plain text.
//   - "lower"  LOWER(json_extract(...))
//
// ":each" cannot be expressed as a simple SQL value and is an error, as is
// any unrecognised modifier. No parameters are returned.
func (v *FieldValue) ToSQLValue(ctx *Context) (string, []any, error) {
	path := fmt.Sprintf("'$.%s'", v.Name)
	extract := "json_extract(Data, " + path + ")"

	switch v.Modifier {
	case "":
		return extract, nil, nil
	case "isset":
		return "(json_type(Data, " + path + ") IS NOT NULL)", nil, nil
	case "length":
		length := "COALESCE(CASE json_type(Data, " + path + ")" +
			" WHEN 'array' THEN json_array_length(Data, " + path + ")" +
			" ELSE LENGTH(" + extract + ") END, 0)"
		return length, nil, nil
	case "lower":
		return "LOWER(" + extract + ")", nil, nil
	case "each":
		return "", nil, fmt.Errorf(":each modifier requires special handling and cannot be used in SQL filter")
	default:
		return "", nil, fmt.Errorf("unknown modifier: %s", v.Modifier)
	}
}

// getLength returns the length used by the :length modifier.
//
//   - nil yields 0.
//   - A string yields its number of characters (runes), matching SQLite's
//     LENGTH() on text rather than the UTF-8 byte count.
//   - []any, []string, []int and []float64 yield their element count.
//   - Any other value yields the rune count of its %v rendering.
func getLength(v any) int64 {
	switch val := v.(type) {
	case nil:
		return 0
	case string:
		return int64(len([]rune(val)))
	case []any:
		return int64(len(val))
	case []string:
		return int64(len(val))
	case []int:
		return int64(len(val))
	case []float64:
		return int64(len(val))
	default:
		return int64(len([]rune(fmt.Sprintf("%v", v))))
	}
}

// ContextValue is an @request.* operand resolved from the request Context.
//
// Path holds the segments after "@request"; for example ["auth", "id"]
// corresponds to @request.auth.id. Modifier is empty or one of "isset",
// "length" or "lower".
type ContextValue struct {
	Path     []string
	Modifier string
}

// Resolve looks up the @request.* value named by v.Path and applies v.Modifier.
//
// Supported roots are "auth" (via resolveAuth), "method", and "query",
// "headers" and "body". The last three need a key in Path[1]. An empty Path,
// or a keyed root with no key, resolves to nil regardless of the modifier.
// "isset" reports presence: for auth, a non-nil, non-"" value; for method,
// always true; for the keyed roots, whether the key exists in the map. A
// missing query or header key resolves to "" (the map's zero value), and a
// missing body key resolves to nil. record is not consulted.
func (v *ContextValue) Resolve(ctx *Context, record map[string]any) (any, error) {
	if len(v.Path) == 0 {
		return nil, nil
	}

	var (
		resolved any
		present  bool
	)

	switch root := v.Path[0]; root {
	case "auth":
		authValue, err := v.resolveAuth(ctx)
		if err != nil {
			return nil, err
		}
		resolved, present = authValue, authValue != nil && authValue != ""
	case "method":
		resolved, present = ctx.Method, true
	case "query", "headers", "body":
		if len(v.Path) < 2 {
			return nil, nil
		}
		key := v.Path[1]
		switch root {
		case "query":
			resolved, present = ctx.Query[key]
		case "headers":
			resolved, present = ctx.Headers[key]
		default:
			// Indexing a nil Body map yields (nil, false).
			resolved, present = ctx.Body[key]
		}
	default:
		return nil, fmt.Errorf("unknown context variable: @request.%s", strings.Join(v.Path, "."))
	}

	switch v.Modifier {
	case "":
		return resolved, nil
	case "isset":
		return present, nil
	case "length":
		return getLength(resolved), nil
	case "lower":
		if text, isText := resolved.(string); isText {
			return strings.ToLower(text), nil
		}
		return resolved, nil
	default:
		return nil, fmt.Errorf("unknown modifier: %s", v.Modifier)
	}
}

// resolveAuth returns the @request.auth.<field> value for v.Path[1].
//
// Bare @request.auth (no field) resolves to nil. The field name is validated
// before the user is examined, so a rule behaves the same for guests and for
// signed-in users:
//
//   - Sensitive fields ("passwordHash", "password", "tokenKey") are rejected.
//   - Unknown fields are rejected.
//   - Known fields ("id", "email", "name", "role", "verified") resolve to the
//     user's value, or to "" for a guest (nil ctx or nil ctx.User).
//
// Previously guests received "" for any field name, so an invalid rule only
// failed once a user signed in.
func (v *ContextValue) resolveAuth(ctx *Context) (any, error) {
	if len(v.Path) < 2 {
		return nil, nil
	}

	field := v.Path[1]
	switch field {
	case "id", "email", "name", "role", "verified":
		// Known, exposable field.
	case "passwordHash", "password", "tokenKey":
		// Never expose sensitive fields.
		return nil, fmt.Errorf("access to sensitive auth fields is not allowed")
	default:
		return nil, fmt.Errorf("unknown auth field: %s", field)
	}

	if ctx == nil || ctx.User == nil {
		// Guest: auth fields resolve to empty strings.
		return "", nil
	}

	user := ctx.User
	switch field {
	case "id":
		return user.ID, nil
	case "email":
		return user.Email, nil
	case "name":
		return user.Name, nil
	case "role":
		return user.Role, nil
	default: // "verified"
		return user.Verified, nil
	}
}

// ToSQLValue resolves the context value now, while the SQL is being built,
// and binds it as a single "?" parameter. For the :isset modifier, a boolean
// result is bound as the integer 1 or 0. Resolution errors are returned
// unchanged.
func (v *ContextValue) ToSQLValue(ctx *Context) (string, []any, error) {
	resolved, err := v.Resolve(ctx, nil)
	if err != nil {
		return "", nil, err
	}

	if flag, isBool := resolved.(bool); isBool && v.Modifier == "isset" {
		asInt := 0
		if flag {
			asInt = 1
		}
		return "?", []any{asInt}, nil
	}

	return "?", []any{resolved}, nil
}

// --- Comparison helper ---

// compare applies the comparison operator op to left and right.
//
//   - nil operands: nil = nil is true, nil != x is true for non-nil x, and
//     every other combination or operator involving nil is false.
//   - "=" and "!=" compare the %v renderings of both operands.
//   - "~" and "!~" match left's rendering against right's rendering, treated
//     as a LIKE pattern converted by likeToRegex.
//   - ">", ">=", "<" and "<=" compare numerically when both operands convert
//     with toFloat. Otherwise they compare the %v renderings lexically.
//
// Any other operator on non-nil operands is an error.
func compare(left any, op string, right any) (bool, error) {
	if left == nil || right == nil {
		bothNil := left == nil && right == nil
		switch op {
		case "=":
			return bothNil, nil
		case "!=":
			return !bothNil, nil
		}
		return false, nil
	}

	ls := fmt.Sprint(left)
	rs := fmt.Sprint(right)

	switch op {
	case "=":
		return ls == rs, nil
	case "!=":
		return ls != rs, nil
	case "~", "!~":
		// The match error is ignored, as before; likeToRegex is expected to
		// produce a valid pattern.
		matched, _ := regexp.MatchString(likeToRegex(rs), ls)
		if op == "!~" {
			return !matched, nil
		}
		return matched, nil
	case ">", ">=", "<", "<=":
		// Ordering operators are handled below.
	default:
		return false, fmt.Errorf("unknown comparison operator: %s", op)
	}

	lf, lok := toFloat(left)
	rf, rok := toFloat(right)
	if lok && rok {
		switch op {
		case ">":
			return lf > rf, nil
		case ">=":
			return lf >= rf, nil
		case "<":
			return lf < rf, nil
		default: // "<="
			return lf <= rf, nil
		}
	}

	// At least one side is non-numeric: fall back to string ordering.
	switch op {
	case ">":
		return ls > rs, nil
	case ">=":
		return ls >= rs, nil
	case "<":
		return ls < rs, nil
	default: // "<="
		return ls <= rs, nil
	}
}

// toFloat converts v to a float64 for the ordering comparisons in compare.
//
// Every built-in signed and unsigned integer type and both float types
// convert directly. A string converts only when the entire string (ignoring
// surrounding whitespace) is a valid strconv.ParseFloat literal. The old
// fmt.Sscanf("%f") accepted numeric prefixes, so "12abc" compared as 12;
// such strings are now non-numeric and fall back to string ordering.
// Every other type reports false.
func toFloat(v any) (float64, bool) {
	switch n := v.(type) {
	case float64:
		return n, true
	case float32:
		return float64(n), true
	case int:
		return float64(n), true
	case int8:
		return float64(n), true
	case int16:
		return float64(n), true
	case int32:
		return float64(n), true
	case int64:
		return float64(n), true
	case uint:
		return float64(n), true
	case uint8:
		return float64(n), true
	case uint16:
		return float64(n), true
	case uint32:
		return float64(n), true
	case uint64:
		return float64(n), true
	case string:
		f, err := strconv.ParseFloat(strings.TrimSpace(n), 64)
		if err != nil {
			return 0, false
		}
		return f, true
	}
	return 0, false
}

// likeToRegex converts a SQL LIKE pattern into an anchored Go regular
// expression: '%' matches any run of characters and '_' matches exactly one
// character. Every other character is matched literally.
//
// The (?is) flags keep in-memory matching aligned with SQLite's LIKE, which
// the SQL path uses. SQLite's LIKE is case-insensitive for ASCII by default,
// and its '%' also spans newlines; without (?s), Go's '.' stops at '\n'.
func likeToRegex(pattern string) string {
	// QuoteMeta escapes regex metacharacters but leaves % and _ untouched.
	result := regexp.QuoteMeta(pattern)
	result = strings.ReplaceAll(result, "%", ".*")
	result = strings.ReplaceAll(result, "_", ".")
	return "(?is)^" + result + "$"
}

// CollectionValue references a field of another collection through
// @collection.*, for checks such as "the current user is a member of this
// project". It only has meaning as an operand of ?= / ?!=, which CompareExpr
// handles itself.
//
// NOTE(review): Alias is not read by any code in this file.
type CollectionValue struct {
	Collection string   // target collection name, e.g. "memberships"
	Alias      string   // optional self-join alias, e.g. "manager" in @collection.users:manager.id
	FieldPath  []string // field path within the target collection, e.g. ["userId"]
}

// Resolve always fails. A collection reference denotes a set of documents, not
// a single value, and is consumed directly by CompareExpr for ?= / ?!=.
func (v *CollectionValue) Resolve(ctx *Context, record map[string]any) (any, error) {
	err := fmt.Errorf("@collection values must be used with ?= or ?!= operators")
	return nil, err
}

// ToSQLValue always fails for the same reason; CompareExpr builds the EXISTS
// subquery itself.
func (v *CollectionValue) ToSQLValue(ctx *Context) (string, []any, error) {
	err := fmt.Errorf("@collection values must be used with ?= or ?!= operators")
	return "", nil, err
}

// CollectionQuery describes a lookup against another collection. It asks
// whether any document in Collection has a Field value satisfying Op against
// Value. CheckCollectionExists executes it.
type CollectionQuery struct {
	Collection string // matched against Documents.CollectionID
	Field      string // top-level JSON key inside each document's Data
	Op         string // comparison operator; CompareExpr always passes "="
	Value      any    // right-hand operand, bound as a query parameter
}

// CheckCollectionExists reports whether at least one document in
// query.Collection satisfies json_extract(Data, '$.<Field>') <Op> Value.
// It backs the ?= / ?!= operators during in-memory rule evaluation.
//
// Changes from the previous version:
//   - A single query is issued. The old pre-check for "any document in the
//     collection" was redundant, because the filtered query already returns
//     nothing for an empty collection.
//   - query.Op is whitelisted (=, !=, >, >=, <, <=, ~ as LIKE, !~ as NOT
//     LIKE) instead of being spliced into SQL verbatim.
//   - A nil query is an error.
//
// NOTE(review): query.Collection (documented on CollectionValue as a
// collection name) is bound to CollectionID. Confirm that collection IDs and
// names coincide.
func CheckCollectionExists(query *CollectionQuery) (bool, error) {
	if query == nil {
		return false, fmt.Errorf("collection query is nil")
	}

	var sqlOp string
	switch query.Op {
	case "=", "!=", ">", ">=", "<", "<=":
		sqlOp = query.Op
	case "~":
		sqlOp = "LIKE"
	case "!~":
		sqlOp = "NOT LIKE"
	default:
		return false, fmt.Errorf("unsupported operator for collection query: %q", query.Op)
	}

	whereClause := fmt.Sprintf("CollectionID = ? AND json_extract(Data, '$.%s') %s ?", query.Field, sqlOp)
	docs, err := models.Documents.Search(whereClause+" LIMIT 1", query.Collection, query.Value)
	if err != nil {
		return false, fmt.Errorf("failed to query collection %s: %w", query.Collection, err)
	}

	return len(docs) > 0, nil
}
← Back