File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -678,9 +678,20 @@ def on_validation_epoch_end(self):
678678 self .validation_step_outputs .clear () # free memory (compute `avg_loss` per epoch)
679679
680680 def save (self , path ):
681+ import copy
682+
683+ # Strip callbacks from hparams before saving: callback objects are not
684+ # YAML-serializable, which causes PyTorch Lightning to raise a ValueError
685+ # during predict() on a loaded model. Callbacks can be re-attached after
686+ # loading via `model.trainer_kwargs["callbacks"] = [...]`.
687+ # Note: save_hyperparameters() stores **trainer_kwargs contents flat, so
688+ # `callbacks` is a top-level key in hparams, not nested under trainer_kwargs.
689+ hparams = copy .deepcopy (dict (self .hparams ))
690+ if "callbacks" in hparams :
691+ del hparams ["callbacks" ]
681692 with fsspec .open (path , "wb" ) as f :
682693 torch .save (
683- {"hyper_parameters" : self . hparams , "state_dict" : self .state_dict ()},
694+ {"hyper_parameters" : hparams , "state_dict" : self .state_dict ()},
684695 f ,
685696 )
686697
Original file line number Diff line number Diff line change @@ -1093,6 +1093,35 @@ def test_save_load_no_dataset(setup_airplane_data):
10931093 np .testing .assert_allclose (forecasts1 ["DilatedRNN" ], forecasts2 ["DilatedRNN" ])
10941094
10951095
1096+ def test_save_load_with_callbacks (setup_airplane_data , tmp_path ):
1097+ """Saving a model with trainer callbacks should not break predict after reload.
1098+
1099+ Custom callbacks are not YAML-serializable; without the fix this causes a
1100+ ValueError when PyTorch Lightning's logger tries to log hparams during predict.
1101+ """
1102+ from pytorch_lightning .callbacks import Callback
1103+
1104+ class _NonYamlCallback (Callback ):
1105+ # lambda attributes are not YAML-safe
1106+ fn = lambda self : None # noqa: E731
1107+
1108+ AirPassengersPanel_train , _ = setup_airplane_data
1109+ model = NHITS (
1110+ h = 12 ,
1111+ input_size = 24 ,
1112+ max_steps = 10 ,
1113+ callbacks = [_NonYamlCallback ()],
1114+ )
1115+ nf = NeuralForecast (models = [model ], freq = "M" )
1116+ nf .fit (AirPassengersPanel_train )
1117+ nf .save (str (tmp_path ))
1118+
1119+ nf2 = NeuralForecast .load (str (tmp_path ))
1120+ # Should not raise a ValueError from YAML serialization
1121+ preds = nf2 .predict (df = AirPassengersPanel_train )
1122+ assert preds is not None
1123+
1124+
10961125def test_save_skips_nonzero_ddp_rank (monkeypatch , tmp_path ):
10971126 """Only rank 0 should write artifacts when DDP is initialized."""
10981127 dist = pytest .importorskip ("torch.distributed" )
You can’t perform that action at this time.
0 commit comments