// validate.go

package assist
import (
"fmt"
"strings"
)
// allowedSQLColumns is the whitelist of document columns that may be
// referenced in filter and ORDER BY expressions. Keys are lowercase; the
// validators in this file lowercase an identifier before looking it up here.
var allowedSQLColumns = map[string]bool{
	"collectionid": true,
	"createdat":    true,
	"data":         true,
	"id":           true,
	"updatedat":    true,
}
// allowedSQLOperators lists the operator and keyword tokens accepted in a
// filter expression. isAllowedOperator compares one lowercased token against
// each entry.
//
// tokenizeSQLFilter splits on whitespace, so a multi-word operator reaches the
// validator one word at a time ("is", "not", "null", ...). The standalone
// words "is" and "not" are therefore listed as well; without them IS NULL,
// IS NOT NULL, NOT LIKE and NOT IN could never pass validation. The multi-word
// spellings are kept for any code that inspects this list directly.
var allowedSQLOperators = []string{
	"=", "!=", "<>", "<", ">", "<=", ">=",
	"like", "not like", "in", "not in", "is null", "is not null",
	"is", "not",
	"and", "or", "between",
}
// validateSQLFilter checks a SQL filter expression in two passes and returns
// an error describing the first violation, or nil if the filter is acceptable.
// An empty filter is always accepted.
//
// Pass 1 (deny-list): the whole lowercased filter, padded with one space on
// each side, is searched for forbidden keywords and patterns. Keyword entries
// carry surrounding spaces so they only match as space-delimited words. The
// search covers the entire string, including the contents of quoted literals.
//
// Pass 2 (allow-list): the filter is tokenized and every token must be a
// quoted string, a number, an allowed operator, punctuation ("(", ")", ","),
// one of the literals true/false/null, an allowed column name, or a token
// that is "data" or starts with "data->" (a JSON path on the data column).
func validateSQLFilter(filter string) error {
	if filter == "" {
		return nil
	}

	// Order matters: the first matching entry decides the error message.
	forbidden := [...]string{
		" drop ", " delete ", " insert ", " update ", " alter ",
		" create ", " truncate ", " exec ", " execute ",
		" union ", " into ", " outfile ", " dumpfile ", " load_file ",
		"--", ";", "/*", "*/", "@@", "char(", "concat(",
		"benchmark(", "sleep(", "waitfor", "pg_sleep",
	}

	// Padding lets a keyword at the very start or end of the filter match
	// its space-delimited entry.
	haystack := " " + strings.ToLower(filter) + " "
	for i := range forbidden {
		if !strings.Contains(haystack, forbidden[i]) {
			continue
		}
		return fmt.Errorf("filter contains forbidden SQL keyword or pattern: %s", strings.TrimSpace(forbidden[i]))
	}

	for _, raw := range tokenizeSQLFilter(filter) {
		tok := strings.ToLower(raw)

		// Literals and operators need no further checking.
		switch {
		case isQuotedString(tok), isNumber(tok):
			continue
		case isAllowedOperator(tok):
			continue
		}

		// Punctuation and keyword literals.
		switch tok {
		case "(", ")", ",", "true", "false", "null":
			continue
		}

		// Anything left must be a whitelisted column or a JSON path rooted
		// at the data column.
		if allowedSQLColumns[tok] || tok == "data" || strings.HasPrefix(tok, "data->") {
			continue
		}
		return fmt.Errorf("filter contains invalid column name: %s", tok)
	}

	return nil
}
// validateSQLOrderBy validates a comma-separated ORDER BY clause and returns
// an error describing the first violation, or nil if the clause is acceptable.
//
// An empty clause is accepted. Otherwise the clause must not contain comment
// markers, a semicolon or parentheses, and every comma-separated term must
// have the form "<column> [ASC|DESC]" (case-insensitive, at most two
// whitespace-separated words). The column must be an allowed column or a JSON
// path rooted at the data column ("data->..."), the same rule the filter
// validator applies. Terms that are empty after trimming (for example after a
// trailing comma) are skipped rather than rejected.
func validateSQLOrderBy(orderBy string) error {
	if orderBy == "" {
		return nil
	}

	// None of these patterns contain letters, so no case folding is needed.
	for _, pattern := range []string{"--", ";", "/*", "*/", "(", ")"} {
		if strings.Contains(orderBy, pattern) {
			return fmt.Errorf("order_by contains forbidden pattern: %s", pattern)
		}
	}

	for _, term := range strings.Split(orderBy, ",") {
		// Fields also discards leading and trailing whitespace.
		fields := strings.Fields(term)
		if len(fields) == 0 {
			continue
		}

		// A bare "data" prefix test would also admit unrelated identifiers
		// such as "database_x"; only the data column itself or a "data->"
		// JSON path may bypass the column whitelist.
		col := strings.ToLower(fields[0])
		if !allowedSQLColumns[col] && col != "data" && !strings.HasPrefix(col, "data->") {
			return fmt.Errorf("order_by contains invalid column: %s", col)
		}

		// Optional sort direction.
		if len(fields) > 1 {
			dir := strings.ToLower(fields[1])
			if dir != "asc" && dir != "desc" {
				return fmt.Errorf("order_by contains invalid direction: %s", dir)
			}
		}

		// Nothing may follow the direction.
		if len(fields) > 2 {
			return fmt.Errorf("order_by contains unexpected tokens")
		}
	}

	return nil
}
// tokenizeSQLFilter splits a SQL filter expression into tokens.
//
// Tokenization rules:
//   - A single- or double-quoted run is kept intact, quotes included. It is
//     appended to whatever token is pending, and the token ends at the
//     closing quote. A quote preceded by a backslash does not open or close
//     a literal. An unterminated literal runs to the end of the input.
//   - Space, tab, CR and LF separate tokens and are dropped.
//   - "(", ")" and "," are single-character tokens.
//   - "=", "<", ">" and "!" start an operator token; a directly following "="
//     or ">" is merged into it (">=", "<>", "!=", ...).
//   - A JSON arrow ("->" or "->>") attached to a preceding identifier stays
//     inside that identifier's token, so data->>'name' is ONE token. Without
//     this, the ">" of the arrow would be split off as a comparison operator
//     and a path on the data column could never be recognised downstream.
//     An arrow that is not attached to anything (e.g. "data -> 'k'") is
//     tokenized exactly as before: "-", ">".
//   - Every other byte is accumulated into the pending token.
//
// NOTE(review): standard SQL escapes a quote inside a literal by doubling it,
// not with a backslash. Confirm that the backslash rule above matches the
// string-literal rules of the database these filters are sent to.
func tokenizeSQLFilter(s string) []string {
	var (
		tokens    []string
		current   strings.Builder
		inQuote   bool
		quoteChar byte
	)

	// flush emits the pending token, if there is one.
	flush := func() {
		if current.Len() > 0 {
			tokens = append(tokens, current.String())
			current.Reset()
		}
	}

	for i := 0; i < len(s); i++ {
		c := s[i]

		// An unescaped quote either opens a literal or, if it matches the
		// opening quote, closes it. The other quote character inside a
		// literal is ordinary content.
		if (c == '\'' || c == '"') && (i == 0 || s[i-1] != '\\') {
			switch {
			case !inQuote:
				inQuote = true
				quoteChar = c
			case c == quoteChar:
				inQuote = false
				current.WriteByte(c)
				flush()
				continue
			}
		}
		if inQuote {
			current.WriteByte(c)
			continue
		}

		switch c {
		case ' ', '\t', '\n', '\r':
			flush()

		case '(', ')', ',':
			flush()
			tokens = append(tokens, string(c))

		case '=', '<', '>', '!':
			if c == '>' {
				// Part of "->" / "->>" attached to an identifier: keep it in
				// the pending token. The length checks require at least one
				// byte before the dash so a free-standing "-" is unaffected.
				pending := current.String()
				if (len(pending) > 1 && strings.HasSuffix(pending, "-")) ||
					(len(pending) > 2 && strings.HasSuffix(pending, "->")) {
					current.WriteByte(c)
					continue
				}
			}
			flush()
			// Look ahead for a two-character operator.
			if i+1 < len(s) && (s[i+1] == '=' || s[i+1] == '>') {
				tokens = append(tokens, s[i:i+2])
				i++
			} else {
				tokens = append(tokens, string(c))
			}

		default:
			current.WriteByte(c)
		}
	}

	flush()
	return tokens
}
// isQuotedString reports whether s is a complete quoted string literal: it
// begins and ends with the same quote character (single or double) and is at
// least two bytes long.
//
// The length check matters because a lone quote character is both a prefix
// and a suffix of itself. Without it, the dangling quote produced by an
// unterminated literal at the end of a filter would be accepted as a string.
func isQuotedString(s string) bool {
	if len(s) < 2 {
		return false
	}
	return (strings.HasPrefix(s, "'") && strings.HasSuffix(s, "'")) ||
		(strings.HasPrefix(s, "\"") && strings.HasSuffix(s, "\""))
}
// isNumber reports whether s is non-empty and consists solely of ASCII
// digits, '.' and '-'.
//
// This is a character-class check, not a numeric parse: besides literals such
// as "42" or "-3.14" it also returns true for strings like "-", "." or
// "1.2.3".
func isNumber(s string) bool {
	if len(s) == 0 {
		return false
	}
	// Every accepted character is a single ASCII byte, so scanning bytes
	// accepts and rejects the same inputs as scanning runes.
	for i := 0; i < len(s); i++ {
		switch b := s[i]; {
		case b >= '0' && b <= '9', b == '.', b == '-':
			// accepted character
		default:
			return false
		}
	}
	return true
}
// isAllowedOperator reports whether s, after lowercasing, is exactly equal to
// one of the entries in allowedSQLOperators. No trimming or partial matching
// is performed.
func isAllowedOperator(s string) bool {
	candidate := strings.ToLower(s)
	for i := 0; i < len(allowedSQLOperators); i++ {
		if allowedSQLOperators[i] == candidate {
			return true
		}
	}
	return false
}
// containsDangerousSQL is a boolean convenience wrapper kept for backward
// compatibility: it reports whether validateSQLFilter rejects s. It returns
// true for any validation failure, not only for forbidden patterns.
func containsDangerousSQL(s string) bool {
	if err := validateSQLFilter(s); err != nil {
		return true
	}
	return false
}