readysite / website / internal / access / auth.go
3.4 KB
auth.go
package access

import (
	"crypto/rand"
	"encoding/base64"
	"log"
	"net/http"
	"os"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"github.com/readysite/readysite/website/models"
	"golang.org/x/crypto/bcrypt"
)

// jwtSecret is the HMAC key used to sign and verify session JWTs. It is set
// once by init from the AUTH_SECRET environment variable.
var jwtSecret []byte

func init() {
	secret, err := resolveJWTSecret(os.Getenv("AUTH_SECRET"), os.Getenv("ENV"))
	if err != nil {
		// Continuing would sign sessions with a predictable (possibly
		// all-zero) key, letting anyone forge tokens. Refuse to start.
		log.Fatalf("access: generating random AUTH_SECRET: %v", err)
	}
	jwtSecret = secret
}

// resolveJWTSecret chooses the JWT signing key.
//
// A non-empty secret other than the "change-me-in-production" placeholder is
// used verbatim. Otherwise, when env is "production", a random 32-byte key
// (base64-encoded) is generated and a warning is logged, because sessions will
// not survive a restart. In any other environment a fixed development key is
// returned. The only error is a failure to read from crypto/rand.
func resolveJWTSecret(secret, env string) ([]byte, error) {
	if secret != "" && secret != "change-me-in-production" {
		return []byte(secret), nil
	}
	if env != "production" {
		// Fixed key so development sessions persist across restarts. It is
		// public, so it must never guard a real deployment.
		return []byte("dev-only-not-for-production"), nil
	}

	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return nil, err
	}
	log.Println("WARNING: AUTH_SECRET not set - using random secret. Sessions will not persist across restarts.")
	log.Println("         Set AUTH_SECRET env var for persistent sessions: openssl rand -base64 32")
	return []byte(base64.StdEncoding.EncodeToString(b)), nil
}

// JWTExpiration is the lifetime of a session (30 days). In this file it sets
// both the JWT "exp" claim and the session cookie's MaxAge.
const JWTExpiration = time.Hour * 24 * 30

// isSecure reports whether the client reached the server over HTTPS. That is
// the case when the connection itself is TLS, or when a reverse proxy set
// X-Forwarded-Proto to exactly "https".
//
// NOTE(review): X-Forwarded-Proto is client-controllable if no proxy strips
// it. In this file it only decides the cookie Secure flag, so spoofing it can
// only make cookies stricter. Confirm no other caller relies on it for trust.
func isSecure(r *http.Request) bool {
	return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
}

// SetSessionCookie creates a signed JWT and sets it as a session cookie.
func SetSessionCookie(w http.ResponseWriter, r *http.Request, userID string) error {
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"user_id": userID,
		"exp":     time.Now().Add(JWTExpiration).Unix(),
	})
	tokenString, err := token.SignedString(jwtSecret)
	if err != nil {
		return err
	}
	http.SetCookie(w, &http.Cookie{
		Name:     "session",
		Value:    tokenString,
		Path:     "/",
		HttpOnly: true,
		Secure:   isSecure(r),
		SameSite: http.SameSiteLaxMode,
		MaxAge:   int(JWTExpiration.Seconds()),
	})
	return nil
}

// ClearSessionCookie tells the client to delete the "session" cookie. It
// re-sends the cookie with an empty value and MaxAge -1, which net/http emits
// as "Max-Age=0". Path, HttpOnly, SameSite and Secure mirror the attributes
// SetSessionCookie uses.
func ClearSessionCookie(w http.ResponseWriter, r *http.Request) {
	expired := http.Cookie{
		Name:     "session",
		Path:     "/",
		Value:    "",
		MaxAge:   -1,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
		Secure:   isSecure(r),
	}
	http.SetCookie(w, &expired)
}

// GetUserFromJWT returns the user identified by the request's "session"
// cookie. It returns nil if any of these hold:
//   - the cookie is absent
//   - the token fails verification
//   - the "user_id" claim is missing or empty
//   - the user lookup returns an error
//
// Only HS256 tokens that carry an "exp" claim are accepted, which is the shape
// SetSessionCookie issues. Pinning the algorithm guards against
// algorithm-confusion attacks. Requiring exp rejects tokens that would
// otherwise never expire, since jwt/v5 treats exp as optional by default.
func GetUserFromJWT(r *http.Request) *models.User {
	cookie, err := r.Cookie("session")
	if err != nil {
		return nil
	}

	token, err := jwt.Parse(cookie.Value,
		func(*jwt.Token) (any, error) { return jwtSecret, nil },
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
		jwt.WithExpirationRequired(),
	)
	if err != nil || !token.Valid {
		return nil
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil
	}

	// Reject an empty ID outright instead of querying the store for it.
	userID, ok := claims["user_id"].(string)
	if !ok || userID == "" {
		return nil
	}

	user, err := models.Users.Get(userID)
	if err != nil {
		return nil
	}
	return user
}

// HashPassword returns the bcrypt hash of password, computed at
// bcrypt.DefaultCost. On error, it returns an empty string together with the
// error from bcrypt.
func HashPassword(password string) (string, error) {
	secret := []byte(password)
	digest, err := bcrypt.GenerateFromPassword(secret, bcrypt.DefaultCost)
	if err != nil {
		return "", err
	}
	return string(digest), nil
}

// CheckPassword reports whether password matches the bcrypt hash. Any error
// from bcrypt yields false, whether it is a mismatch or a malformed hash.
func CheckPassword(password, hash string) bool {
	return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil
}

// GenerateTokenKey returns a fresh random key for token invalidation. The key
// is 16 bytes from crypto/rand, encoded as a 24-character URL-safe base64
// string with padding.
//
// If the system random source fails, it returns "". Callers should treat an
// empty key as a failure rather than store it.
func GenerateTokenKey() string {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return ""
	}
	return base64.URLEncoding.EncodeToString(raw[:])
}
← Back