Skip to content

Commit e728456

Browse files
fix(neural): keep inference on the fitted module's device + validate CUDA index
Two related correctness fixes for STATSPAI_TORCH_DEVICE routing: 1. resolve_torch_device() now rejects "cuda:N" with N >= torch.cuda.device_count() instead of silently constructing an out-of-range device. Also tightens the prefix check from startswith("cuda") to "cuda" / startswith("cuda:") so stray strings like "cudafoo" no longer slip through the CUDA branch. 2. DeepIV.effect() and TARNet/CFRNet/DragonNet predict / propensity paths now place the input tensor on next(module.parameters()).device — the device the network was actually fitted on — instead of re-resolving from the env var each call. Previously, flipping STATSPAI_TORCH_DEVICE between fit and effect (or any post-fit device move) raised a cross-device RuntimeError. Tests: explicit device-count guard test in test_torch_device_resolver, and spy-based assertions in test_deepiv / test_neural_causal that effect / predict tensors land on the same device as the fitted parameters.
1 parent 2408818 commit e728456

6 files changed

Lines changed: 70 additions & 14 deletions

File tree

src/statspai/deepiv/deep_iv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,7 @@ def fit(self) -> CausalResult:
538538
self._effects = effects
539539
self._x_means = X_means
540540
self._x_stds = X_stds
541+
self._device = device
541542

542543
return CausalResult(
543544
method='DeepIV (Hartford et al. 2017)',
@@ -592,8 +593,7 @@ def effect(self, t0: float, t1: float, X: Optional[np.ndarray] = None) -> np.nda
592593
t1_s = (t1 - self._t_mean) / self._t_std
593594
n = len(X_s)
594595

595-
from ..utils._torch_device import resolve_torch_device
596-
device = resolve_torch_device()
596+
device = next(self._response_net.parameters()).device
597597
X_t = torch.tensor(X_s, dtype=torch.float32, device=device)
598598

599599
with torch.no_grad():

src/statspai/neural_causal/models.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,11 @@ def _build_head(input_dim, hidden_layers):
395395
return nn.Sequential(*layers)
396396

397397

398+
def _module_device(module):
399+
"""Return the device that holds a fitted torch module's parameters."""
400+
return next(module.parameters()).device
401+
402+
398403
# ======================================================================
399404
# TARNet
400405
# ======================================================================
@@ -554,6 +559,7 @@ def fit(self) -> CausalResult:
554559
self._head_0 = head_0
555560
self._head_1 = head_1
556561
self._cate = cate
562+
self._device = device
557563

558564
model_info = self._build_model_info(cate, D, n)
559565

@@ -595,7 +601,8 @@ def effect(self, X_new: Optional[np.ndarray] = None) -> np.ndarray:
595601

596602
X_new = np.asarray(X_new, dtype=np.float32)
597603
X_s = (X_new - self._x_mean) / self._x_std
598-
X_t = torch.tensor(X_s, dtype=torch.float32)
604+
device = _module_device(self._repr_net)
605+
X_t = torch.tensor(X_s, dtype=torch.float32, device=device)
599606

600607
self._repr_net.eval()
601608
self._head_0.eval()
@@ -796,6 +803,7 @@ def fit(self) -> CausalResult:
796803
self._head_0 = head_0
797804
self._head_1 = head_1
798805
self._cate = cate
806+
self._device = device
799807

800808
model_info = {
801809
'architecture': 'CFRNet',
@@ -844,7 +852,8 @@ def effect(self, X_new: Optional[np.ndarray] = None) -> np.ndarray:
844852

845853
X_new = np.asarray(X_new, dtype=np.float32)
846854
X_s = (X_new - self._x_mean) / self._x_std
847-
X_t = torch.tensor(X_s, dtype=torch.float32)
855+
device = _module_device(self._repr_net)
856+
X_t = torch.tensor(X_s, dtype=torch.float32, device=device)
848857

849858
self._repr_net.eval()
850859
self._head_0.eval()
@@ -1085,6 +1094,7 @@ def fit(self) -> CausalResult:
10851094
self._prop_head = prop_head
10861095
self._cate = cate
10871096
self._e_hat = e_hat
1097+
self._device = device
10881098

10891099
model_info = {
10901100
'architecture': 'DragonNet',
@@ -1139,7 +1149,8 @@ def effect(self, X_new: Optional[np.ndarray] = None) -> np.ndarray:
11391149

11401150
X_new = np.asarray(X_new, dtype=np.float32)
11411151
X_s = (X_new - self._x_mean) / self._x_std
1142-
X_t = torch.tensor(X_s, dtype=torch.float32)
1152+
device = _module_device(self._repr_net)
1153+
X_t = torch.tensor(X_s, dtype=torch.float32, device=device)
11431154

11441155
self._repr_net.eval()
11451156
self._head_0.eval()
@@ -1176,7 +1187,8 @@ def propensity(self, X_new: Optional[np.ndarray] = None) -> np.ndarray:
11761187

11771188
X_new = np.asarray(X_new, dtype=np.float32)
11781189
X_s = (X_new - self._x_mean) / self._x_std
1179-
X_t = torch.tensor(X_s, dtype=torch.float32)
1190+
device = _module_device(self._repr_net)
1191+
X_t = torch.tensor(X_s, dtype=torch.float32, device=device)
11801192

11811193
self._repr_net.eval()
11821194
self._prop_head.eval()

src/statspai/utils/_torch_device.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,19 @@ def resolve_torch_device(prefer: Optional[str] = None):
6969
return torch.device("mps")
7070
return torch.device("cpu")
7171

72-
if spec.startswith("cuda"):
72+
if spec == "cuda" or spec.startswith("cuda:"):
7373
if not torch.cuda.is_available():
7474
raise RuntimeError(
7575
f"{_ENV_VAR}={raw!r} requested CUDA but torch.cuda.is_available() is False. "
7676
"Install a CUDA-enabled PyTorch build or set STATSPAI_TORCH_DEVICE=cpu."
7777
)
78-
return torch.device(spec)
78+
device = torch.device(spec)
79+
if device.index is not None and device.index >= torch.cuda.device_count():
80+
raise RuntimeError(
81+
f"{_ENV_VAR}={raw!r} requested CUDA device {device.index}, "
82+
f"but only {torch.cuda.device_count()} device(s) are available."
83+
)
84+
return device
7985

8086
if spec == "mps":
8187
if not _mps_available(torch):

tests/test_deepiv.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,20 @@ def test_class_interface(self, linear_iv_data):
101101
result = est.fit()
102102
assert isinstance(result, CausalResult)
103103

104-
def test_effect_method(self, linear_iv_data):
104+
def test_effect_method(self, linear_iv_data, monkeypatch):
105105
est = DeepIV(
106106
data=linear_iv_data, y='y', treat='treat',
107107
instruments=['instrument'], covariates=['covar'],
108108
first_stage_epochs=20, second_stage_epochs=20,
109109
hidden_layers=(32,), n_components=3,
110110
)
111111
est.fit()
112+
monkeypatch.setenv("STATSPAI_TORCH_DEVICE", "cuda")
113+
monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
112114
effects = est.effect(t0=0.0, t1=1.0)
113115
assert len(effects) == 2000
114116
assert np.isfinite(effects).all()
117+
assert next(est._response_net.parameters()).device.type == "cpu"
115118

116119

117120
class TestDeepIVValidation:

tests/test_neural_causal.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import numpy as np
77
import pandas as pd
88

9-
pytest.importorskip("torch", reason="PyTorch required for neural causal tests")
9+
torch = pytest.importorskip(
10+
"torch", reason="PyTorch required for neural causal tests"
11+
)
1012

1113
from statspai.neural_causal import (
1214
tarnet, cfrnet, dragonnet,
@@ -15,6 +17,23 @@
1517
from statspai.core.results import CausalResult
1618

1719

20+
def _spy_tensor_devices(monkeypatch):
21+
"""Capture devices passed to torch.tensor after model fitting."""
22+
devices = []
23+
original_tensor = torch.tensor
24+
25+
def spy_tensor(*args, **kwargs):
26+
devices.append(kwargs.get("device"))
27+
return original_tensor(*args, **kwargs)
28+
29+
monkeypatch.setattr(torch, "tensor", spy_tensor)
30+
return devices
31+
32+
33+
def _module_device(module):
34+
return next(module.parameters()).device
35+
36+
1837
# ======================================================================
1938
# Fixtures: DGPs with known true effects
2039
# ======================================================================
@@ -138,16 +157,18 @@ def test_class_interface(self, small_data):
138157
cate = est.effect()
139158
assert len(cate) == len(small_data)
140159

141-
def test_effect_new_data(self, small_data):
160+
def test_effect_new_data(self, small_data, monkeypatch):
142161
est = TARNet(data=small_data, y='y', treat='d',
143162
covariates=['x1', 'x2'],
144163
epochs=50, repr_layers=(64,),
145164
head_layers=(32,), n_bootstrap=50)
146165
est.fit()
147166

148167
X_new = np.random.randn(10, 2).astype(np.float32)
168+
devices = _spy_tensor_devices(monkeypatch)
149169
cate_new = est.effect(X_new)
150170
assert len(cate_new) == 10
171+
assert devices[-1] == _module_device(est._repr_net)
151172

152173
def test_summary_renders(self, small_data):
153174
result = tarnet(small_data, y='y', treat='d',
@@ -233,7 +254,7 @@ def test_citation(self, small_data):
233254
bib = result.cite()
234255
assert 'shalit2017' in bib
235256

236-
def test_class_effect_method(self, small_data):
257+
def test_class_effect_method(self, small_data, monkeypatch):
237258
est = CFRNet(data=small_data, y='y', treat='d',
238259
covariates=['x1', 'x2'],
239260
epochs=50, repr_layers=(64,),
@@ -243,8 +264,10 @@ def test_class_effect_method(self, small_data):
243264
assert len(cate) == len(small_data)
244265

245266
X_new = np.random.randn(5, 2).astype(np.float32)
267+
devices = _spy_tensor_devices(monkeypatch)
246268
cate_new = est.effect(X_new)
247269
assert len(cate_new) == 5
270+
assert devices[-1] == _module_device(est._repr_net)
248271

249272

250273
# ======================================================================
@@ -285,18 +308,20 @@ def test_propensity_scores(self, small_data):
285308
assert np.all(e >= 0.01)
286309
assert np.all(e <= 0.99)
287310

288-
def test_propensity_new_data(self, small_data):
311+
def test_propensity_new_data(self, small_data, monkeypatch):
289312
est = DragonNet(data=small_data, y='y', treat='d',
290313
covariates=['x1', 'x2'],
291314
epochs=50, repr_layers=(64,),
292315
head_layers=(32,), n_bootstrap=50)
293316
est.fit()
294317

295318
X_new = np.random.randn(10, 2).astype(np.float32)
319+
devices = _spy_tensor_devices(monkeypatch)
296320
e_new = est.propensity(X_new)
297321
assert len(e_new) == 10
298322
assert np.all(e_new >= 0.01)
299323
assert np.all(e_new <= 0.99)
324+
assert devices[-1] == _module_device(est._repr_net)
300325

301326
def test_constant_effect_recovery(self, constant_effect_data):
302327
result = dragonnet(constant_effect_data, y='y', treat='d',
@@ -324,7 +349,7 @@ def test_citation(self, small_data):
324349
bib = result.cite()
325350
assert 'shi2019' in bib
326351

327-
def test_effect_method(self, small_data):
352+
def test_effect_method(self, small_data, monkeypatch):
328353
est = DragonNet(data=small_data, y='y', treat='d',
329354
covariates=['x1', 'x2'],
330355
epochs=50, repr_layers=(64,),
@@ -335,8 +360,10 @@ def test_effect_method(self, small_data):
335360
assert len(cate) == len(small_data)
336361

337362
X_new = np.random.randn(5, 2).astype(np.float32)
363+
devices = _spy_tensor_devices(monkeypatch)
338364
cate_new = est.effect(X_new)
339365
assert len(cate_new) == 5
366+
assert devices[-1] == _module_device(est._repr_net)
340367

341368

342369
# ======================================================================

tests/test_torch_device_resolver.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ def test_explicit_cuda_returns_when_available(monkeypatch):
5454
assert dev.type == "cuda"
5555

5656

57+
def test_explicit_cuda_index_checks_device_count(monkeypatch):
58+
monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
59+
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
60+
monkeypatch.setenv("STATSPAI_TORCH_DEVICE", "cuda:1")
61+
with pytest.raises(RuntimeError, match="only 1 device"):
62+
resolve_torch_device()
63+
64+
5765
def test_auto_falls_back_to_cpu_when_no_accelerator(monkeypatch):
5866
monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
5967
# Force the MPS probe to return False even on Apple Silicon.

0 commit comments

Comments
 (0)