Skip to content

Commit 9c7348e

Browse files
committed
fixed name error and plotting in OPLSDA
1 parent 8a5d687 commit 9c7348e

5 files changed

Lines changed: 159 additions & 25 deletions

File tree

src/scikit_opls/_opls.py

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -493,19 +493,36 @@ def score(self, X: ArrayLike, y: ArrayLike, sample_weight=None) -> float:
493493
"""
494494
return super().score(X, y, sample_weight)
495495

496-
def _filter(self, X: ArrayLike) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
497-
"""Preprocess and orthogonal-filter new ``X`` exactly as at fit time.
496+
def _validate_X_predict(self, X: ArrayLike) -> NDArray[np.float64]: # noqa: N802
497+
"""Validate prediction/projection input against fitted OPLS metadata."""
498+
check_is_fitted(self)
499+
return validate_data(
500+
self,
501+
X,
502+
dtype=np.float64,
503+
copy=self.copy,
504+
reset=False,
505+
)
498506

499-
Returns the filtered ``X`` and the orthogonal scores.
500-
"""
501-
X = validate_data(self, X, reset=False, dtype=np.float64)
507+
def _filter_validated(
508+
self, X: NDArray[np.float64]
509+
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
510+
"""Filter an already validated dense array without checking names again."""
502511
Xs = apply_scaling(X, self.x_mean_, self.x_std_)
503512
# apply_orthogonal_filter returns both the filtered matrix for prediction
504513
# and the replayed orthogonal scores for transform_orthogonal().
505514
return apply_orthogonal_filter(
506515
Xs, self.x_ortho_weights_, self.x_ortho_loadings_
507516
)
508517

518+
def _filter(self, X: ArrayLike) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
519+
"""Preprocess and orthogonal-filter new ``X`` exactly as at fit time.
520+
521+
Returns the filtered ``X`` and the orthogonal scores.
522+
"""
523+
X_valid = self._validate_X_predict(X)
524+
return self._filter_validated(X_valid)
525+
509526
def __sklearn_tags__(self):
510527
tags = super().__sklearn_tags__()
511528
tags.regressor_tags.poor_score = True

src/scikit_opls/_opls_da.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -156,20 +156,25 @@ def fit(self, X: ArrayLike, y: ArrayLike) -> OPLSDA:
156156
self.n_orthogonal_ = self.opls_.n_orthogonal_
157157
return self
158158

159-
def _validate_x_predict(self, X: ArrayLike) -> NDArray[np.float64]:
159+
def _validate_X_predict(self, X: ArrayLike) -> NDArray[np.float64]: # noqa: N802
160160
"""Validate prediction input against the outer OPLSDA fit contract."""
161161
check_is_fitted(self)
162162

163163
# validate_data(..., reset=False) checks n_features_in_ and feature_names_in_
164164
# against OPLSDA, then returns a nameless ndarray. Passing that ndarray to
165165
# the inner OPLS avoids the spurious "fitted without feature names" warning.
166-
return validate_data(
166+
X_valid = validate_data(
167167
self,
168168
X,
169169
dtype=np.float64,
170170
copy=self.copy,
171171
reset=False,
172172
)
173+
return np.asarray(X_valid, dtype=np.float64)
174+
175+
def _validate_x_predict(self, X: ArrayLike) -> NDArray[np.float64]:
176+
"""Backward-compatible alias for the canonical validation helper."""
177+
return self._validate_X_predict(X)
173178

174179
def decision_function(self, X: ArrayLike) -> NDArray[np.float64]:
175180
"""Raw signed OPLS regression output; positive favours ``classes_[1]``.
@@ -185,8 +190,9 @@ def decision_function(self, X: ArrayLike) -> NDArray[np.float64]:
185190
Signed confidence; ``> 0`` predicts ``classes_[1]``. Scores equal to
186191
zero are assigned to ``classes_[0]`` by :meth:`predict`.
187192
"""
188-
X_valid = self._validate_x_predict(X)
189-
return np.asarray(self.opls_.predict(X_valid), dtype=np.float64).ravel()
193+
X_valid = self._validate_X_predict(X)
194+
X_filtered, _ = self.opls_._filter_validated(X_valid)
195+
return np.asarray(self.opls_.pls_.predict(X_filtered), dtype=np.float64).ravel()
190196

191197
def predict(self, X: ArrayLike) -> NDArray:
192198
"""Predict class labels.

src/scikit_opls/plotting.py

Lines changed: 59 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -8,12 +8,13 @@ class with a :meth:`from_estimator` constructor that computes the plotted arrays
88
imported lazily inside :meth:`plot`, so importing this module never requires it.
99
"""
1010

11-
# check_array is under-typed (its dtype kwarg); suppress the resulting
12-
# static-checker false positives.
11+
# The sklearn validation helpers are under-typed; suppress false positives from
12+
# fittedness/metadata validation calls in this plotting adapter.
1313
# pyright: reportArgumentType=false
1414

1515
from __future__ import annotations
1616

17+
import warnings
1718
from numbers import Integral
1819
from typing import TYPE_CHECKING
1920

@@ -22,7 +23,7 @@ class with a :meth:`from_estimator` constructor that computes the plotted arrays
2223
from scipy import sparse
2324
from sklearn.base import BaseEstimator
2425
from sklearn.pipeline import Pipeline
25-
from sklearn.utils.validation import check_array, check_is_fitted
26+
from sklearn.utils.validation import check_is_fitted
2627

2728
from scikit_opls._opls import OPLS
2829
from scikit_opls._opls_da import OPLSDA
@@ -86,11 +87,23 @@ def _unwrap_estimator_and_data(
8687
X = upstream.transform(X)
8788
inner = final_estimator
8889

90+
if sparse.issparse(X):
91+
raise TypeError(
92+
"Input to OPLS plotting is sparse, but plotting requires a dense "
93+
"matrix. If it came from a Pipeline, add a densifying transformer "
94+
"before the final OPLS step."
95+
)
96+
8997
if isinstance(inner, OPLSDA):
9098
check_is_fitted(inner)
99+
# Validate against the outer classifier first: it owns n_features_in_ and
100+
# feature_names_in_ for user-facing OPLSDA calls.
101+
X_checked = inner._validate_X_predict(X)
91102
# OPLSDA is a classifier wrapper; its latent space lives on the inner OPLS.
92103
base = inner.opls_
93104
elif isinstance(inner, OPLS):
105+
check_is_fitted(inner)
106+
X_checked = inner._validate_X_predict(X)
94107
base = inner
95108
else:
96109
raise TypeError(
@@ -103,13 +116,18 @@ def _unwrap_estimator_and_data(
103116
if not isinstance(base, OPLS):
104117
raise TypeError("estimator.opls_ must be a fitted OPLS instance.")
105118
check_is_fitted(base)
106-
if sparse.issparse(X):
119+
if sparse.issparse(X_checked):
107120
raise TypeError(
108121
"Input to OPLS plotting is sparse, but plotting requires a dense "
109122
"matrix. If it came from a Pipeline, add a densifying transformer "
110123
"before the final OPLS step."
111124
)
112-
return base, check_array(X, dtype=np.float64, ensure_min_samples=ensure_min_samples)
125+
if X_checked.shape[0] < ensure_min_samples:
126+
raise ValueError(
127+
f"Found array with {X_checked.shape[0]} sample(s), while a minimum "
128+
f"of {ensure_min_samples} is required."
129+
)
130+
return base, X_checked
113131

114132

115133
class OPLSScoresDisplay:
@@ -237,7 +255,7 @@ def from_estimator(
237255

238256
# Project supplied data through the fitted filter before asking the PLS
239257
# engine for predictive scores.
240-
X_filtered, t_ortho = base._filter(X_trans)
258+
X_filtered, t_ortho = base._filter_validated(X_trans)
241259
scores = base.pls_.transform(X_filtered)
242260
if isinstance(scores, tuple):
243261
t_pred_arr = scores[0]
@@ -353,6 +371,7 @@ def from_estimator(
353371
X: ArrayLike,
354372
*,
355373
component: int = 0,
374+
x_space: str = "centered",
356375
ax: matplotlib.axes.Axes | None = None,
357376
) -> SPlotDisplay:
358377
"""Compute the S-plot arrays from a fitted ``estimator`` and plot them.
@@ -367,6 +386,11 @@ def from_estimator(
367386
Samples to project.
368387
component : int, default=0
369388
The index of the predictive PLS component to plot.
389+
x_space : {"centered", "scaled", "subset-centered"}, default="centered"
390+
Feature space used for the covariance/correlation axes. ``"centered"``
391+
uses original feature units centered by the fitted training mean,
392+
``"scaled"`` uses the model-scaled feature space, and
393+
``"subset-centered"`` centers the provided X subset by its own mean.
370394
ax : matplotlib Axes, default=None
371395
Target axes; a new figure/axes is created when ``None``.
372396
@@ -384,20 +408,42 @@ def from_estimator(
384408
f"component={component} is out of bounds for estimator with "
385409
f"{n_pred} predictive component(s)."
386410
)
411+
if x_space not in {"centered", "scaled", "subset-centered"}:
412+
raise ValueError(
413+
"x_space must be one of {'centered', 'scaled', 'subset-centered'}."
414+
)
415+
if X_trans.shape[0] != base.x_scores_.shape[0]:
416+
warnings.warn(
417+
"SPlotDisplay is usually intended for the training data. "
418+
"The provided X has a different number of samples than the fitted "
419+
"data; covariance and correlation will be computed on this subset.",
420+
UserWarning,
421+
stacklevel=2,
422+
)
387423

388-
# S-plots are computed in the final OPLS input space, after applying the
389-
# same scaling used at fit time and centering the provided sample subset.
390-
Xs = apply_scaling(X_trans, base.x_mean_, base.x_std_)
391-
Xs = Xs - Xs.mean(axis=0)
424+
# Scores always come from the fitted model preprocessing/filtering. The
425+
# S-plot axes can use original-unit or model-scaled feature space.
426+
if x_space == "centered":
427+
X_for_splot = X_trans - base.x_mean_
428+
elif x_space == "scaled":
429+
X_for_splot = apply_scaling(X_trans, base.x_mean_, base.x_std_)
430+
else:
431+
X_for_splot = X_trans - X_trans.mean(axis=0)
392432

393433
# Use the fitted predictive score for the selected component as the common
394434
# reference vector for both covariance and correlation.
395-
t = np.asarray(base.transform(X_trans))[:, component]
435+
X_filtered, _ = base._filter_validated(X_trans)
436+
scores = base.pls_.transform(X_filtered)
437+
if isinstance(scores, tuple):
438+
t_arr = scores[0]
439+
else:
440+
t_arr = scores
441+
t = np.asarray(t_arr)[:, component]
396442
t = t - t.mean()
397443
n = t.shape[0]
398444

399-
covariance = Xs.T @ t / max(n - 1, 1)
400-
x_std = Xs.std(axis=0, ddof=1)
445+
covariance = X_for_splot.T @ t / max(n - 1, 1)
446+
x_std = X_for_splot.std(axis=0, ddof=1)
401447
t_std = float(t.std(ddof=1))
402448
if t_std <= 1e-12:
403449
raise ValueError("Predictive score has zero variance; S-plot is undefined.")
@@ -409,8 +455,6 @@ def from_estimator(
409455
correlation[valid] = covariance[valid] / denom[valid]
410456

411457
if np.any(~valid):
412-
import warnings
413-
414458
warnings.warn(
415459
"Some features have zero variance; their S-plot correlations are NaN.",
416460
RuntimeWarning,

tests/test_opls_da.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -220,7 +220,6 @@ def test_oplsda_dataframe_predict_has_no_feature_name_warning():
220220
clf.predict(X)
221221

222222
messages = [str(w.message) for w in record]
223-
print(messages)
224223
assert not any("feature names" in message for message in messages)
225224

226225

tests/test_plotting.py

Lines changed: 68 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import warnings
6+
57
import pytest
68

79
# Skip this module if matplotlib is not installed
@@ -55,6 +57,40 @@ def test_scores_plot_classification():
5557
plt.close("all")
5658

5759

60+
def test_scores_display_oplsda_dataframe_validates_feature_names():
61+
pd = pytest.importorskip("pandas")
62+
X, y = _classification_data(n_features=4)
63+
df = pd.DataFrame(X, columns=["a", "b", "c", "d"])
64+
model = OPLSDA(n_components=1, n_orthogonal=1).fit(df, y)
65+
66+
with warnings.catch_warnings(record=True) as record:
67+
disp = OPLSScoresDisplay.from_estimator(model, df, y=y)
68+
69+
messages = [str(w.message) for w in record]
70+
assert not any("feature names" in message for message in messages)
71+
assert disp.t_predictive.shape == (df.shape[0],)
72+
73+
with pytest.raises(ValueError, match="feature names"):
74+
OPLSScoresDisplay.from_estimator(model, df[["b", "a", "c", "d"]], y=y)
75+
76+
plt.close("all")
77+
78+
79+
def test_scores_display_replays_filter_for_new_samples():
80+
X, y = _regression_data(seed=7)
81+
model = OPLS(n_components=1, n_orthogonal=2).fit(X[:35], y[:35])
82+
X_new = X[35:]
83+
84+
disp = OPLSScoresDisplay.from_estimator(model, X_new)
85+
86+
np.testing.assert_allclose(disp.t_predictive, model.transform(X_new)[:, 0])
87+
np.testing.assert_allclose(
88+
disp.t_orthogonal,
89+
model.transform_orthogonal(X_new)[:, 0],
90+
)
91+
plt.close("all")
92+
93+
5894
def test_s_plot_regression_and_classification():
5995
X, y = _regression_data()
6096
disp1 = SPlotDisplay.from_estimator(
@@ -172,6 +208,38 @@ def test_splot_display_nan_correlation():
172208
assert not np.isnan(disp.correlation[1:]).any()
173209

174210

211+
def test_splot_display_x_space_controls_covariance_axis():
212+
X, y = _regression_data(seed=8, n_features=4)
213+
scale = np.array([1.0, 3.0, 10.0, 30.0])
214+
X = X * scale
215+
model = OPLS(n_components=1, n_orthogonal=1, scale="standard").fit(X, y)
216+
217+
centered = SPlotDisplay.from_estimator(model, X, x_space="centered")
218+
scaled = SPlotDisplay.from_estimator(model, X, x_space="scaled")
219+
220+
assert not np.allclose(centered.covariance, scaled.covariance)
221+
assert centered.correlation.shape == scaled.correlation.shape
222+
plt.close("all")
223+
224+
225+
def test_splot_display_invalid_x_space_raises():
226+
X, y = _regression_data()
227+
model = OPLS().fit(X, y)
228+
229+
with pytest.raises(ValueError, match="x_space must be one of"):
230+
SPlotDisplay.from_estimator(model, X, x_space="raw")
231+
232+
233+
def test_splot_display_warns_for_non_training_subset():
234+
X, y = _regression_data(seed=9)
235+
model = OPLS(n_components=1, n_orthogonal=1).fit(X, y)
236+
237+
with pytest.warns(UserWarning, match="usually intended for the training data"):
238+
SPlotDisplay.from_estimator(model, X[:10])
239+
240+
plt.close("all")
241+
242+
175243
def test_plotting_pipeline_ending_in_opls():
176244
X, y = _regression_data()
177245
pipe = Pipeline(

0 commit comments

Comments
 (0)