Skip to content

Commit 4f4ca4c

Browse files
authored
[FIX] Reset EarlyStopping state in cross-validation when refit=True (#1529)
1 parent 7958cd4 commit 4f4ca4c

2 files changed

Lines changed: 78 additions & 6 deletions

File tree

neuralforecast/common/_base_model.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from contextlib import contextmanager
99
from copy import deepcopy
1010
from dataclasses import dataclass
11-
from typing import Dict, List, Union
11+
from typing import Dict, List, Optional, Union
1212

1313
import fsspec
1414
import numpy as np
@@ -295,6 +295,8 @@ def __init__(
295295
raise Exception("max_epochs is deprecated, use max_steps instead.")
296296

297297
# Callbacks
298+
self._early_stop_kwargs: Optional[dict] = None
299+
self._early_stop_cb: Optional[EarlyStopping] = None
298300
if early_stop_patience_steps > 0:
299301
valid_monitors = ["ptl/val_loss", "valid_loss", "train_loss"]
300302
if val_monitor not in valid_monitors:
@@ -304,11 +306,12 @@ def __init__(
304306
)
305307
if "callbacks" not in trainer_kwargs:
306308
trainer_kwargs["callbacks"] = []
307-
trainer_kwargs["callbacks"].append(
308-
EarlyStopping(
309-
monitor=val_monitor, patience=early_stop_patience_steps
310-
)
311-
)
309+
self._early_stop_kwargs = {
310+
"monitor": val_monitor,
311+
"patience": early_stop_patience_steps,
312+
}
313+
self._early_stop_cb = EarlyStopping(**self._early_stop_kwargs)
314+
trainer_kwargs["callbacks"].append(self._early_stop_cb)
312315

313316
# Add GPU accelerator if available
314317
if trainer_kwargs.get("accelerator", None) is None:
@@ -604,6 +607,17 @@ def _fit(
604607
self.trainer_kwargs["val_check_interval"] = int(val_check_interval)
605608
self.trainer_kwargs["check_val_every_n_epoch"] = None
606609

610+
# The EarlyStopping callback we add in __init__ accumulates state across fits.
611+
# Rebuild it from the stored kwargs.
612+
# We match by identity, so user-supplied EarlyStopping callbacks are left untouched.
613+
if self._early_stop_cb is not None and self._early_stop_kwargs is not None:
614+
callbacks = self.trainer_kwargs.get("callbacks") or []
615+
for i, cb in enumerate(callbacks):
616+
if cb is self._early_stop_cb:
617+
self._early_stop_cb = EarlyStopping(**self._early_stop_kwargs)
618+
callbacks[i] = self._early_stop_cb
619+
break
620+
607621
if is_local:
608622
model = self
609623
trainer = pl.Trainer(**model.trainer_kwargs)

tests/test_core.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,64 @@ def test_neural_forecast_refit(setup_airplane_data):
253253
)
254254

255255

256+
def test_cross_validation_refit_resets_early_stopping(setup_airplane_data):
    """Check that every refit window starts with a fresh EarlyStopping state.

    With ``cross_validation(refit=True, val_size=...)`` the previous window's
    ``wait_count`` and ``best_score`` must not carry over; otherwise
    subsequent refits stop on the first validation check.
    """
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    AirPassengersPanel_train, _ = setup_airplane_data

    class _CaptureEarlyStoppingState(pl.Callback):
        """Record the EarlyStopping counters seen at the start of each fit."""

        def __init__(self):
            self.snapshots = []

        def on_train_start(self, trainer, pl_module):
            # Only the first EarlyStopping callback on the trainer is inspected.
            early_stoppers = [
                cb for cb in trainer.callbacks if isinstance(cb, EarlyStopping)
            ]
            early_stop = early_stoppers[0] if early_stoppers else None
            assert (
                early_stop is not None
            ), "EarlyStopping callback missing from trainer"
            state = {
                "wait_count": early_stop.wait_count,
                "stopped_epoch": early_stop.stopped_epoch,
                "best_score": float(early_stop.best_score),
            }
            self.snapshots.append(state)

    model_kwargs = dict(
        h=12,
        input_size=24,
        max_steps=4,
        val_check_steps=1,
        early_stop_patience_steps=1,
        callbacks=[_CaptureEarlyStoppingState()],
        enable_progress_bar=False,
    )
    nf = NeuralForecast(models=[NHITS(**model_kwargs)], freq="M")

    # Look the recorder up through the model held by ``nf`` rather than
    # keeping a reference to the instance created above.
    capture = next(
        filter(
            lambda cb: isinstance(cb, _CaptureEarlyStoppingState),
            nf.models[0].trainer_kwargs["callbacks"],
        )
    )

    nf.cross_validation(
        df=AirPassengersPanel_train,
        val_size=12,
        n_windows=3,
        refit=True,
    )

    # One fit per refit window.
    assert len(capture.snapshots) == 3, capture.snapshots
    for window, state in enumerate(capture.snapshots):
        # Each window must begin from EarlyStopping's pristine counters.
        assert state["wait_count"] == 0, (window, state)
        assert state["stopped_epoch"] == 0, (window, state)
        assert state["best_score"] == float("inf"), (window, state)
313+
256314
def test_neural_forecast_scaling(setup_airplane_data):
257315
"""Test scaling functionality for NeuralForecast models."""
258316
AirPassengersPanel_train, AirPassengersPanel_test = setup_airplane_data

0 commit comments

Comments
 (0)