ratelimit.go
// Configurable rate limiting.
package access
import (
"encoding/json"
"net/http"
"strconv"
"sync"
"time"
)
// RateLimitConfig describes a token-bucket limit: Requests tokens are
// replenished evenly over Window, and at most Burst tokens may accumulate.
type RateLimitConfig struct {
	// Requests is the number of requests allowed per window.
	Requests int `json:"requests"`
	// Window is the window length as a duration string, e.g. "30s", "1m", "1h".
	Window string `json:"window"`
	// Burst is the bucket capacity, i.e. the maximum number of stored tokens.
	Burst int `json:"burst"`
}
// ParseRateLimitConfig decodes a JSON-encoded RateLimitConfig.
//
// It returns nil when configJSON is empty, cannot be decoded, or specifies a
// non-positive request count. Otherwise an empty window defaults to "1m" and a
// non-positive burst defaults to the request count.
func ParseRateLimitConfig(configJSON string) *RateLimitConfig {
	if len(configJSON) == 0 {
		return nil
	}

	cfg := new(RateLimitConfig)
	if json.Unmarshal([]byte(configJSON), cfg) != nil || cfg.Requests <= 0 {
		return nil
	}

	// Fill in defaults for optional fields.
	if len(cfg.Window) == 0 {
		cfg.Window = "1m"
	}
	if cfg.Burst <= 0 {
		cfg.Burst = cfg.Requests
	}
	return cfg
}
// WindowDuration returns the configured window as a time.Duration.
//
// Window is parsed with time.ParseDuration (e.g. "30s", "1m", "1h"). A bare
// integer such as "5", which ParseDuration rejects for lacking a unit, is
// interpreted as minutes. Each of the following yields one minute:
//   - a nil receiver
//   - an empty or unparseable window
//   - a window that is not strictly positive
//
// Rejecting non-positive windows matters because Rate divides by the window
// length. "0s" or "-1m" would otherwise produce an infinite, NaN or negative
// refill rate.
func (c *RateLimitConfig) WindowDuration() time.Duration {
	const fallback = time.Minute
	if c == nil || c.Window == "" {
		return fallback
	}

	d, err := time.ParseDuration(c.Window)
	if err != nil {
		mins, convErr := strconv.Atoi(c.Window)
		if convErr != nil {
			return fallback
		}
		d = time.Duration(mins) * time.Minute
	}
	if d <= 0 {
		return fallback
	}
	return d
}
// Rate returns the steady-state refill rate in tokens per second, which is
// Requests spread evenly over WindowDuration. A nil config yields 1 token per
// second.
func (c *RateLimitConfig) Rate() float64 {
	if c != nil {
		perSecond := float64(c.Requests) / c.WindowDuration().Seconds()
		return perSecond
	}
	return 1.0
}
// ConfigurableLimiter rate-limits requests per resource and per client.
//
// Each resource key maps to a resourceLimiter that holds one token bucket per
// client. A background goroutine started by NewConfigurableLimiter evicts idle
// resource limiters; Stop terminates it.
type ConfigurableLimiter struct {
	// mu guards limiters.
	mu sync.RWMutex
	// limiters is keyed by "resourceType:resourceID".
	limiters map[string]*resourceLimiter

	// defaultConfig applies to resources created without their own config.
	defaultConfig *RateLimitConfig
	// authMultiply scales rate and burst for authenticated clients
	// (applied only when greater than 1).
	authMultiply float64

	// cleanupTicker drives the periodic eviction in cleanupLoop.
	cleanupTicker *time.Ticker
	// stopCleanup is closed by Stop to end cleanupLoop.
	stopCleanup chan struct{}
}
// resourceLimiter holds the per-client token buckets for a single resource.
type resourceLimiter struct {
	// buckets holds one token bucket per client, keyed by client IP.
	buckets map[string]*bucket
	// config holds the limit settings shared by every bucket of this resource.
	config *RateLimitConfig
	// mu guards buckets and the state of each bucket in it.
	mu sync.Mutex
	// lastUsed is the most recent access time; cleanup evicts the whole
	// limiter once it is older than the eviction age.
	lastUsed time.Time
}
// NewConfigurableLimiter returns a limiter whose default config allows
// defaultRequests requests per defaultWindow, with a burst of defaultRequests.
// When authMultiplier exceeds 1, authenticated clients get their rate and
// burst scaled by it.
//
// It starts a background goroutine that periodically evicts idle resource
// limiters; call Stop to end it.
func NewConfigurableLimiter(defaultRequests int, defaultWindow string, authMultiplier float64) *ConfigurableLimiter {
	defaults := &RateLimitConfig{
		Requests: defaultRequests,
		Window:   defaultWindow,
		Burst:    defaultRequests,
	}

	limiter := &ConfigurableLimiter{
		limiters:      map[string]*resourceLimiter{},
		defaultConfig: defaults,
		authMultiply:  authMultiplier,
		stopCleanup:   make(chan struct{}),
		cleanupTicker: time.NewTicker(5 * time.Minute),
	}

	go limiter.cleanupLoop()
	return limiter
}
// cleanupLoop evicts resource limiters idle for more than ten minutes on every
// tick of cleanupTicker. It returns once stopCleanup is closed and stops the
// ticker on the way out.
func (cl *ConfigurableLimiter) cleanupLoop() {
	defer cl.cleanupTicker.Stop()
	for {
		select {
		case <-cl.stopCleanup:
			return
		case <-cl.cleanupTicker.C:
			cl.cleanup(10 * time.Minute)
		}
	}
}
// Stop signals the background cleanup goroutine to exit.
//
// Stop is idempotent. Calling it more than once, including concurrently, is
// safe. Without the guard, a second close of stopCleanup would panic. cl.mu
// serializes the check-and-close so two callers cannot both observe the
// channel as open.
func (cl *ConfigurableLimiter) Stop() {
	cl.mu.Lock()
	defer cl.mu.Unlock()

	select {
	case <-cl.stopCleanup:
		// Already closed by an earlier Stop.
	default:
		close(cl.stopCleanup)
	}
}
// cleanup evicts rate-limit state that has been idle for longer than maxAge.
//
// A resource limiter whose lastUsed is older than maxAge is removed outright.
// For resource limiters still in use, individual client buckets are removed
// when both of these hold:
//   - the bucket has not been refilled for maxAge;
//   - it would already have refilled to full capacity.
//
// Without this second pass, the buckets map of a continuously used resource
// gains one entry per distinct client IP and never shrinks.
func (cl *ConfigurableLimiter) cleanup(maxAge time.Duration) {
	cl.mu.Lock()
	defer cl.mu.Unlock()

	now := time.Now()
	for key, rl := range cl.limiters {
		if now.Sub(rl.lastUsed) > maxAge {
			delete(cl.limiters, key)
			continue
		}
		cl.pruneIdleBuckets(rl, now, maxAge)
	}
}

// pruneIdleBuckets deletes rl's client buckets that have been idle for more
// than maxAge and are full at every capacity level.
//
// A bucket is dropped only when even the slowest refill path (the base rate)
// would have reached the largest capacity (burst scaled by the auth
// multiplier). This guarantees a bucket recreated on the next request never
// holds more tokens than the deleted one would have.
func (cl *ConfigurableLimiter) pruneIdleBuckets(rl *resourceLimiter, now time.Time, maxAge time.Duration) {
	cfg := rl.config
	if cfg == nil {
		cfg = cl.defaultConfig
	}
	if cfg == nil {
		return
	}
	rate := cfg.Rate()
	if rate <= 0 {
		// Buckets never refill; keep them so depleted clients stay depleted.
		return
	}
	fullTokens := float64(cfg.Burst) * max(cl.authMultiply, 1.0)

	rl.mu.Lock()
	defer rl.mu.Unlock()
	for clientIP, b := range rl.buckets {
		idle := now.Sub(b.lastRefill)
		if idle > maxAge && b.tokens+idle.Seconds()*rate >= fullTokens {
			delete(rl.buckets, clientIP)
		}
	}
}
// getLimiter returns the resourceLimiter for resourceKey and refreshes its
// lastUsed timestamp.
//
// On first use it creates the limiter with config, or with cl.defaultConfig
// when config is nil. For a resource that already has a limiter, config is
// ignored.
//
// The whole lookup runs under the write lock. lastUsed is written on every
// call and read by cleanup under cl.mu, so writing it after releasing a read
// lock would be a data race with other callers and with cleanup.
func (cl *ConfigurableLimiter) getLimiter(resourceKey string, config *RateLimitConfig) *resourceLimiter {
	cl.mu.Lock()
	defer cl.mu.Unlock()

	now := time.Now()
	if rl, ok := cl.limiters[resourceKey]; ok {
		rl.lastUsed = now
		return rl
	}

	if config == nil {
		config = cl.defaultConfig
	}
	rl := &resourceLimiter{
		buckets:  make(map[string]*bucket),
		config:   config,
		lastUsed: now,
	}
	cl.limiters[resourceKey] = rl
	return rl
}
// CheckResult reports the outcome of a single rate-limit check.
type CheckResult struct {
	// Allowed reports whether the request may proceed.
	Allowed bool
	// Limit is the client's bucket capacity.
	Limit int
	// Remaining is the number of whole tokens left after the check.
	Remaining int
	// ResetAt is when the bucket is expected to be full again.
	ResetAt time.Time
}
// Check consumes one token for clientIP from the bucket of the resource
// identified by resourceType and resourceID, and reports whether the request
// is allowed together with rate-limit information for response headers.
//
// config is used only when the resource has no limiter yet; nil selects the
// limiter's default config. When the auth multiplier exceeds 1, authenticated
// clients get refill rate and capacity scaled by it. A client seen for the
// first time starts with a full bucket at its effective capacity.
//
// ResetAt is when the bucket is expected to be full again. It is computed
// after this request's token (if any) has been taken, with sub-second
// precision. If the refill rate is not positive the bucket never refills, and
// ResetAt is the current time rather than an undefined conversion of an
// infinite or NaN duration.
func (cl *ConfigurableLimiter) Check(resourceType, resourceID, clientIP string, config *RateLimitConfig, isAuthenticated bool) CheckResult {
	rl := cl.getLimiter(resourceType+":"+resourceID, config)

	cfg := rl.config
	if cfg == nil {
		cfg = cl.defaultConfig
	}

	// Authenticated clients get a proportionally faster refill and larger bucket.
	rate := cfg.Rate()
	capacity := float64(cfg.Burst)
	if isAuthenticated && cl.authMultiply > 1.0 {
		rate *= cl.authMultiply
		capacity *= cl.authMultiply
	}

	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	b, ok := rl.buckets[clientIP]
	if !ok {
		b = &bucket{tokens: capacity, lastRefill: now}
		rl.buckets[clientIP] = b
	}

	// Refill for the time elapsed since the last check, capped at capacity.
	elapsed := now.Sub(b.lastRefill).Seconds()
	b.tokens = min(capacity, b.tokens+elapsed*rate)
	b.lastRefill = now

	result := CheckResult{Limit: int(capacity)}
	if b.tokens >= 1 {
		b.tokens--
		result.Allowed = true
	}
	result.Remaining = int(b.tokens)

	// Scale to nanoseconds before converting so fractional seconds survive.
	var resetIn time.Duration
	if missing := capacity - b.tokens; missing > 0 && rate > 0 {
		resetIn = time.Duration(missing / rate * float64(time.Second))
	}
	result.ResetAt = now.Add(resetIn)
	return result
}
// SetRateLimitHeaders writes these headers from result:
//   - X-RateLimit-Limit
//   - X-RateLimit-Remaining
//   - X-RateLimit-Reset, as Unix seconds
//
// For a denied request it also sets Retry-After to the number of seconds until
// result.ResetAt. The value is rounded up and never less than 1, so clients are
// not told to retry before the reset time.
func SetRateLimitHeaders(w http.ResponseWriter, result CheckResult) {
	h := w.Header()
	h.Set("X-RateLimit-Limit", strconv.Itoa(result.Limit))
	h.Set("X-RateLimit-Remaining", strconv.Itoa(result.Remaining))
	h.Set("X-RateLimit-Reset", strconv.FormatInt(result.ResetAt.Unix(), 10))
	if result.Allowed {
		return
	}

	// Integer ceiling of the wait in seconds; truncation would undershoot.
	wait := time.Until(result.ResetAt)
	retryAfter := int((wait + time.Second - 1) / time.Second)
	h.Set("Retry-After", strconv.Itoa(max(retryAfter, 1)))
}
// ResourceLimiter is the shared configurable rate limiter for collections and
// endpoints. By default each client may make 100 requests per minute per
// resource; authenticated clients get twice the rate and burst.
var ResourceLimiter = NewConfigurableLimiter(100, "1m", 2.0)