readysite / website / internal / content / rules / parser.go
8.3 KB
parser.go
// Package rules provides PocketBase-style rule expressions for collection access control.
package rules

import (
	"fmt"
	"strconv"
	"strings"
)

// tokenType identifies the lexical category of a token produced by tokenize.
type tokenType int

// Token kinds, numbered from zero via iota. tokenEOF is therefore the zero
// value; tokenize appends one tokenEOF token at the end of every successful
// scan, and Parser.current returns it when reading past the end.
const (
	tokenEOF tokenType = iota
	tokenLParen
	tokenRParen
	tokenAnd
	tokenOr
	tokenOp         // comparison operator: =, !=, >, >=, <, <=, ~, !~, ?=, ?!=
	tokenField      // bare identifier, optionally with a :modifier suffix (e.g. title:length)
	tokenContext    // @-prefixed reference that is not @collection.*, e.g. @request.auth.id
	tokenCollection // reference starting with "@collection.", e.g. @collection.posts.userId
	tokenString     // single- or double-quoted string; value holds the text without the quotes
	tokenNumber     // optional leading '-', then digits and '.', e.g. 123, 123.45
	tokenBool       // the bare word true or false
	tokenNull       // the bare word null or nil
)

// token is a single lexeme produced by tokenize: its category plus the text
// it carries. For tokenString the value is the string contents with the
// surrounding quotes removed; for tokenEOF it is empty; for every other kind
// it is the matched source text.
type token struct {
	typ   tokenType
	value string
}

// Parser parses rule expressions into an AST.
//
// It walks a token slice produced by tokenize using one method per grammar
// level (parseOr, parseAnd, parseComparison, parseValue).
type Parser struct {
	// tokens is the token stream being parsed.
	tokens []token
	// pos is the index into tokens of the current, not yet consumed, token.
	pos    int
}

// Parse parses a rule expression string into an Expr AST.
//
// The rule is tokenized and then parsed with || binding loosest, && next, and
// a single comparison or a parenthesized sub-expression binding tightest.
//
// The whole input must form exactly one expression. Tokens left over after a
// complete expression (for example a stray ")" or a second comparison that is
// not joined by && or ||) are reported as an error instead of being silently
// discarded, so a malformed rule can never be mistaken for its valid prefix.
//
// It returns a nil Expr and a non-nil error if tokenization or parsing fails.
func Parse(rule string) (Expr, error) {
	tokens, err := tokenize(rule)
	if err != nil {
		return nil, err
	}

	p := &Parser{tokens: tokens}
	expr, err := p.parseOr()
	if err != nil {
		return nil, err
	}

	// parseOr returns as soon as the next token is not "||"; anything other
	// than the terminating EOF at this point is unparsed trailing input.
	if rest := p.current(); rest.typ != tokenEOF {
		return nil, fmt.Errorf("unexpected token %q after end of expression", rest.value)
	}
	return expr, nil
}

// tokenize breaks the rule string into tokens.
//
// Leading and trailing whitespace is trimmed first, and ASCII whitespace
// (space, tab, newline, carriage return, vertical tab, form feed) between
// tokens is skipped. At each position the following lexemes are tried in
// order:
//
//   - "(" and ")"
//   - "&&" and "||"
//   - comparison operators ?!=, ?=, !=, >=, <=, !~, =, >, <, ~
//   - strings delimited by ' or "; a backslash skips the next byte while
//     scanning, and \' and \" are unescaped in the resulting value
//   - @-prefixed references made of letters, digits, '.', '_' and ':';
//     those starting with "@collection." become tokenCollection, all others
//     tokenContext
//   - numbers: an optional leading '-' followed by digits and '.'
//   - identifiers made of letters, digits, '_' and ':'; the bare words
//     true/false become tokenBool and null/nil become tokenNull, everything
//     else (including keywords carrying a :modifier) becomes tokenField
//
// The returned slice always ends with a tokenEOF token. An error is returned
// for an unterminated string or a character that starts none of the lexemes
// above; reported positions are byte offsets into the trimmed rule.
func tokenize(rule string) ([]token, error) {
	// Longer operators are listed first so that, e.g., "!=" is matched whole
	// instead of failing on "!" alone. Built once, not on every iteration.
	operators := []string{"?!=", "?=", "!=", ">=", "<=", "!~", "=", ">", "<", "~"}

	// Same ASCII whitespace set that strings.TrimSpace removes at the edges,
	// so rules spanning several lines (including CRLF line endings) tokenize.
	const whitespace = " \t\n\r\v\f"

	var tokens []token
	rule = strings.TrimSpace(rule)
	i := 0

	for i < len(rule) {
		// Skip whitespace
		for i < len(rule) && strings.IndexByte(whitespace, rule[i]) >= 0 {
			i++
		}
		if i >= len(rule) {
			break
		}

		ch := rule[i]

		// Parentheses
		switch ch {
		case '(':
			tokens = append(tokens, token{tokenLParen, "("})
			i++
			continue
		case ')':
			tokens = append(tokens, token{tokenRParen, ")"})
			i++
			continue
		}

		// Logical operators
		if strings.HasPrefix(rule[i:], "&&") {
			tokens = append(tokens, token{tokenAnd, "&&"})
			i += 2
			continue
		}
		if strings.HasPrefix(rule[i:], "||") {
			tokens = append(tokens, token{tokenOr, "||"})
			i += 2
			continue
		}

		// Comparison operators, including ?= and ?!= for collection
		// existence checks.
		foundOp := false
		for _, op := range operators {
			if strings.HasPrefix(rule[i:], op) {
				tokens = append(tokens, token{tokenOp, op})
				i += len(op)
				foundOp = true
				break
			}
		}
		if foundOp {
			continue
		}

		// Quoted string
		if ch == '\'' || ch == '"' {
			quote := ch
			end := i + 1
			for end < len(rule) && rule[end] != quote {
				if rule[end] == '\\' && end+1 < len(rule) {
					end += 2 // Skip the backslash and the byte it escapes
				} else {
					end++
				}
			}
			if end >= len(rule) {
				return nil, fmt.Errorf("unterminated string at position %d", i)
			}
			// Extract string without quotes, then unescape embedded quotes.
			str := rule[i+1 : end]
			str = strings.ReplaceAll(str, "\\'", "'")
			str = strings.ReplaceAll(str, "\\\"", "\"")
			tokens = append(tokens, token{tokenString, str})
			i = end + 1
			continue
		}

		// Context variable (@request.*) or collection reference (@collection.*)
		if ch == '@' {
			end := i + 1
			// Colon is allowed for aliases like @collection.users:manager.id
			for end < len(rule) && (isAlphaNum(rule[end]) || rule[end] == '.' || rule[end] == '_' || rule[end] == ':') {
				end++
			}
			value := rule[i:end]
			if strings.HasPrefix(value, "@collection.") {
				tokens = append(tokens, token{tokenCollection, value})
			} else {
				tokens = append(tokens, token{tokenContext, value})
			}
			i = end
			continue
		}

		// Number: a '-' only starts a number when a digit follows directly.
		if isDigit(ch) || (ch == '-' && i+1 < len(rule) && isDigit(rule[i+1])) {
			end := i
			if ch == '-' {
				end++
			}
			for end < len(rule) && (isDigit(rule[end]) || rule[end] == '.') {
				end++
			}
			tokens = append(tokens, token{tokenNumber, rule[i:end]})
			i = end
			continue
		}

		// Identifier (field name, bool, null)
		// Also allows modifiers like field:isset, field:length, field:lower, field:each
		if isAlpha(ch) || ch == '_' {
			end := i
			for end < len(rule) && (isAlphaNum(rule[end]) || rule[end] == '_' || rule[end] == ':') {
				end++
			}
			word := rule[i:end]
			// Only bare keywords (without modifiers) are treated as literals.
			switch word {
			case "true", "false":
				tokens = append(tokens, token{tokenBool, word})
			case "null", "nil":
				tokens = append(tokens, token{tokenNull, word})
			default:
				tokens = append(tokens, token{tokenField, word})
			}
			i = end
			continue
		}

		return nil, fmt.Errorf("unexpected character '%c' at position %d", ch, i)
	}

	tokens = append(tokens, token{tokenEOF, ""})
	return tokens, nil
}

// isAlpha reports whether ch is an ASCII letter (a-z or A-Z).
func isAlpha(ch byte) bool {
	switch {
	case 'a' <= ch && ch <= 'z':
		return true
	case 'A' <= ch && ch <= 'Z':
		return true
	default:
		return false
	}
}

// isDigit reports whether ch is an ASCII decimal digit (0-9).
func isDigit(ch byte) bool {
	// Byte subtraction wraps around, so anything below '0' becomes a large
	// value and fails the single comparison as well.
	return ch-'0' <= 9
}

// isAlphaNum reports whether ch is an ASCII letter or an ASCII decimal digit.
func isAlphaNum(ch byte) bool {
	if isAlpha(ch) {
		return true
	}
	return isDigit(ch)
}

// --- Parser methods ---

// current returns the token at the parser's position without consuming it.
// Once the position has moved past the last token it keeps returning an
// empty tokenEOF token.
func (p *Parser) current() token {
	if p.pos < len(p.tokens) {
		return p.tokens[p.pos]
	}
	return token{typ: tokenEOF}
}

// advance consumes the current token and returns it. The position is moved
// forward unconditionally, even when the parser is already at the end.
func (p *Parser) advance() token {
	consumed := p.current()
	p.pos += 1
	return consumed
}

// parseOr parses one or more && groups separated by "||" and folds them into
// left-nested BinaryExpr nodes with Op "||". With no "||" present it returns
// the single && group unchanged.
func (p *Parser) parseOr() (Expr, error) {
	expr, err := p.parseAnd()
	if err != nil {
		return nil, err
	}

	for {
		if p.current().typ != tokenOr {
			return expr, nil
		}
		p.advance() // drop the "||"

		rhs, err := p.parseAnd()
		if err != nil {
			return nil, err
		}
		expr = &BinaryExpr{Left: expr, Op: "||", Right: rhs}
	}
}

// parseAnd parses one or more comparisons separated by "&&" and folds them
// into left-nested BinaryExpr nodes with Op "&&". With no "&&" present it
// returns the single comparison unchanged.
func (p *Parser) parseAnd() (Expr, error) {
	expr, err := p.parseComparison()
	if err != nil {
		return nil, err
	}

	for {
		if p.current().typ != tokenAnd {
			return expr, nil
		}
		p.advance() // drop the "&&"

		rhs, err := p.parseComparison()
		if err != nil {
			return nil, err
		}
		expr = &BinaryExpr{Left: expr, Op: "&&", Right: rhs}
	}
}

// parseComparison parses the tightest-binding grammar level: either a
// parenthesized sub-expression, which restarts at parseOr and must be closed
// by ")", or a "value operator value" triple, returned as a CompareExpr.
func (p *Parser) parseComparison() (Expr, error) {
	// Parenthesized group: "(" expr ")"
	if p.current().typ == tokenLParen {
		p.advance() // opening paren
		inner, err := p.parseOr()
		switch {
		case err != nil:
			return nil, err
		case p.current().typ != tokenRParen:
			return nil, fmt.Errorf("expected ')' but got %v", p.current())
		}
		p.advance() // closing paren
		return inner, nil
	}

	// Left operand
	lhs, err := p.parseValue()
	if err != nil {
		return nil, err
	}

	// Operator between the two operands
	opTok := p.current()
	if opTok.typ != tokenOp {
		return nil, fmt.Errorf("expected comparison operator but got %v", opTok)
	}
	p.advance()

	// Right operand
	rhs, err := p.parseValue()
	if err != nil {
		return nil, err
	}

	return &CompareExpr{Left: lhs, Op: opTok.value, Right: rhs}, nil
}

// parseValue parses a single operand at the current position and consumes
// its token. The token kind decides the result:
//
//   - tokenString: LiteralValue holding the string
//   - tokenNumber: LiteralValue holding an int64 for whole numbers and a
//     float64 otherwise; a malformed number is an error
//   - tokenBool: LiteralValue holding true or false
//   - tokenNull: LiteralValue holding nil
//   - tokenField: FieldValue with an optional ":modifier" split off the name
//   - tokenContext: ContextValue whose Path is the dot-separated reference
//     with a leading "@request." removed; a ":modifier" on the last segment
//     is split off
//   - tokenCollection: CollectionValue with the collection name, an optional
//     ":alias" on the collection segment, and the remaining field path
//
// Any other token kind yields an error and consumes nothing.
func (p *Parser) parseValue() (Value, error) {
	t := p.current()

	switch t.typ {
	case tokenString:
		p.advance()
		return &LiteralValue{Value: t.value}, nil

	case tokenNumber:
		num, err := parseNumberLiteral(t.value)
		if err != nil {
			return nil, err
		}
		p.advance()
		return &LiteralValue{Value: num}, nil

	case tokenBool:
		p.advance()
		return &LiteralValue{Value: t.value == "true"}, nil

	case tokenNull:
		p.advance()
		return &LiteralValue{Value: nil}, nil

	case tokenField:
		p.advance()
		// Modifiers look like field:isset, field:length, field:lower, field:each
		name, modifier := splitAtColon(t.value)
		return &FieldValue{Name: name, Modifier: modifier}, nil

	case tokenContext:
		p.advance()
		// @request.auth.id becomes ["auth", "id"]. A reference without the
		// "@request." prefix is left as is, so its first segment keeps the '@'.
		parts := strings.Split(strings.TrimPrefix(t.value, "@request."), ".")
		// strings.Split always returns at least one element, so the last
		// segment exists; only that segment may carry a modifier.
		last := len(parts) - 1
		var modifier string
		parts[last], modifier = splitAtColon(parts[last])
		return &ContextValue{Path: parts, Modifier: modifier}, nil

	case tokenCollection:
		p.advance()
		// @collection.posts.userId or @collection.users:alias.field
		parts := strings.Split(strings.TrimPrefix(t.value, "@collection."), ".")
		if len(parts) < 2 {
			return nil, fmt.Errorf("@collection reference must have at least collection and field: %s", t.value)
		}
		// The first segment may carry an alias: collectionName:alias
		collectionName, alias := splitAtColon(parts[0])
		return &CollectionValue{
			Collection: collectionName,
			Alias:      alias,
			FieldPath:  parts[1:],
		}, nil

	default:
		return nil, fmt.Errorf("expected value but got %v", t)
	}
}

// parseNumberLiteral converts the text of a number token into an int64 when
// it denotes a whole number that fits in int64, and into a float64 otherwise.
// Text that is not a valid number (for example "1.2.3") is an error.
func parseNumberLiteral(text string) (interface{}, error) {
	// Plain integers are parsed directly so that large values keep full
	// precision instead of being rounded through float64.
	if n, err := strconv.ParseInt(text, 10, 64); err == nil {
		return n, nil
	}

	f, err := strconv.ParseFloat(text, 64)
	if err != nil {
		return nil, fmt.Errorf("invalid number %q", text)
	}

	// Whole-valued decimals such as "2.0" are still reported as integers,
	// but only inside the int64 range, where the conversion is well defined.
	const int64Limit = 1 << 63
	if f >= -int64Limit && f < int64Limit && f == float64(int64(f)) {
		return int64(f), nil
	}
	return f, nil
}

// splitAtColon splits s at its first ':' into the part before and the part
// after it. If s has no ':' or starts with one, s is returned unchanged with
// an empty suffix.
func splitAtColon(s string) (base, suffix string) {
	if idx := strings.Index(s, ":"); idx > 0 {
		return s[:idx], s[idx+1:]
	}
	return s, ""
}
← Back