middleware.go
103 lines1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
package application
import (
"bufio"
"encoding/json"
"fmt"
"log"
"net"
"net/http"
"os"
"strings"
"time"
)
// requestLog is one structured access-log record. LoggingMiddleware fills it
// in and emits it as a single JSON object per request when ENV is
// "production"; encoding/json writes the keys in the field order below.
type requestLog struct {
// Timestamp is taken after the handler returns, in UTC, formatted as RFC 3339.
Timestamp string `json:"timestamp"`
// Method is the HTTP request method.
Method string `json:"method"`
// Path is the request URL path (r.URL.Path, so no query string).
Path string `json:"path"`
// Status is the response status code captured by responseWriter.
Status int `json:"status"`
// Duration is the handler's elapsed time as time.Duration.String(), e.g. "1.5ms".
Duration string `json:"duration"`
// DurationMs is the same elapsed time truncated to whole milliseconds.
DurationMs int64 `json:"duration_ms"`
// IP is the client address derived from X-Forwarded-For or RemoteAddr.
// NOTE(review): X-Forwarded-For may be client-supplied; treat as untrusted.
IP string `json:"ip"`
// UserAgent is the User-Agent request header; omitted from the JSON when empty.
UserAgent string `json:"user_agent,omitempty"`
}
// responseWriter wraps http.ResponseWriter to capture the status code that was
// actually sent to the client, for request logging.
//
// It forwards http.Flusher (for SSE streaming) and http.Hijacker (for
// WebSocket upgrades) to the underlying writer when that writer supports them,
// and exposes Unwrap so http.ResponseController can reach the underlying
// writer for any other optional interface.
//
// Construct it with status set to http.StatusOK: that is the code net/http
// sends when a handler never calls WriteHeader.
type responseWriter struct {
	http.ResponseWriter
	status      int
	wroteHeader bool // true once the final status has been committed
}

// commitStatus records code as the final response status, unless a final
// status has already been committed (later calls are no-ops, matching
// net/http, which ignores superfluous WriteHeader calls).
func (rw *responseWriter) commitStatus(code int) {
	if rw.wroteHeader {
		return
	}
	rw.status = code
	rw.wroteHeader = true
}

// WriteHeader records the status code and delegates to the underlying
// ResponseWriter.
//
// Informational 1xx codes other than 101 Switching Protocols (for example
// 103 Early Hints) are forwarded but not recorded, because net/http still
// expects a final status after them. Only the first final status is
// recorded, so the log matches what the client received.
func (rw *responseWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !informational {
		rw.commitStatus(code)
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write delegates to the underlying ResponseWriter. Writing a body before
// WriteHeader implicitly commits 200 OK, so a later WriteHeader call cannot
// change the status that gets logged.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.commitStatus(http.StatusOK)
	return rw.ResponseWriter.Write(b)
}

// Flush implements http.Flusher for SSE streaming support. It is a no-op when
// the underlying writer cannot flush. Flushing commits the headers (200 OK if
// none were set), so the status is fixed at that point.
func (rw *responseWriter) Flush() {
	if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
		rw.commitStatus(http.StatusOK)
		flusher.Flush()
	}
}

// Hijack implements http.Hijacker for WebSocket upgrade support. It returns an
// error when the underlying writer cannot be hijacked. Any status the handler
// writes on the raw connection after hijacking is not observed here.
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if hijacker, ok := rw.ResponseWriter.(http.Hijacker); ok {
		return hijacker.Hijack()
	}
	return nil, nil, fmt.Errorf("upstream ResponseWriter does not implement http.Hijacker")
}

// Unwrap returns the wrapped ResponseWriter. http.ResponseController uses it
// to find optional interfaces (deadlines, full-duplex, ...) this wrapper does
// not forward itself.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
// LoggingMiddleware returns middleware that logs one line per request, after
// next has finished handling it.
//
// The ENV environment variable is read once, when the middleware is built.
// If it equals "production", each line is a JSON-encoded requestLog.
// Otherwise it is a human-readable "METHOD PATH STATUS DURATION" line.
// Output goes through the standard library's default logger.
func LoggingMiddleware(next http.Handler) http.Handler {
	isProduction := os.Getenv("ENV") == "production"
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()

		// Wrap the writer so the status the handler sends can be logged;
		// 200 is what net/http sends if WriteHeader is never called.
		wrapped := &responseWriter{ResponseWriter: w, status: http.StatusOK}
		next.ServeHTTP(wrapped, r)

		end := time.Now()
		duration := end.Sub(start)

		if !isProduction {
			// Human-readable logging for development.
			log.Printf("%s %s %d %s", r.Method, r.URL.Path, wrapped.status, duration)
			return
		}

		// Structured JSON logging for production.
		entry := requestLog{
			Timestamp:  end.UTC().Format(time.RFC3339),
			Method:     r.Method,
			Path:       r.URL.Path,
			Status:     wrapped.status,
			Duration:   duration.String(),
			DurationMs: duration.Milliseconds(),
			IP:         loggedClientIP(r),
			UserAgent:  r.UserAgent(),
		}
		data, err := json.Marshal(entry)
		if err != nil {
			// Not expected for a struct of strings and ints, but never drop
			// the access-log line silently.
			log.Printf("%s %s %d %s (request log marshal failed: %v)",
				r.Method, r.URL.Path, wrapped.status, duration, err)
			return
		}
		// NOTE(review): log.Println prepends the default logger's prefix and
		// flags (date/time unless changed), so each line is pure JSON only if
		// log.SetFlags(0) is configured elsewhere — confirm.
		log.Println(string(data))
	})
}

// loggedClientIP returns the client address recorded in the access log. It
// uses the first X-Forwarded-For entry when that entry is non-empty.
// Otherwise it uses the host part of r.RemoteAddr, or RemoteAddr verbatim if
// it has no port.
//
// X-Forwarded-For is sent by the client unless a trusted reverse proxy
// overwrites it, so this value is for logging only — never for access control.
func loggedClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		// X-Forwarded-For: client, proxy1, proxy2 — the first entry is the client.
		first := forwarded
		if i := strings.IndexByte(forwarded, ','); i >= 0 {
			first = forwarded[:i]
		}
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	// RemoteAddr is normally "host:port"; drop the port so both sources log
	// the same shape.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}