Get optuna results #1418
lakshmitharun
started this conversation in
General
Replies: 0 comments
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.
-
My code
"""
AutoTFT Hyperparameter Tuning for Chlorophyll Forecasting
With Selected Exogenous Oceanographic Variables (SST, SLA)
Author: Oceanographic ML Pipeline
Date: 2025-12-08
GPU: NVIDIA A40 (46GB VRAM)
"""
import os
import warnings
import pandas as pd
import numpy as np
import xarray as xr
import traceback
import optuna
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr
from neuralforecast import NeuralForecast
from neuralforecast.auto import AutoTFT
from neuralforecast.losses.pytorch import MAE
warnings.filterwarnings("ignore")

# =====================================================================
# 1. CONFIGURATION
# =====================================================================
base_dir = "/home/users/ltharun/data/chl/lstm"

# Grid point to extract (matched by nearest-neighbour selection when loading).
TARGET_LAT, TARGET_LON = 15.0, 90.0

# Optuna search budget and per-trial resources handed to AutoTFT.
NUM_SAMPLES = 200
CPUS_PER_TRIAL = 4
GPUS_PER_TRIAL = 1

# Forecast horizon and length of the held-out evaluation window, in rows
# (daily data, see freq="D" when the model is fitted).
horizon = 3
eval_days = 365

# Zarr store for each source; SLA is read from the SSH store.
_ZARR_FILES = (("CHL", "chl.zarr"), ("SST", "sst.zarr"), ("ssh", "ssh.zarr"))
paths = {store: os.path.join(base_dir, fname) for store, fname in _ZARR_FILES}

# Name of the data variable inside each store, keyed like the `series` dict.
varnames = dict(CHL="CHL", SST="analysed_sst", sla="sla")
# =====================================================================
# 2. DATA LOADING
# =====================================================================
def extract_point(ds_path, varname, lat, lon, tolerance=None):
    """Return the series of ``varname`` at the grid point nearest (lat, lon).

    Parameters
    ----------
    ds_path : str
        Path to a consolidated Zarr store.
    varname : str
        Data variable to read from the store.
    lat, lon : float
        Target coordinates, matched to the store's ``latitude`` /
        ``longitude`` coordinates by nearest-neighbour selection.
    tolerance : float, optional
        Maximum allowed distance (in coordinate units) between the target
        and the selected grid point. ``None`` (default) accepts any
        distance. Set it to make the selection fail with an error instead
        of silently picking a far-away cell, e.g. when the store uses a
        different longitude convention.

    Returns
    -------
    pandas.Series
        The selected values, with the index converted to datetimes and
        sorted. A ``depth`` dimension, if present, is averaged out.
        NOTE(review): assumes time is the only dimension left after the
        point selection; otherwise the index is a MultiIndex. TODO confirm.
    """
    print(f" Loading: {varname} from {os.path.basename(ds_path)}")
    # Close the store once the point has been read. to_series() materialises
    # the (possibly lazy) data, so it must run inside the `with` block.
    with xr.open_zarr(ds_path, consolidated=True) as ds:
        da = ds[varname].sel(
            latitude=lat, longitude=lon, method="nearest", tolerance=tolerance
        )
        if "depth" in da.dims:
            da = da.mean(dim="depth")
        s = da.to_series()
    s.index = pd.to_datetime(s.index)
    return s.sort_index()
# =====================================================================
# 3. LOAD & CLEAN DATA
# =====================================================================
# Key in `paths` holding each series (SLA is stored in the SSH zarr).
store_for_series = {"CHL": "CHL", "SST": "SST", "sla": "ssh"}
series = {
    key: extract_point(
        paths[store_for_series[key]], varnames[key], TARGET_LAT, TARGET_LON
    )
    for key in ["CHL", "SST", "sla"]
}

# Align indices: keep only timestamps present in every series.
common_idx = series["CHL"].index
for s in series.values():
    common_idx = common_idx.intersection(s.index)
df = pd.DataFrame({k: s.reindex(common_idx) for k, s in series.items()})

# Fill CHL gaps by carrying values forward, then backward. Interpolate the
# remaining gaps in the exogenous columns over time. Finally, drop rows that
# interpolation could not fill (e.g. leading gaps).
# NOTE(review): filling the target creates synthetic CHL values that are
# later trained and scored on; confirm this is intended.
df["CHL"] = df["CHL"].ffill().bfill()
df = df.interpolate(method="time")
df = df.dropna(subset=["SST", "sla"])
# =====================================================================
# 4. NEURALFORECAST FORMAT
# =====================================================================
# NeuralForecast expects long format: `unique_id`, `ds`, `y` plus exogenous
# columns. Name the DatetimeIndex "ds" *before* resetting it, so this works
# whatever the index was called ("time", "date", unnamed, ...). Copying
# `df.index` after reset_index would yield integer row positions, which
# pd.to_datetime silently turns into 1970 epoch timestamps.
df = df.rename_axis("ds").reset_index()
df["ds"] = pd.to_datetime(df["ds"], errors="raise")
df["unique_id"] = "point"
df = df.rename(columns={"CHL": "y"})
print(f"Dataset shape: {df.shape}, Date range: {df['ds'].min()} to {df['ds'].max()}")
# =====================================================================
# 5. SPLITS
# =====================================================================
# The last `horizon` rows become the future-exogenous input (target removed).
# Everything before them is available for training.
_history, _tail = df.iloc[:-horizon], df.iloc[-horizon:]
df_train_full = _history.reset_index(drop=True)
df_forecast_input = _tail.drop(columns=["y"]).reset_index(drop=True)

# Hold out a trailing window of `eval_days` rows for evaluation, but only when
# the training data is longer than that window; otherwise there is no
# validation split.
if len(df_train_full) <= eval_days:
    df_train_eval, df_val = df_train_full, None
else:
    df_train_eval = df_train_full.iloc[:-eval_days]
    df_val = df_train_full.iloc[-eval_days:]
# =====================================================================
# 6. METRICS
# =====================================================================
def eval_metrics(y_true, y_pred):
    """Score predictions ``y_pred`` against observations ``y_true``.

    Returns a dict with keys ``MAE``, ``RMSE`` and ``PearsonCorr``. The
    correlation is NaN for fewer than two points, where Pearson's r is
    undefined.
    """
    abs_err = mean_absolute_error(y_true, y_pred)
    root_sq_err = np.sqrt(mean_squared_error(y_true, y_pred))
    if len(y_true) > 1:
        r = pearsonr(y_true, y_pred)[0]
    else:
        r = np.nan
    return {"MAE": abs_err, "RMSE": root_sq_err, "PearsonCorr": r}
# =====================================================================
# 7. MODEL CONFIGURATION (AutoTFT)
# =====================================================================
future_exog_cols = ["SST", "sla"]

# Candidate values for each tuned hyperparameter; Optuna picks one per trial.
_TFT_CHOICES = {
    "input_size": [7, 14, 28, 56, 84, 128],
    "hidden_size": [32, 64, 128, 256],
    "learning_rate": [1e-3, 5e-4, 1e-4],
    "max_steps": [100, 200, 500, 1000, 1500, 2000],
}


def tft_config(trial):
    """Build one AutoTFT configuration from an Optuna ``trial``.

    Samples each hyperparameter in ``_TFT_CHOICES`` categorically and fixes
    the scaler type and the exogenous-variable lists.
    """
    def pick(name):
        return trial.suggest_categorical(name, _TFT_CHOICES[name])

    # Dict entries are evaluated top to bottom, so the parameters are
    # suggested in the order input_size, hidden_size, learning_rate, max_steps.
    return {
        "input_size": pick("input_size"),
        "hidden_size": pick("hidden_size"),
        "learning_rate": pick("learning_rate"),
        "scaler_type": "standard",
        "max_steps": pick("max_steps"),
        "futr_exog_list": future_exog_cols,
        "stat_exog_list": [],
    }


# Optuna backend with a seeded TPE sampler, NUM_SAMPLES trials.
auto_tft = AutoTFT(
    h=horizon,
    loss=MAE(),
    config=tft_config,
    backend="optuna",
    search_alg=optuna.samplers.TPESampler(seed=42),
    num_samples=NUM_SAMPLES,
    cpus=CPUS_PER_TRIAL,
    gpus=GPUS_PER_TRIAL,
    verbose=True,
)
# =====================================================================
# 8. TRAINING + FORECAST
# =====================================================================
results = {}
try:
    nf = NeuralForecast(models=[auto_tft], freq="D")
    print("Fitting AutoTFT with hyperparameter tuning...")
    nf.fit(df=df_train_eval, val_size=30)
except Exception as e:
    # Top-level boundary: report the failure and record it, so the summary
    # section still runs.
    print(f"ERROR: {e}")
    traceback.print_exc()
    results["AutoTFT_error"] = str(e)
else:
    # With backend="optuna", the fitted Auto model keeps its optuna Study in
    # `.results`. Read it through `nf.models` rather than `auto_tft`, since
    # NeuralForecast may train copies of the models it was given.
    study = nf.models[0].results

    # One row per trial. `value` is the objective Optuna minimised, i.e. the
    # trial's validation loss. NOTE(review): this should be MAE here, because
    # no separate valid_loss was given; confirm for your neuralforecast
    # version. `params_*` columns hold the sampled hyperparameters; timing,
    # state and `user_attrs_*` columns are included too.
    trials_df = study.trials_dataframe()
    results["AutoTFT_trials"] = trials_df
    results["AutoTFT_best_params"] = study.best_trial.params
    results["AutoTFT_best_value"] = study.best_value

    trials_path = os.path.join(base_dir, "autotft_optuna_trials.csv")
    trials_df.to_csv(trials_path, index=False)

    # user_attrs_* columns hold full config dicts; too wide for the console.
    summary_cols = [c for c in trials_df.columns if not c.startswith("user_attrs_")]
    print("\nOptuna trials (best 10 by validation loss):")
    print(trials_df[summary_cols].sort_values("value").head(10).to_string(index=False))
    print(f"\nBest trial params: {study.best_trial.params}")
    print(f"Best validation loss: {study.best_value}")
    print(f"All trials saved to: {trials_path}")
# =====================================================================
# 9. FINAL RESULTS
# =====================================================================
# Print each optional section that was stored in `results`, in a fixed order.
for _result_key, _header in (
    ("AutoTFT_val", "\nValidation Metrics:"),
    ("AutoTFT_forecast", "\nForecast Values:"),
):
    if _result_key not in results:
        continue
    print(_header)
    print(results[_result_key])
print("\nDone.")
Is there a way to capture the Optuna results, for example the MAE of each trial and other per-trial metrics?
Beta Was this translation helpful? Give feedback.
All reactions