Solar Flare is a machine learning pipeline for forecasting X-class solar flares from SDO/HMI SHARP magnetogram time series.
The project treats flare forecasting as a rare-event classification problem. A transformer encoder reads sliding windows of SHARP features and predicts both flare occurrence and extreme-tail behavior, combining binary classification with Generalized Pareto Distribution modeling for peaks-over-threshold exceedances.
The model has two prediction heads:
- Class logits — predicts whether an X-class flare will occur in the forecast window.
- GPD parameters — estimates Generalized Pareto tail parameters for extreme exceedances.
Training combines cross-entropy with GPD negative log-likelihood. This lets the model learn both event occurrence and the behavior of rare high-magnitude signals.
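The combined objective can be sketched in plain Python for a single example. This is a minimal illustration, not the code in loss.py: it assumes a single binary logit (the actual head may emit two-class logits), and the function names are hypothetical. The weighting follows the `alpha * CE + (1 - alpha) * GPD_NLL` form listed under the metrics, and the GPD term is the standard negative log-density of an exceedance with scale `scale` and shape `shape`.

```python
import math


def binary_cross_entropy(logit: float, label: int) -> float:
    """Cross-entropy for one example from a single logit (numerically stable)."""
    # softplus(logit) - label * logit equals -[y*log(p) + (1-y)*log(1-p)]
    softplus = max(logit, 0.0) + math.log1p(math.exp(-abs(logit)))
    return softplus - label * logit


def gpd_nll(excess: float, scale: float, shape: float) -> float:
    """Negative log-likelihood of one exceedance under a Generalized Pareto law."""
    if scale <= 0.0 or excess < 0.0:
        raise ValueError("scale must be positive and excess non-negative")
    if abs(shape) < 1e-8:
        # Limit shape -> 0 is the exponential distribution.
        return math.log(scale) + excess / scale
    z = 1.0 + shape * excess / scale
    if z <= 0.0:
        return math.inf  # excess lies outside the GPD support
    return math.log(scale) + (1.0 / shape + 1.0) * math.log(z)


def combined_loss(logit, label, excess, scale, shape, alpha=0.5):
    """Weighted sum of the classification and tail terms (alpha is the CE weight)."""
    ce = binary_cross_entropy(logit, label)
    tail = gpd_nll(excess, scale, shape)
    return alpha * ce + (1.0 - alpha) * tail
```

How windows without any threshold exceedance are handled (for example, by masking the tail term) is not covered by this sketch.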
The live data path uses NASA JSOC/DRMS SHARP records from:
hmi.sharp_720s
The default SHARP features are:
TOTUSJH
TOTUSJZ
R_VALUE
USFLUX
AREA_ACR
Labels are built from GOES/HEK flare events. A window is positive when an X-class flare peak occurs in:
(window_end + 24h, window_end + 72h]
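The half-open interval above excludes a peak exactly 24 hours after the window end and includes one exactly 72 hours after it. A minimal sketch of that rule, assuming flare peak times are already available as `datetime` objects filtered to X-class events (the function name is hypothetical and is not the data_loader.py API):

```python
from datetime import datetime, timedelta

LEAD = timedelta(hours=24)     # start of the forecast interval (exclusive)
HORIZON = timedelta(hours=72)  # end of the forecast interval (inclusive)


def label_window(window_end: datetime, x_flare_peaks) -> int:
    """Return 1 if any X-class peak falls in (window_end + 24h, window_end + 72h]."""
    start = window_end + LEAD
    stop = window_end + HORIZON
    return int(any(start < peak <= stop for peak in x_flare_peaks))
```

For example, with an illustrative window ending at `datetime(2014, 10, 20)`, a peak at `datetime(2014, 10, 21)` is exactly on the excluded boundary and yields 0, while a peak one second later yields 1.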
Each example uses a 72-hour SHARP context window. At the native 720-second cadence, this gives 360 time steps per window.
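The step count follows from 72 h × 3600 s/h ÷ 720 s = 360. The sketch below checks that arithmetic and shows one way to cut sliding windows from a sequence. The `stride` parameter and both function names are assumptions for illustration; the actual windowing lives in data_loader.py.

```python
CADENCE_SECONDS = 720
CONTEXT_HOURS = 72


def steps_per_window(context_hours=CONTEXT_HOURS, cadence_seconds=CADENCE_SECONDS):
    """Number of samples in one context window at a fixed cadence."""
    total_seconds = context_hours * 3600
    if total_seconds % cadence_seconds:
        raise ValueError("context length is not a whole number of cadence steps")
    return total_seconds // cadence_seconds


def sliding_windows(series, window, stride=1):
    """Return every full-length window of `series`, advancing by `stride` samples."""
    last_start = len(series) - window
    return [series[i:i + window] for i in range(0, last_start + 1, stride)]
```

A series shorter than one window yields an empty list, so short HARP segments simply contribute no examples under this scheme.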
Patch-wise train/validation/test splits are done by HARPNAME to avoid leakage across active-region patches.
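A grouped split of this kind can be sketched as follows: whole HARPs are assigned to one partition, so windows from the same active region never straddle train and test. This is an illustrative stand-in, not the repository's patch_wise_train_val_test_split; the function name, fractions, and seed handling are assumptions.

```python
import random


def split_by_harp(harp_ids, val_frac=0.15, test_frac=0.15, seed=0):
    """Assign example indices to train/val/test so that no HARP spans two splits."""
    unique = sorted(set(harp_ids))
    if len(unique) < 3:
        raise ValueError("need at least 3 distinct HARPs for train/val/test")
    rng = random.Random(seed)
    rng.shuffle(unique)
    # Reserve at least one HARP each for validation and test.
    n_val = max(1, round(len(unique) * val_frac))
    n_test = max(1, round(len(unique) * test_frac))
    if n_val + n_test >= len(unique):
        raise ValueError("fractions leave no HARPs for training")
    test_groups = set(unique[:n_test])
    val_groups = set(unique[n_test:n_test + n_val])
    split = {"train": [], "val": [], "test": []}
    for idx, harp in enumerate(harp_ids):
        if harp in test_groups:
            split["test"].append(idx)
        elif harp in val_groups:
            split["val"].append(idx)
        else:
            split["train"].append(idx)
    return split
```

Because the unit of assignment is the HARP rather than the window, the realized split sizes depend on how many windows each HARP contributes and can differ noticeably from the requested fractions.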
Synthetic data is included for local smoke tests without JSOC credentials or network access.
.
├── data_loader.py # SHARP/DRMS loading, labels, windows, splits, synthetic data
├── model.py # Flax transformer model and prediction heads
├── loss.py # Cross-entropy and GPD tail losses
├── train.py # Training loop, validation, metrics, checkpoint output
├── inference.py # Streaming inference and extreme-alert utilities
├── visualize.py # Time series, tail-density, and attention plots
├── notebooks/
│ └── colab_train.ipynb # Colab workflow for GPU training
├── tests/ # Unit tests
├── pyproject.toml # Dependencies and pytest config
└── README.md
Use Python 3.10 or newer.
python3 -m venv .venv
source .venv/bin/activate
pip install -e .
The editable install pulls JAX, Flax, Optax, DRMS, SunPy, scikit-learn, matplotlib, seaborn, scipy, and pytest from pyproject.toml.
source .venv/bin/activate
PYTHONPATH=. pytest tests/
Or without activating the venv:
PYTHONPATH=. python3 -m pytest tests/
No JSOC credentials or network access are required.
Training smoke run (patch-wise split on synthetic HARP IDs, prints validation TSS/HSS each epoch):
PYTHONPATH=. python train.py
Diagnostic figures (precursor series, tail density, attention heatmap):
PYTHONPATH=. python visualize.py
PNG files are written under figures/.
Programmatic check (same path as train.py):
from train import TrainConfig, run_training_synthetic
cfg = TrainConfig(epochs=3, metrics_csv="runs/smoke/metrics.csv", checkpoint_dir="runs/smoke/ckpt")
run_training_synthetic(cfg)
Real SHARP data requires JSOC/DRMS access and a registered JSOC email.
export JSOC_EMAIL="your-email@example.com"
The data pipeline uses SharpSegmentSpec entries to define HARP/time-range segments:
from data_loader import SharpSegmentSpec, build_labeled_arrays_from_drms, dataset_class_balance
from train import TrainConfig, run_training_from_arrays
segments = [
    SharpSegmentSpec("harp.4698", "2014.10.20_00:00:00_TAI", "2014.10.28_00:00:00_TAI"),
    # Add more segments with distinct harpname values for train/val/test splits.
]
x, y, exceed, harps, evt = build_labeled_arrays_from_drms(segments)
print(dataset_class_balance(y))
cfg = TrainConfig(
    epochs=50,
    metrics_csv="runs/metrics.csv",
    checkpoint_dir="runs/ckpt",
)
run_training_from_arrays(x, y, exceed, evt, harps, cfg)
Use several distinct harpname values across segments. A single HARP cannot support a leakage-safe patch-wise train/validation/test split (patch_wise_train_val_test_split needs enough unique HARP groups).
Use notebooks/colab_train.ipynb for GPU training.
Before running the notebook:
- Sync the updated repository to Google Drive.
- Set JSOC_EMAIL in Colab Secrets.
- Restart the runtime after Drive sync.
- Run the synthetic smoke section first.
- Set RUN_REAL_DRMS = True and add multiple SharpSegmentSpec entries before DRMS training.
The notebook has no saved outputs.
Implemented in train.py:
| Metric | Description |
|---|---|
| TSS | True Skill Statistic: sensitivity + specificity − 1 |
| HSS | Heidke Skill Score from the confusion matrix |
| Validation threshold | Grid search (default 201 points on [0, 1]) maximizing TSS on the validation set |
| Class balance | dataset_class_balance() — counts and positive rate |
| Loss | alpha * CE + (1 - alpha) * GPD_NLL (default alpha=0.5) |
Final test metrics use the validation-tuned threshold (no threshold tuning on the test set).
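The metrics in the table can be sketched from a 2×2 confusion matrix. This is a minimal illustration, not the train.py implementation: the function names, the `>=` decision rule, the convention of returning 0 for an empty class, and the tie-breaking toward the lowest threshold are all assumptions. The HSS form used here is the common chance-corrected one, 2(TP·TN − FP·FN) / [(TP+FN)(FN+TN) + (TP+FP)(FP+TN)].

```python
def confusion_counts(y_true, y_prob, threshold):
    """Count TP, FP, TN, FN when predicting positive for prob >= threshold."""
    tp = fp = tn = fn = 0
    for label, prob in zip(y_true, y_prob):
        predicted = prob >= threshold
        if predicted and label:
            tp += 1
        elif predicted and not label:
            fp += 1
        elif label:
            fn += 1
        else:
            tn += 1
    return tp, fp, tn, fn


def tss(tp, fp, tn, fn):
    """True Skill Statistic: sensitivity + specificity - 1."""
    sensitivity = tp / (tp + fn) if tp + fn else 0.0
    specificity = tn / (tn + fp) if tn + fp else 0.0
    return sensitivity + specificity - 1.0


def hss(tp, fp, tn, fn):
    """Heidke Skill Score relative to random-chance agreement."""
    numerator = 2.0 * (tp * tn - fp * fn)
    denominator = (tp + fn) * (fn + tn) + (tp + fp) * (fp + tn)
    return numerator / denominator if denominator else 0.0


def best_threshold(y_true, y_prob, n_grid=201):
    """Scan an evenly spaced grid on [0, 1] and keep the TSS-maximizing threshold."""
    best_t, best_score = 0.0, float("-inf")
    for i in range(n_grid):
        t = i / (n_grid - 1)
        score = tss(*confusion_counts(y_true, y_prob, t))
        if score > best_score:  # strict '>' keeps the lowest threshold on ties
            best_t, best_score = t, score
    return best_t, best_score
```

In use, `best_threshold` would be called once on validation labels and probabilities, and the returned threshold passed unchanged to `confusion_counts` on the test set, which matches the no-tuning-on-test rule above.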
| Path | Contents |
|---|---|
| runs/ | Checkpoints (flax.training.checkpoints) and optional metrics.csv per epoch |
| figures/ | precursor_timeseries.png, tail_density.png, attention_heatmap.png from visualize.py |
Training prints the per-epoch validation TSS, HSS, and threshold to stdout. When TrainConfig.metrics_csv is set, each epoch appends a row with epoch, val_tss, val_hss, val_threshold, and val_tss_grid.
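A metrics file of this shape can be inspected with the standard library. The sketch below assumes the CSV starts with a header row naming the columns listed above, which is not confirmed here; the helper name is hypothetical and the rows are made-up placeholders, not results from any run.

```python
import csv
import io


def best_epoch(csv_file):
    """Return (epoch, val_tss, val_threshold) for the row with the highest val_tss."""
    rows = list(csv.DictReader(csv_file))
    if not rows:
        raise ValueError("metrics file has no rows")
    best = max(rows, key=lambda row: float(row["val_tss"]))
    return int(best["epoch"]), float(best["val_tss"]), float(best["val_threshold"])


# Placeholder values for illustration only; they do not come from a real run.
EXAMPLE_CSV = (
    "epoch,val_tss,val_hss,val_threshold,val_tss_grid\n"
    "1,0.10,0.08,0.500,0.10\n"
    "2,0.35,0.20,0.305,0.35\n"
    "3,0.30,0.22,0.410,0.30\n"
)

epoch, val_tss, threshold = best_epoch(io.StringIO(EXAMPLE_CSV))
```

For a real run, the same function would be called with `open("runs/metrics.csv", newline="")` in place of the in-memory buffer.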
- Synthetic data is for pipeline checks, not scientific validation.
- Real model quality depends on HARP selection, time coverage, class balance, and label quality.
- A single HARP cannot support a real patch-wise train/validation/test split.
- DRMS availability and JSOC credentials are external requirements.
- X-class flare forecasting is highly imbalanced; accuracy alone is not a sufficient metric.
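The last point can be made concrete with a toy example. Assuming an illustrative 1:99 class ratio (not a measured rate for this dataset), a classifier that never predicts a flare scores 99% accuracy while its TSS is zero:

```python
def accuracy_and_tss(y_true, y_pred):
    """Compute accuracy and True Skill Statistic from binary labels and predictions."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    accuracy = (tp + tn) / len(y_true)
    sensitivity = tp / (tp + fn) if tp + fn else 0.0
    specificity = tn / (tn + fp) if tn + fp else 0.0
    return accuracy, sensitivity + specificity - 1.0


# One positive window among 100, and a model that always predicts "no flare".
y_true = [1] + [0] * 99
y_pred = [0] * 100
acc, skill = accuracy_and_tss(y_true, y_pred)
# acc is 0.99 and skill is 0.0
```

This is why the validation threshold is tuned on TSS rather than accuracy.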
Generated run outputs, local environments, caches, notebook checkpoints, and editor metadata should not be committed. See .gitignore for runs/, figures/, .venv/, .pytest_cache/, .cursor/, and related paths.