-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmiddleware_ratelimit.go
More file actions
235 lines (220 loc) · 7.5 KB
/
Copy pathmiddleware_ratelimit.go
File metadata and controls
235 lines (220 loc) · 7.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
package theauth
import (
"bytes"
"encoding/json"
"io"
"net"
"net/http"
"net/netip"
"strings"
"sync"
"sync/atomic"
"time"
"golang.org/x/time/rate"
)
// keyedLimiter is an in-memory, per-key token-bucket rate limiter. Each
// unique key (e.g. an IP or an email) gets its own *rate.Limiter that
// refills at perMinute tokens per minute and holds at most perMinute tokens:
// a fresh key may spend its whole budget at once, after which it is
// throttled to one event every minute/perMinute.
//
// A background goroutine (gcLoop) evicts limiters not used within
// evictAfter so memory stays bounded under attack; Stop ends it. The whole
// struct is goroutine-safe.
//
// Perf re-audit 2026-06-21 (item 2): mu is a sync.RWMutex so concurrent
// Allow calls on already-existing keys share a read lock for the map
// lookup. A first-seen key drops the read lock and takes the write lock to
// insert (Go's RWMutex cannot upgrade in place, so the insert path re-checks
// the map). lastUsed is stored as atomic.Int64 (unix nanos) so the Allow hot
// path can update it without holding the write lock.
type keyedLimiter struct {
	mu          sync.RWMutex             // guards limits
	limits      map[string]*limiterEntry // key -> that key's limiter state
	perMinute   int                      // refill rate (per minute) and burst size for every key
	evictAfter  time.Duration            // idle time after which gcLoop drops a key
	stop        chan struct{}            // closed by Stop to terminate gcLoop
	stopOnce    sync.Once                // makes Stop safe to call repeatedly
	tickerEvery time.Duration            // interval between gcLoop sweeps
}

// limiterEntry is the per-key state stored in keyedLimiter.limits.
type limiterEntry struct {
	lim      *rate.Limiter
	lastUsed atomic.Int64 // unix nanos; updated without holding mu
}
// newKeyedLimiter builds a limiter allowing perMinute events per key per
// minute, with production GC settings: keys idle for ten minutes are
// dropped, and the sweep runs once a minute. The GC goroutine is already
// running on return, so tests should defer Stop(); in production these
// limiters live for the whole process.
func newKeyedLimiter(perMinute int) *keyedLimiter {
	const (
		idleTTL    = 10 * time.Minute
		gcInterval = time.Minute
	)
	return newKeyedLimiterWith(perMinute, idleTTL, gcInterval)
}
// newKeyedLimiterWith is the testable constructor: the caller picks how long
// an idle key survives (evictAfter) and how often the GC goroutine sweeps
// (tickerEvery). The GC goroutine is started before returning; end it with
// Stop.
//
// NOTE(review): tickerEvery must be positive. time.NewTicker panics on a
// non-positive interval, and here that panic would fire on the GC goroutine.
func newKeyedLimiterWith(perMinute int, evictAfter, tickerEvery time.Duration) *keyedLimiter {
	kl := new(keyedLimiter)
	kl.limits = make(map[string]*limiterEntry)
	kl.perMinute = perMinute
	kl.evictAfter = evictAfter
	kl.tickerEvery = tickerEvery
	kl.stop = make(chan struct{})

	go kl.gcLoop()
	return kl
}
// Allow reports whether one more event for key fits within that key's
// budget, consuming a token when it does. An empty key is never limited,
// and neither is any key when perMinute is zero or negative.
func (k *keyedLimiter) Allow(key string) bool {
	if key == "" {
		// Empty key = no limiter applied. Caller decided to skip this dimension.
		return true
	}
	if k.perMinute <= 0 {
		// A non-positive budget disables limiting. Negative values already
		// behaved this way implicitly (rate.Every returns rate.Inf for a
		// non-positive interval), but zero hit an integer divide-by-zero
		// panic below for every first-seen key, i.e. on request paths.
		return true
	}

	// Fast path: entry already exists, take a shared read lock.
	k.mu.RLock()
	entry, ok := k.limits[key]
	k.mu.RUnlock()
	if ok {
		// Update lastUsed atomically without holding the write lock. If gcLoop
		// evicts this entry in the meantime, the next call just builds a fresh
		// limiter for the key.
		entry.lastUsed.Store(time.Now().UnixNano())
		return entry.lim.Allow()
	}

	// Slow path: first request for this key. RWMutex cannot upgrade, so take
	// the write lock and re-check (another goroutine may have inserted it).
	k.mu.Lock()
	entry, ok = k.limits[key]
	if !ok {
		// One token every minute/perMinute. A burst of perMinute lets a fresh
		// client burn the full budget instantly, after which it refills
		// smoothly, which matches what users intuit as "N/min".
		r := rate.Every(time.Minute / time.Duration(k.perMinute))
		entry = &limiterEntry{lim: rate.NewLimiter(r, k.perMinute)}
		entry.lastUsed.Store(time.Now().UnixNano())
		k.limits[key] = entry
	}
	k.mu.Unlock()
	return entry.lim.Allow()
}
// gcLoop runs on its own goroutine until Stop closes k.stop. On every tick
// it deletes entries whose lastUsed is older than evictAfter, holding the
// write lock for the duration of the sweep.
func (k *keyedLimiter) gcLoop() {
	ticker := time.NewTicker(k.tickerEvery)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			staleBefore := time.Now().Add(-k.evictAfter).UnixNano()
			k.mu.Lock()
			for key, entry := range k.limits {
				if entry.lastUsed.Load() >= staleBefore {
					continue // recently used; keep it
				}
				delete(k.limits, key) // deleting during range is safe in Go
			}
			k.mu.Unlock()
		case <-k.stop:
			return
		}
	}
}
// Stop ends the GC goroutine started by the constructor. Repeated calls are
// harmless: only the first one closes the stop channel.
func (k *keyedLimiter) Stop() {
	closeStop := func() { close(k.stop) }
	k.stopOnce.Do(closeStop)
}
// extractClientIPTrusting returns the best-effort client IP for r, for use
// as a per-IP rate-limit key.
//
// X-Forwarded-For is consulted ONLY when r.RemoteAddr belongs to one of the
// operator-configured trusted prefixes. On a public-internet deployment with
// no proxy in front, that allowlist is empty and XFF is ignored, so an
// attacker cannot bypass per-IP rate limits by forging the header (security
// audit H4, 2026-06-20).
//
// When the request does come from a trusted proxy, the XFF chain (every
// header line, in order) is walked from the right. Hops that are themselves
// trusted proxies are skipped, and the first untrusted address is returned.
// The leftmost entry must NOT be used: a proxy appends to whatever the
// client sent, so the leftmost value is attacker-controlled. Using it would
// let a client choose a fresh rate-limit bucket per request.
//
// The walk stops at the first entry that is not a valid IP literal. In that
// case, or when every hop is trusted, the leftmost trusted hop reached is
// returned, falling back to the connection address. Addresses taken from XFF
// are normalized (IPv4-mapped IPv6 is unmapped) so one client yields one key.
//
// Otherwise the connection-level RemoteAddr without its port is returned.
func extractClientIPTrusting(r *http.Request, trusted []netip.Prefix) string {
	remoteHost := remoteAddrHost(r)
	if len(trusted) == 0 || !remoteIsTrusted(remoteHost, trusted) {
		return remoteHost
	}

	// A proxy may add its own header line rather than extending an existing
	// one, so join every line before splitting into hops.
	hops := strings.Split(strings.Join(r.Header.Values("X-Forwarded-For"), ","), ",")
	client := remoteHost
	for i := len(hops) - 1; i >= 0; i-- {
		addr, err := netip.ParseAddr(strings.TrimSpace(hops[i]))
		if err != nil {
			// Nothing left of a malformed hop can be attributed to a proxy we
			// trust; stop at the last good hop.
			break
		}
		client = addr.Unmap().String()
		if !remoteIsTrusted(client, trusted) {
			return client
		}
	}
	return client
}

// remoteAddrHost strips the port from r.RemoteAddr. An address that is not
// in host:port form (e.g. has no port) is returned unchanged.
func remoteAddrHost(r *http.Request) string {
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}

// remoteIsTrusted reports whether ipLiteral falls inside any of the
// configured trusted prefixes. Malformed IPs are never trusted. The address
// is checked both as given and unmapped, because netip.Prefix.Contains never
// matches an IPv4-mapped IPv6 address (::ffff:a.b.c.d) against an IPv4
// prefix.
func remoteIsTrusted(ipLiteral string, trusted []netip.Prefix) bool {
	addr, err := netip.ParseAddr(ipLiteral)
	if err != nil {
		return false
	}
	unmapped := addr.Unmap()
	for _, p := range trusted {
		if p.Contains(addr) || p.Contains(unmapped) {
			return true
		}
	}
	return false
}
// RateLimitByIP returns middleware that caps each source IP at perMinute
// requests per minute; intended for credential endpoints (signin, signup,
// forgot, reset). Every call builds its own limiter, so separate calls
// never share buckets: wire it once per route group at startup.
//
// X-Forwarded-For is honored only when r.RemoteAddr is inside one of the
// Config.TrustedProxies prefixes. Deployments behind a reverse proxy MUST
// opt in by listing the proxy network in TrustedProxies. The default is the
// empty allowlist (no XFF trust), which is the safe behavior on a direct
// public-internet bind (security audit H4, 2026-06-20).
//
// Rejected requests get 429 with "Retry-After: 60" and bump the
// rate-limit-blocked counter labelled rule=ip.
func (a *TheAuth) RateLimitByIP(perMinute int) func(http.Handler) http.Handler {
	limiter := newKeyedLimiter(perMinute)
	proxies := a.trustedProxies
	rejected := a.hooks.Counter(MetricRateLimitBlockedTotal, Labels{AttrRule: "ip"})

	middleware := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			if limiter.Allow(extractClientIPTrusting(r, proxies)) {
				next.ServeHTTP(w, r)
				return
			}
			rejected.Inc()
			w.Header().Set("Retry-After", "60")
			http.Error(w, "rate_limited", http.StatusTooManyRequests)
		}
		return http.HandlerFunc(serve)
	}
	return middleware
}
// RateLimitByEmail returns middleware that caps requests per "email" field of
// a JSON request body at perMinute per minute. It peeks at up to 16 KiB of
// the body to find the email. Downstream handlers then get a body that
// replays those bytes followed by the unread remainder, so they see the
// original request intact, however large it is.
//
// Requests without a decodable, non-empty email in that prefix are passed
// through unlimited (the handler will reject them on its own). Emails are
// trimmed and lower-cased so case/whitespace variants share one bucket.
// Rejected requests get 429 with "Retry-After: 60" and bump the
// rate-limit-blocked counter labelled rule=email.
func (a *TheAuth) RateLimitByEmail(perMinute int) func(http.Handler) http.Handler {
	const maxPeek = 1 << 14 // 16 KiB
	k := newKeyedLimiter(perMinute)
	blocked := a.hooks.Counter(MetricRateLimitBlockedTotal, Labels{AttrRule: "email"})
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.Body == nil {
				// Nothing to peek at (possible for hand-built requests).
				next.ServeHTTP(w, r)
				return
			}
			orig := r.Body
			buf, err := io.ReadAll(io.LimitReader(orig, maxPeek))
			// Always restore the body: replay the peeked bytes, then continue
			// with the rest of the original stream. The original Closer is
			// kept so a downstream Close still reaches the real body.
			// Replacing the body with buf alone would truncate anything
			// beyond maxPeek.
			r.Body = struct {
				io.Reader
				io.Closer
			}{io.MultiReader(bytes.NewReader(buf), orig), orig}
			if err != nil {
				// Read failed; let the downstream handler surface it.
				next.ServeHTTP(w, r)
				return
			}
			var body struct {
				Email string `json:"email"`
			}
			// Best-effort decode; if it fails, we don't have a key, pass through.
			if err := json.Unmarshal(buf, &body); err != nil || body.Email == "" {
				next.ServeHTTP(w, r)
				return
			}
			key := strings.ToLower(strings.TrimSpace(body.Email))
			if !k.Allow(key) {
				blocked.Inc()
				w.Header().Set("Retry-After", "60")
				http.Error(w, "rate_limited", http.StatusTooManyRequests)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}