readysite / website / internal / content / api / middleware.go
3.3 KB
middleware.go
package api

import (
	"net"
	"net/http"
	"strings"

	"github.com/readysite/readysite/website/internal/access"
	"github.com/readysite/readysite/website/internal/helpers"
	"github.com/readysite/readysite/website/models"
)

// SetCORSHeaders sets CORS headers for API responses.
// SECURITY: Uses configurable allowed origins instead of wildcard.
// Set cors_origins setting to comma-separated list of allowed origins.
// Empty value means same-origin only (most secure).
// Use "*" only for truly public APIs (not recommended for auth endpoints).
//
// With an explicit allow-list, the request's Origin is echoed back in
// Access-Control-Allow-Origin only when it exactly matches one of the listed
// entries (surrounding whitespace on each entry is ignored). "Vary: Origin"
// is appended to every such response, matched or not. In that mode the
// response headers depend on the Origin request header, and a shared cache
// must not replay one origin's response to another. Vary is appended rather
// than overwritten so values set by other middleware are preserved.
//
// Access-Control-Allow-Methods, Access-Control-Allow-Headers and
// Access-Control-Max-Age are always set, regardless of the origin decision.
func SetCORSHeaders(w http.ResponseWriter, r *http.Request) {
	origin := r.Header.Get("Origin")
	allowedOrigins := helpers.GetSetting(models.SettingCORSOrigins)
	h := w.Header()

	switch allowedOrigins {
	case "*":
		// Explicit wildcard configured - allow all (not recommended).
		// The response is the same for every origin, so no Vary is needed.
		h.Set("Access-Control-Allow-Origin", "*")
	case "":
		// No allow-list: leave Access-Control-Allow-Origin unset, which
		// blocks cross-origin requests (same-origin only).
	default:
		// The outcome depends on Origin, so caches must key on it even when
		// this particular origin is rejected.
		h.Add("Vary", "Origin")
		if origin != "" {
			for _, allowed := range strings.Split(allowedOrigins, ",") {
				if strings.TrimSpace(allowed) == origin {
					h.Set("Access-Control-Allow-Origin", origin)
					break
				}
			}
		}
	}

	h.Set("Access-Control-Allow-Methods", "GET, POST, PATCH, DELETE, OPTIONS")
	h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, If-Match")
	h.Set("Access-Control-Max-Age", "86400")
}

// CheckRateLimit reports whether r fits within the default API rate limits.
// Requests are keyed by getClientIP(r). Write methods (POST, PATCH, DELETE)
// are checked against access.APIWriteLimiter, which is meant to be the
// stricter limiter; every other method is checked against access.APILimiter.
// It returns true if the request is allowed and false if it is rate limited.
func CheckRateLimit(r *http.Request) bool {
	clientKey := getClientIP(r)

	switch r.Method {
	case http.MethodPost, http.MethodPatch, http.MethodDelete:
		return access.APIWriteLimiter.Allow(clientKey)
	default:
		return access.APILimiter.Allow(clientKey)
	}
}

// CheckCollectionRateLimit applies collection's own rate-limit configuration
// to r.
//
// Requests from admin users (user.Role == "admin") and OPTIONS requests (CORS
// preflight) are let through immediately, without consulting the limiter.
// All other requests are checked against access.ResourceLimiter. The key is
// the collection ID plus the client IP, and a non-nil user marks the request
// as authenticated. The limiter's rate-limit headers are then written to w.
// It returns true if the request is allowed and false if it is rate limited.
func CheckCollectionRateLimit(w http.ResponseWriter, r *http.Request, collection *models.Collection, user *models.User) bool {
	isAdmin := user != nil && user.Role == "admin"
	if isAdmin || r.Method == http.MethodOptions {
		return true
	}

	ip := getClientIP(r)
	limits := access.ParseRateLimitConfig(collection.RateLimit)
	outcome := access.ResourceLimiter.Check("collection", collection.ID, ip, limits, user != nil)

	access.SetRateLimitHeaders(w, outcome)
	return outcome.Allowed
}

// getClientIP returns the client IP address used as a rate-limit key.
//
// If an X-Forwarded-For header is present, its left-most entry (nominally
// the originating client) is used, trimmed and with any ":port" suffix
// removed, provided it parses as an IP address. Otherwise the host part of
// r.RemoteAddr is returned. RemoteAddr is "host:port", and keeping the port
// would give every new TCP connection from the same client a fresh key.
// RemoteAddr is returned unchanged if it is not in host:port form.
//
// NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy
// overwrites it. A client talking to this server directly (or through a
// proxy that only appends) can pick arbitrary keys and evade rate limits.
// Honoring the header only when RemoteAddr belongs to a known proxy would
// close this, but needs deployment-specific configuration.
func getClientIP(r *http.Request) string {
	// hostOnly strips a port from "host:port" / "[v6]:port"; bare IPv4/IPv6
	// addresses (which SplitHostPort rejects) are returned as-is.
	hostOnly := func(addr string) string {
		if host, _, err := net.SplitHostPort(addr); err == nil {
			return host
		}
		return addr
	}

	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first := xff
		if i := strings.IndexByte(xff, ','); i >= 0 {
			first = xff[:i]
		}
		if ip := hostOnly(strings.TrimSpace(first)); net.ParseIP(ip) != nil {
			return ip
		}
	}
	return hostOnly(r.RemoteAddr)
}
← Back