Skip to content

Commit 7910386

Browse files
authored
[FEAT] Transfer learning with cross-validation (#1530)
1 parent 4f4ca4c commit 7910386

2 files changed

Lines changed: 344 additions & 86 deletions

File tree

neuralforecast/core.py

Lines changed: 168 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,9 @@ def fit(
497497
`val_df` can be temporally independent (no requirement that it starts immediately after `df`).
498498
Cannot be used together with `val_size`. Only supported when `df` is a pandas or polars DataFrame.
499499
All series in `val_df` must have the same length.
500-
use_init_models (bool, optional): Use initial model passed when NeuralForecast object was instantiated.
500+
use_init_models (bool, optional): If True, discards any previously fitted weights
501+
and reinitializes the models from the configs passed at `NeuralForecast(__init__)`.
502+
Use this to start training from scratch. Defaults to False.
501503
verbose (bool): Print processing steps.
502504
id_col (str): Column that identifies each serie.
503505
time_col (str): Column that identifies each timestep, its values can be timestamps or integers.
@@ -1766,7 +1768,12 @@ def explain(
17661768
def _reset_models(self):
17671769
self.models = [deepcopy(model) for model in self.models_init]
17681770
if self._fitted:
1769-
print("WARNING: Deleting previously fitted models.")
1771+
warnings.warn(
1772+
"Deleting previously fitted models because `use_init_models=True` "
1773+
"was passed; fitted weights will be discarded and models reinitialized "
1774+
"from the configs given at `NeuralForecast(__init__)`.",
1775+
stacklevel=2,
1776+
)
17701777

17711778
def _no_refit_cross_validation(
17721779
self,
@@ -1776,6 +1783,7 @@ def _no_refit_cross_validation(
17761783
step_size: int,
17771784
val_size: Optional[int],
17781785
test_size: int,
1786+
use_fitted: bool,
17791787
verbose: bool,
17801788
id_col: str,
17811789
time_col: str,
@@ -1786,96 +1794,137 @@ def _no_refit_cross_validation(
17861794
if (df is None) and not (hasattr(self, "dataset")):
17871795
raise Exception("You must pass a DataFrame or have one stored.")
17881796

1789-
# Process and save new dataset (in self)
1790-
if df is not None:
1791-
validate_freq(df[time_col], self.freq)
1792-
self.dataset, self.uids, self.last_dates, self.ds = self._prepare_fit(
1793-
df=df,
1794-
static_df=static_df,
1795-
id_col=id_col,
1796-
time_col=time_col,
1797-
target_col=target_col,
1798-
)
1799-
else:
1800-
if verbose:
1801-
print("Using stored dataset.")
1802-
1803-
if val_size is not None:
1804-
if self.dataset.min_size < (val_size + test_size):
1805-
warnings.warn(
1806-
"Validation and test sets are larger than the shorter time-series."
1797+
# When use_fitted=True we evaluate the already-fitted model on a new
1798+
# holdout `df` without retraining.
1799+
restore_fitted_state = use_fitted and df is not None
1800+
_snapshot: Dict[str, object] = {}
1801+
if restore_fitted_state:
1802+
_snapshot = {
1803+
attr: getattr(self, attr)
1804+
for attr in (
1805+
"scalers_",
1806+
"static_scalers_",
1807+
"dataset",
1808+
"uids",
1809+
"last_dates",
1810+
"ds",
1811+
"id_col",
1812+
"time_col",
1813+
"target_col",
18071814
)
1815+
}
18081816

1809-
fcsts_df = ufp.cv_times(
1810-
times=self.ds,
1811-
uids=self.uids,
1812-
indptr=self.dataset.indptr,
1813-
h=h,
1814-
test_size=test_size,
1815-
step_size=step_size,
1816-
id_col=id_col,
1817-
time_col=time_col,
1818-
)
1819-
# the cv_times is sorted by window and then id
1820-
fcsts_df = ufp.sort(fcsts_df, [id_col, "cutoff", time_col])
1821-
1822-
fcsts_list: List = []
1823-
for model in self.models:
1824-
if self._add_level and (
1825-
model.loss.outputsize_multiplier > 1
1826-
or isinstance(model.loss, (IQLoss, HuberIQLoss))
1827-
):
1828-
continue
1817+
try:
1818+
# Process and save new dataset in self
1819+
if df is not None:
1820+
validate_freq(df[time_col], self.freq)
1821+
self.dataset, self.uids, self.last_dates, self.ds = (
1822+
self._prepare_fit(
1823+
df=df,
1824+
static_df=static_df,
1825+
id_col=id_col,
1826+
time_col=time_col,
1827+
target_col=target_col,
1828+
)
1829+
)
1830+
else:
1831+
if verbose:
1832+
print("Using stored dataset.")
18291833

1830-
model.fit(dataset=self.dataset, val_size=val_size, test_size=test_size)
1831-
model_fcsts = model.predict(
1832-
self.dataset, step_size=step_size, h=h, **data_kwargs
1833-
)
1834-
# Append predictions in memory placeholder
1835-
fcsts_list.append(model_fcsts)
1834+
if val_size is not None:
1835+
if self.dataset.min_size < (val_size + test_size):
1836+
warnings.warn(
1837+
"Validation and test sets are larger than the shorter time-series."
1838+
)
18361839

1837-
fcsts = np.concatenate(fcsts_list, axis=-1)
1838-
# we may have allocated more space than needed
1839-
# each serie can produce at most (serie.size - 1) // self.h CV windows
1840-
effective_sizes = ufp.counts_by_id(fcsts_df, id_col)["counts"].to_numpy()
1841-
needs_trim = effective_sizes.sum() != fcsts.shape[0]
1842-
if self.scalers_ or needs_trim:
1843-
indptr = np.arange(
1844-
0,
1845-
n_windows * h * (self.dataset.n_groups + 1),
1846-
n_windows * h,
1847-
dtype=np.int32,
1840+
fcsts_df = ufp.cv_times(
1841+
times=self.ds,
1842+
uids=self.uids,
1843+
indptr=self.dataset.indptr,
1844+
h=h,
1845+
test_size=test_size,
1846+
step_size=step_size,
1847+
id_col=id_col,
1848+
time_col=time_col,
18481849
)
1849-
if self.scalers_:
1850-
fcsts = self._scalers_target_inverse_transform(fcsts, indptr)
1851-
if needs_trim:
1852-
# we keep only the effective samples of each serie from the cv results
1853-
trimmed = np.empty_like(
1854-
fcsts, shape=(effective_sizes.sum(), fcsts.shape[1])
1855-
)
1856-
cv_indptr = np.append(0, effective_sizes).cumsum(dtype=np.int32)
1857-
for i in range(fcsts.shape[1]):
1858-
ga = GroupedArray(fcsts[:, i], indptr)
1859-
trimmed[:, i] = ga._tails(cv_indptr)
1860-
fcsts = trimmed
1861-
1862-
self._fitted = True
1850+
# the cv_times is sorted by window and then id
1851+
fcsts_df = ufp.sort(fcsts_df, [id_col, "cutoff", time_col])
1852+
1853+
fcsts_list: List = []
1854+
for model in self.models:
1855+
if self._add_level and (
1856+
model.loss.outputsize_multiplier > 1
1857+
or isinstance(model.loss, (IQLoss, HuberIQLoss))
1858+
):
1859+
continue
18631860

1864-
# Add predictions to forecasts DataFrame
1865-
cols = self._get_model_names(add_level=self._add_level)
1866-
if isinstance(self.uids, pl_Series):
1867-
fcsts = pl_DataFrame(dict(zip(cols, fcsts.T)))
1868-
else:
1869-
fcsts = pd.DataFrame(fcsts, columns=cols)
1870-
fcsts_df = ufp.horizontal_concat([fcsts_df, fcsts])
1861+
if use_fitted:
1862+
_saved_model_test_size = model.get_test_size()
1863+
model.set_test_size(test_size)
1864+
try:
1865+
model_fcsts = model.predict(
1866+
self.dataset, step_size=step_size, h=h, **data_kwargs
1867+
)
1868+
finally:
1869+
model.set_test_size(_saved_model_test_size)
1870+
else:
1871+
model.fit(
1872+
dataset=self.dataset,
1873+
val_size=val_size,
1874+
test_size=test_size,
1875+
)
1876+
model_fcsts = model.predict(
1877+
self.dataset, step_size=step_size, h=h, **data_kwargs
1878+
)
1879+
# Append predictions in memory placeholder
1880+
fcsts_list.append(model_fcsts)
18711881

1872-
# Add original input df's y to forecasts DataFrame
1873-
return ufp.join(
1874-
fcsts_df,
1875-
df[[id_col, time_col, target_col]],
1876-
how="left",
1877-
on=[id_col, time_col],
1878-
)
1882+
fcsts = np.concatenate(fcsts_list, axis=-1)
1883+
# we may have allocated more space than needed
1884+
# each serie can produce at most (serie.size - 1) // self.h CV windows
1885+
effective_sizes = ufp.counts_by_id(fcsts_df, id_col)["counts"].to_numpy()
1886+
needs_trim = effective_sizes.sum() != fcsts.shape[0]
1887+
if self.scalers_ or needs_trim:
1888+
indptr = np.arange(
1889+
0,
1890+
n_windows * h * (self.dataset.n_groups + 1),
1891+
n_windows * h,
1892+
dtype=np.int32,
1893+
)
1894+
if self.scalers_:
1895+
fcsts = self._scalers_target_inverse_transform(fcsts, indptr)
1896+
if needs_trim:
1897+
# we keep only the effective samples of each serie from the cv results
1898+
trimmed = np.empty_like(
1899+
fcsts, shape=(effective_sizes.sum(), fcsts.shape[1])
1900+
)
1901+
cv_indptr = np.append(0, effective_sizes).cumsum(dtype=np.int32)
1902+
for i in range(fcsts.shape[1]):
1903+
ga = GroupedArray(fcsts[:, i], indptr)
1904+
trimmed[:, i] = ga._tails(cv_indptr)
1905+
fcsts = trimmed
1906+
1907+
self._fitted = True
1908+
1909+
# Add predictions to forecasts DataFrame
1910+
cols = self._get_model_names(add_level=self._add_level)
1911+
if isinstance(self.uids, pl_Series):
1912+
fcsts = pl_DataFrame(dict(zip(cols, fcsts.T)))
1913+
else:
1914+
fcsts = pd.DataFrame(fcsts, columns=cols)
1915+
fcsts_df = ufp.horizontal_concat([fcsts_df, fcsts])
1916+
1917+
# Add original input df's y to forecasts DataFrame
1918+
return ufp.join(
1919+
fcsts_df,
1920+
df[[id_col, time_col, target_col]],
1921+
how="left",
1922+
on=[id_col, time_col],
1923+
)
1924+
finally:
1925+
if restore_fitted_state:
1926+
for attr, value in _snapshot.items():
1927+
setattr(self, attr, value)
18791928

18801929
def cross_validation(
18811930
self,
@@ -1886,6 +1935,7 @@ def cross_validation(
18861935
val_size: Optional[int] = 0,
18871936
test_size: Optional[int] = None,
18881937
use_init_models: bool = False,
1938+
use_fitted: bool = False,
18891939
verbose: bool = False,
18901940
refit: Union[bool, int] = False,
18911941
id_col: str = "unique_id",
@@ -1910,7 +1960,14 @@ def cross_validation(
19101960
step_size (int): Step size between each window.
19111961
val_size (int, optional): Length of validation size. If passed, set `n_windows=None`. Defaults to 0.
19121962
test_size (int, optional): Length of test size. If passed, set `n_windows=None`.
1913-
use_init_models (bool, optional): Use initial model passed when object was instantiated.
1963+
use_init_models (bool, optional): If True, discards any previously fitted weights
1964+
and reinitializes the models from the configs passed at `NeuralForecast(__init__)`.
1965+
Use this to start cross-validation from scratch. Defaults to False.
1966+
use_fitted (bool, optional): Evaluate the already-fitted model on `df` without retraining
1967+
(transfer-learning cross-validation). Requires a previous `fit` call, `refit=False`,
1968+
`use_init_models=False`, and `prediction_intervals=None`. Local scalers, if any, are
1969+
refit per series on `df` and the fitted state (model weights, stored dataset, scalers)
1970+
is restored after CV completes. Defaults to False.
19141971
verbose (bool): Print processing steps.
19151972
refit (bool or int): Retrain model for each cross validation window.
19161973
If False, the models are trained at the beginning and then used to predict each window.
@@ -1976,6 +2033,30 @@ def cross_validation(
19762033
assert n_windows is not None
19772034
assert test_size is not None
19782035

2036+
if use_fitted:
2037+
if not self._fitted:
2038+
raise ValueError(
2039+
"`use_fitted=True` requires a model previously fitted with `fit`."
2040+
)
2041+
if refit:
2042+
raise ValueError(
2043+
"`use_fitted=True` is only supported with `refit=False`."
2044+
)
2045+
if use_init_models:
2046+
raise ValueError(
2047+
"`use_fitted=True` cannot be combined with `use_init_models=True`; "
2048+
"`use_init_models` discards the fitted weights that `use_fitted` relies on."
2049+
)
2050+
if (
2051+
prediction_intervals is not None
2052+
or getattr(self, "prediction_intervals", None) is not None
2053+
):
2054+
raise ValueError(
2055+
"`use_fitted=True` is not supported with `prediction_intervals` "
2056+
"(calibration requires retraining). This applies whether the "
2057+
"intervals were passed here or configured during the prior `fit` call."
2058+
)
2059+
19792060
# Recover initial model if use_init_models.
19802061
if use_init_models:
19812062
self._reset_models()
@@ -2003,6 +2084,7 @@ def cross_validation(
20032084
step_size=step_size,
20042085
val_size=val_size,
20052086
test_size=test_size,
2087+
use_fitted=use_fitted,
20062088
verbose=verbose,
20072089
id_col=id_col,
20082090
time_col=time_col,

0 commit comments

Comments
 (0)