filter.go
// Package query provides filter parsing, expansion, and view execution.
package query
import (
	"fmt"
	"regexp"
	"strconv"
	"strings"
)
// apiFilterOperators maps every comparison operator accepted in a filter
// expression to the SQL operator written into the generated WHERE clause.
var apiFilterOperators = map[string]string{
	// Equality.
	"=":  "=",
	"!=": "!=",

	// Ordering.
	"<":  "<",
	"<=": "<=",
	">":  ">",
	">=": ">=",

	// Pattern matching.
	"~":  "LIKE",
	"!~": "NOT LIKE",
}
// validAPIFieldName matches identifier-style field names: an ASCII letter or
// underscore followed by any number of ASCII letters, digits, or underscores.
// The empty string does not match.
var validAPIFieldName = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
// apiFilterTokenType identifies the lexical category of a filter token.
type apiFilterTokenType int
// Token categories produced by the API filter lexer.
const (
// apiFilterTokenField is a bare word that does not directly follow a
// comparison operator.
apiFilterTokenField apiFilterTokenType = iota
// apiFilterTokenOperator is a comparison operator such as "=" or "!~".
apiFilterTokenOperator
// apiFilterTokenValue is a quoted string, or a bare word that directly
// follows a comparison operator.
apiFilterTokenValue
// apiFilterTokenAnd is the "&&" connector.
apiFilterTokenAnd
// apiFilterTokenOr is the "||" connector.
apiFilterTokenOr
// apiFilterTokenLParen is an opening parenthesis.
apiFilterTokenLParen
// apiFilterTokenRParen is a closing parenthesis.
apiFilterTokenRParen
// apiFilterTokenEOF marks the end of the token stream.
apiFilterTokenEOF
)
// apiFilterToken is a single lexical token of a filter expression.
type apiFilterToken struct {
// typ is the token's lexical category.
typ apiFilterTokenType
// value is the token text. Quoted string values keep their surrounding
// single quotes; the EOF token carries an empty value.
value string
}
// APIFilterResult contains the parsed WHERE clause and parameters.
type APIFilterResult struct {
// Where is the generated condition text with one ? placeholder per
// comparison. It is empty when the filter expression was empty.
Where string
// Params holds the placeholder values in the order their comparisons
// appear in Where.
Params []any
}
// ParseFilter turns a filter expression into a parameterized WHERE clause.
//
// For example "status = 'published' && views > 100" yields a clause with two
// ? placeholders and the parameters "published" and 100.
//
// Syntax:
//   - comparison operators: = != > >= < <= ~ (LIKE) !~ (NOT LIKE)
//   - connectors: && (AND), || (OR)
//   - grouping: ( )
//   - string values are wrapped in single quotes: name = 'John'
//   - numbers are bare: age > 25
//   - booleans are bare: active = true, active = false
//
// LIKE patterns are used exactly as written; no % wildcards are added. Write
// title ~ '%hello%' for "contains" and title ~ 'hello%' for "starts with".
//
// Every field name must appear in allowedFields. A nil allowedFields disables
// that check and lets any syntactically valid field name through, which is
// dangerous with untrusted input. An empty expression returns an empty result
// and no error.
func ParseFilter(expr string, allowedFields []string) (*APIFilterResult, error) {
	if expr == "" {
		return &APIFilterResult{}, nil
	}

	tokens, err := tokenizeAPIFilter(expr)
	if err != nil {
		return nil, err
	}

	// Set of permitted field names; it stays empty when allowedFields is nil,
	// in which case the allow-all flag below takes over.
	permitted := make(map[string]bool)
	for _, name := range allowedFields {
		permitted[name] = true
	}

	return parseAPIFilterTokens(tokens, permitted, allowedFields == nil)
}
// tokenizeAPIFilter splits a filter expression into lexical tokens and
// appends a terminating EOF token.
//
// Leading and trailing whitespace is trimmed, and spaces and tabs between
// tokens are skipped. Quoted values keep their surrounding single quotes so
// later stages can tell strings from bare words. A bare word becomes a value
// token when it directly follows a comparison operator and a field token
// otherwise.
//
// An error is returned for an unterminated quoted string and for a character
// that cannot start any token (for example a lone '&'). Positions in error
// messages are byte offsets into the trimmed expression.
func tokenizeAPIFilter(expr string) ([]apiFilterToken, error) {
	// Bytes that terminate a bare word.
	const wordBreaks = " \t()&|=!<>~'"
	// Two-character operators are listed first so that ">=" is not read as
	// ">" followed by "=".
	comparisons := [...]string{"!=", ">=", "<=", "!~", "=", ">", "<", "~"}

	input := strings.TrimSpace(expr)
	var out []apiFilterToken
	emit := func(kind apiFilterTokenType, text string) {
		out = append(out, apiFilterToken{typ: kind, value: text})
	}

	pos := 0
scan:
	for pos < len(input) {
		c := input[pos]
		rest := input[pos:]

		switch {
		case c == ' ' || c == '\t':
			pos++
			continue
		case c == '(':
			emit(apiFilterTokenLParen, "(")
			pos++
			continue
		case c == ')':
			emit(apiFilterTokenRParen, ")")
			pos++
			continue
		case strings.HasPrefix(rest, "&&"):
			emit(apiFilterTokenAnd, "&&")
			pos += 2
			continue
		case strings.HasPrefix(rest, "||"):
			emit(apiFilterTokenOr, "||")
			pos += 2
			continue
		}

		for _, op := range comparisons {
			if strings.HasPrefix(rest, op) {
				emit(apiFilterTokenOperator, op)
				pos += len(op)
				continue scan
			}
		}

		if c == '\'' {
			// Search for the closing quote; a backslash hides the byte
			// after it from the search.
			closing := pos + 1
			for closing < len(input) && input[closing] != '\'' {
				step := 1
				if input[closing] == '\\' && closing+1 < len(input) {
					step = 2
				}
				closing += step
			}
			if closing >= len(input) {
				return nil, fmt.Errorf("unterminated string at position %d", pos)
			}
			// The quotes stay in the token text for later processing.
			emit(apiFilterTokenValue, input[pos:closing+1])
			pos = closing + 1
			continue
		}

		// Bare word: number, boolean, null, or field name.
		stop := pos
		for stop < len(input) && strings.IndexByte(wordBreaks, input[stop]) < 0 {
			stop++
		}
		if stop == pos {
			return nil, fmt.Errorf("unexpected character '%c' at position %d", c, pos)
		}
		// Context decides the category: values follow operators, anything
		// else is treated as a field.
		kind := apiFilterTokenField
		if n := len(out); n > 0 && out[n-1].typ == apiFilterTokenOperator {
			kind = apiFilterTokenValue
		}
		emit(kind, input[pos:stop])
		pos = stop
	}

	emit(apiFilterTokenEOF, "")
	return out, nil
}
// parseAPIFilterTokens converts a token stream into a WHERE clause with
// positional parameters.
//
// tokens is the lexer output, normally terminated by an EOF token.
// allowedFields is the set of permitted field names and is consulted only
// when allowAll is false.
//
// Each "field operator value" triple becomes
// json_extract(Data, '$.field') OP ? with the parsed value appended to the
// parameter list. "&&" and "||" become " AND " and " OR ", and parentheses
// are copied through.
//
// The token order is validated so that the generated clause is always
// well-formed: a connector or ")" must follow a condition or ")", "(" must
// appear where a condition could start, and the filter must not end on a
// connector. Field names must match validAPIFieldName, which also keeps them
// safe to embed in the json_extract path.
//
// A token stream with no tokens before EOF yields an empty Where and no error.
//
// NOTE(review): a null value is bound as a nil parameter to "= ?" / "!= ?".
// In standard SQL such comparisons are never true — confirm whether IS NULL
// handling is wanted here.
func parseAPIFilterTokens(tokens []apiFilterToken, allowedFields map[string]bool, allowAll bool) (*APIFilterResult, error) {
	var where strings.Builder
	var params []any

	// expectOperand is true wherever the grammar needs a condition or "("
	// next: at the start, after "(", and after a connector.
	expectOperand := true
	parenDepth := 0
	i := 0
	for i < len(tokens) && tokens[i].typ != apiFilterTokenEOF {
		t := tokens[i]
		switch t.typ {
		case apiFilterTokenLParen:
			if !expectOperand {
				return nil, fmt.Errorf("unexpected '(': expected '&&' or '||'")
			}
			where.WriteString("(")
			parenDepth++
			i++

		case apiFilterTokenRParen:
			if parenDepth == 0 {
				return nil, fmt.Errorf("unmatched closing parenthesis")
			}
			if expectOperand {
				// Covers "()" and "(a = 1 && )".
				return nil, fmt.Errorf("unexpected ')': expected a condition")
			}
			where.WriteString(")")
			parenDepth--
			i++

		case apiFilterTokenAnd, apiFilterTokenOr:
			if expectOperand {
				// Covers a leading connector and two connectors in a row.
				return nil, fmt.Errorf("unexpected '%s': expected a condition", t.value)
			}
			if t.typ == apiFilterTokenAnd {
				where.WriteString(" AND ")
			} else {
				where.WriteString(" OR ")
			}
			expectOperand = true
			i++

		case apiFilterTokenField:
			if !expectOperand {
				return nil, fmt.Errorf("unexpected field '%s'", t.value)
			}
			if !validAPIFieldName.MatchString(t.value) {
				return nil, fmt.Errorf("invalid field name '%s'", t.value)
			}
			if !allowAll && !allowedFields[t.value] {
				return nil, fmt.Errorf("field '%s' is not allowed", t.value)
			}

			// A condition is always the triple: field operator value.
			if i+2 >= len(tokens) || tokens[i+1].typ != apiFilterTokenOperator || tokens[i+2].typ != apiFilterTokenValue {
				return nil, fmt.Errorf("expected 'field operator value' pattern")
			}
			op := tokens[i+1].value
			sqlOp, ok := apiFilterOperators[op]
			if !ok {
				return nil, fmt.Errorf("unknown operator '%s'", op)
			}
			parsedVal, err := parseAPIFilterValue(tokens[i+2].value)
			if err != nil {
				return nil, err
			}

			// Fields are addressed inside the JSON Data column; only the
			// validated field name is interpolated, the value is bound.
			fmt.Fprintf(&where, "json_extract(Data, '$.%s') %s ?", t.value, sqlOp)
			params = append(params, parsedVal)
			expectOperand = false
			i += 3

		default:
			return nil, fmt.Errorf("unexpected token: %v", t)
		}
	}

	if parenDepth != 0 {
		return nil, fmt.Errorf("unmatched opening parenthesis")
	}
	// With balanced parentheses the only way to still need an operand after
	// consuming tokens is a trailing connector.
	if i > 0 && expectOperand {
		return nil, fmt.Errorf("expected a condition after trailing '&&' or '||'")
	}

	return &APIFilterResult{
		Where:  where.String(),
		Params: params,
	}, nil
}
// parseAPIFilterValue converts the text of a value token into a Go value.
//
// Recognized forms, in order:
//   - a single-quoted string, returned without the quotes and with \' turned
//     into ' (other backslash sequences are left untouched)
//   - true / false, returned as bool
//   - null, returned as nil
//   - a base-10 integer that fits in int64, returned as int64
//   - any other number accepted by strconv.ParseFloat; it is returned as
//     int64 when it has no fractional part and fits, otherwise as float64
//
// The whole token must be a valid number: input such as "12abc" or "1.2.3"
// is rejected with an "invalid value" error instead of being truncated.
func parseAPIFilterValue(val string) (any, error) {
	// A quoted string needs both delimiters, hence at least two bytes.
	if len(val) >= 2 && val[0] == '\'' && val[len(val)-1] == '\'' {
		return strings.ReplaceAll(val[1:len(val)-1], `\'`, "'"), nil
	}

	switch val {
	case "true":
		return true, nil
	case "false":
		return false, nil
	case "null":
		return nil, nil
	}

	// Integers are parsed directly so values beyond 2^53 keep full precision.
	if n, err := strconv.ParseInt(val, 10, 64); err == nil {
		return n, nil
	}

	f, err := strconv.ParseFloat(val, 64)
	if err != nil {
		return nil, fmt.Errorf("invalid value '%s'", val)
	}

	// Collapse integral floats such as "3.0" or "1e3" to int64, but only
	// inside the int64 range where the conversion is well defined.
	const int64Limit = 1 << 63
	if f >= -int64Limit && f < int64Limit && f == float64(int64(f)) {
		return int64(f), nil
	}
	return f, nil
}
// ParseSort turns a comma-separated sort expression into an ORDER BY clause.
//
// Each entry is a field name, optionally prefixed with "-" for descending
// order; entries without the prefix sort ascending. For example
// "-created,title" yields
// "CreatedAt DESC, json_extract(Data, '$.title') ASC".
//
// The system fields id, created and updated map to the ID, CreatedAt and
// UpdatedAt columns and are always permitted. Any other field is read from
// the JSON Data column and, when allowedFields is non-nil, must be listed in
// it. Empty entries are skipped; an empty expression yields an empty clause.
// A field name that fails validation or is not permitted produces an error.
func ParseSort(expr string, allowedFields []string) (string, error) {
	if expr == "" {
		return "", nil
	}

	restricted := allowedFields != nil
	permitted := make(map[string]bool)
	for _, name := range allowedFields {
		permitted[name] = true
	}

	var clauses []string
	for _, entry := range strings.Split(expr, ",") {
		name := strings.TrimSpace(entry)
		if name == "" {
			continue
		}

		direction := "ASC"
		if name[0] == '-' {
			direction = "DESC"
			name = name[1:]
		}

		if !validAPIFieldName.MatchString(name) {
			return "", fmt.Errorf("invalid sort field '%s'", name)
		}

		// System fields live in real columns and bypass the allow-list.
		var column string
		switch name {
		case "id":
			column = "ID"
		case "created":
			column = "CreatedAt"
		case "updated":
			column = "UpdatedAt"
		}
		if column == "" {
			if restricted && !permitted[name] {
				return "", fmt.Errorf("sort field '%s' is not allowed", name)
			}
			column = fmt.Sprintf("json_extract(Data, '$.%s')", name)
		}

		clauses = append(clauses, column+" "+direction)
	}

	// Joining an empty list gives "", matching the no-entries case.
	return strings.Join(clauses, ", "), nil
}
// FilterFields returns a copy of data restricted to the requested fields.
//
// Requested names are trimmed of surrounding whitespace before matching. The
// system fields id, collectionId, collectionName, created and updated are
// kept whenever they are present in data. When fields is empty the original
// map is returned as-is, not a copy.
func FilterFields(data map[string]any, fields []string) map[string]any {
	if len(fields) == 0 {
		return data
	}

	// Start from the system fields, which are always retained.
	keep := map[string]bool{
		"id":             true,
		"collectionId":   true,
		"collectionName": true,
		"created":        true,
		"updated":        true,
	}
	for _, name := range fields {
		keep[strings.TrimSpace(name)] = true
	}

	out := make(map[string]any)
	for key := range keep {
		if value, present := data[key]; present {
			out[key] = value
		}
	}
	return out
}
// ParseFields splits a comma-separated fields parameter into field names.
// Names are trimmed of surrounding whitespace and empty entries are dropped.
// The result is nil when no name remains, including for an empty expression.
func ParseFields(expr string) []string {
	isComma := func(r rune) bool { return r == ',' }

	var names []string
	for _, chunk := range strings.FieldsFunc(expr, isComma) {
		if name := strings.TrimSpace(chunk); name != "" {
			names = append(names, name)
		}
	}
	return names
}
// ParseExpand splits a comma-separated expand parameter into expansion paths.
//
// Each entry may be a dotted path for nested expansion, so
// "author,comments.author" yields [[author] [comments author]]. Segments are
// trimmed of surrounding whitespace, and only segments matching
// validAPIFieldName are kept. Entries left with no segments are omitted. The
// result is nil when no path remains.
//
// NOTE(review): an invalid segment is dropped while the rest of its path is
// kept, so "a.b-c.d" becomes [a d] — confirm this is the intended handling.
func ParseExpand(expr string) [][]string {
	if expr == "" {
		return nil
	}

	var paths [][]string
	for _, entry := range strings.Split(expr, ",") {
		var path []string
		for _, segment := range strings.Split(strings.TrimSpace(entry), ".") {
			name := strings.TrimSpace(segment)
			// The pattern requires at least one character, so empty
			// segments are rejected here as well.
			if validAPIFieldName.MatchString(name) {
				path = append(path, name)
			}
		}
		if len(path) != 0 {
			paths = append(paths, path)
		}
	}
	return paths
}
// MatchesFilter reports whether record satisfies filter when evaluated in
// memory, for situations such as real-time event filtering where the SQL
// engine is not available.
//
// A nil filter, or one with an empty Where clause, matches every record.
// Otherwise the decision is delegated to matchesFilterSimple, which works
// from the generated Where text and Params rather than from the original
// filter expression.
func MatchesFilter(record map[string]any, filter *APIFilterResult) bool {
	if filter != nil && filter.Where != "" {
		return matchesFilterSimple(record, filter)
	}
	return true
}
// filterWhereTokenPattern splits a generated WHERE clause into its pieces:
// a whole condition (group 1 is the field name, group 2 the SQL operator),
// a parenthesis, or an AND / OR connector. It is compiled once because
// in-memory filtering may run for every event.
var filterWhereTokenPattern = regexp.MustCompile(
	`json_extract\(Data, '\$\.([^']+)'\)\s*(=|!=|>|>=|<|<=|LIKE|NOT LIKE)\s*\?` +
		`|\(|\)|\bAND\b|\bOR\b`)

// filterWhereToken is one piece of a WHERE clause as seen by the in-memory
// evaluator.
type filterWhereToken struct {
	text   string // "(", ")", "AND" or "OR"; empty for a condition
	isCond bool   // true when the token is a comparison
	result bool   // outcome of the comparison; meaningful only when isCond
}

// filterWhereParser combines already-evaluated conditions according to the
// clause structure, using a small recursive-descent grammar:
//
//	or      := and ("OR" and)*
//	and     := operand ("AND" operand)*
//	operand := condition | "(" or ")"
type filterWhereParser struct {
	tokens []filterWhereToken
	pos    int
	failed bool // set when the tokens do not follow the grammar
}

// accept consumes the next token if it is the structural token text.
func (p *filterWhereParser) accept(text string) bool {
	if p.pos < len(p.tokens) && !p.tokens[p.pos].isCond && p.tokens[p.pos].text == text {
		p.pos++
		return true
	}
	return false
}

// parseOr evaluates a sequence of AND-groups joined by OR.
func (p *filterWhereParser) parseOr() bool {
	result := p.parseAnd()
	for !p.failed && p.accept("OR") {
		rhs := p.parseAnd()
		result = result || rhs
	}
	return result
}

// parseAnd evaluates a sequence of operands joined by AND.
func (p *filterWhereParser) parseAnd() bool {
	result := p.parseOperand()
	for !p.failed && p.accept("AND") {
		rhs := p.parseOperand()
		result = result && rhs
	}
	return result
}

// parseOperand evaluates a single condition or a parenthesized group.
func (p *filterWhereParser) parseOperand() bool {
	if p.pos < len(p.tokens) && p.tokens[p.pos].isCond {
		result := p.tokens[p.pos].result
		p.pos++
		return result
	}
	if p.accept("(") {
		result := p.parseOr()
		if !p.accept(")") {
			p.failed = true
		}
		return result
	}
	p.failed = true
	return false
}

// matchesFilterSimple evaluates a parsed filter against a record in memory.
//
// filter.Where is expected to consist of conditions of the form
// json_extract(Data, '$.field') OP ? joined by AND / OR and grouped with
// parentheses; filter.Params supplies one value per condition, in order.
// Each condition is checked against record[field] with
// evaluateFilterCondition, and the results are combined following the clause
// structure, with AND binding tighter than OR and parentheses overriding both.
//
// A clause with no recognizable condition matches every record. If the clause
// does not follow the expected structure, or there are fewer parameters than
// conditions, the function falls back to a flat heuristic over the conditions
// it could evaluate: any-true when the clause contains " OR ", all-true
// otherwise.
func matchesFilterSimple(record map[string]any, filter *APIFilterResult) bool {
	where := filter.Where
	params := filter.Params

	matches := filterWhereTokenPattern.FindAllStringSubmatchIndex(where, -1)
	tokens := make([]filterWhereToken, 0, len(matches))
	results := make([]bool, 0, len(matches))
	sawCondition := false
	wellFormed := true
	for _, m := range matches {
		if m[2] < 0 {
			// No field group: a parenthesis or a connector.
			tokens = append(tokens, filterWhereToken{text: where[m[0]:m[1]]})
			continue
		}
		sawCondition = true
		if len(results) >= len(params) {
			// No parameter left for this condition; it cannot be evaluated.
			wellFormed = false
			continue
		}
		field := where[m[2]:m[3]]
		op := where[m[4]:m[5]]
		outcome := evaluateFilterCondition(record[field], op, params[len(results)])
		results = append(results, outcome)
		tokens = append(tokens, filterWhereToken{isCond: true, result: outcome})
	}
	if !sawCondition {
		// Nothing recognizable to evaluate: assume a match.
		return true
	}

	if wellFormed {
		parser := filterWhereParser{tokens: tokens}
		result := parser.parseOr()
		if !parser.failed && parser.pos == len(parser.tokens) {
			return result
		}
	}

	// Fallback for clauses that do not follow the expected structure.
	if strings.Contains(where, " OR ") {
		for _, r := range results {
			if r {
				return true
			}
		}
		return false
	}
	for _, r := range results {
		if !r {
			return false
		}
	}
	return true
}
// evaluateFilterCondition checks one comparison between a record value and a
// filter parameter.
//
// op is the SQL operator taken from the WHERE clause: =, !=, >, >=, <, <=,
// LIKE or NOT LIKE. A nil record value equals only a nil parameter, differs
// from every non-nil parameter, and fails every other operator. An
// unrecognized operator yields false.
func evaluateFilterCondition(recordVal any, op string, paramVal any) bool {
	if recordVal == nil {
		switch op {
		case "=":
			return paramVal == nil
		case "!=":
			return paramVal != nil
		default:
			return false
		}
	}

	switch op {
	case "=", "!=":
		equal := filterCompareEqual(recordVal, paramVal)
		return equal == (op == "=")

	case "LIKE", "NOT LIKE":
		matched := filterMatchLike(recordVal, paramVal)
		return matched == (op == "LIKE")

	case ">", ">=", "<", "<=":
		order := filterCompareNumeric(recordVal, paramVal)
		switch op {
		case ">":
			return order > 0
		case ">=":
			return order >= 0
		case "<":
			return order < 0
		}
		return order <= 0
	}

	return false
}
// filterCompareEqual checks if two values are equal.
func filterCompareEqual(a, b any) bool {
// String comparison
aStr, aIsStr := filterToString(a)
bStr, bIsStr := filterToString(b)
if aIsStr && bIsStr {
return aStr == bStr
}
// Numeric comparison
aNum, aIsNum := filterToFloat64(a)
bNum, bIsNum := filterToFloat64(b)
if aIsNum && bIsNum {
return aNum == bNum
}
// Boolean comparison
aBool, aIsBool := a.(bool)
bBool, bIsBool := b.(bool)
if aIsBool && bIsBool {
return aBool == bBool
}
// Default: use fmt.Sprint for comparison
return fmt.Sprint(a) == fmt.Sprint(b)
}
// filterCompareNumeric orders two values, returning -1 when a sorts before b,
// 1 when it sorts after, and 0 otherwise.
//
// If both values convert to float64 they are compared as numbers. If either
// does not, both are formatted with fmt.Sprint and compared as strings.
func filterCompareNumeric(a, b any) int {
	left, leftOK := filterToFloat64(a)
	right, rightOK := filterToFloat64(b)

	if leftOK && rightOK {
		switch {
		case left < right:
			return -1
		case left > right:
			return 1
		default:
			return 0
		}
	}

	// At least one side is not numeric: order by printed form instead.
	leftText, rightText := fmt.Sprint(a), fmt.Sprint(b)
	switch {
	case leftText < rightText:
		return -1
	case leftText > rightText:
		return 1
	default:
		return 0
	}
}
// filterMatchLike reports whether val matches pattern using SQL LIKE-style
// wildcards: % stands for any sequence of characters (including none) and _
// for exactly one character. Every other character matches itself, ignoring
// case. Both arguments are converted to strings with filterToString first.
//
// The pattern is translated to an anchored regular expression. The "s" flag
// lets the wildcards span line breaks, as they do in SQL LIKE; without it a
// value containing a newline could never match a % or _ at that position.
// If the translated expression fails to compile the result is false.
func filterMatchLike(val, pattern any) bool {
	valStr, _ := filterToString(val)
	patternStr, _ := filterToString(pattern)

	var expr strings.Builder
	expr.Grow(len(patternStr) + 8)
	expr.WriteString("(?is)^") // i: case-insensitive, s: "." also matches "\n"
	for _, ch := range patternStr {
		switch ch {
		case '%':
			expr.WriteString(".*")
		case '_':
			expr.WriteString(".")
		default:
			// Escape regexp metacharacters so they are taken literally.
			expr.WriteString(regexp.QuoteMeta(string(ch)))
		}
	}
	expr.WriteString("$")

	re, err := regexp.Compile(expr.String())
	if err != nil {
		return false
	}
	return re.MatchString(valStr)
}
// filterToString renders a value as a string. Strings are returned as-is,
// byte slices are converted directly, and every other value is formatted with
// fmt.Sprint. The boolean result is always true.
func filterToString(v any) (string, bool) {
	if text, ok := v.(string); ok {
		return text, true
	}
	if raw, ok := v.([]byte); ok {
		return string(raw), true
	}
	return fmt.Sprint(v), true
}
// filterToFloat64 converts a value to float64 and reports whether the
// conversion was possible.
//
// All built-in integer and floating-point types are accepted. A string is
// accepted only when, after trimming surrounding whitespace, the whole string
// is a number that strconv.ParseFloat can parse; text with trailing
// characters such as "12abc" is not numeric. Every other type yields
// (0, false).
func filterToFloat64(v any) (float64, bool) {
	switch val := v.(type) {
	case float64:
		return val, true
	case float32:
		return float64(val), true
	case int:
		return float64(val), true
	case int8:
		return float64(val), true
	case int16:
		return float64(val), true
	case int32:
		return float64(val), true
	case int64:
		return float64(val), true
	case uint:
		return float64(val), true
	case uint8:
		return float64(val), true
	case uint16:
		return float64(val), true
	case uint32:
		return float64(val), true
	case uint64:
		return float64(val), true
	case string:
		f, err := strconv.ParseFloat(strings.TrimSpace(val), 64)
		if err != nil {
			return 0, false
		}
		return f, true
	}
	return 0, false
}