Skip to content

Commit b877dae

Browse files
committed
small validation fix to simplify _contains_classifier (it now delegates directly to is_classifier)
1 parent 3299a3f commit b877dae

2 files changed

Lines changed: 29 additions & 20 deletions

File tree

src/scikit_opls/validation.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -117,20 +117,8 @@ def _permuted_scores(
117117

118118

119119
def _contains_classifier(estimator: BaseEstimator) -> bool:
120-
"""Return whether estimator or a simple wrapped estimator is a classifier."""
121-
if is_classifier(estimator):
122-
return True
123-
124-
steps = getattr(estimator, "steps", None)
125-
if steps is not None:
126-
return any(_contains_classifier(step) for _, step in steps)
127-
128-
for attr in ("estimator", "base_estimator", "best_estimator_"):
129-
inner = getattr(estimator, attr, None)
130-
if inner is not None and _contains_classifier(inner):
131-
return True
132-
133-
return False
120+
"""Return whether estimator is tagged as a classifier."""
121+
return is_classifier(estimator)
134122

135123

136124
def _resolve_cv(estimator: BaseEstimator, cv: _CVType, y: NDArray[np.float64]):

tests/test_validation.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66
import pytest
77
from sklearn.utils._testing import assert_allclose
88

9-
from scikit_opls import OPLS
10-
from scikit_opls.validation import _safe_r2_score, permutation_test
9+
from scikit_opls import OPLS, OPLSDA
10+
from scikit_opls.validation import (
11+
_contains_classifier,
12+
_safe_r2_score,
13+
permutation_test,
14+
)
1115

1216
from ._data import make_regression_data as _regression_data
1317

@@ -91,8 +95,6 @@ def test_permutation_test_n_jobs_is_reproducible():
9195

9296

9397
def test_permutation_test_non_regression_estimator_raises():
94-
from scikit_opls import OPLSDA
95-
9698
X, y = _regression_data(seed=10)
9799
labels = np.where(y > 0.0, "hi", "lo")
98100
# OPLSDA is a classifier, raising a clean TypeError immediately
@@ -224,8 +226,6 @@ def test_permutation_test_classifier_wrapped_raises():
224226
from sklearn.model_selection import GridSearchCV
225227
from sklearn.pipeline import Pipeline
226228

227-
from scikit_opls import OPLSDA
228-
229229
X, y = _regression_data(seed=42)
230230
# 1. Pipeline containing a classifier
231231
pipe = Pipeline([("clf", OPLSDA())])
@@ -236,3 +236,24 @@ def test_permutation_test_classifier_wrapped_raises():
236236
gs = GridSearchCV(pipe, {"clf__n_components": [1]})
237237
with pytest.raises(TypeError, match="classifiers like OPLSDA are not supported"):
238238
permutation_test(gs, X, y)
239+
240+
241+
def test_contains_classifier_detects_oplsda():
242+
assert _contains_classifier(OPLSDA())
243+
assert not _contains_classifier(OPLS())
244+
245+
246+
def test_contains_classifier_detects_pipeline_ending_in_oplsda():
247+
from sklearn.pipeline import Pipeline
248+
249+
assert _contains_classifier(Pipeline([("model", OPLSDA())]))
250+
assert not _contains_classifier(Pipeline([("model", OPLS())]))
251+
252+
253+
def test_contains_classifier_detects_grid_search_over_oplsda():
254+
from sklearn.model_selection import GridSearchCV
255+
256+
search = GridSearchCV(OPLSDA(), {"n_orthogonal": [0, 1]}, cv=2)
257+
assert _contains_classifier(search)
258+
other = GridSearchCV(OPLS(), {"n_orthogonal": [0, 1]}, cv=2)
259+
assert not _contains_classifier(other)

0 commit comments

Comments
 (0)