Skip to content

Commit 7ccd73c

Browse files
authored
Add user_id to routing decisions for per-user history isolation (#18)
* Add user_id to routing decisions for per-user history isolation

  Migration 0003 adds a nullable user_id column to calls. CallStore.record(), recent(), and count() accept an optional user_id parameter; recent/count filter to that user when provided. RoutingService.route() and the HTTP RouteRequest body thread user_id through; GET /api/history accepts a user_id query param. All fields are optional for full backward compat.

* Address review: index on user_id, remove user_id from public history API

  Add composite index (user_id, id) to migration 0003 and mark ORM column index=True. Move WHERE before ORDER BY/LIMIT/OFFSET in recent(). Remove user_id query param from GET /api/history (BOLA risk — per-user scoping is the embedding app's responsibility via CallStore.recent(user_id=...)).

* Fix ORM/migration index mismatch; update user_id field description

  Replace index=True on user_id column with __table_args__ composite index sa.Index("ix_calls_user_id_id", "user_id", "id") to match migration 0003 exactly and prevent Alembic autogenerate from reporting false drift. Reword user_id field description: no longer claims history filtering.
1 parent c236aad commit 7ccd73c

6 files changed

Lines changed: 105 additions & 16 deletions

File tree

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""Add user_id column to calls table
2+
3+
Revision ID: 0003
4+
Revises: 0002
5+
Create Date: 2026-06-28
6+
"""
7+
from __future__ import annotations
8+
9+
import sqlalchemy as sa
10+
from alembic import op
11+
12+
revision: str = "0003"
13+
down_revision: str = "0002"
14+
branch_labels: str | None = None
15+
depends_on: str | None = None
16+
17+
18+
def upgrade() -> None:
    """Apply revision 0003: add ``calls.user_id`` and its composite index.

    The column is nullable, so rows written before this revision simply
    carry NULL for ``user_id``.
    """
    user_id_column = sa.Column("user_id", sa.String(255), nullable=True)
    op.add_column("calls", user_id_column)
    # Composite (user_id, id) index; the name must stay in sync with the
    # sa.Index declared in CallRecord.__table_args__.
    op.create_index("ix_calls_user_id_id", "calls", ["user_id", "id"])
21+
22+
23+
def downgrade() -> None:
    """Revert revision 0003: drop the composite index, then ``calls.user_id``."""
    # The index covers user_id, so it is removed before the column itself.
    op.drop_index("ix_calls_user_id_id", table_name="calls")
    op.drop_column("calls", "user_id")

src/xrouter_llm/server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class RouteRequest(BaseModel):
4444
description="Candidate model IDs. Omit to use config models or all registered models.",
4545
)
4646
task: str | None = None
47+
user_id: str | None = Field(default=None, max_length=255, description="Caller identity persisted with the routing decision.")
4748
completion_threshold: float | None = Field(default=None, ge=0.0, le=1.0)
4849
lambda_cost: float | None = Field(default=None, ge=0.0)
4950
lambda_latency: float | None = Field(default=None, ge=0.0)
@@ -96,6 +97,7 @@ def route(req: RouteRequest) -> dict[str, Any]:
9697
lambda_latency=req.lambda_latency,
9798
max_k=req.max_k,
9899
fallback_quality_margin=req.fallback_quality_margin,
100+
user_id=req.user_id,
99101
)
100102
except ValueError as exc:
101103
raise HTTPException(status_code=400, detail=str(exc)) from exc

src/xrouter_llm/serving.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def route(
148148
lambda_latency: float | None = None,
149149
max_k: int | None = None,
150150
fallback_quality_margin: float | None = None,
151+
user_id: str | None = None,
151152
) -> dict[str, Any]:
152153
if not prompt.strip():
153154
raise ValueError("prompt must not be empty")
@@ -241,6 +242,7 @@ def route(
241242
expected_quality=float(breakdown.expected_quality),
242243
cost=float(breakdown.cost),
243244
latency=float(breakdown.latency),
245+
user_id=user_id,
244246
)
245247
return {
246248
"id": call_id,

src/xrouter_llm/store.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,19 @@ class CallRecord(Base):
3434
cost: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
3535
latency: Mapped[float | None] = mapped_column(sa.Float, nullable=True)
3636
feedback: Mapped[Any] = mapped_column(sa.JSON, nullable=True)
37+
user_id: Mapped[str | None] = mapped_column(sa.String(255), nullable=True)
38+
39+
__table_args__ = (
40+
sa.Index("ix_calls_user_id_id", "user_id", "id"),
41+
)
3742

3843

3944
_BASELINE_REVISION = "0001"
4045
# Each entry: (column added by that migration, revision). Keep in ascending order.
4146
# To add a new migration: append ("new_column", "000N") here.
4247
_SCHEMA_CHECKPOINTS: list[tuple[str, str]] = [
4348
("feedback", "0002"),
49+
("user_id", "0003"),
4450
]
4551

4652

@@ -141,6 +147,7 @@ def record(
141147
expected_quality: float,
142148
cost: float,
143149
latency: float,
150+
user_id: str | None = None,
144151
) -> int:
145152
with self._Session() as session:
146153
row = CallRecord(
@@ -153,32 +160,33 @@ def record(
153160
expected_quality=expected_quality,
154161
cost=cost,
155162
latency=latency,
163+
user_id=user_id,
156164
)
157165
session.add(row)
158166
session.flush() # INSERT → DB assigns id; no extra SELECT needed
159167
call_id = row.id
160168
session.commit()
161169
return call_id
162170

163-
def recent(
    self, limit: int = 50, offset: int = 0, *, user_id: str | None = None
) -> list[dict[str, Any]]:
    """Return stored calls, newest (highest id) first, as plain dicts.

    Args:
        limit: Page size; coerced to ``int`` and clamped to 1..1000.
        offset: Rows to skip; coerced to ``int`` and clamped to >= 0.
        user_id: When not ``None``, only rows whose ``user_id`` equals this
            value are returned. ``None`` applies no user filter at all.

    Returns:
        One dict per row, produced by ``_row_to_dict``.
    """
    page_size = min(max(int(limit), 1), 1000)
    skip = max(int(offset), 0)

    query = sa.select(CallRecord)
    if user_id is not None:
        # Filter first, then order/paginate within that user's rows.
        query = query.where(CallRecord.user_id == user_id)
    query = query.order_by(CallRecord.id.desc()).limit(page_size).offset(skip)

    with self._Session() as session:
        records = session.execute(query).scalars().all()
        return [_row_to_dict(record) for record in records]
178183

179-
def count(self, *, user_id: str | None = None) -> int:
    """Return the number of stored calls.

    Args:
        user_id: When not ``None``, count only rows whose ``user_id`` equals
            this value; ``None`` counts every row.
    """
    total = sa.select(sa.func.count(CallRecord.id))
    query = total if user_id is None else total.where(CallRecord.user_id == user_id)
    with self._Session() as session:
        return session.execute(query).scalar_one()
182190

183191
def delete(self, call_id: int) -> bool:
184192
with self._Session() as session:
@@ -225,4 +233,5 @@ def _row_to_dict(r: CallRecord) -> dict[str, Any]:
225233
"cost": r.cost,
226234
"latency": r.latency,
227235
"feedback": r.feedback,
236+
"user_id": r.user_id,
228237
}

tests/test_serving.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,27 @@ def test_feedback_endpoint(tmp_path) -> None:
217217
json={"outcome": "good", "correct_model": "strong"}).status_code == 422
218218

219219

220+
def test_user_id_recorded_in_route(tmp_path) -> None:
    """POST /api/route persists the optional ``user_id`` with each decision."""
    from xrouter_llm.server import create_app

    service = _service(tmp_path)
    client = TestClient(create_app(service))

    # Two identified callers plus one anonymous request (no user_id key).
    payloads = [
        {"prompt": "hello", "models": ["cheap", "strong"], "user_id": "alice"},
        {"prompt": "world", "models": ["cheap", "strong"], "user_id": "bob"},
        {"prompt": "anon", "models": ["cheap", "strong"]},
    ]
    for payload in payloads:
        response = client.post("/api/route", json=payload)
        # Fail here, showing the response body, rather than on a confusing
        # count mismatch further down if a request was rejected.
        assert response.status_code < 400, response.text

    # /api/history is admin-only (no user filter); per-user scoping is done
    # by the embedding app via CallStore.recent(user_id=...) directly.
    assert client.get("/api/history").json()["total"] == 3

    # user_id is persisted and accessible via store
    assert service.store.count(user_id="alice") == 1
    assert service.store.count(user_id="bob") == 1

    # Each user's listing contains exactly that user's single row.
    alice_rows = service.store.recent(user_id="alice")
    assert [row["user_id"] for row in alice_rows] == ["alice"]
    bob_rows = service.store.recent(user_id="bob")
    assert [row["user_id"] for row in bob_rows] == ["bob"]

    # The request without user_id is stored with user_id None.
    anonymous = [row for row in service.store.recent() if row["user_id"] is None]
    assert len(anonymous) == 1
239+
240+
220241
def test_history_pagination(tmp_path) -> None:
221242
from xrouter_llm.server import create_app
222243

tests/test_store.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,16 +132,17 @@ def _legacy_db_empty_version(tmp_path):
132132

133133

134134
def test_legacy_db_no_alembic_version(tmp_path) -> None:
    """CallStore opens a pre-Alembic DB, runs pending migrations (adds feedback + user_id columns)."""
    store = CallStore(_legacy_db(tmp_path))
    store.record(
        ts=1.0,
        config="all",
        prompt="hello",
        task=None,
        selected=["m"],
        candidates=[],
        expected_quality=0.8,
        cost=0.0,
        latency=0.0,
    )
    assert store.count() == 1

    # Both migrated columns must be readable on the freshly written row and
    # hold None, since record() was called without feedback or user_id.
    newest = store.recent()[0]
    assert newest["feedback"] is None
    assert newest["user_id"] is None
145146

146147

147148
def test_legacy_db_empty_alembic_version(tmp_path) -> None:
@@ -174,6 +175,35 @@ def test_in_memory_sqlite_works() -> None:
174175
assert store.recent()[0]["prompt"] == "hello"
175176

176177

178+
def test_user_id_record_and_filter(store) -> None:
    """Rows recorded with a user_id can be counted and listed per user."""

    def call_fields(ts: float, prompt: str) -> dict:
        # Fresh dict (and fresh lists) per call; only ts/prompt vary.
        return dict(
            ts=ts, config="all", prompt=prompt, task=None,
            selected=["m"], candidates=[], expected_quality=0.8, cost=0.0, latency=0.0,
        )

    store.record(**call_fields(1.0, "p1"), user_id="alice")
    store.record(**call_fields(2.0, "p2"), user_id="bob")
    # Third call omits user_id entirely to exercise the anonymous path.
    store.record(**call_fields(3.0, "p3"))

    assert store.count() == 3
    assert store.count(user_id="alice") == 1
    assert store.count(user_id="bob") == 1

    alice_rows = store.recent(user_id="alice")
    assert len(alice_rows) == 1
    only_row = alice_rows[0]
    assert only_row["user_id"] == "alice"
    assert only_row["prompt"] == "p1"

    # anonymous call has user_id None in the response
    anonymous_total = sum(1 for row in store.recent() if row["user_id"] is None)
    assert anonymous_total == 1
205+
206+
177207
def test_auto_migrate_false_skips_migration(tmp_path) -> None:
178208
"""auto_migrate=False does not run migrations (table absent → OperationalError on first use)."""
179209
import pytest

0 commit comments

Comments
 (0)