readysite / website / internal / content / query / filter.go
17.1 KB
filter.go
// Package query provides filter parsing, expansion, and view execution.
package query

import (
	"fmt"
	"regexp"
	"strconv"
	"strings"
)

// Comparison operators for API filters
var apiFilterOperators = map[string]string{
	"=":  "=",
	"!=": "!=",
	">":  ">",
	">=": ">=",
	"<":  "<",
	"<=": "<=",
	"~":  "LIKE", // Contains/like
	"!~": "NOT LIKE",
}

// validAPIFieldName matches alphanumeric field names with underscores
var validAPIFieldName = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)

// Token types for API filter lexer
type apiFilterTokenType int

const (
	apiFilterTokenField apiFilterTokenType = iota
	apiFilterTokenOperator
	apiFilterTokenValue
	apiFilterTokenAnd
	apiFilterTokenOr
	apiFilterTokenLParen
	apiFilterTokenRParen
	apiFilterTokenEOF
)

type apiFilterToken struct {
	typ   apiFilterTokenType
	value string
}

// APIFilterResult contains the parsed WHERE clause and parameters.
type APIFilterResult struct {
	Where  string
	Params []any
}

// ParseFilter parses a filter expression and returns a parameterized WHERE clause.
// Example: "status = 'published' && views > 100"
// Returns: WHERE clause with ? placeholders and slice of parameter values.
//
// Supported operators: = != > >= < <= ~ (LIKE) !~ (NOT LIKE)
// Supported connectors: && (AND) || (OR)
// Grouping: ( )
//
// Values must be quoted with single quotes for strings: name = 'John'
// Numbers are unquoted: age > 25
// Booleans: active = true, active = false
//
// For LIKE operator, % wildcards are NOT automatically added.
// Use: title ~ '%hello%' for contains, title ~ 'hello%' for starts with.
//
// Field names must exist in the provided allowedFields list.
// If allowedFields is nil, all field names are allowed (dangerous!).
func ParseFilter(expr string, allowedFields []string) (*APIFilterResult, error) {
	if expr == "" {
		return &APIFilterResult{Where: "", Params: nil}, nil
	}

	// Build allowed fields map
	fieldSet := make(map[string]bool)
	if allowedFields != nil {
		for _, f := range allowedFields {
			fieldSet[f] = true
		}
	}

	// Tokenize
	tokens, err := tokenizeAPIFilter(expr)
	if err != nil {
		return nil, err
	}

	// Parse tokens into WHERE clause
	return parseAPIFilterTokens(tokens, fieldSet, allowedFields == nil)
}

// tokenizeAPIFilter breaks the expression into tokens.
func tokenizeAPIFilter(expr string) ([]apiFilterToken, error) {
	var tokens []apiFilterToken
	expr = strings.TrimSpace(expr)

	i := 0
	for i < len(expr) {
		// Skip whitespace
		for i < len(expr) && (expr[i] == ' ' || expr[i] == '\t') {
			i++
		}
		if i >= len(expr) {
			break
		}

		ch := expr[i]

		// Parentheses
		if ch == '(' {
			tokens = append(tokens, apiFilterToken{apiFilterTokenLParen, "("})
			i++
			continue
		}
		if ch == ')' {
			tokens = append(tokens, apiFilterToken{apiFilterTokenRParen, ")"})
			i++
			continue
		}

		// Logical operators
		if i+1 < len(expr) && expr[i:i+2] == "&&" {
			tokens = append(tokens, apiFilterToken{apiFilterTokenAnd, "&&"})
			i += 2
			continue
		}
		if i+1 < len(expr) && expr[i:i+2] == "||" {
			tokens = append(tokens, apiFilterToken{apiFilterTokenOr, "||"})
			i += 2
			continue
		}

		// Comparison operators (check longer ones first)
		foundOp := false
		for _, op := range []string{"!=", ">=", "<=", "!~", "=", ">", "<", "~"} {
			if i+len(op) <= len(expr) && expr[i:i+len(op)] == op {
				tokens = append(tokens, apiFilterToken{apiFilterTokenOperator, op})
				i += len(op)
				foundOp = true
				break
			}
		}
		if foundOp {
			continue
		}

		// Quoted string value
		if ch == '\'' {
			end := i + 1
			for end < len(expr) && expr[end] != '\'' {
				if expr[end] == '\\' && end+1 < len(expr) {
					end += 2 // Skip escaped char
				} else {
					end++
				}
			}
			if end >= len(expr) {
				return nil, fmt.Errorf("unterminated string at position %d", i)
			}
			// Include quotes in value for later processing
			tokens = append(tokens, apiFilterToken{apiFilterTokenValue, expr[i : end+1]})
			i = end + 1
			continue
		}

		// Unquoted value (number, boolean, field name)
		end := i
		for end < len(expr) && !strings.ContainsAny(string(expr[end]), " \t()&|=!<>~'") {
			end++
		}
		if end > i {
			word := expr[i:end]
			// Determine if it's a field or value based on context
			// Fields come before operators, values come after
			if len(tokens) > 0 && tokens[len(tokens)-1].typ == apiFilterTokenOperator {
				tokens = append(tokens, apiFilterToken{apiFilterTokenValue, word})
			} else {
				tokens = append(tokens, apiFilterToken{apiFilterTokenField, word})
			}
			i = end
			continue
		}

		return nil, fmt.Errorf("unexpected character '%c' at position %d", ch, i)
	}

	tokens = append(tokens, apiFilterToken{apiFilterTokenEOF, ""})
	return tokens, nil
}

// parseAPIFilterTokens converts tokens to a WHERE clause with parameters.
func parseAPIFilterTokens(tokens []apiFilterToken, allowedFields map[string]bool, allowAll bool) (*APIFilterResult, error) {
	var where strings.Builder
	var params []any

	i := 0
	expectField := true
	parenDepth := 0

	for i < len(tokens) && tokens[i].typ != apiFilterTokenEOF {
		t := tokens[i]

		switch t.typ {
		case apiFilterTokenLParen:
			where.WriteString("(")
			parenDepth++
			expectField = true
			i++

		case apiFilterTokenRParen:
			if parenDepth == 0 {
				return nil, fmt.Errorf("unmatched closing parenthesis")
			}
			where.WriteString(")")
			parenDepth--
			expectField = false
			i++

		case apiFilterTokenAnd:
			where.WriteString(" AND ")
			expectField = true
			i++

		case apiFilterTokenOr:
			where.WriteString(" OR ")
			expectField = true
			i++

		case apiFilterTokenField:
			if !expectField {
				return nil, fmt.Errorf("unexpected field '%s'", t.value)
			}
			// Validate field name
			if !validAPIFieldName.MatchString(t.value) {
				return nil, fmt.Errorf("invalid field name '%s'", t.value)
			}
			if !allowAll && !allowedFields[t.value] {
				return nil, fmt.Errorf("field '%s' is not allowed", t.value)
			}

			// Expect: field operator value
			if i+2 >= len(tokens) || tokens[i+1].typ != apiFilterTokenOperator || tokens[i+2].typ != apiFilterTokenValue {
				return nil, fmt.Errorf("expected 'field operator value' pattern")
			}

			field := t.value
			op := tokens[i+1].value
			val := tokens[i+2].value

			sqlOp, ok := apiFilterOperators[op]
			if !ok {
				return nil, fmt.Errorf("unknown operator '%s'", op)
			}

			// Parse value
			parsedVal, err := parseAPIFilterValue(val)
			if err != nil {
				return nil, err
			}

			// For JSON fields in SQLite, use json_extract
			// Document.Data is JSON, so we query: json_extract(Data, '$.fieldname')
			where.WriteString(fmt.Sprintf("json_extract(Data, '$.%s') %s ?", field, sqlOp))
			params = append(params, parsedVal)

			expectField = false
			i += 3

		default:
			return nil, fmt.Errorf("unexpected token: %v", t)
		}
	}

	if parenDepth != 0 {
		return nil, fmt.Errorf("unmatched opening parenthesis")
	}

	return &APIFilterResult{
		Where:  where.String(),
		Params: params,
	}, nil
}

// parseAPIFilterValue parses a value token into a Go value.
func parseAPIFilterValue(val string) (any, error) {
	// Quoted string
	if strings.HasPrefix(val, "'") && strings.HasSuffix(val, "'") {
		// Remove quotes and unescape
		s := val[1 : len(val)-1]
		s = strings.ReplaceAll(s, "\\'", "'")
		return s, nil
	}

	// Boolean
	if val == "true" {
		return true, nil
	}
	if val == "false" {
		return false, nil
	}

	// Null
	if val == "null" {
		return nil, nil
	}

	// Number - try to parse as float64
	var f float64
	if _, err := fmt.Sscanf(val, "%f", &f); err == nil {
		// Check if it's an integer
		if f == float64(int64(f)) {
			return int64(f), nil
		}
		return f, nil
	}

	return nil, fmt.Errorf("invalid value '%s'", val)
}

// ParseSort parses a sort expression and returns an ORDER BY clause.
// Example: "-created,title" -> "json_extract(Data, '$.created') DESC, json_extract(Data, '$.title') ASC"
// Prefix with - for descending order.
func ParseSort(expr string, allowedFields []string) (string, error) {
	if expr == "" {
		return "", nil
	}

	// Build allowed fields map
	fieldSet := make(map[string]bool)
	if allowedFields != nil {
		for _, f := range allowedFields {
			fieldSet[f] = true
		}
	}

	parts := strings.Split(expr, ",")
	var orderParts []string

	for _, part := range parts {
		part = strings.TrimSpace(part)
		if part == "" {
			continue
		}

		desc := false
		if strings.HasPrefix(part, "-") {
			desc = true
			part = part[1:]
		}

		// Validate field name
		if !validAPIFieldName.MatchString(part) {
			return "", fmt.Errorf("invalid sort field '%s'", part)
		}

		// Check against allowed fields
		if allowedFields != nil && !fieldSet[part] {
			// Allow system fields: id, created, updated
			if part != "id" && part != "created" && part != "updated" {
				return "", fmt.Errorf("sort field '%s' is not allowed", part)
			}
		}

		// Map to actual columns
		var column string
		switch part {
		case "id":
			column = "ID"
		case "created":
			column = "CreatedAt"
		case "updated":
			column = "UpdatedAt"
		default:
			column = fmt.Sprintf("json_extract(Data, '$.%s')", part)
		}

		direction := "ASC"
		if desc {
			direction = "DESC"
		}

		orderParts = append(orderParts, fmt.Sprintf("%s %s", column, direction))
	}

	if len(orderParts) == 0 {
		return "", nil
	}

	return strings.Join(orderParts, ", "), nil
}

// FilterFields filters a map to only include the specified fields.
// If fields is empty, returns the original map unchanged.
func FilterFields(data map[string]any, fields []string) map[string]any {
	if len(fields) == 0 {
		return data
	}

	// Build field set
	fieldSet := make(map[string]bool)
	for _, f := range fields {
		fieldSet[strings.TrimSpace(f)] = true
	}

	// System fields always included
	fieldSet["id"] = true
	fieldSet["collectionId"] = true
	fieldSet["collectionName"] = true
	fieldSet["created"] = true
	fieldSet["updated"] = true

	result := make(map[string]any)
	for k, v := range data {
		if fieldSet[k] {
			result[k] = v
		}
	}

	return result
}

// ParseFields parses a comma-separated fields parameter.
func ParseFields(expr string) []string {
	if expr == "" {
		return nil
	}

	parts := strings.Split(expr, ",")
	var fields []string
	for _, p := range parts {
		p = strings.TrimSpace(p)
		if p != "" {
			fields = append(fields, p)
		}
	}
	return fields
}

// ParseExpand parses a comma-separated expand parameter.
// Supports nested expansion with dots: "author,comments.author"
// Returns a list of expand paths.
func ParseExpand(expr string) [][]string {
	if expr == "" {
		return nil
	}

	parts := strings.Split(expr, ",")
	var paths [][]string

	for _, p := range parts {
		p = strings.TrimSpace(p)
		if p == "" {
			continue
		}

		// Split by dot for nested paths
		segments := strings.Split(p, ".")
		var path []string
		for _, seg := range segments {
			seg = strings.TrimSpace(seg)
			if seg != "" && validAPIFieldName.MatchString(seg) {
				path = append(path, seg)
			}
		}
		if len(path) > 0 {
			paths = append(paths, path)
		}
	}

	return paths
}

// MatchesFilter evaluates an APIFilterResult filter against a record in memory.
// Returns true if the record matches the filter, false otherwise.
// This is used for real-time event filtering where we can't use SQL.
func MatchesFilter(record map[string]any, filter *APIFilterResult) bool {
	if filter == nil || filter.Where == "" {
		return true
	}

	// Re-parse the original expression to evaluate against the record
	// We need to walk through conditions and evaluate them
	// For now, use a simpler approach: extract conditions and check them

	// Note: This is a simplified implementation that handles common cases.
	// For complex expressions with nested parentheses, the SQL evaluation
	// in the database is more reliable.

	return matchesFilterSimple(record, filter)
}

// matchesFilterSimple evaluates a parsed filter against a record.
// It evaluates each condition in the filter parameters against the record fields.
func matchesFilterSimple(record map[string]any, filter *APIFilterResult) bool {
	// The filter.Where contains placeholders like "json_extract(Data, '$.field') = ?"
	// and filter.Params contains the values.
	//
	// We need to extract field names and operators from the WHERE clause
	// and evaluate them against the record.

	where := filter.Where
	params := filter.Params

	// Simple state machine to handle AND/OR
	// This handles expressions like: "A AND B" or "A OR B" but not nested parens well
	// For real-time filtering, this covers most use cases.

	// Split by AND/OR while tracking which connector
	type condition struct {
		fieldExpr string
		op        string
		paramIdx  int
	}

	// Extract conditions from WHERE clause
	// Pattern: json_extract(Data, '$.fieldname') OP ?
	condPattern := regexp.MustCompile(`json_extract\(Data, '\$\.([^']+)'\)\s*(=|!=|>|>=|<|<=|LIKE|NOT LIKE)\s*\?`)
	matches := condPattern.FindAllStringSubmatchIndex(where, -1)

	if len(matches) == 0 {
		// No recognizable conditions, assume match
		return true
	}

	// Evaluate each condition
	paramIdx := 0
	allResults := make([]bool, 0, len(matches))

	for _, match := range matches {
		if paramIdx >= len(params) {
			break
		}

		// Extract field name and operator
		fieldStart, fieldEnd := match[2], match[3]
		opStart, opEnd := match[4], match[5]

		fieldName := where[fieldStart:fieldEnd]
		op := where[opStart:opEnd]

		// Get value from record
		recordVal := record[fieldName]
		paramVal := params[paramIdx]
		paramIdx++

		// Evaluate condition
		result := evaluateFilterCondition(recordVal, op, paramVal)
		allResults = append(allResults, result)
	}

	// Determine how to combine results based on presence of AND/OR
	// Simple heuristic: if ANY OR is present, use OR logic; otherwise AND
	if strings.Contains(where, " OR ") {
		// At least one must be true
		for _, r := range allResults {
			if r {
				return true
			}
		}
		return false
	}

	// Default: all must be true (AND)
	for _, r := range allResults {
		if !r {
			return false
		}
	}
	return true
}

// evaluateFilterCondition evaluates a single condition.
func evaluateFilterCondition(recordVal any, op string, paramVal any) bool {
	// Handle nil record value
	if recordVal == nil {
		if op == "=" && paramVal == nil {
			return true
		}
		if op == "!=" && paramVal != nil {
			return true
		}
		return false
	}

	// Convert values to comparable types
	switch op {
	case "=":
		return filterCompareEqual(recordVal, paramVal)
	case "!=":
		return !filterCompareEqual(recordVal, paramVal)
	case ">":
		return filterCompareNumeric(recordVal, paramVal) > 0
	case ">=":
		return filterCompareNumeric(recordVal, paramVal) >= 0
	case "<":
		return filterCompareNumeric(recordVal, paramVal) < 0
	case "<=":
		return filterCompareNumeric(recordVal, paramVal) <= 0
	case "LIKE":
		return filterMatchLike(recordVal, paramVal)
	case "NOT LIKE":
		return !filterMatchLike(recordVal, paramVal)
	}

	return false
}

// filterCompareEqual checks if two values are equal.
func filterCompareEqual(a, b any) bool {
	// String comparison
	aStr, aIsStr := filterToString(a)
	bStr, bIsStr := filterToString(b)
	if aIsStr && bIsStr {
		return aStr == bStr
	}

	// Numeric comparison
	aNum, aIsNum := filterToFloat64(a)
	bNum, bIsNum := filterToFloat64(b)
	if aIsNum && bIsNum {
		return aNum == bNum
	}

	// Boolean comparison
	aBool, aIsBool := a.(bool)
	bBool, bIsBool := b.(bool)
	if aIsBool && bIsBool {
		return aBool == bBool
	}

	// Default: use fmt.Sprint for comparison
	return fmt.Sprint(a) == fmt.Sprint(b)
}

// filterCompareNumeric compares two values numerically.
// Returns -1 if a < b, 0 if a == b, 1 if a > b.
func filterCompareNumeric(a, b any) int {
	aNum, aOk := filterToFloat64(a)
	bNum, bOk := filterToFloat64(b)

	if !aOk || !bOk {
		// Fall back to string comparison
		aStr := fmt.Sprint(a)
		bStr := fmt.Sprint(b)
		if aStr < bStr {
			return -1
		}
		if aStr > bStr {
			return 1
		}
		return 0
	}

	if aNum < bNum {
		return -1
	}
	if aNum > bNum {
		return 1
	}
	return 0
}

// filterMatchLike performs SQL LIKE-style pattern matching.
func filterMatchLike(val, pattern any) bool {
	valStr, _ := filterToString(val)
	patternStr, _ := filterToString(pattern)

	// Convert SQL LIKE pattern to regex
	// % matches any sequence, _ matches single character
	regexPattern := "^"
	for _, ch := range patternStr {
		switch ch {
		case '%':
			regexPattern += ".*"
		case '_':
			regexPattern += "."
		case '.', '+', '*', '?', '^', '$', '(', ')', '[', ']', '{', '}', '|', '\\':
			regexPattern += "\\" + string(ch)
		default:
			regexPattern += string(ch)
		}
	}
	regexPattern += "$"

	re, err := regexp.Compile("(?i)" + regexPattern) // Case-insensitive
	if err != nil {
		return false
	}

	return re.MatchString(valStr)
}

// filterToString converts a value to string.
func filterToString(v any) (string, bool) {
	switch val := v.(type) {
	case string:
		return val, true
	case []byte:
		return string(val), true
	default:
		return fmt.Sprint(v), true
	}
}

// filterToFloat64 converts a value to float64.
func filterToFloat64(v any) (float64, bool) {
	switch val := v.(type) {
	case float64:
		return val, true
	case float32:
		return float64(val), true
	case int:
		return float64(val), true
	case int64:
		return float64(val), true
	case int32:
		return float64(val), true
	case uint:
		return float64(val), true
	case uint64:
		return float64(val), true
	case uint32:
		return float64(val), true
	case string:
		var f float64
		if _, err := fmt.Sscanf(val, "%f", &f); err == nil {
			return f, true
		}
	}
	return 0, false
}
← Back