readysite / website / internal / access / ratelimit.go
6.3 KB
ratelimit.go
// Configurable rate limiting.
package access

import (
	"encoding/json"
	"net/http"
	"strconv"
	"sync"
	"time"
)

// RateLimitConfig represents a configurable rate limit.
//
// It is decoded from JSON by ParseRateLimitConfig using the field tags below,
// e.g. {"requests": 100, "window": "1m", "burst": 100}. Window is interpreted
// by WindowDuration, and Requests together with Window determine the refill
// rate returned by Rate.
type RateLimitConfig struct {
	Requests int    `json:"requests"` // Requests allowed per window
	Window   string `json:"window"`   // Time window (e.g., "1m", "1h")
	Burst    int    `json:"burst"`    // Burst capacity (max tokens)
}

// ParseRateLimitConfig decodes a JSON-encoded rate limit configuration.
//
// It returns nil when configJSON is empty, cannot be decoded into a
// RateLimitConfig, or carries a non-positive "requests" value. Otherwise the
// returned config has defaults filled in: an empty window becomes "1m" and a
// non-positive burst is set equal to requests.
func ParseRateLimitConfig(configJSON string) *RateLimitConfig {
	if configJSON == "" {
		return nil
	}

	cfg := new(RateLimitConfig)
	if json.Unmarshal([]byte(configJSON), cfg) != nil || cfg.Requests <= 0 {
		// Undecodable input and limits that allow no requests are both
		// treated as "no config".
		return nil
	}

	// Fill in the optional fields.
	if cfg.Window == "" {
		cfg.Window = "1m"
	}
	if cfg.Burst < 1 {
		cfg.Burst = cfg.Requests
	}
	return cfg
}

// WindowDuration converts the Window string into a time.Duration.
//
// Window may be a Go duration string ("30s", "1m", "1h") or a bare integer,
// which is read as a number of minutes. The result is always positive: a nil
// receiver, an empty or unparsable window, a zero or negative window, and a
// minute count too large for a time.Duration all fall back to one minute.
// Callers divide by this value to obtain a rate, so a zero or negative window
// must never be returned.
func (c *RateLimitConfig) WindowDuration() time.Duration {
	if c == nil || c.Window == "" {
		return time.Minute
	}

	// Duration syntax first (e.g. "1m", "30s", "1h"). Note that "0" is a
	// valid duration string, so the sign check is needed here too.
	if d, err := time.ParseDuration(c.Window); err == nil {
		if d > 0 {
			return d
		}
		return time.Minute
	}

	// Otherwise accept a bare number and treat it as minutes, rejecting
	// values that are non-positive or would overflow time.Duration.
	const maxMinutes = time.Duration(1<<63-1) / time.Minute
	if mins, err := strconv.Atoi(c.Window); err == nil && mins > 0 && time.Duration(mins) <= maxMinutes {
		return time.Duration(mins) * time.Minute
	}
	return time.Minute
}

// Rate reports how many tokens are added per second: Requests divided by the
// length of the window in seconds. A nil config yields one token per second.
func (c *RateLimitConfig) Rate() float64 {
	if c != nil {
		return float64(c.Requests) / c.WindowDuration().Seconds()
	}
	return 1.0
}

// ConfigurableLimiter manages per-resource rate limiters.
//
// Limiters are created on demand by getLimiter and stored in limiters, keyed
// by "resourceType:resourceID" as built in Check; mu guards that map.
// defaultConfig is used for resources that supply no config of their own.
// NewConfigurableLimiter starts a goroutine that runs cleanup on every
// cleanupTicker tick until Stop closes stopCleanup.
type ConfigurableLimiter struct {
	mu            sync.RWMutex
	limiters      map[string]*resourceLimiter
	defaultConfig *RateLimitConfig
	authMultiply  float64 // Multiplier for authenticated users
	cleanupTicker *time.Ticker
	stopCleanup   chan struct{}
}

// resourceLimiter holds the rate limiting state for a single resource: one
// token bucket per client and the config those buckets are filled from.
//
// mu is held by Check while it reads or modifies buckets. lastUsed is the
// time getLimiter last returned this limiter; cleanup compares it against its
// maximum age to decide whether the limiter can be dropped.
type resourceLimiter struct {
	buckets  map[string]*bucket // keyed by client IP
	config   *RateLimitConfig
	mu       sync.Mutex
	lastUsed time.Time
}

// NewConfigurableLimiter builds a limiter whose default config allows
// defaultRequests per defaultWindow (burst equal to defaultRequests) and whose
// limits are scaled by authMultiplier for authenticated callers.
//
// The returned limiter owns a background goroutine that runs cleanup every
// five minutes; call Stop to end it.
func NewConfigurableLimiter(defaultRequests int, defaultWindow string, authMultiplier float64) *ConfigurableLimiter {
	defaults := &RateLimitConfig{
		Requests: defaultRequests,
		Window:   defaultWindow,
		Burst:    defaultRequests,
	}

	limiter := new(ConfigurableLimiter)
	limiter.limiters = map[string]*resourceLimiter{}
	limiter.defaultConfig = defaults
	limiter.authMultiply = authMultiplier
	limiter.stopCleanup = make(chan struct{})
	limiter.cleanupTicker = time.NewTicker(5 * time.Minute)

	// Every field is set before the goroutine starts reading them.
	go limiter.cleanupLoop()

	return limiter
}

// cleanupLoop runs until stopCleanup is closed, evicting limiters that have
// been idle for more than ten minutes each time the cleanup ticker fires. The
// ticker is stopped when the loop exits.
func (cl *ConfigurableLimiter) cleanupLoop() {
	defer cl.cleanupTicker.Stop()

	for {
		select {
		case <-cl.stopCleanup:
			return
		case <-cl.cleanupTicker.C:
			cl.cleanup(10 * time.Minute)
		}
	}
}

// Stop signals the cleanup goroutine to exit by closing stopCleanup.
//
// Stop is idempotent and safe to call from multiple goroutines: closing an
// already-closed channel panics, so the channel is checked and closed while
// holding cl.mu, and later calls are no-ops. Stop does not wait for the
// goroutine to finish.
func (cl *ConfigurableLimiter) Stop() {
	cl.mu.Lock()
	defer cl.mu.Unlock()

	select {
	case <-cl.stopCleanup:
		// Already closed by an earlier Stop.
	default:
		close(cl.stopCleanup)
	}
}

// cleanup releases rate limiting state that is no longer needed.
//
// A limiter whose lastUsed is more than maxAge in the past is removed
// entirely. NOTE: that eviction is by age only, so the limiter's buckets are
// dropped even if some of them have not refilled yet.
//
// For limiters that stay, per-client buckets idle for more than maxAge are
// pruned so the bucket map does not grow with every client ever seen. A bucket
// is only pruned once it would have refilled to the largest capacity Check can
// apply (Burst scaled by the authenticated multiplier) at the base rate, so
// the bucket Check recreates later never holds more tokens than the pruned one
// would have.
//
// cl.mu is held for the whole pass; each limiter's own mu is taken while its
// buckets are inspected.
func (cl *ConfigurableLimiter) cleanup(maxAge time.Duration) {
	cl.mu.Lock()
	defer cl.mu.Unlock()

	// Largest factor Check may apply to a bucket's capacity.
	capScale := 1.0
	if cl.authMultiply > 1.0 {
		capScale = cl.authMultiply
	}

	now := time.Now()
	for key, rl := range cl.limiters {
		if now.Sub(rl.lastUsed) > maxAge {
			delete(cl.limiters, key)
			continue
		}
		rl.pruneRefilled(now, maxAge, capScale)
	}
}

// pruneRefilled deletes buckets that have been idle for longer than maxAge and
// would by now hold at least Burst*capScale tokens at the configured rate.
func (rl *resourceLimiter) pruneRefilled(now time.Time, maxAge time.Duration, capScale float64) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	if rl.config == nil {
		// Without a config the refill state cannot be computed; keep everything.
		return
	}
	rate := rl.config.Rate()
	full := float64(rl.config.Burst) * capScale

	for ip, b := range rl.buckets {
		idle := now.Sub(b.lastRefill)
		if idle > maxAge && b.tokens+idle.Seconds()*rate >= full {
			delete(rl.buckets, ip)
		}
	}
}

// getLimiter returns the limiter for resourceKey, creating it on first use.
//
// config is only consulted when the limiter is created (nil selects the
// default config); later calls for the same key keep the config the limiter
// was created with.
//
// The lookup, the lastUsed refresh and the creation all happen under cl.mu's
// write lock. cleanup reads lastUsed while holding that same lock, so
// refreshing it under a read lock or with no lock held would be a data race.
// It would also leave a window in which cleanup could evict a limiter that
// has just been handed to a caller.
func (cl *ConfigurableLimiter) getLimiter(resourceKey string, config *RateLimitConfig) *resourceLimiter {
	cl.mu.Lock()
	defer cl.mu.Unlock()

	now := time.Now()
	if rl, exists := cl.limiters[resourceKey]; exists {
		rl.lastUsed = now
		return rl
	}

	if config == nil {
		config = cl.defaultConfig
	}

	rl := &resourceLimiter{
		buckets:  make(map[string]*bucket),
		config:   config,
		lastUsed: now,
	}
	cl.limiters[resourceKey] = rl
	return rl
}

// CheckResult contains the result of a rate limit check.
//
// Allowed reports whether the request may proceed. Limit is the bucket
// capacity that applied to the request, Remaining is the whole number of
// tokens left in the bucket, and ResetAt is the estimated time at which the
// bucket will be full again. SetRateLimitHeaders turns these values into
// response headers.
type CheckResult struct {
	Allowed   bool
	Limit     int
	Remaining int
	ResetAt   time.Time
}

// Check consumes one token for clientIP on the resource identified by
// resourceType and resourceID, and reports whether the request is allowed.
//
// config is used only when the resource's limiter does not exist yet; nil
// selects the limiter's default config. When isAuthenticated is true and the
// limiter's multiplier is greater than 1, both the refill rate and the bucket
// capacity are scaled by it. Buckets are keyed by clientIP alone, so
// authenticated and anonymous requests from the same address share a bucket.
//
// In the result, Limit is the capacity that applied, Remaining is the number
// of whole tokens left after this request, and ResetAt is when the bucket
// will be full again given the tokens left after this request.
func (cl *ConfigurableLimiter) Check(resourceType, resourceID, clientIP string, config *RateLimitConfig, isAuthenticated bool) CheckResult {
	resourceKey := resourceType + ":" + resourceID
	rl := cl.getLimiter(resourceKey, config)

	effectiveConfig := rl.config
	if effectiveConfig == nil {
		effectiveConfig = cl.defaultConfig
	}

	// Effective rate and capacity (higher for authenticated users).
	rate := effectiveConfig.Rate()
	capacity := float64(effectiveConfig.Burst)
	if isAuthenticated && cl.authMultiply > 1.0 {
		rate *= cl.authMultiply
		capacity *= cl.authMultiply
	}

	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	b, ok := rl.buckets[clientIP]
	if !ok {
		// A new client starts with a full bucket at the capacity that
		// applies to it, so Remaining is consistent with Limit.
		b = &bucket{tokens: capacity, lastRefill: now}
		rl.buckets[clientIP] = b
	}

	// Refill for the time elapsed since the last check, capped at capacity.
	elapsed := now.Sub(b.lastRefill).Seconds()
	b.tokens = min(capacity, b.tokens+elapsed*rate)
	b.lastRefill = now

	// Spend a token if one is available.
	allowed := b.tokens >= 1
	if allowed {
		b.tokens--
	}

	// Time until the bucket is full again, computed after the token was
	// spent. The conversion keeps sub-second precision; it is skipped when
	// the rate is not positive (the bucket would never refill) and clamped
	// so an extreme config cannot overflow time.Duration.
	const maxReset = time.Duration(1<<63 - 1)
	resetAt := now
	if deficit := capacity - b.tokens; deficit > 0 && rate > 0 {
		wait := maxReset
		if secs := deficit / rate; secs < maxReset.Seconds() {
			wait = time.Duration(secs * float64(time.Second))
		}
		resetAt = now.Add(wait)
	}

	return CheckResult{
		Allowed:   allowed,
		Limit:     int(capacity),
		Remaining: int(b.tokens),
		ResetAt:   resetAt,
	}
}

// SetRateLimitHeaders writes the rate limit state in result to the response
// headers of w.
//
// It always sets X-RateLimit-Limit, X-RateLimit-Remaining and
// X-RateLimit-Reset (ResetAt as Unix seconds). When the request was not
// allowed it also sets Retry-After to the number of seconds until ResetAt,
// rounded up and never less than 1, so a client that honours the header does
// not retry before the reset time.
func SetRateLimitHeaders(w http.ResponseWriter, result CheckResult) {
	h := w.Header()
	h.Set("X-RateLimit-Limit", strconv.Itoa(result.Limit))
	h.Set("X-RateLimit-Remaining", strconv.Itoa(result.Remaining))
	h.Set("X-RateLimit-Reset", strconv.FormatInt(result.ResetAt.Unix(), 10))

	if result.Allowed {
		return
	}

	// Round up to whole seconds using integer arithmetic; truncating would
	// tell the client to come back up to a second too early.
	wait := time.Until(result.ResetAt)
	retryAfter := int64(wait / time.Second)
	if wait%time.Second > 0 {
		retryAfter++
	}
	if retryAfter < 1 {
		retryAfter = 1
	}
	h.Set("Retry-After", strconv.FormatInt(retryAfter, 10))
}

// ResourceLimiter is the package-level configurable rate limiter used for
// collections and endpoints. Its default is 100 requests per minute, doubled
// for authenticated users.
var ResourceLimiter = NewConfigurableLimiter(100, "1m", 2.0)
← Back