Skip to content

Commit 936a3cf

Browse files
authored
[FIX] Error when saving model with custom callbacks (#1487)
1 parent 8afc73f commit 936a3cf

2 files changed

Lines changed: 41 additions & 1 deletion

File tree

neuralforecast/common/_base_model.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,9 +678,20 @@ def on_validation_epoch_end(self):
678678
self.validation_step_outputs.clear() # free memory (compute `avg_loss` per epoch)
679679

680680
def save(self, path):
681+
import copy
682+
683+
# Strip callbacks from hparams before saving: callback objects are not
684+
# YAML-serializable, which causes PyTorch Lightning to raise a ValueError
685+
# during predict() on a loaded model. Callbacks can be re-attached after
686+
# loading via `model.trainer_kwargs["callbacks"] = [...]`.
687+
# Note: save_hyperparameters() stores **trainer_kwargs contents flat, so
688+
# `callbacks` is a top-level key in hparams, not nested under trainer_kwargs.
689+
hparams = copy.deepcopy(dict(self.hparams))
690+
if "callbacks" in hparams:
691+
del hparams["callbacks"]
681692
with fsspec.open(path, "wb") as f:
682693
torch.save(
683-
{"hyper_parameters": self.hparams, "state_dict": self.state_dict()},
694+
{"hyper_parameters": hparams, "state_dict": self.state_dict()},
684695
f,
685696
)
686697

tests/test_core.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,6 +1093,35 @@ def test_save_load_no_dataset(setup_airplane_data):
10931093
np.testing.assert_allclose(forecasts1["DilatedRNN"], forecasts2["DilatedRNN"])
10941094

10951095

1096+
def test_save_load_with_callbacks(setup_airplane_data, tmp_path):
1097+
"""Saving a model with trainer callbacks should not break predict after reload.
1098+
1099+
Custom callbacks are not YAML-serializable; without the fix this causes a
1100+
ValueError when PyTorch Lightning's logger tries to log hparams during predict.
1101+
"""
1102+
from pytorch_lightning.callbacks import Callback
1103+
1104+
class _NonYamlCallback(Callback):
1105+
# lambda attributes are not YAML-safe
1106+
fn = lambda self: None # noqa: E731
1107+
1108+
AirPassengersPanel_train, _ = setup_airplane_data
1109+
model = NHITS(
1110+
h=12,
1111+
input_size=24,
1112+
max_steps=10,
1113+
callbacks=[_NonYamlCallback()],
1114+
)
1115+
nf = NeuralForecast(models=[model], freq="M")
1116+
nf.fit(AirPassengersPanel_train)
1117+
nf.save(str(tmp_path))
1118+
1119+
nf2 = NeuralForecast.load(str(tmp_path))
1120+
# Should not raise a ValueError from YAML serialization
1121+
preds = nf2.predict(df=AirPassengersPanel_train)
1122+
assert preds is not None
1123+
1124+
10961125
def test_save_skips_nonzero_ddp_rank(monkeypatch, tmp_path):
10971126
"""Only rank 0 should write artifacts when DDP is initialized."""
10981127
dist = pytest.importorskip("torch.distributed")

0 commit comments

Comments
 (0)