|
1 | 1 | """Permutation testing for OPLS model significance.""" |
2 | 2 |
|
3 | | -# sklearn.base.clone, check_array and joblib.Parallel are under-typed (Parallel is |
4 | | -# annotated as returning Optional); suppress the resulting static-checker false |
5 | | -# positives (the test suite is the real correctness gate). |
| 3 | +# sklearn/joblib typing is incomplete for clone, check_array and Parallel. |
| 4 | +# Runtime validation and tests are the correctness gate. |
6 | 5 | # pyright: reportAttributeAccessIssue=false, reportArgumentType=false |
7 | 6 | # pyright: reportGeneralTypeIssues=false |
8 | 7 |
|
|
31 | 30 | _CVType = int | BaseCrossValidator | BaseShuffleSplit | Iterable | None |
32 | 31 |
|
33 | 32 |
|
| 33 | +def _as_univariate_array(name: str, values: ArrayLike) -> NDArray[np.float64]: |
| 34 | + """Return values as a finite 1D float64 array.""" |
| 35 | + try: |
| 36 | + arr = column_or_1d(np.asarray(values, dtype=np.float64), warn=False) |
| 37 | + except ValueError as exc: |
| 38 | + raise ValueError( |
| 39 | + f"{name} must be univariate; multi-output targets are not supported." |
| 40 | + ) from exc |
| 41 | + if not np.all(np.isfinite(arr)): |
| 42 | + raise ValueError(f"{name} must contain only finite values.") |
| 43 | + return arr |
| 44 | + |
| 45 | + |
34 | 46 | def _safe_r2_score(y_true: ArrayLike, y_pred: ArrayLike) -> float: |
35 | | - y_true_arr = np.asarray(y_true, dtype=np.float64).ravel() |
36 | | - y_pred_arr = np.asarray(y_pred, dtype=np.float64).ravel() |
| 47 | + y_true_arr = _as_univariate_array("y_true", y_true) |
| 48 | + y_pred_arr = _as_univariate_array("y_pred", y_pred) |
37 | 49 | if y_true_arr.shape != y_pred_arr.shape: |
38 | 50 | raise ValueError( |
39 | 51 | "y_true and y_pred must have the same flattened shape, " |
40 | 52 | f"got {y_true_arr.shape} and {y_pred_arr.shape}." |
41 | 53 | ) |
42 | 54 | if not _has_nonzero_variation(y_true_arr): |
43 | | - # sklearn's r2_score defines constant-target cases awkwardly for model |
44 | | - # significance; NaN makes the undefined metric explicit downstream. |
45 | | - return np.nan |
| 55 | + return float("nan") |
46 | 56 | return float(r2_score(y_true_arr, y_pred_arr)) |
47 | 57 |
|
48 | 58 |
|
49 | 59 | def _cross_val_q2( |
50 | 60 | estimator: BaseEstimator, X: ArrayLike, y: ArrayLike, cv: _CVType |
51 | 61 | ) -> float: |
52 | | - """Out-of-fold Q2 of ``estimator`` on ``(X, y)`` using the provided ``cv``.""" |
| 62 | + """Out-of-fold Q2 of ``estimator`` on ``(X, y)``.""" |
53 | 63 | y_pred = cross_val_predict(clone(estimator), X, y, cv=cv) |
54 | 64 | return _safe_r2_score(y, y_pred) |
55 | 65 |
|
@@ -77,42 +87,71 @@ class PermutationResult: |
77 | 87 |
|
78 | 88 |
|
79 | 89 | def _fitted_r2y(fitted: BaseEstimator) -> float: |
80 | | - # GridSearchCV and similar search estimators expose the selected model through |
81 | | - # best_estimator_; recurse until we reach the OPLS-like estimator itself. |
82 | 90 | if hasattr(fitted, "r2y_"): |
83 | 91 | return float(getattr(fitted, "r2y_")) |
84 | | - if hasattr(fitted, "cv_results_") and not hasattr(fitted, "best_estimator_"): |
| 92 | + |
| 93 | + best = getattr(fitted, "best_estimator_", None) |
| 94 | + if best is not None: |
| 95 | + return _fitted_r2y(best) |
| 96 | + |
| 97 | + if hasattr(fitted, "cv_results_"): |
85 | 98 | raise TypeError( |
86 | 99 | "Search meta-estimators must use refit=True so permutation_test can " |
87 | 100 | "access best_estimator_." |
88 | 101 | ) |
89 | | - if hasattr(fitted, "best_estimator_"): |
90 | | - return _fitted_r2y(getattr(fitted, "best_estimator_")) |
| 102 | + |
91 | 103 | raise TypeError( |
92 | | - "permutation_test requires an OPLS-like regression estimator exposing r2y_, " |
93 | | - "or a GridSearchCV wrapping one." |
| 104 | + "permutation_test requires an OPLS-like regression estimator exposing " |
| 105 | + "r2y_, or a refit-enabled search estimator wrapping one." |
94 | 106 | ) |
95 | 107 |
|
96 | 108 |
|
97 | 109 | def _permuted_scores( |
98 | 110 | estimator: BaseEstimator, X: ArrayLike, y_perm: ArrayLike, cv: _CVType |
99 | 111 | ) -> tuple[float, float]: |
100 | | - """R2Y and out-of-fold Q2 for one permuted target (one parallel task).""" |
| 112 | + """Return R2Y/Q2 for one permuted target.""" |
101 | 113 | fitted = clone(estimator).fit(X, y_perm) |
102 | 114 | r2y = _fitted_r2y(fitted) |
103 | 115 | q2 = _cross_val_q2(estimator, X, y_perm, cv=cv) |
104 | 116 | return r2y, q2 |
105 | 117 |
|
106 | 118 |
|
107 | 119 | def _contains_classifier(estimator: BaseEstimator) -> bool: |
108 | | - # Walk simple meta-estimators such as CalibratedClassifierCV(estimator=...). |
| 120 | + """Return whether estimator or a simple wrapped estimator is a classifier.""" |
109 | 121 | if is_classifier(estimator): |
110 | 122 | return True |
111 | | - if hasattr(estimator, "estimator"): |
112 | | - return _contains_classifier(getattr(estimator, "estimator")) |
| 123 | + |
| 124 | + steps = getattr(estimator, "steps", None) |
| 125 | + if steps is not None: |
| 126 | + return any(_contains_classifier(step) for _, step in steps) |
| 127 | + |
| 128 | + for attr in ("estimator", "base_estimator", "best_estimator_"): |
| 129 | + inner = getattr(estimator, attr, None) |
| 130 | + if inner is not None and _contains_classifier(inner): |
| 131 | + return True |
| 132 | + |
113 | 133 | return False |
114 | 134 |
|
115 | 135 |
|
| 136 | +def _resolve_cv(estimator: BaseEstimator, cv: _CVType, y: NDArray[np.float64]): |
| 137 | + if cv is None: |
| 138 | + estimator_cv = getattr(estimator, "cv", None) |
| 139 | + cv = estimator_cv if estimator_cv is not None else min(5, len(y)) |
| 140 | + |
| 141 | + # Materialize one-shot split iterables so observed and permuted passes reuse |
| 142 | + # the same splits instead of consuming the iterator once. |
| 143 | + if cv is not None and not isinstance(cv, Integral) and not hasattr(cv, "split"): |
| 144 | + cv = list(cv) |
| 145 | + |
| 146 | + return check_cv(cv, y=y, classifier=False) |
| 147 | + |
| 148 | + |
| 149 | +def _empirical_p_value(observed: float, permuted: NDArray[np.float64]) -> float: |
| 150 | + if np.isnan(observed): |
| 151 | + return float("nan") |
| 152 | + return float((1 + int(np.sum(permuted >= observed))) / (permuted.size + 1)) |
| 153 | + |
| 154 | + |
116 | 155 | def permutation_test( |
117 | 156 | estimator: BaseEstimator, |
118 | 157 | X: ArrayLike, |
@@ -175,68 +214,33 @@ def permutation_test( |
175 | 214 | n_permutations = _validate_int("n_permutations", n_permutations, minimum=1) |
176 | 215 |
|
177 | 216 | X = check_array(X, dtype=np.float64) |
178 | | - try: |
179 | | - y = column_or_1d(np.asarray(y, dtype=np.float64), warn=False) |
180 | | - except ValueError as exc: |
181 | | - raise ValueError( |
182 | | - "permutation_test currently requires a univariate response; " |
183 | | - "multi-output targets are not supported." |
184 | | - ) from exc |
| 217 | + y = _as_univariate_array("y", y) |
185 | 218 | check_consistent_length(X, y) |
186 | | - if not np.all(np.isfinite(y)): |
187 | | - raise ValueError("y must contain only finite values.") |
188 | 219 | if len(y) < 3: |
189 | 220 | raise ValueError( |
190 | 221 | "permutation_test requires at least 3 samples so each CV training " |
191 | 222 | "fold can contain at least 2 samples." |
192 | 223 | ) |
193 | 224 |
|
194 | | - if cv is None: |
195 | | - estimator_cv = getattr(estimator, "cv", None) |
196 | | - # Prefer an estimator-owned cv setting when present; otherwise keep folds |
197 | | - # valid for small data by capping the default at n_samples. |
198 | | - cv = estimator_cv if estimator_cv is not None else min(5, len(y)) |
199 | | - # A one-shot iterable of splits would be consumed by the observed-Q2 pass and |
200 | | - # leave nothing for the permutations; materialise it so every pass sees the |
201 | | - # same splits. |
202 | | - if cv is not None and not isinstance(cv, Integral) and not hasattr(cv, "split"): |
203 | | - cv = list(cv) |
204 | | - cv_checked = check_cv(cv, y=y, classifier=False) |
| 225 | + cv_checked = _resolve_cv(estimator, cv, y) |
205 | 226 |
|
206 | | - # Fit once on the true labels to establish the observed in-sample R2Y. |
207 | 227 | fitted = clone(estimator).fit(X, y) |
208 | 228 | observed_r2y = _fitted_r2y(fitted) |
209 | | - |
210 | | - rng = check_random_state(random_state) |
211 | | - # Q2 is always out-of-fold, so compute it through the same CV object used for |
212 | | - # every permutation. |
213 | 229 | observed_q2 = _cross_val_q2(estimator, X, y, cv=cv_checked) |
214 | 230 |
|
215 | | - # Draw all permutations serially from the RNG so the result is independent of |
216 | | - # the execution order the parallel backend chooses. |
| 231 | + rng = check_random_state(random_state) |
217 | 232 | perms = [rng.permutation(y) for _ in range(n_permutations)] |
218 | 233 | scored = Parallel(n_jobs=n_jobs)( |
219 | 234 | delayed(_permuted_scores)(estimator, X, y_perm, cv_checked) for y_perm in perms |
220 | 235 | ) |
221 | 236 | permuted_r2y = np.asarray([r2y for r2y, _ in scored], dtype=np.float64) |
222 | 237 | permuted_q2 = np.asarray([q2 for _, q2 in scored], dtype=np.float64) |
223 | 238 |
|
224 | | - # An undefined observed metric (NaN) must not masquerade as significant. |
225 | | - r2y_p = ( |
226 | | - np.nan |
227 | | - if np.isnan(observed_r2y) |
228 | | - else (1 + int(np.sum(permuted_r2y >= observed_r2y))) / (n_permutations + 1) |
229 | | - ) |
230 | | - q2_p = ( |
231 | | - np.nan |
232 | | - if np.isnan(observed_q2) |
233 | | - else (1 + int(np.sum(permuted_q2 >= observed_q2))) / (n_permutations + 1) |
234 | | - ) |
235 | 239 | return PermutationResult( |
236 | 240 | r2y=observed_r2y, |
237 | 241 | q2=observed_q2, |
238 | 242 | permuted_r2y=permuted_r2y, |
239 | 243 | permuted_q2=permuted_q2, |
240 | | - r2y_p_value=float(r2y_p), |
241 | | - q2_p_value=float(q2_p), |
| 244 | + r2y_p_value=_empirical_p_value(observed_r2y, permuted_r2y), |
| 245 | + q2_p_value=_empirical_p_value(observed_q2, permuted_q2), |
242 | 246 | ) |
0 commit comments