readysite / hosting / internal / access / limiter.go
1.8 KB
limiter.go
package access

import (
	"net"
	"net/http"
	"sync"
	"time"
)

// bucket is the token-bucket state tracked for a single rate-limit key.
type bucket struct {
	tokens     float64   // tokens currently available; an allowed request spends one
	lastRefill time.Time // when tokens was last brought up to date
}

// Limiter implements a per-key token bucket rate limiter.
//
// Each key owns a bucket that holds at most capacity tokens and is refilled
// continuously at rate tokens per second. Every allowed request spends one
// token. All state is guarded by mu, so the methods below may be called
// concurrently.
type Limiter struct {
	mu        sync.Mutex
	buckets   map[string]*bucket
	rate      float64   // tokens added per second
	capacity  float64   // maximum tokens per bucket (the burst size)
	lastSweep time.Time // when evictStale last scanned buckets; zero means never
}

// NewLimiter creates a rate limiter that refills requestsPerMinute tokens per
// minute for each key and lets a key make up to burst requests back to back.
func NewLimiter(requestsPerMinute int, burst int) *Limiter {
	return &Limiter{
		buckets:  make(map[string]*bucket),
		rate:     float64(requestsPerMinute) / 60.0,
		capacity: float64(burst),
	}
}

// Allow reports whether a request for key may proceed, spending one token
// from the key's bucket if so. A key seen for the first time starts with a
// full bucket.
func (l *Limiter) Allow(key string) bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	now := time.Now()
	l.evictStale(now)

	b, ok := l.buckets[key]
	if !ok {
		b = &bucket{tokens: l.capacity, lastRefill: now}
		l.buckets[key] = b
	}

	// Credit the tokens earned since the last visit, never exceeding capacity.
	elapsed := now.Sub(b.lastRefill).Seconds()
	b.tokens = min(l.capacity, b.tokens+elapsed*l.rate)
	b.lastRefill = now

	if b.tokens < 1 {
		return false
	}
	b.tokens--
	return true
}

// evictStale drops buckets idle for longer than staleAfter to bound memory
// use. It only scans once the map holds more than sweepThreshold keys, and at
// most once per sweepInterval. The caller must hold l.mu.
//
// The scan is O(len(l.buckets)). Without the interval check, a flood of
// distinct keys that keeps the map above the threshold with non-stale buckets
// would force a full scan on every request while the lock is held.
//
// An evicted key starts over with a full bucket on its next request. For a
// limiter that refills fewer than capacity tokens within staleAfter, that
// grants more tokens than the key would otherwise have earned.
func (l *Limiter) evictStale(now time.Time) {
	const (
		sweepThreshold = 10000
		sweepInterval  = time.Minute
		staleAfter     = 10 * time.Minute
	)
	if len(l.buckets) <= sweepThreshold || now.Sub(l.lastSweep) < sweepInterval {
		return
	}
	l.lastSweep = now

	for k, b := range l.buckets {
		if now.Sub(b.lastRefill) > staleAfter {
			delete(l.buckets, k)
		}
	}
}

// Middleware returns an HTTP middleware that rate limits requests by client
// IP address. Once a client's bucket is empty, it responds 429 Too Many
// Requests and does not call next.
//
// The key is the host part of r.RemoteAddr, which net/http servers set to
// "IP:port". A client's separate connections usually use different source
// ports, so keying on the full address would hand each connection a fresh
// bucket. If RemoteAddr cannot be split, it is used unchanged as the key.
//
// NOTE(review): behind a reverse proxy RemoteAddr is the proxy's address;
// confirm whether a trusted forwarded header should be consulted instead.
func (l *Limiter) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		key := r.RemoteAddr
		if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
			key = host
		}

		if !l.Allow(key) {
			http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
			return
		}

		next.ServeHTTP(w, r)
	})
}

var (
	// APILimiter throttles API requests: 60 per minute per client, with
	// bursts of at most 60.
	APILimiter = NewLimiter(60 /* per minute */, 60 /* burst */)

	// AuthLimiter throttles magic link requests: 3 per minute per client,
	// with bursts of at most 3.
	AuthLimiter = NewLimiter(3 /* per minute */, 3 /* burst */)
)
← Back