66import pytest
77from sklearn .utils ._testing import assert_allclose
88
9- from scikit_opls import OPLS
10- from scikit_opls .validation import _safe_r2_score , permutation_test
9+ from scikit_opls import OPLS , OPLSDA
10+ from scikit_opls .validation import (
11+ _contains_classifier ,
12+ _safe_r2_score ,
13+ permutation_test ,
14+ )
1115
1216from ._data import make_regression_data as _regression_data
1317
@@ -91,8 +95,6 @@ def test_permutation_test_n_jobs_is_reproducible():
9195
9296
9397def test_permutation_test_non_regression_estimator_raises ():
94- from scikit_opls import OPLSDA
95-
9698 X , y = _regression_data (seed = 10 )
9799 labels = np .where (y > 0.0 , "hi" , "lo" )
98100 # OPLSDA is a classifier, raising a clean TypeError immediately
@@ -224,8 +226,6 @@ def test_permutation_test_classifier_wrapped_raises():
224226 from sklearn .model_selection import GridSearchCV
225227 from sklearn .pipeline import Pipeline
226228
227- from scikit_opls import OPLSDA
228-
229229 X , y = _regression_data (seed = 42 )
230230 # 1. Pipeline containing a classifier
231231 pipe = Pipeline ([("clf" , OPLSDA ())])
@@ -236,3 +236,24 @@ def test_permutation_test_classifier_wrapped_raises():
236236 gs = GridSearchCV (pipe , {"clf__n_components" : [1 ]})
237237 with pytest .raises (TypeError , match = "classifiers like OPLSDA are not supported" ):
238238 permutation_test (gs , X , y )
239+
240+
241+ def test_contains_classifier_detects_oplsda ():
242+ assert _contains_classifier (OPLSDA ())
243+ assert not _contains_classifier (OPLS ())
244+
245+
246+ def test_contains_classifier_detects_pipeline_ending_in_oplsda ():
247+ from sklearn .pipeline import Pipeline
248+
249+ assert _contains_classifier (Pipeline ([("model" , OPLSDA ())]))
250+ assert not _contains_classifier (Pipeline ([("model" , OPLS ())]))
251+
252+
253+ def test_contains_classifier_detects_grid_search_over_oplsda ():
254+ from sklearn .model_selection import GridSearchCV
255+
256+ search = GridSearchCV (OPLSDA (), {"n_orthogonal" : [0 , 1 ]}, cv = 2 )
257+ assert _contains_classifier (search )
258+ other = GridSearchCV (OPLS (), {"n_orthogonal" : [0 , 1 ]}, cv = 2 )
259+ assert not _contains_classifier (other )
0 commit comments