readysite / website / internal / assist / validate.go
5.9 KB
validate.go
package assist

import (
	"fmt"
	"strings"
)

// allowedSQLColumns is the whitelist of document columns that filters and
// ORDER BY clauses may reference. Keys are lowercase; tokens are lowercased
// before they are looked up here.
var allowedSQLColumns = map[string]bool{
	"collectionid": true,
	"createdat":    true,
	"data":         true,
	"id":           true,
	"updatedat":    true,
}

// allowedSQLOperators are the operators and keywords a filter may use.
//
// The filter tokenizer splits on whitespace, so each word of a multi-word
// operator arrives as its own token. "not" and "is" are therefore listed on
// their own; without them "NOT LIKE", "NOT IN", "IS NULL" and "IS NOT NULL"
// could never validate. The multi-word spellings are kept for backward
// compatibility with anything that matches a whole operator phrase.
var allowedSQLOperators = []string{
	"=", "!=", "<>", "<", ">", "<=", ">=",
	"like", "not like", "in", "not in", "is null", "is not null",
	"and", "or", "between",
	"not", "is",
}

// validateSQLFilter validates a SQL filter (WHERE-clause fragment) using a
// whitelist approach. Only simple read-only predicates over the document
// columns are allowed.
//
// Two layers of checks are applied:
//  1. A substring blocklist over the lowercased filter rejects statement
//     separators, comments, well-known dangerous keywords/functions and any
//     backslash.
//  2. Every token from tokenizeSQLFilter must be one of the following:
//     - a string or numeric literal (per isQuotedString / isNumber)
//     - an allowed operator or keyword
//     - "(", ")" or ","
//     - true, false or null
//     - a column in allowedSQLColumns
//     - a JSON path on the data column, i.e. a token starting with "data->",
//       which also covers "data->>"
//
// An empty filter is valid. Otherwise the returned error names the first
// forbidden pattern or unexpected token.
func validateSQLFilter(filter string) error {
	if filter == "" {
		return nil
	}

	// Keyword entries carry surrounding spaces and the filter is padded, so
	// they only match as space-delimited words. The token whitelist below is
	// what rejects unknown identifiers in general.
	//
	// A backslash is rejected outright because dialects disagree on "\'".
	// Standard SQL (SQLite, PostgreSQL) reads it as a backslash followed by a
	// closing quote; MySQL in its default mode reads an escaped quote. If the
	// tokenizer and the database disagree about where a literal ends, text the
	// whitelist skipped as "inside a string" is executed as SQL, e.g.
	//   id = 'a\' or evil(1) = 1 or id = 'b'
	dangerous := []string{
		" drop ", " delete ", " insert ", " update ", " alter ",
		" create ", " truncate ", " exec ", " execute ",
		" union ", " into ", " outfile ", " dumpfile ", " load_file ",
		"--", ";", "/*", "*/", "@@", "char(", "concat(",
		"benchmark(", "sleep(", "waitfor", "pg_sleep",
		"\\",
	}
	padded := " " + strings.ToLower(filter) + " "
	for _, d := range dangerous {
		if strings.Contains(padded, d) {
			return fmt.Errorf("filter contains forbidden SQL keyword or pattern: %s", strings.TrimSpace(d))
		}
	}

	for _, token := range tokenizeSQLFilter(filter) {
		token = strings.ToLower(token)
		switch {
		case isQuotedString(token), isNumber(token):
			// String and numeric literals.
		case isAllowedOperator(token):
			// Operators and SQL keywords.
		case token == "(", token == ")", token == ",":
			// Grouping and list punctuation.
		case token == "true", token == "false", token == "null":
			// Common SQL literals.
		case allowedSQLColumns[token]:
			// Whitelisted column, including plain "data".
		case strings.HasPrefix(token, "data->"):
			// JSON path on the data column: data->'field', data->>'field'.
			// The token was lowercased above, so Data->'field' matches too.
		default:
			return fmt.Errorf("filter contains invalid column name: %s", token)
		}
	}

	return nil
}

// validateSQLOrderBy validates an ORDER BY clause.
//
// orderBy must be a comma-separated list of items, each either "<column>" or
// "<column> ASC|DESC" (the direction is case-insensitive). <column> must be
// one of allowedSQLColumns or a JSON path on the data column starting with
// "data->", for example data->>'title'. Comment markers, ";" and parentheses
// are rejected anywhere in the clause, which rules out function calls and
// subqueries.
//
// An empty or all-whitespace clause is accepted. Empty items (a leading,
// trailing or doubled comma) are rejected because they would produce invalid
// SQL.
func validateSQLOrderBy(orderBy string) error {
	if strings.TrimSpace(orderBy) == "" {
		return nil
	}

	// None of these patterns contain letters, so no case folding is needed.
	for _, d := range []string{"--", ";", "/*", "*/", "(", ")"} {
		if strings.Contains(orderBy, d) {
			return fmt.Errorf("order_by contains forbidden pattern: %s", d)
		}
	}

	for _, item := range strings.Split(orderBy, ",") {
		fields := strings.Fields(item)
		if len(fields) == 0 {
			return fmt.Errorf("order_by contains an empty column")
		}

		// Only whitelisted columns or JSON paths rooted at data are allowed.
		// Prefix-matching plain "data" would also admit unrelated names such
		// as "datafoo".
		col := strings.ToLower(fields[0])
		if !allowedSQLColumns[col] && !strings.HasPrefix(col, "data->") {
			return fmt.Errorf("order_by contains invalid column: %s", col)
		}

		// The optional second word must be a sort direction.
		if len(fields) > 1 {
			dir := strings.ToLower(fields[1])
			if dir != "asc" && dir != "desc" {
				return fmt.Errorf("order_by contains invalid direction: %s", dir)
			}
		}

		// No more than a column and a direction per item.
		if len(fields) > 2 {
			return fmt.Errorf("order_by contains unexpected tokens")
		}
	}

	return nil
}

// tokenizeSQLFilter splits a SQL filter expression into the tokens checked by
// validateSQLFilter.
//
// Tokenization rules:
//   - Spaces, tabs, CR and LF separate tokens and are dropped.
//   - "(", ")" and "," are always single-character tokens.
//   - '=', '<', '>' and '!' start an operator token. When the next byte is '='
//     or '>', both bytes form one token (">=", "<>", "!=", ...).
//   - A quoted run, '...' or "...", is one token including its quotes and is
//     never merged with adjacent text. As in standard SQL (SQLite,
//     PostgreSQL), a doubled quote inside the run is an escaped quote and a
//     backslash is an ordinary character. An unterminated run extends to the
//     end of the input.
//   - A word immediately followed by "->" or "->>" and a terminated
//     single-quoted key stays one token, e.g. data->>'title'. Segments may be
//     chained (data->'a'->>'b'). Any other "->" falls through to the rules
//     above.
//   - Every other byte is appended to the current word.
//
// NOTE(review): the quoting rules assume a standard-SQL dialect such as SQLite
// or PostgreSQL. MySQL's default backslash escaping differs; confirm which
// database executes these filters.
func tokenizeSQLFilter(s string) []string {
	var tokens []string
	var current strings.Builder

	// flush emits the word accumulated so far, if any.
	flush := func() {
		if current.Len() > 0 {
			tokens = append(tokens, current.String())
			current.Reset()
		}
	}

	// quoteEnd returns the index just past the quote that closes the quoted
	// run opening at s[start], and false if the run is unterminated.
	quoteEnd := func(start int) (int, bool) {
		q := s[start]
		for j := start + 1; j < len(s); j++ {
			if s[j] != q {
				continue
			}
			if j+1 < len(s) && s[j+1] == q {
				j++ // doubled quote: an escaped quote character
				continue
			}
			return j + 1, true
		}
		return len(s), false
	}

	// jsonPathEnd reports whether s[i:] starts with "->" or "->>" followed by
	// a terminated single-quoted key, returning the index just past that key.
	jsonPathEnd := func(i int) (int, bool) {
		if !strings.HasPrefix(s[i:], "->") {
			return 0, false
		}
		k := i + 2
		if k < len(s) && s[k] == '>' {
			k++
		}
		if k >= len(s) || s[k] != '\'' {
			return 0, false
		}
		return quoteEnd(k)
	}

	for i := 0; i < len(s); {
		c := s[i]

		// JSON path segment glued to the preceding word.
		if c == '-' && current.Len() > 0 {
			if end, ok := jsonPathEnd(i); ok {
				current.WriteString(s[i:end])
				i = end
				// Keep chaining segments; otherwise the path token is complete.
				if _, more := jsonPathEnd(i); !more {
					flush()
				}
				continue
			}
		}

		switch {
		case c == '\'' || c == '"':
			flush()
			end, _ := quoteEnd(i)
			tokens = append(tokens, s[i:end])
			i = end
		case c == ' ' || c == '\t' || c == '\n' || c == '\r':
			flush()
			i++
		case c == '(' || c == ')' || c == ',':
			flush()
			tokens = append(tokens, s[i:i+1])
			i++
		case c == '=' || c == '<' || c == '>' || c == '!':
			flush()
			n := 1
			if i+1 < len(s) && (s[i+1] == '=' || s[i+1] == '>') {
				n = 2 // compound operator
			}
			tokens = append(tokens, s[i:i+n])
			i += n
		default:
			current.WriteByte(c)
			i++
		}
	}
	flush()

	return tokens
}

// isQuotedString reports whether s is exactly one complete single-quoted SQL
// string literal. That means it has length >= 2, starts and ends with ', and
// every quote inside it is part of a doubled ('') escape.
//
// Double-quoted tokens are deliberately not accepted. In standard SQL (and in
// SQLite and PostgreSQL) "..." is a quoted identifier, so treating it as a
// literal would let column or function names such as "pg_read_file" slip past
// the column whitelist. Likewise, a lone "'" or a run with an unpaired quote
// inside is not a closed literal and must not be skipped.
func isQuotedString(s string) bool {
	if len(s) < 2 || s[0] != '\'' || s[len(s)-1] != '\'' {
		return false
	}
	inner := s[1 : len(s)-1]
	for i := 0; i < len(inner); i++ {
		if inner[i] != '\'' {
			continue
		}
		// A quote inside the literal is only valid as the first half of ''.
		if i+1 == len(inner) || inner[i+1] != '\'' {
			return false
		}
		i++
	}
	return true
}

// isNumber reports whether s is a plain decimal numeric literal. It accepts an
// optional leading '-', then at least one digit and at most one '.'. Examples
// that pass are "42", "-3", "1.5", ".5" and "5.".
//
// Tokens made only of signs or dots ("-", ".", "-."), and strings like "1-2"
// or "1.2.3", are rejected. Such tokens are not numbers and previously slipped
// through the filter whitelist.
func isNumber(s string) bool {
	if len(s) > 0 && s[0] == '-' {
		s = s[1:]
	}
	digits, dots := 0, 0
	for i := 0; i < len(s); i++ {
		switch c := s[i]; {
		case '0' <= c && c <= '9':
			digits++
		case c == '.':
			dots++
		default:
			return false
		}
	}
	return digits > 0 && dots <= 1
}

// isAllowedOperator reports whether s, after strings.ToLower, exactly equals
// one of the entries in allowedSQLOperators.
func isAllowedOperator(s string) bool {
	needle := strings.ToLower(s)
	for i := range allowedSQLOperators {
		if allowedSQLOperators[i] == needle {
			return true
		}
	}
	return false
}

// containsDangerousSQL reports whether s is rejected by validateSQLFilter.
// It is retained for backward compatibility with callers that only need a
// yes/no answer instead of the descriptive error.
func containsDangerousSQL(s string) bool {
	if err := validateSQLFilter(s); err != nil {
		return true
	}
	return false
}
← Back