1111 cast ,
1212)
1313
14- import fastapi .params
1514import fastapi .routing
16- import fastapi .security .base
17- import fastapi .utils
1815from fastapi import APIRouter
19- from fastapi .routing import APIRoute
16+ from fastapi .routing import APIRoute , _EffectiveRouteContext , _iter_routes_with_context
2017from pydantic import BaseModel
2118from starlette .routing import BaseRoute
2219from typing_extensions import TypeVar , assert_never
@@ -104,11 +101,29 @@ def only_exists_in_older_versions(self, endpoint: _Call) -> _Call:
104101
105102def copy_router (router : _R ) -> _R :
106103 router = copy (router )
107- router .routes = [copy_route (r ) for r in router .routes ]
104+ router .routes = [
105+ copy_route (route , effective_route_context )
106+ for route , effective_route_context in _iter_routes_with_context (router .routes )
107+ ]
108+ router ._mark_routes_changed ()
108109 return router
109110
110111
111- def copy_route (route : _RouteT ) -> _RouteT :
112+ def copy_route (route : _RouteT , effective_route_context : _EffectiveRouteContext | None = None ) -> _RouteT :
113+ """Copy a route and materialize FastAPI's include-router context into the copy.
114+
115+ FastAPI 0.137+ keeps included routers as tree nodes. Its _EffectiveRouteContext is the merged
116+ route state produced by that tree: prefix, dependencies, tags, responses, schema flags, etc.
117+ Cadwyn mutates copied routes per API version, so it must copy the original route and apply that
118+ effective state to the copy. The original router tree is left intact.
119+ """
120+ if (
121+ effective_route_context is not None
122+ and not isinstance (route , APIRoute )
123+ and effective_route_context .starlette_route is not None
124+ ):
125+ return cast ("_RouteT" , copy (effective_route_context .starlette_route ))
126+
112127 if not isinstance (route , APIRoute ):
113128 return copy (route )
114129
@@ -120,19 +135,46 @@ def copy_route(route: _RouteT) -> _RouteT:
120135 # These can hold TypeAdapters for recursive types (e.g. JsonValue) that cause
121136 # infinite recursion during deepcopy.
122137 memo : dict [int , Any ] = {}
123- for attr in ("dependant" , "_flat_dependant" , "body_field" , "response_model" ):
138+ for attr in ("dependant" , "_flat_dependant" , "body_field" , "response_model" , "dependency_overrides_provider" ):
124139 obj = getattr (route , attr , None )
125140 if obj is not None :
126141 memo [id (obj )] = obj
127142 new_route = deepcopy (route , memo )
128- new_route .dependant = copy (route .dependant )
129- if getattr (route , "_flat_dependant" , None ) is not None :
130- new_route ._flat_dependant = copy (route ._flat_dependant )
131- new_route .body_field = route .body_field
132- new_route .dependencies = copy (route .dependencies )
143+ if effective_route_context is not None :
144+ _apply_effective_route_context_to_route (new_route , effective_route_context )
145+ _refresh_route_app (new_route )
146+ else :
147+ new_route .dependant = copy (route .dependant )
148+ if getattr (route , "_flat_dependant" , None ) is not None :
149+ new_route ._flat_dependant = copy (route ._flat_dependant )
150+ new_route .body_field = route .body_field
151+ new_route .dependencies = copy (route .dependencies )
133152 return new_route
134153
135154
155+ def _apply_effective_route_context_to_route (route : APIRoute , effective_route_context : _EffectiveRouteContext ) -> None :
156+ for attr_name , attr_value in vars (effective_route_context ).items ():
157+ if attr_name in {"original_route" , "starlette_route" }:
158+ continue
159+ setattr (route , attr_name , _copy_effective_route_context_attr (attr_name , attr_value ))
160+
161+
162+ def _copy_effective_route_context_attr (attr_name : str , attr_value : Any ) -> Any :
163+ if attr_name in {"dependant" , "_flat_dependant" } and attr_value is not None :
164+ return copy (attr_value )
165+ if isinstance (attr_value , (dict , list , set )):
166+ return copy (attr_value )
167+ return attr_value
168+
169+
170+ def _refresh_route_app (route : APIRoute ) -> None :
171+ route .app = fastapi .routing .request_response (route .get_route_handler ())
172+
173+
174+ def _route_methods (route : APIRoute ) -> set [str ]:
175+ return route .methods or set ()
176+
177+
136178class _EndpointTransformer (Generic [_R , _WR ]):
137179 def __init__ (
138180 self ,
@@ -149,9 +191,12 @@ def __init__(
149191 self .api_version_parameter_name = api_version_parameter_name
150192 self .api_version_location : APIVersionLocation = api_version_location
151193 self .schema_generators = generate_versioned_models (versions )
194+ self .head_router = copy_router (parent_router )
152195
153196 self .routes_that_never_existed = [
154- route for route in parent_router .routes if isinstance (route , APIRoute ) and _DELETED_ROUTE_TAG in route .tags
197+ route
198+ for route in self .head_router .routes
199+ if isinstance (route , APIRoute ) and _DELETED_ROUTE_TAG in route .tags
155200 ]
156201
157202 def transform (self ) -> GeneratedRouters [_R , _WR ]:
@@ -165,7 +210,7 @@ def transform(self) -> GeneratedRouters[_R, _WR]:
165210 self .schema_generators [str (version .value )].annotation_transformer .migrate_router_to_version (router )
166211 self .schema_generators [str (version .value )].annotation_transformer .migrate_router_to_version (webhook_router )
167212
168- self ._attach_routes_to_data_converters (router , self .parent_router , version )
213+ self ._attach_routes_to_data_converters (router , self .head_router , version )
169214
170215 routers [version .value ] = router
171216 webhook_routers [version .value ] = webhook_router
@@ -183,7 +228,7 @@ def transform(self) -> GeneratedRouters[_R, _WR]:
183228 f"{ self .routes_that_never_existed } " ,
184229 )
185230
186- for route_index , head_route in enumerate (self .parent_router .routes ):
231+ for route_index , head_route in enumerate (self .head_router .routes ):
187232 if not isinstance (head_route , APIRoute ):
188233 continue
189234 _add_request_and_response_params (head_route )
@@ -326,7 +371,7 @@ def _extract_all_routes_identifiers_for_route_to_converter_matching(
326371 annotation = route .body_field .field_info .annotation
327372 if annotation is not None and lenient_issubclass (annotation , BaseModel ):
328373 request_bodies .add (annotation )
329- for method in route . methods :
374+ for method in _route_methods ( route ) :
330375 path_to_route_methods_mapping [route .path ][method ].add (index )
331376
332377 head_response_models = {model .__cadwyn_original_model__ for model in response_models }
@@ -363,7 +408,7 @@ def _apply_endpoint_changes_to_router( # noqa: C901
363408 if deleted_routes :
364409 method_union = set ()
365410 for deleted_route in deleted_routes :
366- method_union |= deleted_route . methods
411+ method_union |= _route_methods ( deleted_route )
367412 raise RouterGenerationError (
368413 f'Endpoint "{ list (method_union )} { instruction .endpoint_path } " you tried to delete in '
369414 f'"{ version_change .__name__ } " was already deleted in a newer version. If you really have '
@@ -372,7 +417,7 @@ def _apply_endpoint_changes_to_router( # noqa: C901
372417 f"{ [r .endpoint .__name__ for r in deleted_routes ]} " ,
373418 )
374419 for original_route in original_routes :
375- methods_to_which_we_applied_changes |= original_route . methods
420+ methods_to_which_we_applied_changes |= _route_methods ( original_route )
376421 original_route .tags .append (_DELETED_ROUTE_TAG )
377422 err = (
378423 'Endpoint "{endpoint_methods} {endpoint_path}" you tried to delete in'
@@ -382,7 +427,7 @@ def _apply_endpoint_changes_to_router( # noqa: C901
382427 if original_routes :
383428 method_union = set ()
384429 for original_route in original_routes :
385- method_union |= original_route . methods
430+ method_union |= _route_methods ( original_route )
386431 raise RouterGenerationError (
387432 f'Endpoint "{ list (method_union )} { instruction .endpoint_path } " you tried to restore in'
388433 f' "{ version_change .__name__ } " already existed in a newer version. If you really have two '
@@ -408,13 +453,13 @@ def _apply_endpoint_changes_to_router( # noqa: C901
408453 f"endpoints that can be restored: { [r .endpoint .__name__ for r in e .routes ]} " ,
409454 ) from e
410455 for deleted_route in deleted_routes :
411- methods_to_which_we_applied_changes |= deleted_route . methods
456+ methods_to_which_we_applied_changes |= _route_methods ( deleted_route )
412457 deleted_route .tags .remove (_DELETED_ROUTE_TAG )
413458
414459 routes_that_never_existed = _get_routes (
415460 self .routes_that_never_existed ,
416461 deleted_route .path ,
417- deleted_route . methods ,
462+ _route_methods ( deleted_route ) ,
418463 deleted_route .endpoint .__name__ ,
419464 is_deleted = True ,
420465 )
@@ -425,7 +470,8 @@ def _apply_endpoint_changes_to_router( # noqa: C901
425470 # to remove it because I like its clarity very much
426471 routes = routes_that_never_existed
427472 raise RouterGenerationError (
428- f'Endpoint "{ list (deleted_route .methods )} { deleted_route .path } " you tried to restore '
473+ f'Endpoint "{ list (_route_methods (deleted_route ))} { deleted_route .path } " '
474+ "you tried to restore "
429475 f'in "{ version_change .__name__ } " has { len (routes_that_never_existed )} applicable '
430476 f"routes with the same function name and path that could be restored. This can cause "
431477 f"problems during version generation. Specifically, Cadwyn won't be able to warn "
@@ -439,7 +485,7 @@ def _apply_endpoint_changes_to_router( # noqa: C901
439485 )
440486 elif isinstance (instruction , EndpointHadInstruction ):
441487 for original_route in original_routes :
442- methods_to_which_we_applied_changes |= original_route . methods
488+ methods_to_which_we_applied_changes |= _route_methods ( original_route )
443489 _apply_endpoint_had_instruction (
444490 version_change .__name__ ,
445491 instruction ,
@@ -468,7 +514,7 @@ def _validate_no_repetitions_in_routes(routes: list[fastapi.routing.APIRoute]):
468514 route_map = {}
469515
470516 for route in routes :
471- route_info = _EndpointInfo (route .path , frozenset (route . methods ))
517+ route_info = _EndpointInfo (route .path , frozenset (_route_methods ( route ) ))
472518 if route_info in route_map :
473519 raise RouteAlreadyExistsError (route , route_map [route_info ])
474520 route_map [route_info ] = route
@@ -485,7 +531,8 @@ def _add_data_migrations_to_route(
485531 if not (route .dependant .request_param_name and route .dependant .response_param_name ): # pragma: no cover
486532 raise CadwynError (
487533 f"{ route .dependant .request_param_name = } , { route .dependant .response_param_name = } "
488- f"for route { list (route .methods )} { route .path } which should not be possible. Please, contact my author." ,
534+ f"for route { list (_route_methods (route ))} { route .path } which should not be possible. "
535+ "Please, contact my author." ,
489536 )
490537
491538 route .endpoint = versions ._versioned (
@@ -514,7 +561,7 @@ def _apply_endpoint_had_instruction(
514561 if getattr (original_route , attr_name ) == attr :
515562 raise RouterGenerationError (
516563 f'Expected attribute "{ attr_name } " of endpoint'
517- f' "{ list (original_route . methods )} { original_route .path } "'
564+ f' "{ list (_route_methods ( original_route ) )} { original_route .path } "'
518565 f' to be different in "{ version_change_name } ", but it was the same.'
519566 " It means that your version change has no effect on the attribute"
520567 " and can be removed." ,
@@ -526,7 +573,7 @@ def _apply_endpoint_had_instruction(
526573 new_path_params .discard (api_version_parameter_name )
527574 if new_path_params != original_path_params :
528575 raise RouterPathParamsModifiedError (
529- f'When altering the path of "{ list (original_route . methods )} { original_route .path } " '
576+ f'When altering the path of "{ list (_route_methods ( original_route ) )} { original_route .path } " '
530577 f'in "{ version_change_name } ", you have tried to change its path params '
531578 f'from "{ list (original_path_params )} " to "{ list (new_path_params )} ". It is not allowed to '
532579 "change the path params of a route because the endpoint was created to handle the old path "
@@ -552,7 +599,7 @@ def _get_routes(
552599 if (
553600 isinstance (route , fastapi .routing .APIRoute )
554601 and route .path .rstrip ("/" ) == endpoint_path
555- and set (route . methods ).issubset (endpoint_methods )
602+ and _route_methods (route ).issubset (endpoint_methods )
556603 and (endpoint_func_name is None or route .endpoint .__name__ == endpoint_func_name )
557604 and (_DELETED_ROUTE_TAG in route .tags ) == is_deleted
558605 )
@@ -563,7 +610,7 @@ def _get_route_from_func(
563610 routes : Sequence [BaseRoute ],
564611 endpoint : Endpoint ,
565612) -> Union [fastapi .routing .APIRoute , None ]:
566- for route in routes :
613+ for route , _effective_route_context in _iter_routes_with_context ( routes ) :
567614 if isinstance (route , fastapi .routing .APIRoute ) and (route .endpoint == endpoint ):
568615 return route
569616 return None
0 commit comments