Skip to content

Commit 486ab57

Browse files
authored
[FIX] Allow CV without passing a df (#1545)
1 parent 035a82e commit 486ab57

2 files changed

Lines changed: 43 additions & 0 deletions

File tree

neuralforecast/core.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1939,6 +1939,23 @@ def _no_refit_cross_validation(
19391939
fcsts_df = ufp.horizontal_concat([fcsts_df, fcsts])
19401940

19411941
# Add original input df's y to forecasts DataFrame
1942+
if df is None:
1943+
# Reconstruct the target from the stored dataset. The dataset's
1944+
# temporal values are scaled, so undo any target scaling.
1945+
target_values = (
1946+
self.dataset.temporal[:, self.dataset.y_idx].clone().numpy()
1947+
)
1948+
if self.scalers_:
1949+
target_values = self._scalers_target_inverse_transform(
1950+
target_values.reshape(-1, 1), self.dataset.indptr
1951+
).reshape(-1)
1952+
df = type(fcsts_df)(
1953+
{
1954+
id_col: ufp.repeat(self.uids, np.diff(self.dataset.indptr)),
1955+
time_col: self.ds,
1956+
target_col: target_values,
1957+
}
1958+
)
19421959
return ufp.join(
19431960
fcsts_df,
19441961
df[[id_col, time_col, target_col]],

tests/test_core.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,32 @@ def test_neural_forecast_fit_cross_validation(setup_airplane_data):
215215
pd.testing.assert_frame_equal(init_cv, after_cv)
216216
pd.testing.assert_frame_equal(after_fcst, init_fcst)
217217

218+
219+
# cross_validation() with no `df` should reuse the stored dataset
220+
@pytest.mark.parametrize("use_polars", [False, True])
221+
def test_cross_validation_without_df_uses_stored_dataset(
222+
use_polars, setup_airplane_data, setup_airplane_data_polars
223+
):
224+
if use_polars:
225+
df, _ = setup_airplane_data_polars
226+
freq = "1mo"
227+
col_kwargs = dict(id_col="uid", time_col="time", target_col="target")
228+
assert_frame_equal = polars.testing.assert_frame_equal
229+
else:
230+
df, _ = setup_airplane_data
231+
freq = "M"
232+
col_kwargs = {}
233+
assert_frame_equal = pd.testing.assert_frame_equal
234+
235+
models = [NHITS(h=12, input_size=24, max_steps=2, random_seed=0)]
236+
nf = NeuralForecast(models=models, freq=freq)
237+
nf.fit(df, **col_kwargs)
238+
# use_init_models resets to the same seeded weights before each run
239+
cv_with_df = nf.cross_validation(df, use_init_models=True, **col_kwargs)
240+
cv_no_df = nf.cross_validation(use_init_models=True, **col_kwargs)
241+
assert_frame_equal(cv_no_df, cv_with_df)
242+
243+
218244
# test cross_validation with refit
219245
def test_neural_forecast_refit(setup_airplane_data):
220246
AirPassengersPanel_train, _ = setup_airplane_data

0 commit comments

Comments
 (0)