Skip to content

Commit 2e8d158

Browse files
authored
Add support for fastapi's lazy routers (#379)
Support FastAPI 0.137.1+ included-router trees, including routes added after inclusion, for both versioned and unversioned Cadwyn routing
1 parent e1ff8f1 commit 2e8d158

8 files changed

Lines changed: 561 additions & 53 deletions

File tree

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ Please follow [the Keep a Changelog standard](https://keepachangelog.com/en/1.0.
55

66
## [Unreleased]
77

8+
## [7.1.0]
9+
10+
### Added
11+
12+
- Support FastAPI 0.137.1+ included-router trees, including routes added to included routers after inclusion, for Cadwyn's versioned and unversioned routing
13+
814
## [7.0.0]
915

1016
- Breaking change if you relied on lifespan running multiple times: Fixed duplicated lifecycle callbacks in Cadwyn routers, see [this issue](https://github.com/zmievsa/cadwyn/issues/372) for more details

cadwyn/applications.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from fastapi.openapi.utils import get_openapi
1919
from fastapi.params import Depends
2020
from fastapi.responses import HTMLResponse
21+
from fastapi.routing import _EffectiveRouteContext, _IncludedRouter, _iter_routes_with_context
2122
from fastapi.templating import Jinja2Templates
2223
from fastapi.utils import generate_unique_id
2324
from starlette.middleware import Middleware
@@ -38,7 +39,7 @@
3839
VersionPickingMiddleware,
3940
_generate_api_version_dependency,
4041
)
41-
from cadwyn.route_generation import generate_versioned_routers
42+
from cadwyn.route_generation import copy_route, generate_versioned_routers
4243
from cadwyn.routing import _RootCadwynAPIRouter
4344
from cadwyn.structure import VersionBundle
4445

@@ -51,6 +52,22 @@
5152
logger = getLogger(__name__)
5253

5354

55+
def _get_effective_include_in_schema(route: BaseRoute, effective_route_context: _EffectiveRouteContext | None) -> bool:
56+
if effective_route_context is not None:
57+
starlette_route = effective_route_context.starlette_route
58+
if starlette_route is not None:
59+
return bool(getattr(starlette_route, "include_in_schema", False))
60+
return bool(effective_route_context.include_in_schema)
61+
return bool(getattr(route, "include_in_schema", False))
62+
63+
64+
def _materialize_routes(routes: Sequence[BaseRoute]) -> list[BaseRoute]:
65+
return [
66+
copy_route(route, effective_route_context) if effective_route_context is not None else route
67+
for route, effective_route_context in _iter_routes_with_context(routes)
68+
]
69+
70+
5471
@dataclasses.dataclass(**DATACLASS_SLOTS)
5572
class FakeDependencyOverridesProvider:
5673
dependency_overrides: dict[Callable[..., Any], Callable[..., Any]]
@@ -255,7 +272,13 @@ def __init__(
255272
unversioned_router = APIRouter(**self._kwargs_to_router)
256273
self._add_utility_endpoints(unversioned_router)
257274
self._add_default_versioned_routers()
275+
route_count_before_include = len(self.router.routes)
276+
unversioned_route_count_before_include = len(self.router.unversioned_routes)
258277
self.include_router(unversioned_router)
278+
utility_routes = _materialize_routes(self.router.routes[route_count_before_include:])
279+
self.router.routes[route_count_before_include:] = utility_routes
280+
self.router.unversioned_routes[unversioned_route_count_before_include:] = utility_routes
281+
self.router._mark_routes_changed()
259282
self.add_middleware(
260283
versioning_middleware_class,
261284
api_version_parameter_name=api_version_parameter_name,
@@ -411,15 +434,19 @@ async def openapi_jsons(self, req: Request) -> JSONResponse:
411434
)
412435

413436
def _there_are_public_unversioned_routes(self):
414-
return any(isinstance(route, Route) and route.include_in_schema for route in self.router.unversioned_routes)
437+
return any(
438+
isinstance(route, Route) and _get_effective_include_in_schema(route, effective_route_context)
439+
for route, effective_route_context in _iter_routes_with_context(self.router.unversioned_routes)
440+
)
415441

416442
def _filter_openapi_tags(self, routes: list) -> Union[list[dict[str, Any]], None]:
417443
if not self.openapi_tags:
418444
return self.openapi_tags
419445
used_tags: set[str | Enum] = set()
420-
for route in routes:
421-
if isinstance(route, routing.APIRoute) and route.include_in_schema:
422-
used_tags.update(route.tags)
446+
for route, effective_route_context in _iter_routes_with_context(routes):
447+
route_data = cast("Any", effective_route_context or route)
448+
if isinstance(route, routing.APIRoute) and route_data.include_in_schema:
449+
used_tags.update(route_data.tags)
423450
return [tag for tag in self.openapi_tags if tag.get("name") in used_tags]
424451

425452
async def swagger_dashboard(self, req: Request) -> Response:
@@ -505,8 +532,8 @@ def _add_versioned_routers(
505532
)
506533
added_routes.append(versioned_router.routes[-1])
507534

508-
added_route_count = 0
509535
for router in (first_router, *other_routers):
536+
route_count_before_include = len(versioned_router.routes)
510537
self.router.versioned_routers[version].include_router(
511538
router,
512539
dependencies=[
@@ -522,9 +549,15 @@ def _add_versioned_routers(
522549
)
523550
],
524551
)
525-
added_route_count += len(router.routes)
526-
527-
added_routes.extend(versioned_router.routes[-added_route_count:])
528-
self.router.routes.extend(added_routes)
552+
newly_added_routes = versioned_router.routes[route_count_before_include:]
553+
if not any(isinstance(route, _IncludedRouter) for route in router.routes):
554+
# There is no nested FastAPI include tree to preserve, so Cadwyn can materialize the plain
555+
# routes here and keep the root router's version bookkeeping stable.
556+
newly_added_routes = _materialize_routes(newly_added_routes)
557+
versioned_router.routes[route_count_before_include:] = newly_added_routes
558+
versioned_router._mark_routes_changed()
559+
added_routes.extend(newly_added_routes)
560+
561+
self.router.extend_routes(added_routes)
529562

530563
return added_routes

cadwyn/route_generation.py

Lines changed: 76 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,9 @@
1111
cast,
1212
)
1313

14-
import fastapi.params
1514
import fastapi.routing
16-
import fastapi.security.base
17-
import fastapi.utils
1815
from fastapi import APIRouter
19-
from fastapi.routing import APIRoute
16+
from fastapi.routing import APIRoute, _EffectiveRouteContext, _iter_routes_with_context
2017
from pydantic import BaseModel
2118
from starlette.routing import BaseRoute
2219
from typing_extensions import TypeVar, assert_never
@@ -104,11 +101,29 @@ def only_exists_in_older_versions(self, endpoint: _Call) -> _Call:
104101

105102
def copy_router(router: _R) -> _R:
106103
router = copy(router)
107-
router.routes = [copy_route(r) for r in router.routes]
104+
router.routes = [
105+
copy_route(route, effective_route_context)
106+
for route, effective_route_context in _iter_routes_with_context(router.routes)
107+
]
108+
router._mark_routes_changed()
108109
return router
109110

110111

111-
def copy_route(route: _RouteT) -> _RouteT:
112+
def copy_route(route: _RouteT, effective_route_context: _EffectiveRouteContext | None = None) -> _RouteT:
113+
"""Copy a route and materialize FastAPI's include-router context into the copy.
114+
115+
FastAPI 0.137+ keeps included routers as tree nodes. Its _EffectiveRouteContext is the merged
116+
route state produced by that tree: prefix, dependencies, tags, responses, schema flags, etc.
117+
Cadwyn mutates copied routes per API version, so it must copy the original route and apply that
118+
effective state to the copy. The original router tree is left intact.
119+
"""
120+
if (
121+
effective_route_context is not None
122+
and not isinstance(route, APIRoute)
123+
and effective_route_context.starlette_route is not None
124+
):
125+
return cast("_RouteT", copy(effective_route_context.starlette_route))
126+
112127
if not isinstance(route, APIRoute):
113128
return copy(route)
114129

@@ -120,19 +135,46 @@ def copy_route(route: _RouteT) -> _RouteT:
120135
# These can hold TypeAdapters for recursive types (e.g. JsonValue) that cause
121136
# infinite recursion during deepcopy.
122137
memo: dict[int, Any] = {}
123-
for attr in ("dependant", "_flat_dependant", "body_field", "response_model"):
138+
for attr in ("dependant", "_flat_dependant", "body_field", "response_model", "dependency_overrides_provider"):
124139
obj = getattr(route, attr, None)
125140
if obj is not None:
126141
memo[id(obj)] = obj
127142
new_route = deepcopy(route, memo)
128-
new_route.dependant = copy(route.dependant)
129-
if getattr(route, "_flat_dependant", None) is not None:
130-
new_route._flat_dependant = copy(route._flat_dependant)
131-
new_route.body_field = route.body_field
132-
new_route.dependencies = copy(route.dependencies)
143+
if effective_route_context is not None:
144+
_apply_effective_route_context_to_route(new_route, effective_route_context)
145+
_refresh_route_app(new_route)
146+
else:
147+
new_route.dependant = copy(route.dependant)
148+
if getattr(route, "_flat_dependant", None) is not None:
149+
new_route._flat_dependant = copy(route._flat_dependant)
150+
new_route.body_field = route.body_field
151+
new_route.dependencies = copy(route.dependencies)
133152
return new_route
134153

135154

155+
def _apply_effective_route_context_to_route(route: APIRoute, effective_route_context: _EffectiveRouteContext) -> None:
156+
for attr_name, attr_value in vars(effective_route_context).items():
157+
if attr_name in {"original_route", "starlette_route"}:
158+
continue
159+
setattr(route, attr_name, _copy_effective_route_context_attr(attr_name, attr_value))
160+
161+
162+
def _copy_effective_route_context_attr(attr_name: str, attr_value: Any) -> Any:
163+
if attr_name in {"dependant", "_flat_dependant"} and attr_value is not None:
164+
return copy(attr_value)
165+
if isinstance(attr_value, (dict, list, set)):
166+
return copy(attr_value)
167+
return attr_value
168+
169+
170+
def _refresh_route_app(route: APIRoute) -> None:
171+
route.app = fastapi.routing.request_response(route.get_route_handler())
172+
173+
174+
def _route_methods(route: APIRoute) -> set[str]:
175+
return route.methods or set()
176+
177+
136178
class _EndpointTransformer(Generic[_R, _WR]):
137179
def __init__(
138180
self,
@@ -149,9 +191,12 @@ def __init__(
149191
self.api_version_parameter_name = api_version_parameter_name
150192
self.api_version_location: APIVersionLocation = api_version_location
151193
self.schema_generators = generate_versioned_models(versions)
194+
self.head_router = copy_router(parent_router)
152195

153196
self.routes_that_never_existed = [
154-
route for route in parent_router.routes if isinstance(route, APIRoute) and _DELETED_ROUTE_TAG in route.tags
197+
route
198+
for route in self.head_router.routes
199+
if isinstance(route, APIRoute) and _DELETED_ROUTE_TAG in route.tags
155200
]
156201

157202
def transform(self) -> GeneratedRouters[_R, _WR]:
@@ -165,7 +210,7 @@ def transform(self) -> GeneratedRouters[_R, _WR]:
165210
self.schema_generators[str(version.value)].annotation_transformer.migrate_router_to_version(router)
166211
self.schema_generators[str(version.value)].annotation_transformer.migrate_router_to_version(webhook_router)
167212

168-
self._attach_routes_to_data_converters(router, self.parent_router, version)
213+
self._attach_routes_to_data_converters(router, self.head_router, version)
169214

170215
routers[version.value] = router
171216
webhook_routers[version.value] = webhook_router
@@ -183,7 +228,7 @@ def transform(self) -> GeneratedRouters[_R, _WR]:
183228
f"{self.routes_that_never_existed}",
184229
)
185230

186-
for route_index, head_route in enumerate(self.parent_router.routes):
231+
for route_index, head_route in enumerate(self.head_router.routes):
187232
if not isinstance(head_route, APIRoute):
188233
continue
189234
_add_request_and_response_params(head_route)
@@ -326,7 +371,7 @@ def _extract_all_routes_identifiers_for_route_to_converter_matching(
326371
annotation = route.body_field.field_info.annotation
327372
if annotation is not None and lenient_issubclass(annotation, BaseModel):
328373
request_bodies.add(annotation)
329-
for method in route.methods:
374+
for method in _route_methods(route):
330375
path_to_route_methods_mapping[route.path][method].add(index)
331376

332377
head_response_models = {model.__cadwyn_original_model__ for model in response_models}
@@ -363,7 +408,7 @@ def _apply_endpoint_changes_to_router( # noqa: C901
363408
if deleted_routes:
364409
method_union = set()
365410
for deleted_route in deleted_routes:
366-
method_union |= deleted_route.methods
411+
method_union |= _route_methods(deleted_route)
367412
raise RouterGenerationError(
368413
f'Endpoint "{list(method_union)} {instruction.endpoint_path}" you tried to delete in '
369414
f'"{version_change.__name__}" was already deleted in a newer version. If you really have '
@@ -372,7 +417,7 @@ def _apply_endpoint_changes_to_router( # noqa: C901
372417
f"{[r.endpoint.__name__ for r in deleted_routes]}",
373418
)
374419
for original_route in original_routes:
375-
methods_to_which_we_applied_changes |= original_route.methods
420+
methods_to_which_we_applied_changes |= _route_methods(original_route)
376421
original_route.tags.append(_DELETED_ROUTE_TAG)
377422
err = (
378423
'Endpoint "{endpoint_methods} {endpoint_path}" you tried to delete in'
@@ -382,7 +427,7 @@ def _apply_endpoint_changes_to_router( # noqa: C901
382427
if original_routes:
383428
method_union = set()
384429
for original_route in original_routes:
385-
method_union |= original_route.methods
430+
method_union |= _route_methods(original_route)
386431
raise RouterGenerationError(
387432
f'Endpoint "{list(method_union)} {instruction.endpoint_path}" you tried to restore in'
388433
f' "{version_change.__name__}" already existed in a newer version. If you really have two '
@@ -408,13 +453,13 @@ def _apply_endpoint_changes_to_router( # noqa: C901
408453
f"endpoints that can be restored: {[r.endpoint.__name__ for r in e.routes]}",
409454
) from e
410455
for deleted_route in deleted_routes:
411-
methods_to_which_we_applied_changes |= deleted_route.methods
456+
methods_to_which_we_applied_changes |= _route_methods(deleted_route)
412457
deleted_route.tags.remove(_DELETED_ROUTE_TAG)
413458

414459
routes_that_never_existed = _get_routes(
415460
self.routes_that_never_existed,
416461
deleted_route.path,
417-
deleted_route.methods,
462+
_route_methods(deleted_route),
418463
deleted_route.endpoint.__name__,
419464
is_deleted=True,
420465
)
@@ -425,7 +470,8 @@ def _apply_endpoint_changes_to_router( # noqa: C901
425470
# to remove it because I like its clarity very much
426471
routes = routes_that_never_existed
427472
raise RouterGenerationError(
428-
f'Endpoint "{list(deleted_route.methods)} {deleted_route.path}" you tried to restore '
473+
f'Endpoint "{list(_route_methods(deleted_route))} {deleted_route.path}" '
474+
"you tried to restore "
429475
f'in "{version_change.__name__}" has {len(routes_that_never_existed)} applicable '
430476
f"routes with the same function name and path that could be restored. This can cause "
431477
f"problems during version generation. Specifically, Cadwyn won't be able to warn "
@@ -439,7 +485,7 @@ def _apply_endpoint_changes_to_router( # noqa: C901
439485
)
440486
elif isinstance(instruction, EndpointHadInstruction):
441487
for original_route in original_routes:
442-
methods_to_which_we_applied_changes |= original_route.methods
488+
methods_to_which_we_applied_changes |= _route_methods(original_route)
443489
_apply_endpoint_had_instruction(
444490
version_change.__name__,
445491
instruction,
@@ -468,7 +514,7 @@ def _validate_no_repetitions_in_routes(routes: list[fastapi.routing.APIRoute]):
468514
route_map = {}
469515

470516
for route in routes:
471-
route_info = _EndpointInfo(route.path, frozenset(route.methods))
517+
route_info = _EndpointInfo(route.path, frozenset(_route_methods(route)))
472518
if route_info in route_map:
473519
raise RouteAlreadyExistsError(route, route_map[route_info])
474520
route_map[route_info] = route
@@ -485,7 +531,8 @@ def _add_data_migrations_to_route(
485531
if not (route.dependant.request_param_name and route.dependant.response_param_name): # pragma: no cover
486532
raise CadwynError(
487533
f"{route.dependant.request_param_name=}, {route.dependant.response_param_name=} "
488-
f"for route {list(route.methods)} {route.path} which should not be possible. Please, contact my author.",
534+
f"for route {list(_route_methods(route))} {route.path} which should not be possible. "
535+
"Please, contact my author.",
489536
)
490537

491538
route.endpoint = versions._versioned(
@@ -514,7 +561,7 @@ def _apply_endpoint_had_instruction(
514561
if getattr(original_route, attr_name) == attr:
515562
raise RouterGenerationError(
516563
f'Expected attribute "{attr_name}" of endpoint'
517-
f' "{list(original_route.methods)} {original_route.path}"'
564+
f' "{list(_route_methods(original_route))} {original_route.path}"'
518565
f' to be different in "{version_change_name}", but it was the same.'
519566
" It means that your version change has no effect on the attribute"
520567
" and can be removed.",
@@ -526,7 +573,7 @@ def _apply_endpoint_had_instruction(
526573
new_path_params.discard(api_version_parameter_name)
527574
if new_path_params != original_path_params:
528575
raise RouterPathParamsModifiedError(
529-
f'When altering the path of "{list(original_route.methods)} {original_route.path}" '
576+
f'When altering the path of "{list(_route_methods(original_route))} {original_route.path}" '
530577
f'in "{version_change_name}", you have tried to change its path params '
531578
f'from "{list(original_path_params)}" to "{list(new_path_params)}". It is not allowed to '
532579
"change the path params of a route because the endpoint was created to handle the old path "
@@ -552,7 +599,7 @@ def _get_routes(
552599
if (
553600
isinstance(route, fastapi.routing.APIRoute)
554601
and route.path.rstrip("/") == endpoint_path
555-
and set(route.methods).issubset(endpoint_methods)
602+
and _route_methods(route).issubset(endpoint_methods)
556603
and (endpoint_func_name is None or route.endpoint.__name__ == endpoint_func_name)
557604
and (_DELETED_ROUTE_TAG in route.tags) == is_deleted
558605
)
@@ -563,7 +610,7 @@ def _get_route_from_func(
563610
routes: Sequence[BaseRoute],
564611
endpoint: Endpoint,
565612
) -> Union[fastapi.routing.APIRoute, None]:
566-
for route in routes:
613+
for route, _effective_route_context in _iter_routes_with_context(routes):
567614
if isinstance(route, fastapi.routing.APIRoute) and (route.endpoint == endpoint):
568615
return route
569616
return None

0 commit comments

Comments
 (0)