88from contextlib import contextmanager
99from copy import deepcopy
1010from dataclasses import dataclass
11- from typing import Dict , List , Union
11+ from typing import Dict , List , Optional , Union
1212
1313import fsspec
1414import numpy as np
@@ -295,6 +295,8 @@ def __init__(
295295 raise Exception ("max_epochs is deprecated, use max_steps instead." )
296296
297297 # Callbacks
298+ self ._early_stop_kwargs : Optional [dict ] = None
299+ self ._early_stop_cb : Optional [EarlyStopping ] = None
298300 if early_stop_patience_steps > 0 :
299301 valid_monitors = ["ptl/val_loss" , "valid_loss" , "train_loss" ]
300302 if val_monitor not in valid_monitors :
@@ -304,11 +306,12 @@ def __init__(
304306 )
305307 if "callbacks" not in trainer_kwargs :
306308 trainer_kwargs ["callbacks" ] = []
307- trainer_kwargs ["callbacks" ].append (
308- EarlyStopping (
309- monitor = val_monitor , patience = early_stop_patience_steps
310- )
311- )
309+ self ._early_stop_kwargs = {
310+ "monitor" : val_monitor ,
311+ "patience" : early_stop_patience_steps ,
312+ }
313+ self ._early_stop_cb = EarlyStopping (** self ._early_stop_kwargs )
314+ trainer_kwargs ["callbacks" ].append (self ._early_stop_cb )
312315
313316 # Add GPU accelerator if available
314317 if trainer_kwargs .get ("accelerator" , None ) is None :
@@ -604,6 +607,17 @@ def _fit(
604607 self .trainer_kwargs ["val_check_interval" ] = int (val_check_interval )
605608 self .trainer_kwargs ["check_val_every_n_epoch" ] = None
606609
610+ # The EarlyStopping callback added in __init__ keeps its internal state between fits,
611+ # so replace it with a fresh instance built from the stored kwargs before each fit.
612+ # The swap matches by identity, so user-supplied EarlyStopping callbacks are left untouched.
613+ if self ._early_stop_cb is not None and self ._early_stop_kwargs is not None :
614+ callbacks = self .trainer_kwargs .get ("callbacks" ) or []
615+ for i , cb in enumerate (callbacks ):
616+ if cb is self ._early_stop_cb :
617+ self ._early_stop_cb = EarlyStopping (** self ._early_stop_kwargs )
618+ callbacks [i ] = self ._early_stop_cb
619+ break
620+
607621 if is_local :
608622 model = self
609623 trainer = pl .Trainer (** model .trainer_kwargs )
0 commit comments