// parser.go
// Package rules provides PocketBase-style rule expressions for collection access control.
package rules
import (
	"fmt"
	"strconv"
	"strings"
)
// Token types for the lexer.
type tokenType int

const (
	tokenEOF tokenType = iota // end-of-input sentinel, always appended by tokenize
	tokenLParen               // (
	tokenRParen               // )
	tokenAnd                  // &&
	tokenOr                   // ||
	tokenOp                   // =, !=, >, >=, <, <=, ~, !~, ?=, ?!=
	tokenField                // field name, optionally with a :modifier suffix
	tokenContext              // @request.* context variable, e.g. @request.auth.id
	tokenCollection           // @collection.* reference, e.g. @collection.posts.userId
	tokenString               // quoted string literal: 'string' or "string"
	tokenNumber               // numeric literal: 123, 123.45, -7
	tokenBool                 // true, false
	tokenNull                 // null, nil
)
// token is a single lexical unit produced by tokenize: its type plus
// the associated text (for tokenString, the content with surrounding
// quotes stripped and quote escapes resolved).
type token struct {
	typ   tokenType
	value string
}
// Parser parses rule expressions into an AST by recursive descent
// over a pre-tokenized stream.
type Parser struct {
	tokens []token // token stream ending in tokenEOF
	pos    int     // index of the next unconsumed token
}
// Parse parses a rule expression string into an Expr AST.
//
// The entire input must form a single expression; trailing tokens
// (e.g. "a = 1 b = 2", or an unbalanced ")") are reported as an
// error rather than silently ignored.
func Parse(rule string) (Expr, error) {
	tokens, err := tokenize(rule)
	if err != nil {
		return nil, err
	}
	p := &Parser{tokens: tokens, pos: 0}
	expr, err := p.parseOr()
	if err != nil {
		return nil, err
	}
	// Previously unchecked: make sure parseOr consumed everything.
	if tok := p.current(); tok.typ != tokenEOF {
		return nil, fmt.Errorf("unexpected token %q after expression", tok.value)
	}
	return expr, nil
}
// tokenize breaks the rule string into a flat token stream terminated
// by a tokenEOF sentinel.
//
// Recognized tokens: parentheses, the logical operators && and ||,
// comparison operators (=, !=, >, >=, <, <=, ~, !~, ?=, ?!=), single-
// or double-quoted strings (with backslash escapes for either quote
// character and for backslash itself), @request.* context variables,
// @collection.* references (with optional :alias), numbers, booleans,
// null/nil, and bare identifiers which may carry a :modifier suffix
// (e.g. field:isset).
func tokenize(rule string) ([]token, error) {
	var tokens []token
	rule = strings.TrimSpace(rule)
	i := 0
	for i < len(rule) {
		// Skip whitespace ('\r' included so CRLF input lexes cleanly).
		for i < len(rule) && (rule[i] == ' ' || rule[i] == '\t' || rule[i] == '\n' || rule[i] == '\r') {
			i++
		}
		if i >= len(rule) {
			break
		}
		ch := rule[i]
		// Parentheses
		if ch == '(' {
			tokens = append(tokens, token{tokenLParen, "("})
			i++
			continue
		}
		if ch == ')' {
			tokens = append(tokens, token{tokenRParen, ")"})
			i++
			continue
		}
		// Logical operators
		if i+1 < len(rule) && rule[i:i+2] == "&&" {
			tokens = append(tokens, token{tokenAnd, "&&"})
			i += 2
			continue
		}
		if i+1 < len(rule) && rule[i:i+2] == "||" {
			tokens = append(tokens, token{tokenOr, "||"})
			i += 2
			continue
		}
		// Comparison operators, longest first so ">=" is not split into
		// ">" then "=". ?= and ?!= are the collection existence checks.
		foundOp := false
		for _, op := range []string{"?!=", "?=", "!=", ">=", "<=", "!~", "=", ">", "<", "~"} {
			if i+len(op) <= len(rule) && rule[i:i+len(op)] == op {
				tokens = append(tokens, token{tokenOp, op})
				i += len(op)
				foundOp = true
				break
			}
		}
		if foundOp {
			continue
		}
		// Quoted string ('...' or "...") with backslash escapes.
		if ch == '\'' || ch == '"' {
			quote := ch
			end := i + 1
			for end < len(rule) && rule[end] != quote {
				if rule[end] == '\\' && end+1 < len(rule) {
					end += 2 // Skip escaped char
				} else {
					end++
				}
			}
			if end >= len(rule) {
				return nil, fmt.Errorf("unterminated string at position %d", i)
			}
			// Unescape in a single left-to-right pass: \', \" and \\
			// collapse to the escaped character; any other backslash
			// sequence is kept verbatim. (The previous ReplaceAll-based
			// approach never collapsed \\, so literal backslashes came
			// out doubled.)
			var sb strings.Builder
			for j := i + 1; j < end; j++ {
				if rule[j] == '\\' && j+1 < end {
					switch rule[j+1] {
					case '\'', '"', '\\':
						j++
						sb.WriteByte(rule[j])
						continue
					}
				}
				sb.WriteByte(rule[j])
			}
			tokens = append(tokens, token{tokenString, sb.String()})
			i = end + 1
			continue
		}
		// Context variable (@request.*) or collection reference (@collection.*)
		if ch == '@' {
			end := i + 1
			// Also allow colon for aliases like @collection.users:manager.id
			for end < len(rule) && (isAlphaNum(rule[end]) || rule[end] == '.' || rule[end] == '_' || rule[end] == ':') {
				end++
			}
			value := rule[i:end]
			if strings.HasPrefix(value, "@collection.") {
				tokens = append(tokens, token{tokenCollection, value})
			} else {
				tokens = append(tokens, token{tokenContext, value})
			}
			i = end
			continue
		}
		// Number (optional leading '-', optional decimal part)
		if isDigit(ch) || (ch == '-' && i+1 < len(rule) && isDigit(rule[i+1])) {
			end := i
			if ch == '-' {
				end++
			}
			for end < len(rule) && (isDigit(rule[end]) || rule[end] == '.') {
				end++
			}
			tokens = append(tokens, token{tokenNumber, rule[i:end]})
			i = end
			continue
		}
		// Identifier: field name (possibly with modifiers such as
		// field:isset, field:length, field:lower, field:each), bool, or null.
		if isAlpha(ch) || ch == '_' {
			end := i
			for end < len(rule) && (isAlphaNum(rule[end]) || rule[end] == '_' || rule[end] == ':') {
				end++
			}
			word := rule[i:end]
			// Bare keywords only; a :modifier suffix makes them fields.
			switch word {
			case "true", "false":
				tokens = append(tokens, token{tokenBool, word})
			case "null", "nil":
				tokens = append(tokens, token{tokenNull, word})
			default:
				tokens = append(tokens, token{tokenField, word})
			}
			i = end
			continue
		}
		return nil, fmt.Errorf("unexpected character '%c' at position %d", ch, i)
	}
	tokens = append(tokens, token{tokenEOF, ""})
	return tokens, nil
}
// isAlpha reports whether ch is an ASCII letter (a-z or A-Z).
func isAlpha(ch byte) bool {
	folded := ch | 0x20 // fold ASCII upper case onto lower case
	return 'a' <= folded && folded <= 'z'
}

// isDigit reports whether ch is an ASCII decimal digit (0-9).
func isDigit(ch byte) bool {
	return '0' <= ch && ch <= '9'
}

// isAlphaNum reports whether ch is an ASCII letter or digit.
func isAlphaNum(ch byte) bool {
	return isDigit(ch) || isAlpha(ch)
}
// --- Parser methods ---
// current returns the token at the parser's position without consuming
// it. Past the end of the slice it yields a tokenEOF, so callers never
// index out of range.
func (p *Parser) current() token {
	if p.pos < len(p.tokens) {
		return p.tokens[p.pos]
	}
	return token{tokenEOF, ""}
}
// advance consumes and returns the current token, moving the parser
// one position forward (it keeps returning tokenEOF once exhausted).
func (p *Parser) advance() token {
	consumed := p.current()
	p.pos++
	return consumed
}
// parseOr parses a chain of ||-joined terms — the lowest-precedence
// level — folding them into a left-associative BinaryExpr tree.
func (p *Parser) parseOr() (Expr, error) {
	expr, err := p.parseAnd()
	if err != nil {
		return nil, err
	}
	for {
		if p.current().typ != tokenOr {
			return expr, nil
		}
		p.advance() // consume ||
		operand, err := p.parseAnd()
		if err != nil {
			return nil, err
		}
		expr = &BinaryExpr{Left: expr, Op: "||", Right: operand}
	}
}
// parseAnd parses a chain of &&-joined comparisons. && binds tighter
// than ||, so this sits between parseOr and parseComparison.
func (p *Parser) parseAnd() (Expr, error) {
	expr, err := p.parseComparison()
	if err != nil {
		return nil, err
	}
	for p.current().typ == tokenAnd {
		p.advance() // consume &&
		rhs, err := p.parseComparison()
		if err != nil {
			return nil, err
		}
		expr = &BinaryExpr{Left: expr, Op: "&&", Right: rhs}
	}
	return expr, nil
}
// parseComparison parses either a parenthesized sub-expression or a
// single `value op value` comparison.
func (p *Parser) parseComparison() (Expr, error) {
	// A '(' restarts the precedence climb from the top.
	if p.current().typ == tokenLParen {
		p.advance() // consume (
		inner, err := p.parseOr()
		if err != nil {
			return nil, err
		}
		if p.current().typ != tokenRParen {
			return nil, fmt.Errorf("expected ')' but got %v", p.current())
		}
		p.advance() // consume )
		return inner, nil
	}
	// Otherwise: left operand, operator, right operand.
	lhs, err := p.parseValue()
	if err != nil {
		return nil, err
	}
	opTok := p.current()
	if opTok.typ != tokenOp {
		return nil, fmt.Errorf("expected comparison operator but got %v", opTok)
	}
	p.advance()
	rhs, err := p.parseValue()
	if err != nil {
		return nil, err
	}
	return &CompareExpr{Left: lhs, Op: opTok.value, Right: rhs}, nil
}
// parseValue parses a single comparison operand: a literal (string,
// number, bool, null), a plain field reference, an @request context
// variable, or an @collection reference.
func (p *Parser) parseValue() (Value, error) {
	t := p.current()
	switch t.typ {
	case tokenString:
		p.advance()
		return &LiteralValue{Value: t.value}, nil
	case tokenNumber:
		p.advance()
		// Parse integers via ParseInt so values above 2^53 keep full
		// precision (the previous Sscanf "%f" round trip rounded them,
		// and silently ignored its own error).
		if n, err := strconv.ParseInt(t.value, 10, 64); err == nil {
			return &LiteralValue{Value: n}, nil
		}
		f, err := strconv.ParseFloat(t.value, 64)
		if err != nil {
			return nil, fmt.Errorf("invalid number %q: %v", t.value, err)
		}
		// Collapse whole-valued floats (e.g. "2.0") to int64, matching
		// how bare integers are represented. Guarded to the exactly
		// representable range so the float->int conversion is well
		// defined (out-of-range conversions are implementation-specific).
		if f >= -(1<<53) && f <= 1<<53 && f == float64(int64(f)) {
			return &LiteralValue{Value: int64(f)}, nil
		}
		return &LiteralValue{Value: f}, nil
	case tokenBool:
		p.advance()
		return &LiteralValue{Value: t.value == "true"}, nil
	case tokenNull:
		p.advance()
		return &LiteralValue{Value: nil}, nil
	case tokenField:
		p.advance()
		// Split off an optional modifier: field:isset, field:length,
		// field:lower, field:each.
		name := t.value
		var modifier string
		if colonIdx := strings.Index(name, ":"); colonIdx > 0 {
			modifier = name[colonIdx+1:]
			name = name[:colonIdx]
		}
		return &FieldValue{Name: name, Modifier: modifier}, nil
	case tokenContext:
		p.advance()
		// "@request.auth.id" -> path ["auth", "id"]; the last segment
		// may carry a ":modifier" suffix.
		path := strings.TrimPrefix(t.value, "@request.")
		parts := strings.Split(path, ".")
		var modifier string
		if len(parts) > 0 {
			lastPart := parts[len(parts)-1]
			if colonIdx := strings.Index(lastPart, ":"); colonIdx > 0 {
				modifier = lastPart[colonIdx+1:]
				parts[len(parts)-1] = lastPart[:colonIdx]
			}
		}
		return &ContextValue{Path: parts, Modifier: modifier}, nil
	case tokenCollection:
		p.advance()
		// "@collection.posts.userId" or "@collection.users:alias.field".
		path := strings.TrimPrefix(t.value, "@collection.")
		parts := strings.Split(path, ".")
		if len(parts) < 2 {
			return nil, fmt.Errorf("@collection reference must have at least collection and field: %s", t.value)
		}
		// The first segment may carry an alias: collectionName:alias.
		collectionName, alias := parts[0], ""
		if colonIdx := strings.Index(collectionName, ":"); colonIdx > 0 {
			collectionName, alias = collectionName[:colonIdx], collectionName[colonIdx+1:]
		}
		return &CollectionValue{
			Collection: collectionName,
			Alias:      alias,
			FieldPath:  parts[1:],
		}, nil
	default:
		return nil, fmt.Errorf("expected value but got %v", t)
	}
}