Advanced Middleware
Middleware is a function that executes before and after an HTTP request reaches a handler. It separates cross-cutting concerns such as authentication, logging, CORS, and rate limiting from handler logic, reducing code duplication and improving maintainability.
Middleware Pattern
net/http Style (Function Wrapping)
package main
import (
"fmt"
"log"
"net/http"
"time"
)
// Middleware decorates an http.Handler with extra behavior and returns the
// wrapped handler.
type Middleware func(http.Handler) http.Handler

// Chain wraps h with the given middlewares so that middlewares[0] is the
// outermost layer: a request passes through middlewares[0], then
// middlewares[1], ..., and finally reaches h. With no middlewares, h is
// returned unchanged.
func Chain(h http.Handler, middlewares ...Middleware) http.Handler {
	wrapped := h
	// Wrap innermost first (the last middleware) so the first one ends up outside.
	for idx := range middlewares {
		wrapped = middlewares[len(middlewares)-1-idx](wrapped)
	}
	return wrapped
}
// LoggingMiddleware logs one line per request with the method, path, the
// status code that was actually sent, and the time next took to run.
func LoggingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		// Wrapper to capture response status code
		rw := &responseWriter{ResponseWriter: w, status: http.StatusOK}
		next.ServeHTTP(rw, r)
		log.Printf("%s %s %d %v", r.Method, r.URL.Path, rw.status, time.Since(start))
	})
}

// responseWriter wraps an http.ResponseWriter and records the final status
// code sent to the client. Construct it with status: http.StatusOK, which is
// what net/http sends when a handler writes nothing and never calls WriteHeader.
type responseWriter struct {
	http.ResponseWriter
	status      int
	wroteHeader bool // final status committed; later WriteHeader calls cannot change it
}

// WriteHeader records the first final (non-1xx) status and forwards the call.
// Later calls are still forwarded (net/http reports them as superfluous) but
// must not change the logged status, because the client already got the first.
func (rw *responseWriter) WriteHeader(status int) {
	if !rw.wroteHeader {
		rw.status = status
		// 1xx informational responses (e.g. 103 Early Hints) precede the real status.
		rw.wroteHeader = status >= 200
	}
	rw.ResponseWriter.WriteHeader(status)
}

// Write forwards to the wrapped writer. A Write before any final WriteHeader
// makes net/http send an implicit 200 OK, so that is what gets logged.
func (rw *responseWriter) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		rw.status = http.StatusOK
		rw.wroteHeader = true
	}
	return rw.ResponseWriter.Write(b)
}

// Unwrap exposes the underlying writer so http.NewResponseController can still
// reach Flush, Hijack and deadline control through this wrapper.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
// RecoveryMiddleware turns a panic in next into a logged error and a
// 500 Internal Server Error response, keeping the server alive.
//
// http.ErrAbortHandler is re-panicked: net/http uses it to abort a response
// on purpose (and suppresses its stack trace), so it must not become a 500.
// If next had already started writing, the 500 cannot replace the status
// that was sent; the client may see a truncated response.
func RecoveryMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}
			if rec == http.ErrAbortHandler {
				panic(rec)
			}
			log.Printf("panic: %v", rec)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		}()
		next.ServeHTTP(w, r)
	})
}
// main serves GET /hello on :8080 behind the recovery and logging middleware.
func main() {
	mux := http.NewServeMux()
	mux.HandleFunc("GET /hello", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "Hello, World!")
	})

	// Recovery is the outermost layer: it wraps Logging, which wraps the mux.
	handler := Chain(mux, RecoveryMiddleware, LoggingMiddleware)

	// http.ListenAndServe sets no timeouts, so slow or idle clients can hold
	// connections open indefinitely; use an explicitly configured server.
	srv := &http.Server{
		Addr:              ":8080",
		Handler:           handler,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       15 * time.Second,
		WriteTimeout:      15 * time.Second,
		IdleTimeout:       60 * time.Second,
	}
	// ListenAndServe always returns a non-nil error; report it instead of exiting silently.
	log.Fatal(srv.ListenAndServe())
}
JWT Authentication Middleware
package main
import (
"context"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
)
// contextKey is an unexported key type for request-context values. A distinct
// named type keeps these keys from colliding with keys set by other packages.
type contextKey string

// claimsKey is the context key under which JWTMiddleware stores *Claims.
const claimsKey contextKey = "claims"

// jwtSecret is the HMAC key used both to sign (GenerateToken) and to verify
// (JWTMiddleware) tokens.
// NOTE(review): hard-coded for the example; load it from configuration or a
// secret store before production use.
var jwtSecret = []byte("your-secret-key-change-in-production")

// Claims is the JWT payload: application-specific user fields plus the
// standard registered claims (exp, iat, iss, ...).
type Claims struct {
	UserID int64  `json:"user_id"`
	Email  string `json:"email"`
	Role   string `json:"role"`
	jwt.RegisteredClaims
}
// GenerateToken issues an HS256-signed JWT for the given user, valid for 24
// hours from now, with issuer "my-app".
//
// It returns the compact serialized token, or the error from signing.
func GenerateToken(userID int64, email, role string) (string, error) {
	// Read the clock once so exp - iat is exactly the 24h lifetime.
	now := time.Now()
	claims := Claims{
		UserID: userID,
		Email:  email,
		Role:   role,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(24 * time.Hour)),
			IssuedAt:  jwt.NewNumericDate(now),
			Issuer:    "my-app",
		},
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return token.SignedString(jwtSecret)
}
// JWTMiddleware authenticates requests carrying
// "Authorization: Bearer <token>" (net/http style).
//
// The token must be an HS256 JWT signed with jwtSecret, issued by "my-app"
// (as GenerateToken does) and must carry an expiry; the jwt/v5 parser also
// rejects expired tokens. On success the parsed *Claims are stored in the
// request context (read them with GetClaims) and next is called. Every
// failure answers 401 with a JSON error body and a WWW-Authenticate challenge.
func JWTMiddleware(next http.Handler) http.Handler {
	// unauthorized writes a 401 JSON error. http.Error is not used because it
	// forces Content-Type: text/plain, which mislabels the JSON body.
	unauthorized := func(w http.ResponseWriter, body string) {
		h := w.Header()
		h.Set("Content-Type", "application/json")
		h.Set("X-Content-Type-Options", "nosniff")
		// RFC 6750 §3: a 401 for bearer-token auth carries this challenge.
		h.Set("WWW-Authenticate", "Bearer")
		w.WriteHeader(http.StatusUnauthorized)
		w.Write([]byte(body + "\n"))
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Parse Authorization: Bearer <token> header
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			unauthorized(w, `{"error":{"code":"UNAUTHORIZED","message":"Authentication token is required"}}`)
			return
		}
		// The auth-scheme name is case-insensitive (RFC 7235 §2.1).
		scheme, tokenString, ok := strings.Cut(authHeader, " ")
		tokenString = strings.TrimSpace(tokenString)
		if !ok || !strings.EqualFold(scheme, "Bearer") || tokenString == "" {
			unauthorized(w, `{"error":{"code":"INVALID_TOKEN","message":"Not a Bearer token format"}}`)
			return
		}

		// Parse and validate token
		claims := &Claims{}
		token, err := jwt.ParseWithClaims(tokenString, claims,
			func(t *jwt.Token) (any, error) {
				// Defence in depth; WithValidMethods below already pins HS256.
				if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
					return nil, jwt.ErrSignatureInvalid
				}
				return jwtSecret, nil
			},
			// Accept exactly the algorithm GenerateToken uses, not any HMAC variant.
			jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
			jwt.WithIssuer("my-app"),
			// Without this, a token that omits "exp" would never expire.
			jwt.WithExpirationRequired(),
		)
		if err != nil || !token.Valid {
			unauthorized(w, `{"error":{"code":"INVALID_TOKEN","message":"Invalid token"}}`)
			return
		}

		// Store claims in context
		ctx := context.WithValue(r.Context(), claimsKey, claims)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// GetClaims returns the *Claims that JWTMiddleware stored in r's context, or
// nil when no claims are present (the request did not pass JWTMiddleware).
func GetClaims(r *http.Request) *Claims {
	if c, ok := r.Context().Value(claimsKey).(*Claims); ok {
		return c
	}
	return nil
}
// AdminOnly lets a request through only when its JWT claims have Role "admin".
// It must run after JWTMiddleware, which puts the claims in the context;
// requests without claims or with another role get 403 and a JSON error body.
func AdminOnly(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if claims := GetClaims(r); claims == nil || claims.Role != "admin" {
			// Written by hand because http.Error would label the JSON text/plain.
			h := w.Header()
			h.Set("Content-Type", "application/json")
			h.Set("X-Content-Type-Options", "nosniff")
			w.WriteHeader(http.StatusForbidden)
			w.Write([]byte(`{"error":{"code":"FORBIDDEN","message":"Admin privileges required"}}` + "\n"))
			return
		}
		next.ServeHTTP(w, r)
	})
}
CORS Middleware
package main
import "net/http"
// CORSConfig configures CORSMiddleware.
type CORSConfig struct {
	AllowOrigins  []string // exact origins to allow; "*" allows (and echoes) any origin
	AllowMethods  []string // sent as Access-Control-Allow-Methods
	AllowHeaders  []string // sent as Access-Control-Allow-Headers
	ExposeHeaders []string // sent as Access-Control-Expose-Headers when non-empty
	MaxAge        int      // Preflight cache duration in seconds; <= 0 omits Access-Control-Max-Age
}

// CORSMiddleware returns middleware that adds CORS response headers.
//
// An allowed request Origin is echoed in Access-Control-Allow-Origin ("*" in
// AllowOrigins echoes every non-empty origin rather than sending a literal
// "*"). Every response carries "Vary: Origin" because its headers depend on
// the request's Origin, so shared caches must not reuse it across origins.
// OPTIONS requests are answered directly with 204 No Content, including
// Access-Control-Max-Age when config.MaxAge > 0; other requests go to next.
//
// Header values are computed once here rather than on every request.
func CORSMiddleware(config CORSConfig) func(http.Handler) http.Handler {
	allowOriginSet := make(map[string]bool, len(config.AllowOrigins))
	for _, o := range config.AllowOrigins {
		allowOriginSet[o] = true
	}
	allowAnyOrigin := allowOriginSet["*"]
	allowMethods := joinStrings(config.AllowMethods, ", ")
	allowHeaders := joinStrings(config.AllowHeaders, ", ")
	exposeHeaders := joinStrings(config.ExposeHeaders, ", ")

	// Decimal-format MaxAge by hand; strconv is not imported in this file.
	maxAge := ""
	if config.MaxAge > 0 {
		var digits [20]byte
		i := len(digits)
		for n := config.MaxAge; n > 0; n /= 10 {
			i--
			digits[i] = byte('0' + n%10)
		}
		maxAge = string(digits[i:])
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()
			h.Add("Vary", "Origin")
			// Check if origin is allowed; never emit an empty Allow-Origin value.
			if origin := r.Header.Get("Origin"); origin != "" && (allowAnyOrigin || allowOriginSet[origin]) {
				h.Set("Access-Control-Allow-Origin", origin)
			}
			h.Set("Access-Control-Allow-Methods", allowMethods)
			h.Set("Access-Control-Allow-Headers", allowHeaders)
			if len(config.ExposeHeaders) > 0 {
				h.Set("Access-Control-Expose-Headers", exposeHeaders)
			}
			// Handle preflight requests
			if r.Method == http.MethodOptions {
				if maxAge != "" {
					h.Set("Access-Control-Max-Age", maxAge)
				}
				w.WriteHeader(http.StatusNoContent)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// joinStrings concatenates ss with sep between consecutive elements
// (same result as strings.Join); an empty or nil ss yields "".
func joinStrings(ss []string, sep string) string {
	var out []byte
	for i, s := range ss {
		if i > 0 {
			out = append(out, sep...)
		}
		out = append(out, s...)
	}
	return string(out)
}
// Usage example
func newCORSMiddleware() func(http.Handler) http.Handler {
return CORSMiddleware(CORSConfig{
AllowOrigins: []string{"https://example.com", "http://localhost:3000"},
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
AllowHeaders: []string{"Content-Type", "Authorization", "X-Request-ID"},
ExposeHeaders: []string{"X-Request-ID"},
MaxAge: 86400,
})
}
Request Logging Middleware (slog + Request ID)
package main
import (
"context"
"log/slog"
"net/http"
"time"
"github.com/google/uuid"
)
// requestIDKey is the unexported context key holding the request ID string.
type requestIDKey struct{}

// RequestIDMiddleware ensures every request has an ID: it reuses the client's
// X-Request-ID header or generates a random UUID, echoes the ID back in the
// X-Request-ID response header, and stores it in the request context under
// requestIDKey{} for downstream handlers and loggers.
//
// NOTE(review): a client-supplied ID is trusted as-is (no length or charset
// check); consider bounding it before it reaches logs.
func RequestIDMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		id := r.Header.Get("X-Request-ID")
		if id == "" {
			id = uuid.New().String()
		}
		w.Header().Set("X-Request-ID", id)
		next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), requestIDKey{}, id)))
	})
}
// StructuredLoggingMiddleware returns middleware that emits one slog "request"
// record per request: request ID, method, path, remote address, response
// status and duration.
//
// The request ID is read from the context of the request this middleware
// receives, so RequestIDMiddleware must run before (outside) it for the ID to
// be non-empty.
func StructuredLoggingMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		handle := func(w http.ResponseWriter, r *http.Request) {
			began := time.Now()
			recorder := &responseWriter{ResponseWriter: w, status: http.StatusOK}
			next.ServeHTTP(recorder, r)

			id, _ := r.Context().Value(requestIDKey{}).(string) // "" when absent
			logger.Info("request",
				slog.String("request_id", id),
				slog.String("method", r.Method),
				slog.String("path", r.URL.Path),
				slog.String("remote_addr", r.RemoteAddr),
				slog.Int("status", recorder.status),
				slog.Duration("duration", time.Since(began)),
			)
		}
		return http.HandlerFunc(handle)
	}
}
Rate Limiting Middleware
package main
import (
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
)
// IPRateLimiter keeps one token-bucket limiter per client key (typically an
// IP address). Limiters are created lazily, all with the same rate and burst.
type IPRateLimiter struct {
	limiters map[string]*rate.Limiter // guarded by mu
	mu       sync.RWMutex
	r        rate.Limit // sustained requests per second, per key
	b        int        // bucket size: maximum burst, per key
}

// NewIPRateLimiter returns an empty limiter set that will allow r requests per
// second with bursts of up to b for each distinct key.
func NewIPRateLimiter(r rate.Limit, b int) *IPRateLimiter {
	rl := new(IPRateLimiter)
	rl.limiters = map[string]*rate.Limiter{}
	rl.r, rl.b = r, b
	return rl
}
// getLimiter returns the limiter for ip, creating it on first use.
//
// The fast path takes only the read lock. On a miss the map is checked again
// under the write lock: without that re-check, concurrent first requests from
// the same IP would each install a fresh limiter with a full burst,
// overwriting one another and letting a parallel burst bypass the limit.
//
// NOTE(review): entries are never evicted, so memory grows with the number of
// distinct keys seen; production use needs periodic cleanup of idle entries.
func (rl *IPRateLimiter) getLimiter(ip string) *rate.Limiter {
	rl.mu.RLock()
	limiter, ok := rl.limiters[ip]
	rl.mu.RUnlock()
	if ok {
		return limiter
	}

	rl.mu.Lock()
	defer rl.mu.Unlock()
	if limiter, ok = rl.limiters[ip]; !ok {
		limiter = rate.NewLimiter(rl.r, rl.b)
		rl.limiters[ip] = limiter
	}
	return limiter
}
// RateLimitMiddleware rejects requests over rl's per-client limit with
// 429 Too Many Requests, a Retry-After: 1 header and a JSON error body.
//
// Clients are keyed by the host part of r.RemoteAddr. The port must be
// dropped: it is a per-connection ephemeral port, so keying on the full
// address would give every new connection a fresh bucket. Behind a proxy, use
// X-Real-IP / X-Forwarded-For instead, but only when set by a trusted proxy.
func RateLimitMiddleware(rl *IPRateLimiter) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !rl.getLimiter(remoteHost(r.RemoteAddr)).Allow() {
				// Written by hand because http.Error would label the JSON text/plain.
				h := w.Header()
				h.Set("Retry-After", "1")
				h.Set("Content-Type", "application/json")
				h.Set("X-Content-Type-Options", "nosniff")
				w.WriteHeader(http.StatusTooManyRequests)
				w.Write([]byte(`{"error":{"code":"RATE_LIMITED","message":"Too many requests. Please try again later."}}` + "\n"))
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// remoteHost strips the port from a "host:port" address as produced in
// http.Request.RemoteAddr ("192.0.2.1:1234" -> "192.0.2.1",
// "[2001:db8::1]:1234" -> "2001:db8::1"). An address without a port is
// returned as-is, minus IPv6 brackets; unbracketed IPv6 is assumed not to occur.
func remoteHost(addr string) string {
	host := addr
	// Scan back to the last ':' that is not inside "[...]".
	for i := len(host) - 1; i >= 0 && host[i] != ']'; i-- {
		if host[i] == ':' {
			host = host[:i]
			break
		}
	}
	if n := len(host); n >= 2 && host[0] == '[' && host[n-1] == ']' {
		host = host[1 : n-1]
	}
	return host
}
Timeout Middleware
package main
import (
"context"
"net/http"
"sync"
"time"
)
// TimeoutMiddleware bounds how long next may take before responding.
//
// next runs in its own goroutine with a request context that is cancelled
// after timeout. If next has not started the response (no WriteHeader or
// Write) by then, the client gets 504 Gateway Timeout with a JSON error body.
// Once the timeout fires, further writes from next fail with
// http.ErrHandlerTimeout instead of touching the finished response. A panic in
// next is re-raised on the serving goroutine so outer recovery middleware sees
// it rather than the whole process crashing.
//
// The middleware cannot stop a handler that ignores r.Context().Done(); such
// a goroutine keeps running after the response is sent.
func TimeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx, cancel := context.WithTimeout(r.Context(), timeout)
			defer cancel()

			// The handler gets a private copy of the headers so it never touches
			// w.Header() concurrently with the timeout path below.
			header := w.Header().Clone()
			if header == nil {
				header = make(http.Header)
			}
			tw := &timeoutResponseWriter{ResponseWriter: w, header: header}

			done := make(chan struct{})
			// Buffered so the goroutine never blocks if we have already timed out.
			panicked := make(chan any, 1)
			go func() {
				defer func() {
					if p := recover(); p != nil {
						panicked <- p
					}
				}()
				next.ServeHTTP(tw, r.WithContext(ctx))
				close(done)
			}()

			select {
			case p := <-panicked:
				panic(p)
			case <-done:
				tw.mu.Lock()
				defer tw.mu.Unlock()
				if !tw.wroteHeader {
					// next returned without writing; still deliver the headers it set.
					tw.copyHeaders()
				}
			case <-ctx.Done():
				tw.mu.Lock()
				defer tw.mu.Unlock()
				tw.timedOut = true
				if !tw.wroteHeader {
					// Written by hand because http.Error would label the JSON text/plain.
					h := w.Header()
					h.Del("Content-Length")
					h.Set("Content-Type", "application/json")
					h.Set("X-Content-Type-Options", "nosniff")
					w.WriteHeader(http.StatusGatewayTimeout)
					w.Write([]byte(`{"error":{"code":"TIMEOUT","message":"Request processing timed out"}}` + "\n"))
				}
			}
		})
	}
}

// timeoutResponseWriter is the http.ResponseWriter given to the handler under
// TimeoutMiddleware. Every write to the real writer happens under mu, so the
// handler goroutine and the timeout path never write concurrently.
type timeoutResponseWriter struct {
	http.ResponseWriter
	mu          sync.Mutex
	header      http.Header // handler-private headers; copied to the real writer on first write
	wroteHeader bool        // a status line has been sent on the real writer
	timedOut    bool        // the timeout path owns the response; reject further writes
}

// Header returns the handler's private header map.
func (tw *timeoutResponseWriter) Header() http.Header {
	return tw.header
}

// WriteHeader sends status on the real writer unless a status was already
// sent or the request has timed out.
func (tw *timeoutResponseWriter) WriteHeader(status int) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.timedOut || tw.wroteHeader {
		return
	}
	tw.writeHeaderLocked(status)
}

// Write forwards p to the real writer, implicitly sending 200 OK first if no
// status was written. After a timeout it returns http.ErrHandlerTimeout.
func (tw *timeoutResponseWriter) Write(p []byte) (int, error) {
	tw.mu.Lock()
	defer tw.mu.Unlock()
	if tw.timedOut {
		return 0, http.ErrHandlerTimeout
	}
	if !tw.wroteHeader {
		tw.writeHeaderLocked(http.StatusOK)
	}
	return tw.ResponseWriter.Write(p)
}

// copyHeaders publishes the handler's headers to the real writer. Caller holds mu.
func (tw *timeoutResponseWriter) copyHeaders() {
	dst := tw.ResponseWriter.Header()
	for k, vv := range tw.header {
		dst[k] = vv
	}
}

// writeHeaderLocked publishes headers and sends status. Caller holds mu.
func (tw *timeoutResponseWriter) writeHeaderLocked(status int) {
	tw.copyHeaders()
	tw.wroteHeader = true
	tw.ResponseWriter.WriteHeader(status)
}
Practical Example: Production-Level Gin Server
package main
import (
"log/slog"
"net/http"
"os"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"golang.org/x/time/rate"
)
// setupRouter builds the Gin engine with the middleware stack
// (recovery -> request ID -> structured logging -> CORS -> rate limiting),
// a public /health route and an authenticated /api/v1 group.
func setupRouter() *gin.Engine {
	r := gin.New() // use gin.New() instead of gin.Default() for direct middleware control
	// By default Gin trusts X-Forwarded-For / X-Real-IP from every peer, which
	// would make c.ClientIP() (the rate-limit key below) client-controlled.
	// Trust no proxies; list real proxy addresses/CIDRs here when deployed
	// behind a load balancer.
	if err := r.SetTrustedProxies(nil); err != nil {
		panic(err) // a nil list has nothing to parse; this is a programming error
	}
	logger := slog.New(slog.NewJSONHandler(os.Stdout, nil))

	// 1. Recovery (placed outermost)
	r.Use(gin.CustomRecoveryWithWriter(os.Stderr, func(c *gin.Context, err any) {
		logger.Error("panic recovered", slog.Any("error", err))
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
			"error": gin.H{"code": "INTERNAL_ERROR", "message": "An internal server error occurred"},
		})
	}))

	// 2. Inject request ID
	r.Use(func(c *gin.Context) {
		id := c.GetHeader("X-Request-ID")
		if id == "" {
			id = uuid.New().String()
		}
		c.Set("request_id", id)
		c.Header("X-Request-ID", id)
		c.Next()
	})

	// 3. Structured logging
	r.Use(func(c *gin.Context) {
		start := time.Now()
		c.Next()
		logger.Info("request",
			slog.String("request_id", c.GetString("request_id")),
			slog.String("method", c.Request.Method),
			slog.String("path", c.Request.URL.Path),
			slog.Int("status", c.Writer.Status()),
			slog.Duration("latency", time.Since(start)),
		)
	})

	// 4. CORS
	r.Use(func(c *gin.Context) {
		c.Header("Access-Control-Allow-Origin", "*")
		c.Header("Access-Control-Allow-Methods", "GET,POST,PUT,PATCH,DELETE,OPTIONS")
		c.Header("Access-Control-Allow-Headers", "Content-Type,Authorization,X-Request-ID")
		if c.Request.Method == http.MethodOptions {
			c.AbortWithStatus(http.StatusNoContent)
			return
		}
		c.Next()
	})

	// 5. Rate Limiting (10 requests/sec, burst 20)
	limiter := NewIPRateLimiter(rate.Limit(10), 20)
	r.Use(func(c *gin.Context) {
		if !limiter.getLimiter(c.ClientIP()).Allow() {
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"error": gin.H{"code": "RATE_LIMITED", "message": "Please try again later"},
			})
			return
		}
		c.Next()
	})

	// Public routes
	r.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"status": "ok", "time": time.Now()})
	})

	// Route group requiring authentication
	api := r.Group("/api/v1")
	api.Use(func(c *gin.Context) {
		// NOTE(review): placeholder only. It checks that an Authorization
		// header is present and does NOT verify the token. Replace it with real
		// JWT validation (signature, algorithm, expiry) before production use.
		token := c.GetHeader("Authorization")
		if token == "" {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
				"error": gin.H{"code": "UNAUTHORIZED", "message": "Authentication required"},
			})
			return
		}
		c.Next()
	})
	{
		api.GET("/users/me", func(c *gin.Context) {
			c.JSON(http.StatusOK, gin.H{"user": gin.H{"id": 1, "name": "John Doe"}})
		})
	}
	return r
}
// main serves the Gin router on :8080 and exits non-zero if the server fails.
func main() {
	// r.Run wraps http.ListenAndServe, which sets no timeouts, so slow clients
	// could hold connections open indefinitely; configure an http.Server instead.
	srv := &http.Server{
		Addr:              ":8080",
		Handler:           setupRouter(),
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       15 * time.Second,
		WriteTimeout:      15 * time.Second,
		IdleTimeout:       60 * time.Second,
	}
	// ListenAndServe only returns on failure (or after Shutdown); don't drop the error.
	if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		slog.Error("server stopped", slog.Any("error", err))
		os.Exit(1)
	}
}
Middleware Execution Order
Request →
[Recovery start] →
[RequestID start] →
[Logging start] →
[CORS start] →
[RateLimit start] →
[JWT start] (only for /api/v1 routes) →
Handler execution
[JWT end] →
[RateLimit end] →
[CORS end] →
[Logging end] (record status code, response time) →
[RequestID end] →
[Recovery end] →
← Response
Pro Tips
1. Keep middleware thin: Each middleware should have a single responsibility
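2. Use c.Abort(): In Gin, c.Abort() (or AbortWithStatusJSON) prevents the remaining middlewares and handlers from running. It does not stop the current function, so return immediately after calling it.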
3. Pass data via context: Share data between middlewares using c.Set/c.Get or context.WithValue
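4. Middleware order matters: Recovery should be outermost so it catches panics from everything inside it. Logging comes next so that status codes and response times are recorded for normal requests. A panicking request unwinds past the Logging middleware before Recovery handles it, so log panics from the recovery handler itself (as the Gin example does).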