@@ -118,6 +118,7 @@ def __init__(
118118 step_size : int = 1 ,
119119 num_lr_decays : int = 0 ,
120120 early_stop_patience_steps : int = - 1 ,
121+ val_monitor : str = "ptl/val_loss" ,
121122 scaler_type : str = "identity" ,
122123 futr_exog_list : Union [List , None ] = None ,
123124 hist_exog_list : Union [List , None ] = None ,
@@ -295,11 +296,17 @@ def __init__(
295296
296297 # Callbacks
297298 if early_stop_patience_steps > 0 :
299+ valid_monitors = ["ptl/val_loss" , "valid_loss" , "train_loss" ]
300+ if val_monitor not in valid_monitors :
301+ raise ValueError (
302+ f"val_monitor='{ val_monitor } ' is not supported. "
303+ f"Valid options are: { valid_monitors } ."
304+ )
298305 if "callbacks" not in trainer_kwargs :
299306 trainer_kwargs ["callbacks" ] = []
300307 trainer_kwargs ["callbacks" ].append (
301308 EarlyStopping (
302- monitor = "ptl/val_loss" , patience = early_stop_patience_steps
309+ monitor = val_monitor , patience = early_stop_patience_steps
303310 )
304311 )
305312
@@ -398,6 +405,7 @@ def __init__(
398405 max (max_steps // self .num_lr_decays , 1 ) if self .num_lr_decays > 0 else 10e7
399406 )
400407 self .early_stop_patience_steps = early_stop_patience_steps
408+ self .val_monitor = val_monitor
401409 self .val_check_steps = val_check_steps
402410 self .windows_batch_size = windows_batch_size
403411 self .step_size = step_size
0 commit comments