You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
fix(neural): keep inference on the fitted module's device + validate CUDA index
Two related correctness fixes for STATSPAI_TORCH_DEVICE routing:
1. resolve_torch_device() now rejects "cuda:N" with N >= torch.cuda.device_count()
instead of silently constructing an out-of-range device. Also tightens the
prefix check from startswith("cuda") to "cuda" / startswith("cuda:") so
stray strings like "cudafoo" no longer slip through the CUDA branch.
2. DeepIV.effect() and TARNet/CFRNet/DragonNet predict / propensity paths now
place the input tensor on next(module.parameters()).device — the device the
network was actually fitted on — instead of re-resolving from the env var
each call. Previously, flipping STATSPAI_TORCH_DEVICE between fit and effect
(or any post-fit device move) raised a cross-device RuntimeError.
Tests: explicit device-count guard test in test_torch_device_resolver, and
spy-based assertions in test_deepiv / test_neural_causal that effect / predict
tensors land on the same device as the fitted parameters.
0 commit comments