readysite / website / internal / access / limiter.go
3.0 KB
limiter.go
// Rate limiting middleware.
package access

import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"
)

// bucket holds the token-bucket state tracked for a single rate-limit key.
//
// tokens is the number of requests currently available; it is fractional
// because Allow refills it in proportion to elapsed time. lastRefill is the
// time of the most recent Allow call for the key, which is also the timestamp
// Cleanup compares against when evicting idle buckets.
type bucket struct {
	tokens     float64
	lastRefill time.Time
}

// Limiter implements a token bucket rate limiter that keeps one bucket per key.
//
// mu guards buckets: both Allow and Cleanup lock it before touching the map.
// buckets is populated lazily by Allow the first time a key is seen. rate and
// capacity are set by NewLimiter and are only read by the methods in this file.
type Limiter struct {
	mu       sync.Mutex
	buckets  map[string]*bucket
	rate     float64 // tokens per second
	capacity float64 // max tokens
}

// NewLimiter builds a Limiter that refills requestsPerMinute tokens per minute
// and lets each key hold at most burst tokens. The per-minute figure is
// converted to the per-second refill rate used internally.
func NewLimiter(requestsPerMinute int, burst int) *Limiter {
	limiter := new(Limiter)
	limiter.buckets = map[string]*bucket{}
	limiter.rate = float64(requestsPerMinute) / 60.0
	limiter.capacity = float64(burst)
	return limiter
}

// Allow reports whether the caller identified by key (usually IP or user ID)
// may proceed, consuming one token from that key's bucket when it may.
//
// A key seen for the first time starts with a full bucket. On every call the
// bucket is topped up for the time elapsed since its previous call, capped at
// the limiter's capacity.
func (l *Limiter) Allow(key string) bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	now := time.Now()

	entry, seen := l.buckets[key]
	if !seen {
		entry = &bucket{lastRefill: now, tokens: l.capacity}
		l.buckets[key] = entry
	}

	// Credit the tokens earned while the key was idle, never exceeding capacity.
	idle := now.Sub(entry.lastRefill).Seconds()
	entry.tokens = min(l.capacity, entry.tokens+idle*l.rate)
	entry.lastRefill = now

	// Spend one token only if a whole one is available.
	allowed := entry.tokens >= 1
	if allowed {
		entry.tokens--
	}
	return allowed
}

// Middleware returns an HTTP middleware that rate limits by client IP and
// answers 429 Too Many Requests once a client's bucket is empty.
//
// The key is the host part of r.RemoteAddr. The port is stripped because it
// differs for every connection a client opens, which would otherwise give each
// connection its own bucket. When an X-Forwarded-For header is present, its
// first entry (the originating client, by convention) is used instead.
//
// NOTE(review): X-Forwarded-For is supplied by the client unless a trusted
// proxy overwrites it, so a directly exposed server can be bypassed by varying
// the header. Confirm this middleware only runs behind such a proxy.
func (l *Limiter) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Fall back to the raw address if it is not in host:port form.
		key := r.RemoteAddr
		if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
			key = host
		}

		// The header may carry a comma-separated proxy chain; keying on the
		// whole value would split one client across several buckets.
		if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
			first, _, _ := strings.Cut(forwarded, ",")
			if first = strings.TrimSpace(first); first != "" {
				key = first
			}
		}

		if !l.Allow(key) {
			http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
			return
		}

		next.ServeHTTP(w, r)
	})
}

// Cleanup evicts every bucket whose most recent Allow call is more than maxAge
// in the past, so keys that stop sending requests do not accumulate forever.
// It is meant to be called periodically (e.g., every 5 minutes).
func (l *Limiter) Cleanup(maxAge time.Duration) {
	l.mu.Lock()
	defer l.mu.Unlock()

	now := time.Now()
	for key, entry := range l.buckets {
		idle := now.Sub(entry.lastRefill)
		if idle <= maxAge {
			continue
		}
		// Deleting the current key while ranging over a map is safe in Go.
		delete(l.buckets, key)
	}
}

// --- Pre-configured limiters for common use cases ---

// Each limiter below is created with burst equal to its per-minute rate. A
// bucket starts full, so a new key may send up to burst requests back to back
// and is then admitted at the refill rate. Buckets are per key; Middleware
// derives the key from the request's remote address or X-Forwarded-For header.
//
// NOTE(review): no Cleanup call for these limiters is visible in this file;
// confirm one is scheduled elsewhere so idle buckets are evicted.
var (
	// AuthLimiter limits auth endpoints: burst of 5, refilled at 5 tokens per minute.
	AuthLimiter = NewLimiter(5, 5)

	// UploadLimiter limits file uploads: burst of 10, refilled at 10 tokens per minute.
	UploadLimiter = NewLimiter(10, 10)

	// ChatLimiter limits message creation: burst of 30, refilled at 30 tokens per minute.
	ChatLimiter = NewLimiter(30, 30)

	// APILimiter limits public API read requests: burst of 100, refilled at
	// 100 tokens per minute.
	APILimiter = NewLimiter(100, 100)

	// APIWriteLimiter limits public API write requests (POST/PATCH/DELETE):
	// burst of 30, refilled at 30 tokens per minute.
	// This is more restrictive than reads to prevent abuse.
	APIWriteLimiter = NewLimiter(30, 30)
)

// RateLimitAuth returns handler guarded by the shared AuthLimiter, so requests
// that exceed the auth limit are rejected before handler runs.
func RateLimitAuth(handler http.Handler) http.Handler {
	limited := AuthLimiter.Middleware(handler)
	return limited
}

// RateLimitUpload returns handler guarded by the shared UploadLimiter, so
// requests that exceed the upload limit are rejected before handler runs.
func RateLimitUpload(handler http.Handler) http.Handler {
	limited := UploadLimiter.Middleware(handler)
	return limited
}

// RateLimitChat returns handler guarded by the shared ChatLimiter, so requests
// that exceed the chat limit are rejected before handler runs.
func RateLimitChat(handler http.Handler) http.Handler {
	limited := ChatLimiter.Middleware(handler)
	return limited
}
← Back