
marco-hening-tallarico/solar-flare-tail-transformer


Solar Flare Precursor Modeling

Solar Flare is a machine learning pipeline for forecasting X-class solar flares from SDO/HMI SHARP magnetogram time series.

The project treats flare forecasting as a rare-event classification problem. A transformer encoder reads sliding windows of SHARP features and predicts both flare occurrence and extreme-tail behavior, combining binary classification with Generalized Pareto Distribution modeling for peaks-over-threshold exceedances.

What the model does

The model has two prediction heads:

  1. Class logits — predicts whether an X-class flare will occur in the forecast window.
  2. GPD parameters — estimates Generalized Pareto tail parameters for extreme exceedances.

Training combines cross-entropy with GPD negative log-likelihood. This lets the model learn both event occurrence and the behavior of rare high-magnitude signals.
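The combined objective can be sketched as follows. This version uses plain Python for readability and is not the JAX code in loss.py. It makes three assumptions:

- each example has two class logits;
- the GPD has scale `sigma > 0` and shape `xi`;
- the tail term is averaged only over examples with a positive exceedance.

The helper names (`gpd_nll`, `cross_entropy`, `combined_loss`) are hypothetical. The `alpha=0.5` default matches the Loss row in the Metrics table.

```python
import math


def gpd_nll(y: float, sigma: float, xi: float) -> float:
    """Negative log-likelihood of one exceedance y >= 0 under GPD(sigma, xi)."""
    if sigma <= 0:
        raise ValueError("GPD scale sigma must be positive")
    if abs(xi) < 1e-8:
        # xi -> 0 limit: the GPD reduces to an exponential distribution.
        return math.log(sigma) + y / sigma
    z = 1.0 + xi * y / sigma
    if z <= 0:
        return math.inf  # y lies outside the GPD support for this (sigma, xi)
    return math.log(sigma) + (1.0 + 1.0 / xi) * math.log(z)


def cross_entropy(logits: list[float], label: int) -> float:
    """Softmax cross-entropy for one example, using a stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[label]


def combined_loss(logits, labels, exceedances, sigmas, xis, alpha=0.5):
    """alpha * mean CE + (1 - alpha) * mean GPD NLL over positive exceedances."""
    ce = sum(cross_entropy(l, y) for l, y in zip(logits, labels)) / len(labels)
    tail = [gpd_nll(e, s, x) for e, s, x in zip(exceedances, sigmas, xis) if e > 0]
    gpd = sum(tail) / len(tail) if tail else 0.0  # no exceedances in batch: tail term is 0
    return alpha * ce + (1.0 - alpha) * gpd
```

Masking non-exceedances keeps the GPD term from being fit to values below the threshold, which the peaks-over-threshold model does not describe.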

Data

The live data path uses NASA JSOC/DRMS SHARP records from:

hmi.sharp_720s

The default SHARP features are:

TOTUSJH
TOTUSJZ
R_VALUE
USFLUX
AREA_ACR

Labels are built from GOES/HEK flare events. A window is positive when an X-class flare peak occurs in:

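(window_end + 24h, window_end + 72h]

That is, the peak must occur more than 24 hours and at most 72 hours after the end of the context window.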
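The labeling rule can be sketched as below. The sketch assumes flare events arrive as `(peak_time, goes_class)` pairs, for example after parsing a HEK query, and it treats any GOES class starting with "X" as X-class. `label_window` is a hypothetical helper, not the function in data_loader.py. Note the half-open interval: a peak exactly 24 h after the window end is excluded, while a peak exactly 72 h after it is included.

```python
from datetime import datetime, timedelta


def label_window(
    window_end: datetime,
    events: list[tuple[datetime, str]],
    lead: timedelta = timedelta(hours=24),
    horizon: timedelta = timedelta(hours=72),
) -> int:
    """Return 1 if an X-class peak falls in (window_end + lead, window_end + horizon]."""
    lo, hi = window_end + lead, window_end + horizon
    return int(any(
        goes_class.upper().startswith("X") and lo < peak <= hi
        for peak, goes_class in events
    ))
```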

Each example uses a 72-hour SHARP context window. At the native 720-second cadence, this gives 360 time steps per window.
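The step count follows from 72 h × 3600 s/h / 720 s = 360. The sketch below slices such windows from one HARP's feature series. It assumes the series is a time-ordered list of per-record feature vectors with no gaps. The `stride` parameter is a hypothetical knob; the stride data_loader.py actually uses is not stated here.

```python
CADENCE_S = 720                               # native hmi.sharp_720s cadence, seconds
CONTEXT_H = 72                                # context window length, hours
WINDOW_STEPS = CONTEXT_H * 3600 // CADENCE_S  # 360 time steps


def sliding_windows(series, length=WINDOW_STEPS, stride=1):
    """Return every full-length window series[i:i + length], stepping by stride."""
    if length <= 0 or stride <= 0:
        raise ValueError("length and stride must be positive")
    return [series[i:i + length] for i in range(0, len(series) - length + 1, stride)]
```

Real SHARP series can have missing records. A production loader would therefore also need to check that each window spans contiguous T_REC values; this sketch omits that check.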

Patch-wise train/validation/test splits are done by HARPNAME to avoid leakage across active-region patches.
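A leakage-safe grouped split needs only the standard library. The approach is to shuffle the unique HARP names, then assign every window of a given HARP to exactly one partition. The fractions, the seed, and the `split_by_harp` name are illustrative assumptions. The repository's own `patch_wise_train_val_test_split` may choose groups differently.

```python
import random


def split_by_harp(harpnames, val_frac=0.15, test_frac=0.15, seed=0):
    """Return (train, val, test) index lists with no HARP shared between partitions."""
    groups = sorted(set(harpnames))  # sort first so the seeded shuffle is reproducible
    n = len(groups)
    n_test = max(1, round(n * test_frac))
    n_val = max(1, round(n * val_frac))
    if n_test + n_val >= n:
        raise ValueError(f"need more distinct HARPs for a 3-way split, got {n}")
    random.Random(seed).shuffle(groups)
    test_g = set(groups[:n_test])
    val_g = set(groups[n_test:n_test + n_val])
    train_idx, val_idx, test_idx = [], [], []
    for i, harp in enumerate(harpnames):
        if harp in test_g:
            test_idx.append(i)
        elif harp in val_g:
            val_idx.append(i)
        else:
            train_idx.append(i)
    return train_idx, val_idx, test_idx
```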

Synthetic data is included for local smoke tests without JSOC credentials or network access.

Project layout

.
├── data_loader.py              # SHARP/DRMS loading, labels, windows, splits, synthetic data
├── model.py                    # Flax transformer model and prediction heads
├── loss.py                     # Cross-entropy and GPD tail losses
├── train.py                    # Training loop, validation, metrics, checkpoint output
├── inference.py                # Streaming inference and extreme-alert utilities
├── visualize.py                # Time series, tail-density, and attention plots
├── notebooks/
│   └── colab_train.ipynb       # Colab workflow for GPU training
├── tests/                      # Unit tests
├── pyproject.toml              # Dependencies and pytest config
└── README.md

Installation

Use Python 3.10 or newer.

python3 -m venv .venv
source .venv/bin/activate
pip install -e .

The editable install pulls in JAX, Flax, Optax, DRMS, SunPy, scikit-learn, matplotlib, seaborn, scipy, and pytest, as declared in pyproject.toml.

Running tests

source .venv/bin/activate
PYTHONPATH=. pytest tests/

Or without activating the venv:

PYTHONPATH=. python3 -m pytest tests/

Synthetic smoke test

No JSOC credentials or network required.

Training smoke run (patch-wise split on synthetic HARP IDs, prints validation TSS/HSS each epoch):

PYTHONPATH=. python train.py

Diagnostic figures (precursor series, tail density, attention heatmap):

PYTHONPATH=. python visualize.py

PNG files are written under figures/.

Programmatic check (runs the same code path as train.py):

from train import TrainConfig, run_training_synthetic

cfg = TrainConfig(epochs=3, metrics_csv="runs/smoke/metrics.csv", checkpoint_dir="runs/smoke/ckpt")
run_training_synthetic(cfg)

Real DRMS data

Real SHARP data requires JSOC/DRMS access and a registered JSOC email.

export JSOC_EMAIL="your-email@example.com"

The data pipeline uses SharpSegmentSpec entries to define HARP/time-range segments:

from data_loader import SharpSegmentSpec, build_labeled_arrays_from_drms, dataset_class_balance
from train import TrainConfig, run_training_from_arrays

segments = [
    SharpSegmentSpec("harp.4698", "2014.10.20_00:00:00_TAI", "2014.10.28_00:00:00_TAI"),
    # Add more segments with distinct harpname values for train/val/test splits.
]
x, y, exceed, harps, evt = build_labeled_arrays_from_drms(segments)
print(dataset_class_balance(y))

cfg = TrainConfig(
    epochs=50,
    metrics_csv="runs/metrics.csv",
    checkpoint_dir="runs/ckpt",
)
run_training_from_arrays(x, y, exceed, evt, harps, cfg)

Use several distinct harpname values across segments. A single HARP cannot support a leakage-safe patch-wise train/validation/test split (patch_wise_train_val_test_split needs enough unique HARP groups).
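A segment like the one above corresponds to a standard JSOC record-set query on hmi.sharp_720s, whose prime keys are HARPNUM and T_REC. The sketch below makes three assumptions:

- harpnames follow the `harp.<HARPNUM>` convention seen in the example;
- queries go through the public `drms` client, where `Client.query` returns a pandas DataFrame;
- data_loader.py may build its queries differently.

`sharp_recordset` and `fetch_sharp` are hypothetical helpers.

```python
DEFAULT_KEYS = ["TOTUSJH", "TOTUSJZ", "R_VALUE", "USFLUX", "AREA_ACR"]


def sharp_recordset(harpname: str, start: str, end: str,
                    series: str = "hmi.sharp_720s") -> str:
    """Build a JSOC record-set string such as hmi.sharp_720s[4698][t0-t1]."""
    harpnum = int(harpname.split(".")[-1])  # assumes names like "harp.4698"
    return f"{series}[{harpnum}][{start}-{end}]"


def fetch_sharp(harpname, start, end, email, keys=DEFAULT_KEYS):
    """Query SHARP keywords from JSOC; needs network access and the drms package."""
    import drms  # imported lazily so sharp_recordset works without drms installed

    client = drms.Client(email=email)
    return client.query(sharp_recordset(harpname, start, end),
                        key=",".join(["T_REC", *keys]))
```

For example, `fetch_sharp("harp.4698", "2014.10.20_00:00:00_TAI", "2014.10.28_00:00:00_TAI", email=os.environ["JSOC_EMAIL"])` would return one row per 720-second record in that range.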

Colab

Use notebooks/colab_train.ipynb for GPU training.

Before running the notebook:

  1. Sync the updated repository to Google Drive.
  2. Set JSOC_EMAIL in Colab Secrets.
  3. Restart the runtime after Drive sync.
  4. Run the synthetic smoke section first.
  5. Set RUN_REAL_DRMS = True and add multiple SharpSegmentSpec entries before DRMS training.

The notebook is committed with no saved cell outputs.

Metrics

Implemented in train.py:

| Metric | Description |
|---|---|
| TSS | True Skill Statistic: sensitivity + specificity − 1 |
| HSS | Heidke Skill Score from the confusion matrix |
| Validation threshold | Grid search (default 201 points on [0, 1]) maximizing TSS on the validation set |
| Class balance | dataset_class_balance(): counts and positive rate |
| Loss | alpha * CE + (1 - alpha) * GPD_NLL (default alpha=0.5) |
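The two skill scores can be written out from confusion-matrix counts. This sketch uses the common definitions:

- TSS = TP/(TP+FN) − FP/(FP+TN)
- HSS = 2(TP·TN − FN·FP) / [(TP+FN)(FN+TN) + (TP+FP)(FP+TN)]

The function names are hypothetical. It is also an assumption that train.py uses this exact HSS form and returns 0.0 for degenerate denominators.

```python
def confusion(preds, labels):
    """Count (tp, fn, fp, tn) for binary predictions and labels (1 = X-class flare)."""
    tp = sum(1 for p, y in zip(preds, labels) if p and y)
    fn = sum(1 for p, y in zip(preds, labels) if not p and y)
    fp = sum(1 for p, y in zip(preds, labels) if p and not y)
    tn = sum(1 for p, y in zip(preds, labels) if not p and not y)
    return tp, fn, fp, tn


def tss(tp, fn, fp, tn):
    """True Skill Statistic: sensitivity + specificity - 1, i.e. TPR - FPR."""
    tpr = tp / (tp + fn) if tp + fn else 0.0
    fpr = fp / (fp + tn) if fp + tn else 0.0
    return tpr - fpr


def hss(tp, fn, fp, tn):
    """Heidke Skill Score: skill relative to a random forecast with the same marginals."""
    denom = (tp + fn) * (fn + tn) + (tp + fp) * (fp + tn)
    return 2.0 * (tp * tn - fn * fp) / denom if denom else 0.0
```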

Final test metrics use the validation-tuned threshold (no threshold tuning on the test set).
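The validation-only threshold search can be sketched as below, assuming the model outputs a per-window probability of an X-class flare. The search scans 201 evenly spaced thresholds on [0, 1] and keeps the first one with the highest validation TSS. That fixed threshold is then applied to the test set. `tune_threshold` and `apply_threshold` are hypothetical names.

```python
def _tss(preds, labels):
    """TSS (TPR - FPR) for binary predictions; 0.0 for an empty class."""
    tp = sum(1 for p, y in zip(preds, labels) if p and y)
    fn = sum(1 for p, y in zip(preds, labels) if not p and y)
    fp = sum(1 for p, y in zip(preds, labels) if p and not y)
    tn = sum(1 for p, y in zip(preds, labels) if not p and not y)
    tpr = tp / (tp + fn) if tp + fn else 0.0
    fpr = fp / (fp + tn) if fp + tn else 0.0
    return tpr - fpr


def tune_threshold(val_probs, val_labels, n_grid=201):
    """Return (threshold, tss) maximizing validation TSS over a grid on [0, 1]."""
    best_t, best_tss = 0.5, float("-inf")
    for k in range(n_grid):
        t = k / (n_grid - 1)
        score = _tss([p >= t for p in val_probs], val_labels)
        if score > best_tss:  # strict '>' keeps the first threshold on ties
            best_t, best_tss = t, score
    return best_t, best_tss


def apply_threshold(probs, threshold):
    """Binarize test probabilities with a threshold tuned on validation data only."""
    return [int(p >= threshold) for p in probs]
```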

Outputs

| Path | Contents |
|---|---|
| runs/ | Checkpoints (flax.training.checkpoints) and an optional per-epoch metrics.csv |
| figures/ | precursor_timeseries.png, tail_density.png, attention_heatmap.png from visualize.py |

Training prints the validation TSS, HSS, and threshold for each epoch to stdout. When TrainConfig.metrics_csv is set, each epoch also appends one row with the columns epoch, val_tss, val_hss, val_threshold, and val_tss_grid.
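A per-epoch CSV logger with those columns can be built on the standard library `csv` module. The header is written once, when the file is new. The `append_metrics_row` name and the example values are illustrative. This sketch does not specify what train.py stores in `val_tss_grid`.

```python
import csv
import os

FIELDS = ["epoch", "val_tss", "val_hss", "val_threshold", "val_tss_grid"]


def append_metrics_row(path, row):
    """Append one epoch's validation metrics, writing the header if the file is new."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    write_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDS)
        if write_header:
            writer.writeheader()
        writer.writerow(row)
```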

Limitations

  • Synthetic data is for pipeline checks, not scientific validation.
  • Real model quality depends on HARP selection, time coverage, class balance, and label quality.
  • A single HARP cannot support a real patch-wise train/validation/test split.
  • DRMS availability and JSOC credentials are external requirements.
  • X-class flare forecasting is highly imbalanced; accuracy alone is not a sufficient metric.
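A quick worked example shows why accuracy misleads on imbalanced data. The counts are illustrative, not results from this project. With 1,000 windows of which 10 are positive, a model that never predicts a flare scores 99% accuracy but a TSS of 0.

```python
def accuracy(tp, fn, fp, tn):
    return (tp + tn) / (tp + fn + fp + tn)


def tss(tp, fn, fp, tn):
    tpr = tp / (tp + fn) if tp + fn else 0.0
    fpr = fp / (fp + tn) if fp + tn else 0.0
    return tpr - fpr


# Illustrative counts: 10 positive and 990 negative windows.
always_negative = dict(tp=0, fn=10, fp=0, tn=990)
# A detector that catches 8 of 10 flares at the cost of 50 false alarms.
detector = dict(tp=8, fn=2, fp=50, tn=940)

for name, cm in [("always_negative", always_negative), ("detector", detector)]:
    print(name, round(accuracy(**cm), 3), round(tss(**cm), 3))
# always_negative 0.99 0.0
# detector 0.948 0.749
```

The detector has lower accuracy but far higher TSS, which is why TSS and HSS are the headline metrics here.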

Git hygiene

Generated run outputs, local environments, caches, notebook checkpoints, and editor metadata should not be committed. See .gitignore for runs/, figures/, .venv/, .pytest_cache/, .cursor/, and related paths.

About

Tail-aware JAX/Flax transformer for X-class solar flare forecasting from SHARP.
