Skip to content

Commit 3299a3f

Browse files
committed
fixed validation.py
1 parent ab2ef65 commit 3299a3f

2 files changed

Lines changed: 78 additions & 64 deletions

File tree

src/scikit_opls/validation.py

Lines changed: 64 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""Permutation testing for OPLS model significance."""
22

3-
# sklearn.base.clone, check_array and joblib.Parallel are under-typed (Parallel is
4-
# annotated as returning Optional); suppress the resulting static-checker false
5-
# positives (the test suite is the real correctness gate).
3+
# sklearn/joblib typing is incomplete for clone, check_array and Parallel.
4+
# Runtime validation and tests are the correctness gate.
65
# pyright: reportAttributeAccessIssue=false, reportArgumentType=false
76
# pyright: reportGeneralTypeIssues=false
87

@@ -31,25 +30,36 @@
3130
_CVType = int | BaseCrossValidator | BaseShuffleSplit | Iterable | None
3231

3332

33+
def _as_univariate_array(name: str, values: ArrayLike) -> NDArray[np.float64]:
34+
"""Return values as a finite 1D float64 array."""
35+
try:
36+
arr = column_or_1d(np.asarray(values, dtype=np.float64), warn=False)
37+
except ValueError as exc:
38+
raise ValueError(
39+
f"{name} must be univariate; multi-output targets are not supported."
40+
) from exc
41+
if not np.all(np.isfinite(arr)):
42+
raise ValueError(f"{name} must contain only finite values.")
43+
return arr
44+
45+
3446
def _safe_r2_score(y_true: ArrayLike, y_pred: ArrayLike) -> float:
35-
y_true_arr = np.asarray(y_true, dtype=np.float64).ravel()
36-
y_pred_arr = np.asarray(y_pred, dtype=np.float64).ravel()
47+
y_true_arr = _as_univariate_array("y_true", y_true)
48+
y_pred_arr = _as_univariate_array("y_pred", y_pred)
3749
if y_true_arr.shape != y_pred_arr.shape:
3850
raise ValueError(
3951
"y_true and y_pred must have the same flattened shape, "
4052
f"got {y_true_arr.shape} and {y_pred_arr.shape}."
4153
)
4254
if not _has_nonzero_variation(y_true_arr):
43-
# sklearn's r2_score defines constant-target cases awkwardly for model
44-
# significance; NaN makes the undefined metric explicit downstream.
45-
return np.nan
55+
return float("nan")
4656
return float(r2_score(y_true_arr, y_pred_arr))
4757

4858

4959
def _cross_val_q2(
5060
estimator: BaseEstimator, X: ArrayLike, y: ArrayLike, cv: _CVType
5161
) -> float:
52-
"""Out-of-fold Q2 of ``estimator`` on ``(X, y)`` using the provided ``cv``."""
62+
"""Out-of-fold Q2 of ``estimator`` on ``(X, y)``."""
5363
y_pred = cross_val_predict(clone(estimator), X, y, cv=cv)
5464
return _safe_r2_score(y, y_pred)
5565

@@ -77,42 +87,71 @@ class PermutationResult:
7787

7888

7989
def _fitted_r2y(fitted: BaseEstimator) -> float:
80-
# GridSearchCV and similar search estimators expose the selected model through
81-
# best_estimator_; recurse until we reach the OPLS-like estimator itself.
8290
if hasattr(fitted, "r2y_"):
8391
return float(getattr(fitted, "r2y_"))
84-
if hasattr(fitted, "cv_results_") and not hasattr(fitted, "best_estimator_"):
92+
93+
best = getattr(fitted, "best_estimator_", None)
94+
if best is not None:
95+
return _fitted_r2y(best)
96+
97+
if hasattr(fitted, "cv_results_"):
8598
raise TypeError(
8699
"Search meta-estimators must use refit=True so permutation_test can "
87100
"access best_estimator_."
88101
)
89-
if hasattr(fitted, "best_estimator_"):
90-
return _fitted_r2y(getattr(fitted, "best_estimator_"))
102+
91103
raise TypeError(
92-
"permutation_test requires an OPLS-like regression estimator exposing r2y_, "
93-
"or a GridSearchCV wrapping one."
104+
"permutation_test requires an OPLS-like regression estimator exposing "
105+
"r2y_, or a refit-enabled search estimator wrapping one."
94106
)
95107

96108

97109
def _permuted_scores(
98110
estimator: BaseEstimator, X: ArrayLike, y_perm: ArrayLike, cv: _CVType
99111
) -> tuple[float, float]:
100-
"""R2Y and out-of-fold Q2 for one permuted target (one parallel task)."""
112+
"""Return R2Y/Q2 for one permuted target."""
101113
fitted = clone(estimator).fit(X, y_perm)
102114
r2y = _fitted_r2y(fitted)
103115
q2 = _cross_val_q2(estimator, X, y_perm, cv=cv)
104116
return r2y, q2
105117

106118

107119
def _contains_classifier(estimator: BaseEstimator) -> bool:
108-
# Walk simple meta-estimators such as CalibratedClassifierCV(estimator=...).
120+
"""Return whether estimator or a simple wrapped estimator is a classifier."""
109121
if is_classifier(estimator):
110122
return True
111-
if hasattr(estimator, "estimator"):
112-
return _contains_classifier(getattr(estimator, "estimator"))
123+
124+
steps = getattr(estimator, "steps", None)
125+
if steps is not None:
126+
return any(_contains_classifier(step) for _, step in steps)
127+
128+
for attr in ("estimator", "base_estimator", "best_estimator_"):
129+
inner = getattr(estimator, attr, None)
130+
if inner is not None and _contains_classifier(inner):
131+
return True
132+
113133
return False
114134

115135

136+
def _resolve_cv(estimator: BaseEstimator, cv: _CVType, y: NDArray[np.float64]):
137+
if cv is None:
138+
estimator_cv = getattr(estimator, "cv", None)
139+
cv = estimator_cv if estimator_cv is not None else min(5, len(y))
140+
141+
# Materialize one-shot split iterables so observed and permuted passes reuse
142+
# the same splits instead of consuming the iterator once.
143+
if cv is not None and not isinstance(cv, Integral) and not hasattr(cv, "split"):
144+
cv = list(cv)
145+
146+
return check_cv(cv, y=y, classifier=False)
147+
148+
149+
def _empirical_p_value(observed: float, permuted: NDArray[np.float64]) -> float:
150+
if np.isnan(observed):
151+
return float("nan")
152+
return float((1 + int(np.sum(permuted >= observed))) / (permuted.size + 1))
153+
154+
116155
def permutation_test(
117156
estimator: BaseEstimator,
118157
X: ArrayLike,
@@ -175,68 +214,33 @@ def permutation_test(
175214
n_permutations = _validate_int("n_permutations", n_permutations, minimum=1)
176215

177216
X = check_array(X, dtype=np.float64)
178-
try:
179-
y = column_or_1d(np.asarray(y, dtype=np.float64), warn=False)
180-
except ValueError as exc:
181-
raise ValueError(
182-
"permutation_test currently requires a univariate response; "
183-
"multi-output targets are not supported."
184-
) from exc
217+
y = _as_univariate_array("y", y)
185218
check_consistent_length(X, y)
186-
if not np.all(np.isfinite(y)):
187-
raise ValueError("y must contain only finite values.")
188219
if len(y) < 3:
189220
raise ValueError(
190221
"permutation_test requires at least 3 samples so each CV training "
191222
"fold can contain at least 2 samples."
192223
)
193224

194-
if cv is None:
195-
estimator_cv = getattr(estimator, "cv", None)
196-
# Prefer an estimator-owned cv setting when present; otherwise keep folds
197-
# valid for small data by capping the default at n_samples.
198-
cv = estimator_cv if estimator_cv is not None else min(5, len(y))
199-
# A one-shot iterable of splits would be consumed by the observed-Q2 pass and
200-
# leave nothing for the permutations; materialise it so every pass sees the
201-
# same splits.
202-
if cv is not None and not isinstance(cv, Integral) and not hasattr(cv, "split"):
203-
cv = list(cv)
204-
cv_checked = check_cv(cv, y=y, classifier=False)
225+
cv_checked = _resolve_cv(estimator, cv, y)
205226

206-
# Fit once on the true labels to establish the observed in-sample R2Y.
207227
fitted = clone(estimator).fit(X, y)
208228
observed_r2y = _fitted_r2y(fitted)
209-
210-
rng = check_random_state(random_state)
211-
# Q2 is always out-of-fold, so compute it through the same CV object used for
212-
# every permutation.
213229
observed_q2 = _cross_val_q2(estimator, X, y, cv=cv_checked)
214230

215-
# Draw all permutations serially from the RNG so the result is independent of
216-
# the execution order the parallel backend chooses.
231+
rng = check_random_state(random_state)
217232
perms = [rng.permutation(y) for _ in range(n_permutations)]
218233
scored = Parallel(n_jobs=n_jobs)(
219234
delayed(_permuted_scores)(estimator, X, y_perm, cv_checked) for y_perm in perms
220235
)
221236
permuted_r2y = np.asarray([r2y for r2y, _ in scored], dtype=np.float64)
222237
permuted_q2 = np.asarray([q2 for _, q2 in scored], dtype=np.float64)
223238

224-
# An undefined observed metric (NaN) must not masquerade as significant.
225-
r2y_p = (
226-
np.nan
227-
if np.isnan(observed_r2y)
228-
else (1 + int(np.sum(permuted_r2y >= observed_r2y))) / (n_permutations + 1)
229-
)
230-
q2_p = (
231-
np.nan
232-
if np.isnan(observed_q2)
233-
else (1 + int(np.sum(permuted_q2 >= observed_q2))) / (n_permutations + 1)
234-
)
235239
return PermutationResult(
236240
r2y=observed_r2y,
237241
q2=observed_q2,
238242
permuted_r2y=permuted_r2y,
239243
permuted_q2=permuted_q2,
240-
r2y_p_value=float(r2y_p),
241-
q2_p_value=float(q2_p),
244+
r2y_p_value=_empirical_p_value(observed_r2y, permuted_r2y),
245+
q2_p_value=_empirical_p_value(observed_q2, permuted_q2),
242246
)

tests/test_preprocessing.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,7 @@ def test_apply_scaling_validates_representative_bad_inputs():
7979
# 2. wrong mean_ shape
8080
with pytest.raises(ValueError, match="mean_ must have shape"):
8181
apply_scaling(X, np.zeros(2), np.ones(3))
82-
# 3. zero scale
83-
with pytest.raises(ValueError, match="scale_ must not contain zeros"):
84-
apply_scaling(X, np.zeros(3), np.array([1.0, 0.0, 1.0]))
85-
# 4. nonfinite input
82+
# 3. nonfinite input
8683
with pytest.raises(ValueError, match="finite"):
8784
apply_scaling(np.array([[1.0, np.inf, 1.0]]), np.zeros(3), np.ones(3))
8885

@@ -162,3 +159,16 @@ def test_apply_scaling_rejects_negative_scale():
162159

163160
with pytest.raises(ValueError, match="positive"):
164161
apply_scaling(X, mean, scale)
162+
163+
164+
@pytest.mark.parametrize(
165+
"scale",
166+
[
167+
np.array([1.0, 0.0, 1.0]),
168+
np.array([1.0, -1.0, 1.0]),
169+
],
170+
)
171+
def test_apply_scaling_rejects_non_positive_scale(scale):
172+
X = np.ones((4, 3))
173+
with pytest.raises(ValueError, match="scale_ must contain only positive values"):
174+
apply_scaling(X, np.zeros(3), scale)

0 commit comments

Comments
 (0)