limiter.go
// Rate limiting middleware.
package access
import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"
)
// bucket is the token-bucket state kept for a single key: how many tokens are
// available and when that count was last brought up to date.
type bucket struct {
	tokens     float64   // tokens currently available; one is spent per permitted request
	lastRefill time.Time // instant at which tokens was last recomputed
}
// Limiter is a per-key token-bucket rate limiter. Every key owns a bucket that
// refills continuously at rate tokens per second, up to capacity tokens.
//
// Build one with NewLimiter. The zero value leaves buckets nil, and Allow
// would panic when it first writes to that map.
type Limiter struct {
	mu       sync.Mutex         // serializes all access to buckets
	buckets  map[string]*bucket // per-key state, created lazily by Allow
	rate     float64            // refill speed in tokens per second
	capacity float64            // bucket size, i.e. the largest burst allowed
}
// NewLimiter returns a Limiter that sustains requestsPerMinute requests per
// key and lets an idle key spend up to burst requests back to back.
// Internally the per-minute rate is converted to tokens per second.
func NewLimiter(requestsPerMinute int, burst int) *Limiter {
	perSecond := float64(requestsPerMinute) / 60.0
	lim := &Limiter{
		rate:     perSecond,
		capacity: float64(burst),
	}
	lim.buckets = make(map[string]*bucket)
	return lim
}
// Allow reports whether one more request for key may proceed. When it may, a
// token is taken from the key's bucket. A key that has never been seen
// starts with a full bucket. Safe to call from multiple goroutines.
func (l *Limiter) Allow(key string) bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	// Sample the clock under the lock so lastRefill never moves backwards.
	now := time.Now()
	bk, found := l.buckets[key]
	if !found {
		bk = &bucket{tokens: l.capacity, lastRefill: now}
		l.buckets[key] = bk
	}

	// Credit the tokens earned since the last visit, never exceeding capacity.
	refilled := bk.tokens + now.Sub(bk.lastRefill).Seconds()*l.rate
	bk.tokens = min(refilled, l.capacity)
	bk.lastRefill = now

	allowed := bk.tokens >= 1
	if allowed {
		bk.tokens--
	}
	return allowed
}
// Middleware returns an HTTP middleware that rate limits by client IP. A
// request whose bucket is empty gets a 429 Too Many Requests response and is
// not passed to next.
//
// The rate-limit key comes from clientKey.
func (l *Limiter) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !l.Allow(clientKey(r)) {
			http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// clientKey derives the rate-limit key (a client host/IP) for r.
//
// r.RemoteAddr has the form "host:port", so the port is stripped. Otherwise
// every new connection from the same host, with its new ephemeral port, would
// get a fresh, full bucket.
//
// X-Forwarded-For is client-controllable. A caller can prepend arbitrary
// entries, so keying on the whole header (or its left-most entry) would let an
// attacker mint a new bucket per request. The right-most entry of the last
// header line is the one appended by the nearest proxy, so that entry is used.
// NOTE(review): this still assumes a trusted proxy sits in front of the
// server. If the service is reachable directly, the header should be ignored.
func clientKey(r *http.Request) string {
	if vals := r.Header.Values("X-Forwarded-For"); len(vals) > 0 {
		last := vals[len(vals)-1]
		if i := strings.LastIndexByte(last, ','); i >= 0 {
			last = last[i+1:]
		}
		if ip := stripPort(strings.TrimSpace(last)); ip != "" {
			return ip
		}
	}
	return stripPort(r.RemoteAddr)
}

// stripPort returns the host part of a "host:port" address (including the
// bracketed IPv6 form). It returns addr unchanged when addr carries no port.
func stripPort(addr string) string {
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	return addr
}
// Cleanup deletes buckets that have been idle for longer than maxAge, so the
// bucket map does not grow with every key ever seen. Call it periodically
// (e.g., every 5 minutes).
//
// A bucket is only deleted once it would have refilled to capacity. A key
// with no bucket starts over with a full one, so dropping a partially drained
// bucket would hand that key free tokens on its next request. This can happen
// whenever maxAge is shorter than a full refill. With a zero rate, buckets
// never refill, so drained ones are kept.
func (l *Limiter) Cleanup(maxAge time.Duration) {
	l.mu.Lock()
	defer l.mu.Unlock()
	now := time.Now()
	for key, b := range l.buckets {
		idle := now.Sub(b.lastRefill)
		if idle <= maxAge {
			continue
		}
		// Same refill arithmetic as Allow: would this bucket be full by now?
		if b.tokens+idle.Seconds()*l.rate >= l.capacity {
			delete(l.buckets, key) // deleting during range is permitted in Go
		}
	}
}
// --- Pre-configured limiters for common use cases ---
//
// Every limiter here uses a burst equal to its per-minute rate. A client that
// has been idle long enough can therefore spend a whole minute's allowance at
// once.
//
// NOTE(review): nothing in this section calls Cleanup on these limiters.
// Confirm that a periodic Cleanup is scheduled elsewhere. Otherwise each
// bucket map keeps one entry per distinct client key.

// Feature-specific limiters.
var (
	// AuthLimiter limits auth endpoints to 5 requests per minute per IP.
	AuthLimiter = NewLimiter(5, 5)

	// UploadLimiter limits file uploads to 10 per minute per IP.
	UploadLimiter = NewLimiter(10, 10)

	// ChatLimiter limits message creation to 30 per minute per IP.
	ChatLimiter = NewLimiter(30, 30)
)

// Public API limiters. Writes get a tighter budget than reads.
var (
	// APILimiter limits public API read requests to 100 per minute per IP.
	APILimiter = NewLimiter(100, 100)

	// APIWriteLimiter limits public API write requests (POST/PATCH/DELETE) to
	// 30 per minute per IP. It is more restrictive than APILimiter to prevent
	// abuse.
	APIWriteLimiter = NewLimiter(30, 30)
)
// RateLimitAuth returns handler guarded by AuthLimiter's middleware.
func RateLimitAuth(handler http.Handler) http.Handler {
	limit := AuthLimiter.Middleware
	return limit(handler)
}
// RateLimitUpload returns handler guarded by UploadLimiter's middleware.
func RateLimitUpload(handler http.Handler) http.Handler {
	limit := UploadLimiter.Middleware
	return limit(handler)
}
// RateLimitChat returns handler guarded by ChatLimiter's middleware.
func RateLimitChat(handler http.Handler) http.Handler {
	limit := ChatLimiter.Middleware
	return limit(handler)
}