Skip to content

Commit 1cf449c

Browse files
authored
[FIX] Further fixes for CV with no df (#1548)
1 parent 0dd6828 commit 1cf449c

2 files changed

Lines changed: 16 additions & 8 deletions

File tree

neuralforecast/core.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,6 +1863,11 @@ def _no_refit_cross_validation(
18631863
)
18641864
)
18651865
else:
1866+
id_col, time_col, target_col = (
1867+
self.id_col,
1868+
self.time_col,
1869+
self.target_col,
1870+
)
18661871
if verbose:
18671872
print("Using stored dataset.")
18681873

@@ -1953,13 +1958,14 @@ def _no_refit_cross_validation(
19531958
if df is None:
19541959
# Reconstruct the target from the stored dataset. The dataset's
19551960
# temporal values are scaled, so undo any target scaling.
1956-
target_values = (
1957-
self.dataset.temporal[:, self.dataset.y_idx].clone().numpy()
1958-
)
1961+
target_column = self.dataset.temporal[:, self.dataset.y_idx]
19591962
if self.scalers_:
19601963
target_values = self._scalers_target_inverse_transform(
1961-
target_values.reshape(-1, 1), self.dataset.indptr
1964+
target_column.clone().numpy().reshape(-1, 1),
1965+
self.dataset.indptr,
19621966
).reshape(-1)
1967+
else:
1968+
target_values = target_column.numpy()
19631969
df = type(fcsts_df)(
19641970
{
19651971
id_col: ufp.repeat(self.uids, np.diff(self.dataset.indptr)),

tests/test_core.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,10 +239,11 @@ def test_neural_forecast_fit_cross_validation(setup_airplane_data):
239239
pd.testing.assert_frame_equal(after_fcst, init_fcst)
240240

241241

242-
# cross_validation() with no `df` should reuse the stored dataset
242+
# cross_validation() with no `df` should reuse the stored dataset.
243+
@pytest.mark.parametrize("local_scaler_type", [None, "standard"])
243244
@pytest.mark.parametrize("use_polars", [False, True])
244245
def test_cross_validation_without_df_uses_stored_dataset(
245-
use_polars, setup_airplane_data, setup_airplane_data_polars
246+
use_polars, local_scaler_type, setup_airplane_data, setup_airplane_data_polars
246247
):
247248
if use_polars:
248249
df, _ = setup_airplane_data_polars
@@ -256,11 +257,12 @@ def test_cross_validation_without_df_uses_stored_dataset(
256257
assert_frame_equal = pd.testing.assert_frame_equal
257258

258259
models = [NHITS(h=12, input_size=24, max_steps=2, random_seed=0)]
259-
nf = NeuralForecast(models=models, freq=freq)
260+
nf = NeuralForecast(models=models, freq=freq, local_scaler_type=local_scaler_type)
260261
nf.fit(df, **col_kwargs)
261262
# use_init_models resets to the same seeded weights before each run
262263
cv_with_df = nf.cross_validation(df, use_init_models=True, **col_kwargs)
263-
cv_no_df = nf.cross_validation(use_init_models=True, **col_kwargs)
264+
# No column kwargs here: the stored dataset's column names must be preserved.
265+
cv_no_df = nf.cross_validation(use_init_models=True)
264266
assert_frame_equal(cv_no_df, cv_with_df)
265267

266268

0 commit comments

Comments
 (0)