@@ -497,7 +497,9 @@ def fit(
497497 `val_df` can be temporally independent (no requirement that it starts immediately after `df`).
498498 Cannot be used together with `val_size`. Only supported when `df` is a pandas or polars DataFrame.
499499 All series in `val_df` must have the same length.
500- use_init_models (bool, optional): Use initial model passed when NeuralForecast object was instantiated.
500+ use_init_models (bool, optional): If True, discards any previously fitted weights
501+ and reinitializes the models from the configs passed at `NeuralForecast(__init__)`.
502+ Use this to start training from scratch. Defaults to False.
501503 verbose (bool): Print processing steps.
502504 id_col (str): Column that identifies each serie.
503505 time_col (str): Column that identifies each timestep, its values can be timestamps or integers.
@@ -1766,7 +1768,12 @@ def explain(
17661768 def _reset_models (self ):
17671769 self .models = [deepcopy (model ) for model in self .models_init ]
17681770 if self ._fitted :
1769- print ("WARNING: Deleting previously fitted models." )
1771+ warnings .warn (
1772+ "Deleting previously fitted models because `use_init_models=True` "
1773+ "was passed; fitted weights will be discarded and models reinitialized "
1774+ "from the configs given at `NeuralForecast(__init__)`." ,
1775+ stacklevel = 2 ,
1776+ )
17701777
17711778 def _no_refit_cross_validation (
17721779 self ,
@@ -1776,6 +1783,7 @@ def _no_refit_cross_validation(
17761783 step_size : int ,
17771784 val_size : Optional [int ],
17781785 test_size : int ,
1786+ use_fitted : bool ,
17791787 verbose : bool ,
17801788 id_col : str ,
17811789 time_col : str ,
@@ -1786,96 +1794,137 @@ def _no_refit_cross_validation(
17861794 if (df is None ) and not (hasattr (self , "dataset" )):
17871795 raise Exception ("You must pass a DataFrame or have one stored." )
17881796
1789- # Process and save new dataset (in self)
1790- if df is not None :
1791- validate_freq (df [time_col ], self .freq )
1792- self .dataset , self .uids , self .last_dates , self .ds = self ._prepare_fit (
1793- df = df ,
1794- static_df = static_df ,
1795- id_col = id_col ,
1796- time_col = time_col ,
1797- target_col = target_col ,
1798- )
1799- else :
1800- if verbose :
1801- print ("Using stored dataset." )
1802-
1803- if val_size is not None :
1804- if self .dataset .min_size < (val_size + test_size ):
1805- warnings .warn (
1806- "Validation and test sets are larger than the shorter time-series."
1797+ # When use_fitted=True we evaluate the already-fitted model on a new
1798+ # holdout `df` without retraining.
1799+ restore_fitted_state = use_fitted and df is not None
1800+ _snapshot : Dict [str , object ] = {}
1801+ if restore_fitted_state :
1802+ _snapshot = {
1803+ attr : getattr (self , attr )
1804+ for attr in (
1805+ "scalers_" ,
1806+ "static_scalers_" ,
1807+ "dataset" ,
1808+ "uids" ,
1809+ "last_dates" ,
1810+ "ds" ,
1811+ "id_col" ,
1812+ "time_col" ,
1813+ "target_col" ,
18071814 )
1815+ }
18081816
1809- fcsts_df = ufp .cv_times (
1810- times = self .ds ,
1811- uids = self .uids ,
1812- indptr = self .dataset .indptr ,
1813- h = h ,
1814- test_size = test_size ,
1815- step_size = step_size ,
1816- id_col = id_col ,
1817- time_col = time_col ,
1818- )
1819- # the cv_times is sorted by window and then id
1820- fcsts_df = ufp .sort (fcsts_df , [id_col , "cutoff" , time_col ])
1821-
1822- fcsts_list : List = []
1823- for model in self .models :
1824- if self ._add_level and (
1825- model .loss .outputsize_multiplier > 1
1826- or isinstance (model .loss , (IQLoss , HuberIQLoss ))
1827- ):
1828- continue
1817+ try :
1818+ # Process and save new dataset in self
1819+ if df is not None :
1820+ validate_freq (df [time_col ], self .freq )
1821+ self .dataset , self .uids , self .last_dates , self .ds = (
1822+ self ._prepare_fit (
1823+ df = df ,
1824+ static_df = static_df ,
1825+ id_col = id_col ,
1826+ time_col = time_col ,
1827+ target_col = target_col ,
1828+ )
1829+ )
1830+ else :
1831+ if verbose :
1832+ print ("Using stored dataset." )
18291833
1830- model .fit (dataset = self .dataset , val_size = val_size , test_size = test_size )
1831- model_fcsts = model .predict (
1832- self .dataset , step_size = step_size , h = h , ** data_kwargs
1833- )
1834- # Append predictions in memory placeholder
1835- fcsts_list .append (model_fcsts )
1834+ if val_size is not None :
1835+ if self .dataset .min_size < (val_size + test_size ):
1836+ warnings .warn (
1837+ "Validation and test sets are larger than the shorter time-series."
1838+ )
18361839
1837- fcsts = np .concatenate (fcsts_list , axis = - 1 )
1838- # we may have allocated more space than needed
1839- # each serie can produce at most (serie.size - 1) // self.h CV windows
1840- effective_sizes = ufp .counts_by_id (fcsts_df , id_col )["counts" ].to_numpy ()
1841- needs_trim = effective_sizes .sum () != fcsts .shape [0 ]
1842- if self .scalers_ or needs_trim :
1843- indptr = np .arange (
1844- 0 ,
1845- n_windows * h * (self .dataset .n_groups + 1 ),
1846- n_windows * h ,
1847- dtype = np .int32 ,
1840+ fcsts_df = ufp .cv_times (
1841+ times = self .ds ,
1842+ uids = self .uids ,
1843+ indptr = self .dataset .indptr ,
1844+ h = h ,
1845+ test_size = test_size ,
1846+ step_size = step_size ,
1847+ id_col = id_col ,
1848+ time_col = time_col ,
18481849 )
1849- if self .scalers_ :
1850- fcsts = self ._scalers_target_inverse_transform (fcsts , indptr )
1851- if needs_trim :
1852- # we keep only the effective samples of each serie from the cv results
1853- trimmed = np .empty_like (
1854- fcsts , shape = (effective_sizes .sum (), fcsts .shape [1 ])
1855- )
1856- cv_indptr = np .append (0 , effective_sizes ).cumsum (dtype = np .int32 )
1857- for i in range (fcsts .shape [1 ]):
1858- ga = GroupedArray (fcsts [:, i ], indptr )
1859- trimmed [:, i ] = ga ._tails (cv_indptr )
1860- fcsts = trimmed
1861-
1862- self ._fitted = True
1850+ # the cv_times is sorted by window and then id
1851+ fcsts_df = ufp .sort (fcsts_df , [id_col , "cutoff" , time_col ])
1852+
1853+ fcsts_list : List = []
1854+ for model in self .models :
1855+ if self ._add_level and (
1856+ model .loss .outputsize_multiplier > 1
1857+ or isinstance (model .loss , (IQLoss , HuberIQLoss ))
1858+ ):
1859+ continue
18631860
1864- # Add predictions to forecasts DataFrame
1865- cols = self ._get_model_names (add_level = self ._add_level )
1866- if isinstance (self .uids , pl_Series ):
1867- fcsts = pl_DataFrame (dict (zip (cols , fcsts .T )))
1868- else :
1869- fcsts = pd .DataFrame (fcsts , columns = cols )
1870- fcsts_df = ufp .horizontal_concat ([fcsts_df , fcsts ])
1861+ if use_fitted :
1862+ _saved_model_test_size = model .get_test_size ()
1863+ model .set_test_size (test_size )
1864+ try :
1865+ model_fcsts = model .predict (
1866+ self .dataset , step_size = step_size , h = h , ** data_kwargs
1867+ )
1868+ finally :
1869+ model .set_test_size (_saved_model_test_size )
1870+ else :
1871+ model .fit (
1872+ dataset = self .dataset ,
1873+ val_size = val_size ,
1874+ test_size = test_size ,
1875+ )
1876+ model_fcsts = model .predict (
1877+ self .dataset , step_size = step_size , h = h , ** data_kwargs
1878+ )
1879+ # Append predictions in memory placeholder
1880+ fcsts_list .append (model_fcsts )
18711881
1872- # Add original input df's y to forecasts DataFrame
1873- return ufp .join (
1874- fcsts_df ,
1875- df [[id_col , time_col , target_col ]],
1876- how = "left" ,
1877- on = [id_col , time_col ],
1878- )
1882+ fcsts = np .concatenate (fcsts_list , axis = - 1 )
1883+ # we may have allocated more space than needed
1884+ # each serie can produce at most (serie.size - 1) // self.h CV windows
1885+ effective_sizes = ufp .counts_by_id (fcsts_df , id_col )["counts" ].to_numpy ()
1886+ needs_trim = effective_sizes .sum () != fcsts .shape [0 ]
1887+ if self .scalers_ or needs_trim :
1888+ indptr = np .arange (
1889+ 0 ,
1890+ n_windows * h * (self .dataset .n_groups + 1 ),
1891+ n_windows * h ,
1892+ dtype = np .int32 ,
1893+ )
1894+ if self .scalers_ :
1895+ fcsts = self ._scalers_target_inverse_transform (fcsts , indptr )
1896+ if needs_trim :
1897+ # we keep only the effective samples of each serie from the cv results
1898+ trimmed = np .empty_like (
1899+ fcsts , shape = (effective_sizes .sum (), fcsts .shape [1 ])
1900+ )
1901+ cv_indptr = np .append (0 , effective_sizes ).cumsum (dtype = np .int32 )
1902+ for i in range (fcsts .shape [1 ]):
1903+ ga = GroupedArray (fcsts [:, i ], indptr )
1904+ trimmed [:, i ] = ga ._tails (cv_indptr )
1905+ fcsts = trimmed
1906+
1907+ self ._fitted = True
1908+
1909+ # Add predictions to forecasts DataFrame
1910+ cols = self ._get_model_names (add_level = self ._add_level )
1911+ if isinstance (self .uids , pl_Series ):
1912+ fcsts = pl_DataFrame (dict (zip (cols , fcsts .T )))
1913+ else :
1914+ fcsts = pd .DataFrame (fcsts , columns = cols )
1915+ fcsts_df = ufp .horizontal_concat ([fcsts_df , fcsts ])
1916+
1917+ # Add original input df's y to forecasts DataFrame
1918+ return ufp .join (
1919+ fcsts_df ,
1920+ df [[id_col , time_col , target_col ]],
1921+ how = "left" ,
1922+ on = [id_col , time_col ],
1923+ )
1924+ finally :
1925+ if restore_fitted_state :
1926+ for attr , value in _snapshot .items ():
1927+ setattr (self , attr , value )
18791928
18801929 def cross_validation (
18811930 self ,
@@ -1886,6 +1935,7 @@ def cross_validation(
18861935 val_size : Optional [int ] = 0 ,
18871936 test_size : Optional [int ] = None ,
18881937 use_init_models : bool = False ,
1938+ use_fitted : bool = False ,
18891939 verbose : bool = False ,
18901940 refit : Union [bool , int ] = False ,
18911941 id_col : str = "unique_id" ,
@@ -1910,7 +1960,14 @@ def cross_validation(
19101960 step_size (int): Step size between each window.
19111961 val_size (int, optional): Length of validation size. If passed, set `n_windows=None`. Defaults to 0.
19121962 test_size (int, optional): Length of test size. If passed, set `n_windows=None`.
1913- use_init_models (bool, optional): Use initial model passed when object was instantiated.
1963+ use_init_models (bool, optional): If True, discards any previously fitted weights
1964+ and reinitializes the models from the configs passed at `NeuralForecast(__init__)`.
1965+ Use this to start cross-validation from scratch. Defaults to False.
1966+ use_fitted (bool, optional): Evaluate the already-fitted model on `df` without retraining
1967+ (transfer-learning cross-validation). Requires a previous `fit` call, `refit=False`,
1968+ `use_init_models=False`, and `prediction_intervals=None`. Local scalers, if any, are
1969+ refit per series on `df` and the fitted state (model weights, stored dataset, scalers)
1970+ is restored after CV completes. Defaults to False.
19141971 verbose (bool): Print processing steps.
19151972 refit (bool or int): Retrain model for each cross validation window.
19161973 If False, the models are trained at the beginning and then used to predict each window.
@@ -1976,6 +2033,30 @@ def cross_validation(
19762033 assert n_windows is not None
19772034 assert test_size is not None
19782035
2036+ if use_fitted :
2037+ if not self ._fitted :
2038+ raise ValueError (
2039+ "`use_fitted=True` requires a model previously fitted with `fit`."
2040+ )
2041+ if refit :
2042+ raise ValueError (
2043+ "`use_fitted=True` is only supported with `refit=False`."
2044+ )
2045+ if use_init_models :
2046+ raise ValueError (
2047+ "`use_fitted=True` cannot be combined with `use_init_models=True`; "
2048+ "`use_init_models` discards the fitted weights that `use_fitted` relies on."
2049+ )
2050+ if (
2051+ prediction_intervals is not None
2052+ or getattr (self , "prediction_intervals" , None ) is not None
2053+ ):
2054+ raise ValueError (
2055+ "`use_fitted=True` is not supported with `prediction_intervals` "
2056+ "(calibration requires retraining). This applies whether the "
2057+ "intervals were passed here or configured during the prior `fit` call."
2058+ )
2059+
19792060 # Recover initial model if use_init_models.
19802061 if use_init_models :
19812062 self ._reset_models ()
@@ -2003,6 +2084,7 @@ def cross_validation(
20032084 step_size = step_size ,
20042085 val_size = val_size ,
20052086 test_size = test_size ,
2087+ use_fitted = use_fitted ,
20062088 verbose = verbose ,
20072089 id_col = id_col ,
20082090 time_col = time_col ,
0 commit comments