Skip to content

Commit b169878

Browse files
committed
Add LOO integration tests for MCMC results
Add several integration tests covering leave-one-out (LOO) functionality for MCMC results: mechanistic point-wise and curve-wise LOO, the ArviZ consistency check, surrogate-mode LOO (including the one-step Euler and reuse-the-inferred-noise variants), compare() between fits, and the LOO pointwise and plotting outputs. Also import numpy for array checks and select matplotlib's non-interactive Agg backend in the plot test. The tests use small MCMC configs and, where applicable, a trained NeuralODE surrogate to validate the elpd and pareto_k shapes and the n_data_points behavior.
1 parent 6046f59 commit b169878

1 file changed

Lines changed: 111 additions & 0 deletions

File tree

tests/integration/test_mcmc.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import random
22

33
import jax.numpy as jnp
4+
import numpy as np
45
import optax
56

67
import catalax.mcmc as cmc
@@ -77,6 +78,116 @@ def test_surrogate_model(self, generate_data):
7778
yerrs=1e-5,
7879
)
7980

81+
def test_loo_mechanistic(self, generate_data):
    """Check the LOO unit counts of a mechanistic fit for both leave-out modes.

    A short MCMC run is scored twice: ``leave_out="point"`` must report as
    many data points as the observable data array has entries, and
    ``leave_out="curve"`` must report one unit per measurement.
    """
    model, dataset = generate_data
    fit = cmm.run_mcmc(
        model=model,
        dataset=dataset,
        config=cmm.MCMCConfig(num_warmup=50, num_samples=100, verbose=0),
        yerrs=1e-2,
    )

    # Point mode (no yerrs override, so the fit's own noise is reused):
    # every (measurement, time, obs) entry is expected to count.
    pointwise = fit.loo(dataset, leave_out="point")
    point_count = int(pointwise.n_data_points)
    data_array = dataset.to_jax_arrays(model.get_observable_state_order())[0]
    assert point_count == data_array.size

    # Curve mode: each measurement series collapses to a single unit.
    curvewise = fit.loo(dataset, leave_out="curve")
    assert int(curvewise.n_data_points) == len(dataset.measurements)
97+
98+
def test_loo_consistency_check(self, generate_data):
    """Require the LOO consistency check to report agreement (mechanistic fit).

    The test intends the eval-model reconstruction to match ArviZ's native
    LOO; it passes when the ``"agree"`` entry of the returned report is truthy.
    """
    model, dataset = generate_data
    two_chain_config = cmm.MCMCConfig(
        num_warmup=50,
        num_samples=100,
        num_chains=2,
        verbose=0,
    )
    fit = cmm.run_mcmc(
        model=model,
        dataset=dataset,
        config=two_chain_config,
        yerrs=1e-2,
    )

    report = fit.loo_consistency_check(dataset, yerrs=1e-2)
    # The whole report is the assertion message so a failure shows the details.
    assert report["agree"], report
106+
107+
def test_loo_surrogate(self, generate_data):
    """Check that a surrogate-mode fit still produces LOO results.

    A small NeuralODE is trained on an augmented dataset and used as the
    surrogate for a short MCMC run. LOO is then evaluated against the
    original dataset in three configurations, each of which must report a
    positive number of data points.
    """
    model, dataset = generate_data
    augmented = dataset.augment(n_augmentations=10)

    # Build and train a deliberately tiny surrogate network.
    untrained = ctn.NeuralODE.from_model(
        model,
        width_size=8,
        depth=2,
        activation=ctn.RBFLayer(0.2),  # type: ignore
    )
    schedule = ctn.Strategy()
    schedule.add_step(
        lr=1e-2,
        length=0.1,
        steps=100,
        batch_size=15,
        loss=optax.log_cosh,
    )
    surrogate = ctn.train_neural_ode(
        model=untrained,
        dataset=augmented,
        strategy=schedule,
        print_every=1000,
        weight_scale=1e-7,
    )

    fit = cmm.run_mcmc(
        model=model,
        dataset=augmented,
        config=cmm.MCMCConfig(num_warmup=50, num_samples=100, verbose=0),
        surrogate=surrogate,
        yerrs=1e-2,
    )

    # Default scoring: reuse the sampled rates, Euler-integrate, and score
    # against the *measured* concentrations rather than the surrogate rates.
    # The yerrs stored on a surrogate fit is rate-space, so a
    # concentration-space value is passed explicitly.
    default_loo = fit.loo(dataset, yerrs=0.5)
    assert int(default_loo.n_data_points) > 0
    # The headline diagnostic: one Pareto-k value per held-out data point.
    k_values = np.asarray(default_loo.pareto_k)
    assert k_values.shape[0] == int(default_loo.n_data_points)

    variants = (
        # One-step-ahead integration.
        {"yerrs": 0.5, "integration": "euler_onestep"},
        # Reuse the inferred noise instead of passing yerrs.
        {"sigma_source": "reuse"},
    )
    for loo_kwargs in variants:
        assert int(fit.loo(dataset, **loo_kwargs).n_data_points) > 0
155+
156+
def test_loo_compare(self, generate_data):
    """Check that compare() returns a table indexed by "self" and "other".

    Two fits are produced with the same model, dataset and config, and the
    second is compared against the first under the label ``"other"``.
    """
    model, dataset = generate_data
    config = cmm.MCMCConfig(num_warmup=50, num_samples=100, verbose=0)

    # Two independent runs with identical settings, executed in order.
    first_fit, second_fit = (
        cmm.run_mcmc(model=model, dataset=dataset, config=config, yerrs=1e-2)
        for _ in range(2)
    )

    ranking = first_fit.compare({"other": second_fit}, dataset)
    assert set(ranking.index) == {"self", "other"}
166+
167+
def test_loo_plots(self, generate_data):
    """Check the pointwise LOO shapes and that every LOO plot call returns something.

    Covers ``loo_pointwise`` plus the influence plot and the ``"elpd"`` and
    ``"pareto_k"`` heatmaps.
    """
    import matplotlib

    # Select the non-interactive Agg backend before any plotting call.
    matplotlib.use("Agg")

    model, dataset = generate_data
    fit = cmm.run_mcmc(
        model=model,
        dataset=dataset,
        config=cmm.MCMCConfig(num_warmup=50, num_samples=100, verbose=0),
        yerrs=1e-2,
    )

    pointwise = fit.loo_pointwise(dataset, yerrs=0.5)
    measurement_count = len(dataset.measurements)
    observable_count = len(model.get_observable_state_order())
    elpd_shape = pointwise.elpd.shape
    # Axis 0 is checked against measurements, axis 2 against observables.
    assert elpd_shape[0] == measurement_count
    assert elpd_shape[2] == observable_count
    assert pointwise.pareto_k.shape == elpd_shape

    # Influence overlay (marker size = influence), then one heatmap per metric.
    assert fit.plot_loo_influence(dataset, yerrs=0.5) is not None
    for metric in ("elpd", "pareto_k"):
        assert fit.plot_loo_heatmap(dataset, metric=metric, yerrs=0.5) is not None
190+
80191
def test_initial_estimator(self):
81192
# Create a simple Michaelis-Menten model
82193
model = Model(name="test")

0 commit comments

Comments
 (0)