eval.go
package rules
import (
	"fmt"
	"net/http"
	"regexp"
	"strconv"
	"strings"
	"unicode/utf8"

	"github.com/readysite/readysite/website/models"
)
// RuleType classifies how an access rule string is interpreted
// (see ParseRuleType).
type RuleType int

// Rule kinds. The numeric values are spelled out and are identical to the
// 0, 1, 2 sequence an iota block would produce.
const (
	// RuleTypeLocked restricts access to admins ("null" / "locked" rule).
	RuleTypeLocked RuleType = 0
	// RuleTypePublic lets anyone access (empty rule string).
	RuleTypePublic RuleType = 1
	// RuleTypeExpression requires the rule expression to evaluate to true.
	RuleTypeExpression RuleType = 2
)
// Context provides request context for rule evaluation.
// Its fields back the @request.* references of rule expressions (see
// ContextValue); NewContext builds one from an *http.Request.
type Context struct {
	User    *models.User      // Current authenticated user (nil if guest)
	Method  string            // HTTP method
	Query   map[string]string // Query parameters (NewContext keeps the first value of each)
	Headers map[string]string // Request headers; NewContext lower-cases names and turns "-" into "_"
	Body    map[string]any    // Request body fields; may be nil
}
// NewContext builds a rule-evaluation Context from an HTTP request.
//
// Only the first value of each query parameter and header is kept; entries
// with no values are skipped. Header names are normalized to lower case with
// "-" replaced by "_", so "X-Api-Key" becomes "x_api_key". user is nil for
// guests; body is stored as given and may be nil.
func NewContext(r *http.Request, user *models.User, body map[string]any) *Context {
	query := make(map[string]string)
	for name, values := range r.URL.Query() {
		if len(values) == 0 {
			continue
		}
		query[name] = values[0]
	}

	headers := make(map[string]string)
	for name, values := range r.Header {
		if len(values) == 0 {
			continue
		}
		normalized := strings.ToLower(strings.ReplaceAll(name, "-", "_"))
		headers[normalized] = values[0]
	}

	return &Context{
		User:    user,
		Method:  r.Method,
		Query:   query,
		Headers: headers,
		Body:    body,
	}
}
// ParseRuleType classifies a rule string:
//   - "" is public (anyone may access),
//   - "null" or "locked" is locked (admins only),
//   - anything else is an expression to be parsed and evaluated.
func ParseRuleType(rule string) RuleType {
	switch rule {
	case "":
		return RuleTypePublic
	case "null", "locked":
		return RuleTypeLocked
	default:
		return RuleTypeExpression
	}
}
// Evaluate checks rule against the request context and a record, reporting
// whether access is granted.
//
// Locked rules pass only when ctx.User is set and its Role is "admin". Public
// rules always pass. Expression rules are parsed with Parse and evaluated
// against ctx and record; a parse failure returns false with the wrapped error.
func Evaluate(rule string, ctx *Context, record map[string]any) (bool, error) {
	switch ParseRuleType(rule) {
	case RuleTypePublic:
		return true, nil
	case RuleTypeLocked:
		isAdmin := ctx.User != nil && ctx.User.Role == "admin"
		return isAdmin, nil
	case RuleTypeExpression:
		expr, err := Parse(rule)
		if err != nil {
			return false, fmt.Errorf("failed to parse rule: %w", err)
		}
		return expr.Evaluate(ctx, record)
	default:
		return false, nil
	}
}
// ToSQLFilter converts a rule into a SQL WHERE clause (without the "WHERE"
// keyword) plus its bind parameters.
//
// Locked rules return an error because they must be enforced before querying.
// Public rules need no filtering and return an empty clause. Expression rules
// are parsed with Parse and rendered via Expr.ToSQL.
func ToSQLFilter(rule string, ctx *Context) (string, []any, error) {
	kind := ParseRuleType(rule)
	if kind == RuleTypeLocked {
		return "", nil, fmt.Errorf("locked rule cannot be converted to SQL filter")
	}
	if kind != RuleTypeExpression {
		// Public rule: nothing to filter on.
		return "", nil, nil
	}

	expr, err := Parse(rule)
	if err != nil {
		return "", nil, fmt.Errorf("failed to parse rule: %w", err)
	}
	return expr.ToSQL(ctx)
}
// --- Expression AST ---

// Expr is a node of a parsed rule expression. Every node can be evaluated in
// memory against one record and translated into a SQL WHERE fragment.
type Expr interface {
	// ToSQL renders the node as a SQL WHERE clause fragment plus bind parameters.
	ToSQL(ctx *Context) (string, []any, error)
	// Evaluate reports whether the node holds for the given context and record.
	Evaluate(ctx *Context, record map[string]any) (bool, error)
}
// BinaryExpr represents a binary logical expression (left op right).
// Evaluate short-circuits: Right is not evaluated when Left alone decides the
// result.
type BinaryExpr struct {
	Left  Expr
	Op    string // "&&" (logical AND) or "||" (logical OR)
	Right Expr
}
// Evaluate computes Left Op Right with short-circuiting. Left is always
// evaluated first (and its error returned). A false Left settles "&&" and a
// true Left settles "||"; only otherwise is Right evaluated. Any other
// operator yields an error.
func (e *BinaryExpr) Evaluate(ctx *Context, record map[string]any) (bool, error) {
	lhs, err := e.Left.Evaluate(ctx, record)
	if err != nil {
		return false, err
	}

	switch {
	case e.Op == "&&" && !lhs:
		return false, nil
	case e.Op == "||" && lhs:
		return true, nil
	case e.Op == "&&" || e.Op == "||":
		return e.Right.Evaluate(ctx, record)
	}
	return false, fmt.Errorf("unknown binary operator: %s", e.Op)
}
// ToSQL renders the logical expression as "(left AND right)" or
// "(left OR right)". Parameters are concatenated in placeholder order
// (left's first).
//
// An operator other than "&&" or "||" is rejected, matching Evaluate, instead
// of being silently rendered as AND.
func (e *BinaryExpr) ToSQL(ctx *Context) (string, []any, error) {
	var sqlOp string
	switch e.Op {
	case "&&":
		sqlOp = "AND"
	case "||":
		sqlOp = "OR"
	default:
		return "", nil, fmt.Errorf("unknown binary operator: %s", e.Op)
	}

	leftSQL, leftParams, err := e.Left.ToSQL(ctx)
	if err != nil {
		return "", nil, err
	}
	rightSQL, rightParams, err := e.Right.ToSQL(ctx)
	if err != nil {
		return "", nil, err
	}

	// Build a fresh slice so the result never writes into leftParams' backing array.
	params := make([]any, 0, len(leftParams)+len(rightParams))
	params = append(params, leftParams...)
	params = append(params, rightParams...)
	return fmt.Sprintf("(%s %s %s)", leftSQL, sqlOp, rightSQL), params, nil
}
// CompareExpr represents a comparison expression (left op right).
// Op "?=" / "?!=" are @collection existence checks (one side must be a
// *CollectionValue); a left *FieldValue with the "each" modifier is compared
// item by item in Evaluate.
type CompareExpr struct {
	Left  Value
	Op    string // =, !=, >, >=, <, <=, ~, !~, ?=, ?!=
	Right Value
}
// Evaluate resolves both operands and compares them with compare.
//
// Two forms are dispatched before that: ?= / ?!= perform a @collection
// existence check, and a left-hand field carrying the :each modifier is
// compared against every element of its array value.
func (e *CompareExpr) Evaluate(ctx *Context, record map[string]any) (bool, error) {
	if e.Op == "?=" || e.Op == "?!=" {
		return e.evaluateCollectionCheck(ctx, record)
	}
	if field, isField := e.Left.(*FieldValue); isField && field.Modifier == "each" {
		return e.evaluateEach(ctx, record, field)
	}

	lhs, err := e.Left.Resolve(ctx, record)
	if err != nil {
		return false, err
	}
	rhs, err := e.Right.Resolve(ctx, record)
	if err != nil {
		return false, err
	}
	return compare(lhs, e.Op, rhs)
}
// evaluateEach handles the :each modifier by applying the comparison to every
// array item. It returns true only if ALL items match the condition.
//
// A nil record fails the check. A missing/nil field or an empty slice passes
// vacuously. Slices of type []any, []string, []int, []int64 and []float64 are
// iterated; any other value is compared directly as a single item. The right
// operand is resolved once and reused for every item.
func (e *CompareExpr) evaluateEach(ctx *Context, record map[string]any, fv *FieldValue) (bool, error) {
	if record == nil {
		return false, nil
	}
	// Read the raw array; the :each modifier itself has no value transform.
	value := record[fv.Name]
	if value == nil {
		return true, nil // Missing field: all (zero) items match (vacuous truth)
	}

	right, err := e.Right.Resolve(ctx, record)
	if err != nil {
		return false, err
	}

	switch arr := value.(type) {
	case []any:
		return eachItemMatches(arr, e.Op, right)
	case []string:
		return eachItemMatches(arr, e.Op, right)
	case []int:
		return eachItemMatches(arr, e.Op, right)
	case []int64:
		return eachItemMatches(arr, e.Op, right)
	case []float64:
		return eachItemMatches(arr, e.Op, right)
	default:
		// Single value - compare directly
		return compare(value, e.Op, right)
	}
}

// eachItemMatches reports whether compare(item, op, right) holds for every
// item, stopping at the first mismatch or error. An empty slice matches.
func eachItemMatches[T any](items []T, op string, right any) (bool, error) {
	for _, item := range items {
		match, err := compare(item, op, right)
		if err != nil {
			return false, err
		}
		if !match {
			return false, nil
		}
	}
	return true, nil
}
// evaluateCollectionCheck implements the ?= and ?!= operators during in-memory
// evaluation.
//
// Exactly one operand must be a @collection reference (*CollectionValue). The
// other operand is resolved against ctx and record. CheckCollectionExists then
// reports whether some document of the referenced collection has field
// FieldPath[0] equal to that value. ?= returns the result and ?!= its negation.
// Equality membership is symmetric, so the side the reference sits on does not
// change the query.
// NOTE(review): only FieldPath[0] is consulted; deeper paths and Alias are ignored.
func (e *CompareExpr) evaluateCollectionCheck(ctx *Context, record map[string]any) (bool, error) {
	var colVal *CollectionValue
	var otherVal Value
	if cv, ok := e.Left.(*CollectionValue); ok {
		colVal, otherVal = cv, e.Right
	} else if cv, ok := e.Right.(*CollectionValue); ok {
		colVal, otherVal = cv, e.Left
	} else {
		return false, fmt.Errorf("?= and ?!= operators require a @collection reference")
	}

	other, err := otherVal.Resolve(ctx, record)
	if err != nil {
		return false, err
	}
	if len(colVal.FieldPath) == 0 {
		return false, fmt.Errorf("@collection reference missing field path")
	}

	exists, err := CheckCollectionExists(&CollectionQuery{
		Collection: colVal.Collection,
		Field:      colVal.FieldPath[0],
		Op:         "=",
		Value:      other,
	})
	if err != nil {
		return false, err
	}
	if e.Op == "?=" {
		return exists, nil
	}
	// ?!=
	return !exists, nil
}
// ToSQL renders the comparison as a SQL WHERE fragment plus bind parameters.
//
// ?= / ?!= are delegated to toSQLCollectionCheck. The other operators map as
// follows:
//   - "=" / "!=" become SQLite's null-safe IS / IS NOT. A missing (NULL) field
//     then behaves like compare does in memory: NULL = NULL is true and
//     NULL != value is true. Plain = / != would yield NULL, i.e. exclude the row.
//   - ">", ">=", "<", "<=" are emitted unchanged.
//   - "~" / "!~" become LIKE / NOT LIKE.
//
// Any other operator is rejected rather than spliced into the SQL text.
func (e *CompareExpr) ToSQL(ctx *Context) (string, []any, error) {
	// Handle ?= and ?!= operators for collection existence checks
	if e.Op == "?=" || e.Op == "?!=" {
		return e.toSQLCollectionCheck(ctx)
	}

	var sqlOp string
	switch e.Op {
	case "=":
		sqlOp = "IS"
	case "!=":
		sqlOp = "IS NOT"
	case ">", ">=", "<", "<=":
		sqlOp = e.Op
	case "~":
		sqlOp = "LIKE"
	case "!~":
		sqlOp = "NOT LIKE"
	default:
		return "", nil, fmt.Errorf("unknown comparison operator: %s", e.Op)
	}

	leftSQL, leftParams, err := e.Left.ToSQLValue(ctx)
	if err != nil {
		return "", nil, err
	}
	rightSQL, rightParams, err := e.Right.ToSQLValue(ctx)
	if err != nil {
		return "", nil, err
	}

	params := make([]any, 0, len(leftParams)+len(rightParams))
	params = append(params, leftParams...)
	params = append(params, rightParams...)
	return fmt.Sprintf("%s %s %s", leftSQL, sqlOp, rightSQL), params, nil
}
// toSQLCollectionCheck generates SQL for the ?= and ?!= operators.
//
// Exactly one operand must be a @collection reference. The fragment is
//
//	COALESCE(<other> IN (SELECT json_extract(Data, '$.<field>') FROM Documents WHERE CollectionID = ?), 0)
//
// prefixed with NOT for ?!=. The parameters are the other operand's parameters
// followed by the collection name.
//
// The other operand is placed outside the subquery on purpose. A record field
// renders as json_extract(Data, ...) with an unqualified Data column, and inside
// a subquery over Documents that name would bind to the subquery's rows instead
// of the record being filtered.
//
// COALESCE maps the NULL that IN can yield (NULL operand, or no match among NULL
// values) to 0. This gives the same truth values as an EXISTS check, so ?!= is
// the exact negation of ?=.
//
// Single quotes in the field name are doubled to keep the JSON-path literal
// intact. Only FieldPath[0] is used.
func (e *CompareExpr) toSQLCollectionCheck(ctx *Context) (string, []any, error) {
	// One side should be a CollectionValue, the other is the value to match
	var colVal *CollectionValue
	var otherVal Value
	if cv, ok := e.Left.(*CollectionValue); ok {
		colVal, otherVal = cv, e.Right
	} else if cv, ok := e.Right.(*CollectionValue); ok {
		colVal, otherVal = cv, e.Left
	} else {
		return "", nil, fmt.Errorf("?= and ?!= operators require a @collection reference")
	}

	otherSQL, otherParams, err := otherVal.ToSQLValue(ctx)
	if err != nil {
		return "", nil, err
	}
	if len(colVal.FieldPath) == 0 {
		return "", nil, fmt.Errorf("@collection reference missing field path")
	}
	field := strings.ReplaceAll(colVal.FieldPath[0], "'", "''")

	sql := fmt.Sprintf(
		"COALESCE(%s IN (SELECT json_extract(Data, '$.%s') FROM Documents WHERE CollectionID = ?), 0)",
		otherSQL,
		field,
	)
	// Placeholder order: the other operand's placeholders come first, then CollectionID.
	params := make([]any, 0, len(otherParams)+1)
	params = append(params, otherParams...)
	params = append(params, colVal.Collection)

	if e.Op == "?!=" {
		sql = "NOT " + sql
	}
	return sql, params, nil
}
// --- Values ---

// Value is an operand of a comparison: a literal, a record field reference, a
// request-context reference (@request.*) or a @collection reference.
type Value interface {
	// ToSQLValue renders the operand for SQL (a column expression or a
	// placeholder) together with its bind parameters.
	ToSQLValue(ctx *Context) (string, []any, error)
	// Resolve returns the operand's concrete value for in-memory evaluation.
	Resolve(ctx *Context, record map[string]any) (any, error)
}
// LiteralValue is a constant operand written directly in a rule: a string,
// number, bool or null.
type LiteralValue struct {
	Value any
}

// Resolve returns the literal itself; ctx and record are not consulted.
func (v *LiteralValue) Resolve(ctx *Context, record map[string]any) (any, error) {
	literal := v.Value
	return literal, nil
}

// ToSQLValue binds the literal as a single "?" placeholder parameter.
func (v *LiteralValue) ToSQLValue(ctx *Context) (string, []any, error) {
	params := []any{v.Value}
	return "?", params, nil
}
// FieldValue represents a reference to a record field.
// In memory it is read from the record map; in SQL it is read from the JSON
// Data column via json_extract.
// Supports modifiers: :isset, :length, :lower, :each
type FieldValue struct {
	Name     string // Record key / JSON key inside Data
	Modifier string // Optional: "isset", "length", "lower", "each"
}
// Resolve returns the record's value for the field with Modifier applied:
//   - "" and "each": the raw value (:each is expanded by CompareExpr.evaluateEach)
//   - "isset": whether the key is present in record, even if it holds nil
//   - "length": getLength of the value
//   - "lower": the value lower-cased when it is a string, otherwise unchanged
//
// A nil record behaves like an empty one: the value is nil, :isset is false and
// :length is 0. Indexing a nil map is safe in Go. Previously every modifier
// returned nil here, which made "field:isset = false" fail. Unknown modifiers
// return an error.
func (v *FieldValue) Resolve(ctx *Context, record map[string]any) (any, error) {
	value, exists := record[v.Name]

	switch v.Modifier {
	case "", "each":
		return value, nil
	case "isset":
		return exists, nil
	case "length":
		return getLength(value), nil
	case "lower":
		if s, ok := value.(string); ok {
			return strings.ToLower(s), nil
		}
		return value, nil
	default:
		return nil, fmt.Errorf("unknown modifier: %s", v.Modifier)
	}
}
// ToSQLValue renders the field as a SQLite JSON expression over the Data column.
//
// The name is embedded in a JSON-path string literal ('$.<name>'). Any single
// quote in it is doubled so it cannot terminate the literal. Modifiers mirror
// Resolve:
//   - "":       json_extract(Data, '$.<name>')
//   - "isset":  1 when the key exists (even holding JSON null), else 0
//   - "length": element count for arrays, LENGTH() otherwise, 0 when missing
//   - "lower":  LOWER() of the extracted value
//
// :each cannot be expressed as a scalar SQL value and returns an error, as
// does an unknown modifier.
func (v *FieldValue) ToSQLValue(ctx *Context) (string, []any, error) {
	path := "'$." + strings.ReplaceAll(v.Name, "'", "''") + "'"
	base := "json_extract(Data, " + path + ")"

	switch v.Modifier {
	case "":
		return base, nil, nil
	case "isset":
		// json_type is SQL NULL only when the path is absent; a key holding JSON
		// null yields 'null'. This matches Resolve's key-existence check, whereas
		// json_extract(...) IS NOT NULL would treat a null-valued key as unset.
		return "(json_type(Data, " + path + ") IS NOT NULL)", nil, nil
	case "length":
		// json_array_length on a plain text value raises "malformed JSON", so it
		// is only used for arrays. LENGTH covers strings and numbers, and
		// COALESCE turns a missing key into 0 like getLength(nil).
		return fmt.Sprintf(
			"COALESCE(CASE json_type(Data, %[1]s) WHEN 'array' THEN json_array_length(Data, %[1]s) ELSE LENGTH(%[2]s) END, 0)",
			path, base,
		), nil, nil
	case "lower":
		return "LOWER(" + base + ")", nil, nil
	case "each":
		// :each requires special handling in comparison - cannot be converted to simple SQL
		return "", nil, fmt.Errorf(":each modifier requires special handling and cannot be used in SQL filter")
	default:
		return "", nil, fmt.Errorf("unknown modifier: %s", v.Modifier)
	}
}
// getLength returns the length of a value as used by the :length modifier.
//
// nil has length 0. Strings count characters (runes) rather than UTF-8 bytes,
// matching SQLite's LENGTH() used by the SQL translation. The supported slice
// types count elements. Any other value counts the characters of its %v form.
func getLength(v any) int64 {
	switch val := v.(type) {
	case nil:
		return 0
	case string:
		return int64(utf8.RuneCountInString(val))
	case []any:
		return int64(len(val))
	case []string:
		return int64(len(val))
	case []int:
		return int64(len(val))
	case []int64:
		return int64(len(val))
	case []float64:
		return int64(len(val))
	default:
		// Fall back to the length of the value's string form.
		return int64(utf8.RuneCountInString(fmt.Sprintf("%v", v)))
	}
}
// ContextValue represents a reference to request context (@request.*).
// Path[0] selects the source ("auth", "method", "query", "headers" or "body");
// for keyed sources Path[1] names the entry.
// Supports modifiers: :isset, :length, :lower
type ContextValue struct {
	Path     []string // e.g., ["auth", "id"] for @request.auth.id
	Modifier string   // Optional: "isset", "length", "lower"
}
// Resolve looks up the @request.* value named by Path and applies Modifier.
//
// Path[0] picks the source:
//   - "auth": a field of the current user (see resolveAuth). It counts as set
//     when it resolves to something other than nil or "".
//   - "method": the HTTP method; always set.
//   - "query", "headers", "body": the entry keyed by Path[1]. It is set when
//     the key is present. A missing key resolves to "" for query/headers and
//     to nil for body.
//
// An empty Path, or a query/headers/body Path without a key, resolves to nil
// with no error. Any other source is an error.
//
// Modifiers: :isset returns whether the value is set, :length its getLength,
// and :lower lower-cases strings (other values pass through). Unknown
// modifiers are errors. record is not used.
func (v *ContextValue) Resolve(ctx *Context, record map[string]any) (any, error) {
	if len(v.Path) == 0 {
		return nil, nil
	}
	source := v.Path[0]
	keyed := source == "query" || source == "headers" || source == "body"
	if keyed && len(v.Path) < 2 {
		return nil, nil
	}

	var value any
	var isSet bool
	switch source {
	case "auth":
		authValue, err := v.resolveAuth(ctx)
		if err != nil {
			return nil, err
		}
		value, isSet = authValue, authValue != nil && authValue != ""
	case "method":
		value, isSet = ctx.Method, true
	case "query":
		value, isSet = ctx.Query[v.Path[1]]
	case "headers":
		value, isSet = ctx.Headers[v.Path[1]]
	case "body":
		// Indexing a nil Body map is safe and reports the key as absent.
		value, isSet = ctx.Body[v.Path[1]]
	default:
		return nil, fmt.Errorf("unknown context variable: @request.%s", strings.Join(v.Path, "."))
	}

	switch v.Modifier {
	case "":
		return value, nil
	case "isset":
		return isSet, nil
	case "length":
		return getLength(value), nil
	case "lower":
		s, isString := value.(string)
		if !isString {
			return value, nil
		}
		return strings.ToLower(s), nil
	default:
		return nil, fmt.Errorf("unknown modifier: %s", v.Modifier)
	}
}
// resolveAuth returns the @request.auth.<field> value named by Path[1].
//
// Guests (nil ctx.User) get "" for any named field, without validating the
// name, and nil when no field is named. For a signed-in user, no field name
// yields nil. "id", "email", "name", "role" and "verified" map to the user's
// fields. "passwordHash", "password" and "tokenKey" are refused with an error,
// and any other name is an "unknown auth field" error.
func (v *ContextValue) resolveAuth(ctx *Context) (any, error) {
	hasField := len(v.Path) >= 2
	user := ctx.User

	if user == nil {
		if !hasField {
			return nil, nil
		}
		return "", nil
	}
	if !hasField {
		return nil, nil
	}

	switch field := v.Path[1]; field {
	case "id":
		return user.ID, nil
	case "email":
		return user.Email, nil
	case "name":
		return user.Name, nil
	case "role":
		return user.Role, nil
	case "verified":
		return user.Verified, nil
	case "passwordHash", "password", "tokenKey":
		// Never expose sensitive fields.
		return nil, fmt.Errorf("access to sensitive auth fields is not allowed")
	default:
		return nil, fmt.Errorf("unknown auth field: %s", field)
	}
}
// ToSQLValue resolves the request value immediately (no record is involved)
// and binds it as a single "?" parameter. With the :isset modifier the
// resolved bool is bound as the integer 1 or 0.
func (v *ContextValue) ToSQLValue(ctx *Context) (string, []any, error) {
	resolved, err := v.Resolve(ctx, nil)
	if err != nil {
		return "", nil, err
	}

	param := resolved
	if set, isBool := resolved.(bool); isBool && v.Modifier == "isset" {
		param = 0
		if set {
			param = 1
		}
	}
	return "?", []any{param}, nil
}
// --- Comparison helper ---

// compare applies the comparison operator op to two resolved operand values.
// It is the in-memory counterpart of the SQL produced by CompareExpr.ToSQL.
//
//   - nil operands: nil = nil is true and nil != non-nil is true. Every other
//     operator involving nil is false, with no error whatever op is.
//   - "=" / "!=" compare the fmt %v forms, so 1, 1.0 and "1" are all equal.
//   - "~" / "!~" are SQL LIKE matches ("%" = any run, "_" = one character).
//     They are case-insensitive and may span newlines, like SQLite's default
//     LIKE.
//   - ">", ">=", "<", "<=" compare numerically when both operands are numbers
//     (or strings that are entirely numeric). Otherwise they compare the %v
//     forms lexicographically.
//
// Any other operator on non-nil operands returns an error.
func compare(left any, op string, right any) (bool, error) {
	if left == nil || right == nil {
		bothNil := left == nil && right == nil
		switch op {
		case "=":
			return bothNil, nil
		case "!=":
			return !bothNil, nil
		}
		return false, nil
	}

	leftStr := fmt.Sprintf("%v", left)
	rightStr := fmt.Sprintf("%v", right)

	switch op {
	case "=":
		return leftStr == rightStr, nil
	case "!=":
		return leftStr != rightStr, nil
	case "~", "!~":
		matched, err := regexp.MatchString(likeToRegex(rightStr), leftStr)
		if err != nil {
			return false, fmt.Errorf("invalid LIKE pattern %q: %w", rightStr, err)
		}
		return matched == (op == "~"), nil
	case ">", ">=", "<", "<=":
		// Ordering operators are handled below.
	default:
		return false, fmt.Errorf("unknown comparison operator: %s", op)
	}

	leftNum, leftOk := toFloat(left)
	rightNum, rightOk := toFloat(right)
	if leftOk && rightOk {
		switch op {
		case ">":
			return leftNum > rightNum, nil
		case ">=":
			return leftNum >= rightNum, nil
		case "<":
			return leftNum < rightNum, nil
		default: // "<="
			return leftNum <= rightNum, nil
		}
	}

	// At least one side is not numeric: fall back to string ordering.
	switch op {
	case ">":
		return leftStr > rightStr, nil
	case ">=":
		return leftStr >= rightStr, nil
	case "<":
		return leftStr < rightStr, nil
	default: // "<="
		return leftStr <= rightStr, nil
	}
}

// toFloat converts Go numeric values to float64. It also converts strings that
// parse entirely as a floating-point number, ignoring surrounding whitespace.
// The second result reports success.
//
// Strings with trailing non-numeric text, such as "2024-01-01" or "10 apples",
// are not numeric. fmt.Sscanf("%f") would read just the leading "2024", making
// all dates of a year compare equal; here they keep lexicographic ordering.
func toFloat(v any) (float64, bool) {
	switch n := v.(type) {
	case float64:
		return n, true
	case float32:
		return float64(n), true
	case int:
		return float64(n), true
	case int8:
		return float64(n), true
	case int16:
		return float64(n), true
	case int32:
		return float64(n), true
	case int64:
		return float64(n), true
	case uint:
		return float64(n), true
	case uint8:
		return float64(n), true
	case uint16:
		return float64(n), true
	case uint32:
		return float64(n), true
	case uint64:
		return float64(n), true
	case string:
		f, err := strconv.ParseFloat(strings.TrimSpace(n), 64)
		if err != nil {
			return 0, false
		}
		return f, true
	}
	return 0, false
}

// likeToRegex translates a SQL LIKE pattern into an anchored Go regular
// expression: "%" becomes ".*", "_" becomes ".", and every other character is
// matched literally.
//
// The (?is) flags make the match case-insensitive and let wildcards span
// newlines, approximating SQLite's default LIKE used by the SQL translation.
// SQLite folds ASCII letters only; Go's (?i) also folds non-ASCII letters.
func likeToRegex(pattern string) string {
	// QuoteMeta leaves % and _ untouched, so they can be rewritten afterwards.
	quoted := regexp.QuoteMeta(pattern)
	wildcards := strings.NewReplacer("%", ".*", "_", ".")
	return "(?is)^" + wildcards.Replace(quoted) + "$"
}
// CollectionValue represents a reference to another collection (@collection.*).
// Used for cross-collection rule checks like "user is member of this project".
// It is only meaningful as an operand of ?= / ?!=; its own Resolve and
// ToSQLValue always return an error.
type CollectionValue struct {
	Collection string   // Target collection name (e.g., "memberships")
	Alias      string   // Optional alias for self-joins (e.g., "manager" in @collection.users:manager.id); NOTE(review): not read by the ?= / ?!= code in this file
	FieldPath  []string // Field path within the collection (e.g., ["userId"]); the ?= / ?!= code uses only FieldPath[0]
}
// Resolve always fails. A @collection reference stands for documents of
// another collection, not a single value, so it can only appear on one side of
// ?= / ?!=, where CompareExpr performs the lookup itself.
func (v *CollectionValue) Resolve(ctx *Context, record map[string]any) (any, error) {
	err := fmt.Errorf("@collection values must be used with ?= or ?!= operators")
	return nil, err
}

// ToSQLValue always fails for the same reason. CompareExpr builds the
// subquery itself when it sees ?= or ?!=.
func (v *CollectionValue) ToSQLValue(ctx *Context) (string, []any, error) {
	err := fmt.Errorf("@collection values must be used with ?= or ?!= operators")
	return "", nil, err
}
// CollectionQuery represents a query against another collection for rule evaluation.
// CheckCollectionExists looks for a document whose CollectionID equals
// Collection and whose JSON field Field satisfies "Field Op Value".
type CollectionQuery struct {
	Collection string // Matched against the documents' CollectionID column
	Field      string // JSON key inside the documents' Data column
	Op         string // Comparison operator; "~" / "!~" map to LIKE / NOT LIKE
	Value      any    // Bound as the comparison parameter
}
// CheckCollectionExists checks if any record exists in a collection matching the criteria.
// Used by ?= operator in rule expressions.
//
// It runs a single query for a document with CollectionID = query.Collection
// whose JSON field query.Field satisfies "query.Op query.Value". "~" / "!~"
// map to LIKE / NOT LIKE.
//
// Single quotes in the field name are doubled so it cannot end the JSON-path
// string literal early. Operators other than =, !=, >, >=, <, <=, ~ and !~ are
// rejected rather than spliced into the SQL text. A nil query is an error.
func CheckCollectionExists(query *CollectionQuery) (bool, error) {
	if query == nil {
		return false, fmt.Errorf("collection query is nil")
	}

	var sqlOp string
	switch query.Op {
	case "=", "!=", ">", ">=", "<", "<=":
		sqlOp = query.Op
	case "~":
		sqlOp = "LIKE"
	case "!~":
		sqlOp = "NOT LIKE"
	default:
		return false, fmt.Errorf("unsupported collection query operator: %s", query.Op)
	}

	// A separate "is the collection non-empty" pre-query is unnecessary: an
	// empty collection simply matches nothing here.
	field := strings.ReplaceAll(query.Field, "'", "''")
	whereClause := fmt.Sprintf("CollectionID = ? AND json_extract(Data, '$.%s') %s ?", field, sqlOp)
	docs, err := models.Documents.Search(whereClause+" LIMIT 1", query.Collection, query.Value)
	if err != nil {
		return false, fmt.Errorf("failed to query collection %s: %w", query.Collection, err)
	}
	return len(docs) > 0, nil
}