@@ -215,6 +215,32 @@ def test_neural_forecast_fit_cross_validation(setup_airplane_data):
215215 pd .testing .assert_frame_equal (init_cv , after_cv )
216216 pd .testing .assert_frame_equal (after_fcst , init_fcst )
217217
218+
219+ # cross_validation() with no `df` should reuse the stored dataset
220+ @pytest .mark .parametrize ("use_polars" , [False , True ])
221+ def test_cross_validation_without_df_uses_stored_dataset (
222+ use_polars , setup_airplane_data , setup_airplane_data_polars
223+ ):
224+ if use_polars :
225+ df , _ = setup_airplane_data_polars
226+ freq = "1mo"
227+ col_kwargs = dict (id_col = "uid" , time_col = "time" , target_col = "target" )
228+ assert_frame_equal = polars .testing .assert_frame_equal
229+ else :
230+ df , _ = setup_airplane_data
231+ freq = "M"
232+ col_kwargs = {}
233+ assert_frame_equal = pd .testing .assert_frame_equal
234+
235+ models = [NHITS (h = 12 , input_size = 24 , max_steps = 2 , random_seed = 0 )]
236+ nf = NeuralForecast (models = models , freq = freq )
237+ nf .fit (df , ** col_kwargs )
238+ # use_init_models resets to the same seeded weights before each run
239+ cv_with_df = nf .cross_validation (df , use_init_models = True , ** col_kwargs )
240+ cv_no_df = nf .cross_validation (use_init_models = True , ** col_kwargs )
241+ assert_frame_equal (cv_no_df , cv_with_df )
242+
243+
218244# test cross_validation with refit
219245def test_neural_forecast_refit (setup_airplane_data ):
220246 AirPassengersPanel_train , _ = setup_airplane_data
0 commit comments