@@ -8,12 +8,13 @@ class with a :meth:`from_estimator` constructor that computes the plotted arrays
88imported lazily inside :meth:`plot`, so importing this module never requires it.
99"""
1010
11- # check_array is under-typed (its dtype kwarg) ; suppress the resulting
12- # static-checker false positives .
11+ # The sklearn validation helpers are under-typed ; suppress false positives from
12+ # fittedness/metadata validation calls in this plotting adapter .
1313# pyright: reportArgumentType=false
1414
1515from __future__ import annotations
1616
17+ import warnings
1718from numbers import Integral
1819from typing import TYPE_CHECKING
1920
@@ -22,7 +23,7 @@ class with a :meth:`from_estimator` constructor that computes the plotted arrays
2223from scipy import sparse
2324from sklearn .base import BaseEstimator
2425from sklearn .pipeline import Pipeline
25- from sklearn .utils .validation import check_array , check_is_fitted
26+ from sklearn .utils .validation import check_is_fitted
2627
2728from scikit_opls ._opls import OPLS
2829from scikit_opls ._opls_da import OPLSDA
@@ -86,11 +87,23 @@ def _unwrap_estimator_and_data(
8687 X = upstream .transform (X )
8788 inner = final_estimator
8889
90+ if sparse .issparse (X ):
91+ raise TypeError (
92+ "Input to OPLS plotting is sparse, but plotting requires a dense "
93+ "matrix. If it came from a Pipeline, add a densifying transformer "
94+ "before the final OPLS step."
95+ )
96+
8997 if isinstance (inner , OPLSDA ):
9098 check_is_fitted (inner )
99+ # Validate against the outer classifier first: it owns n_features_in_ and
100+ # feature_names_in_ for user-facing OPLSDA calls.
101+ X_checked = inner ._validate_X_predict (X )
91102 # OPLSDA is a classifier wrapper; its latent space lives on the inner OPLS.
92103 base = inner .opls_
93104 elif isinstance (inner , OPLS ):
105+ check_is_fitted (inner )
106+ X_checked = inner ._validate_X_predict (X )
94107 base = inner
95108 else :
96109 raise TypeError (
@@ -103,13 +116,18 @@ def _unwrap_estimator_and_data(
103116 if not isinstance (base , OPLS ):
104117 raise TypeError ("estimator.opls_ must be a fitted OPLS instance." )
105118 check_is_fitted (base )
106- if sparse .issparse (X ):
119+ if sparse .issparse (X_checked ):
107120 raise TypeError (
108121 "Input to OPLS plotting is sparse, but plotting requires a dense "
109122 "matrix. If it came from a Pipeline, add a densifying transformer "
110123 "before the final OPLS step."
111124 )
112- return base , check_array (X , dtype = np .float64 , ensure_min_samples = ensure_min_samples )
125+ if X_checked .shape [0 ] < ensure_min_samples :
126+ raise ValueError (
127+ f"Found array with { X_checked .shape [0 ]} sample(s), while a minimum "
128+ f"of { ensure_min_samples } is required."
129+ )
130+ return base , X_checked
113131
114132
115133class OPLSScoresDisplay :
@@ -237,7 +255,7 @@ def from_estimator(
237255
238256 # Project supplied data through the fitted filter before asking the PLS
239257 # engine for predictive scores.
240- X_filtered , t_ortho = base ._filter (X_trans )
258+ X_filtered , t_ortho = base ._filter_validated (X_trans )
241259 scores = base .pls_ .transform (X_filtered )
242260 if isinstance (scores , tuple ):
243261 t_pred_arr = scores [0 ]
@@ -353,6 +371,7 @@ def from_estimator(
353371 X : ArrayLike ,
354372 * ,
355373 component : int = 0 ,
374+ x_space : str = "centered" ,
356375 ax : matplotlib .axes .Axes | None = None ,
357376 ) -> SPlotDisplay :
358377 """Compute the S-plot arrays from a fitted ``estimator`` and plot them.
@@ -367,6 +386,11 @@ def from_estimator(
367386 Samples to project.
368387 component : int, default=0
369388 The index of the predictive PLS component to plot.
389+ x_space : {"centered", "scaled", "subset-centered"}, default="centered"
390+ Feature space used for the covariance/correlation axes. ``"centered"``
391+ uses original feature units centered by the fitted training mean,
392+ ``"scaled"`` uses the model-scaled feature space, and
393+ ``"subset-centered"`` centers the provided X subset by its own mean.
370394 ax : matplotlib Axes, default=None
371395 Target axes; a new figure/axes is created when ``None``.
372396
@@ -384,20 +408,42 @@ def from_estimator(
384408 f"component={ component } is out of bounds for estimator with "
385409 f"{ n_pred } predictive component(s)."
386410 )
411+ if x_space not in {"centered" , "scaled" , "subset-centered" }:
412+ raise ValueError (
413+ "x_space must be one of {'centered', 'scaled', 'subset-centered'}."
414+ )
415+ if X_trans .shape [0 ] != base .x_scores_ .shape [0 ]:
416+ warnings .warn (
417+ "SPlotDisplay is usually intended for the training data. "
418+ "The provided X has a different number of samples than the fitted "
419+ "data; covariance and correlation will be computed on this subset." ,
420+ UserWarning ,
421+ stacklevel = 2 ,
422+ )
387423
388- # S-plots are computed in the final OPLS input space, after applying the
389- # same scaling used at fit time and centering the provided sample subset.
390- Xs = apply_scaling (X_trans , base .x_mean_ , base .x_std_ )
391- Xs = Xs - Xs .mean (axis = 0 )
424+ # Scores always come from the fitted model preprocessing/filtering. The
425+ # S-plot axes can use original-unit or model-scaled feature space.
426+ if x_space == "centered" :
427+ X_for_splot = X_trans - base .x_mean_
428+ elif x_space == "scaled" :
429+ X_for_splot = apply_scaling (X_trans , base .x_mean_ , base .x_std_ )
430+ else :
431+ X_for_splot = X_trans - X_trans .mean (axis = 0 )
392432
393433 # Use the fitted predictive score for the selected component as the common
394434 # reference vector for both covariance and correlation.
395- t = np .asarray (base .transform (X_trans ))[:, component ]
435+ X_filtered , _ = base ._filter_validated (X_trans )
436+ scores = base .pls_ .transform (X_filtered )
437+ if isinstance (scores , tuple ):
438+ t_arr = scores [0 ]
439+ else :
440+ t_arr = scores
441+ t = np .asarray (t_arr )[:, component ]
396442 t = t - t .mean ()
397443 n = t .shape [0 ]
398444
399- covariance = Xs .T @ t / max (n - 1 , 1 )
400- x_std = Xs .std (axis = 0 , ddof = 1 )
445+ covariance = X_for_splot .T @ t / max (n - 1 , 1 )
446+ x_std = X_for_splot .std (axis = 0 , ddof = 1 )
401447 t_std = float (t .std (ddof = 1 ))
402448 if t_std <= 1e-12 :
403449 raise ValueError ("Predictive score has zero variance; S-plot is undefined." )
@@ -409,8 +455,6 @@ def from_estimator(
409455 correlation [valid ] = covariance [valid ] / denom [valid ]
410456
411457 if np .any (~ valid ):
412- import warnings
413-
414458 warnings .warn (
415459 "Some features have zero variance; their S-plot correlations are NaN." ,
416460 RuntimeWarning ,
0 commit comments