Skip to content

Commit 1c19393

Browse files
committed
simplify inspection
1 parent b877dae commit 1c19393

2 files changed

Lines changed: 74 additions & 112 deletions

File tree

src/scikit_opls/_inspection.py

Lines changed: 63 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,9 @@
1-
"""Internal stateless math for OPLS VIP scores and explained-variance metrics.
1+
"""Stateless math helpers for OPLS explained-variance and VIP diagnostics.
22
3-
Private module — not part of the public API. The VIP scores are exposed as lazy
4-
``vip_`` / ``ortho_vip_`` properties on :class:`~scikit_opls.OPLS` and
5-
:class:`~scikit_opls.OPLSDA`; these functions compute them from fitted weights.
6-
7-
VIP (Variable Importance in Projection) is defined in the style of Galindo-Prieto
8-
et al. (2014); these are not intended to reproduce ropls VIP values exactly:
9-
10-
- predictive VIP is the standard PLS VIP of the predictive model fitted on the
11-
orthogonally filtered X, weighting each component by the Y variance it explains;
12-
- orthogonal VIP is an X-variance-weighted score for the removed orthogonal
13-
components, weighting each component by the X variance it explains.
14-
15-
For non-empty blocks with positive explained variance, VIP is normalized so that
16-
sum(vip**2) == n_features. Empty or degenerate blocks return zeros.
3+
Private module — not part of the public API. Used by the fitted attributes of
4+
:class:`~scikit_opls.OPLS` and :class:`~scikit_opls.OPLSDA`. VIP scores are
5+
normalized so ``sum(vip**2) == n_features`` when component importance is
6+
positive; degenerate inputs return zeros.
177
"""
188

199
from __future__ import annotations
@@ -27,15 +17,17 @@
2717
def _safe_total_ss(X: NDArray[np.float64]) -> float:
2818
"""Total sum of squares with a nonzero guard."""
2919
total = float(np.sum(np.asarray(X, dtype=np.float64) ** 2))
30-
return max(total, np.finfo(np.float64).eps)
20+
return max(total, _EPS)
3121

3222

33-
def component_explained_x_variance(
23+
def _validate_x_scores_loadings(
3424
X: NDArray[np.float64],
3525
scores: NDArray[np.float64],
3626
loadings: NDArray[np.float64],
37-
) -> NDArray[np.float64]:
38-
"""Per-component ``SS(t_i @ p_i.T) / SS(X)`` for fitted arrays."""
27+
) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]:
28+
X = np.asarray(X, dtype=np.float64)
29+
scores = np.asarray(scores, dtype=np.float64)
30+
loadings = np.asarray(loadings, dtype=np.float64)
3931
if X.ndim != 2 or scores.ndim != 2 or loadings.ndim != 2:
4032
raise ValueError("X, scores and loadings must all be 2D arrays.")
4133
if scores.shape[0] != X.shape[0]:
@@ -44,6 +36,22 @@ def component_explained_x_variance(
4436
raise ValueError("loadings must have one row per feature of X.")
4537
if scores.shape[1] != loadings.shape[1]:
4638
raise ValueError("scores and loadings must have the same number of components.")
39+
if not np.all(np.isfinite(X)):
40+
raise ValueError("X must contain only finite values.")
41+
if not np.all(np.isfinite(scores)):
42+
raise ValueError("scores must contain only finite values.")
43+
if not np.all(np.isfinite(loadings)):
44+
raise ValueError("loadings must contain only finite values.")
45+
return X, scores, loadings
46+
47+
48+
def component_explained_x_variance(
49+
X: NDArray[np.float64],
50+
scores: NDArray[np.float64],
51+
loadings: NDArray[np.float64],
52+
) -> NDArray[np.float64]:
53+
"""Per-component ``SS(t_i @ p_i.T) / SS(X)`` for fitted arrays."""
54+
X, scores, loadings = _validate_x_scores_loadings(X, scores, loadings)
4755
total = _safe_total_ss(X)
4856
out = np.empty(scores.shape[1], dtype=np.float64)
4957
for i in range(scores.shape[1]):
@@ -52,28 +60,6 @@ def component_explained_x_variance(
5260
return out
5361

5462

55-
def cumulative_r2_from_residuals(
56-
original: NDArray[np.float64],
57-
residuals_by_component: list[NDArray[np.float64]],
58-
) -> NDArray[np.float64]:
59-
"""Cumulative R² from a sequence of residual matrices."""
60-
total = _safe_total_ss(original)
61-
return np.asarray(
62-
[1.0 - float(np.sum(resid**2)) / total for resid in residuals_by_component],
63-
dtype=np.float64,
64-
)
65-
66-
67-
def component_r2_from_cumulative(
68-
cumulative: NDArray[np.float64],
69-
) -> NDArray[np.float64]:
70-
"""Convert cumulative R² to per-component increments."""
71-
cumulative = np.asarray(cumulative, dtype=np.float64)
72-
if cumulative.size == 0:
73-
return cumulative
74-
return np.diff(np.r_[0.0, cumulative])
75-
76-
7763
def component_r2y_from_scores(
7864
y: NDArray[np.float64],
7965
scores: NDArray[np.float64],
@@ -87,10 +73,28 @@ def component_r2y_from_scores(
8773
y_arr = np.asarray(y, dtype=np.float64)
8874
if y_arr.ndim == 1:
8975
y_arr = y_arr.reshape(-1, 1)
76+
if y_arr.ndim != 2:
77+
raise ValueError(f"y must be 1D or 2D, got shape {y_arr.shape}.")
9078
T = np.asarray(scores, dtype=np.float64)
79+
if T.ndim != 2:
80+
raise ValueError(f"scores must be 2D, got shape {T.shape}.")
9181
Q = np.asarray(y_loadings, dtype=np.float64)
9282
if Q.ndim == 1:
93-
Q = Q.reshape(-1, 1)
83+
# A 1D y_loadings is one value per component (single target), matching the
84+
# (n_targets, n_components) convention used elsewhere (predictive_vip).
85+
Q = Q.reshape(1, -1)
86+
elif Q.ndim != 2:
87+
raise ValueError(f"y_loadings must be 1D or 2D, got shape {Q.shape}.")
88+
if T.shape[0] != y_arr.shape[0]:
89+
raise ValueError("scores must have one row per sample of y.")
90+
if Q.shape[1] != T.shape[1]:
91+
raise ValueError("y_loadings must have one column per component.")
92+
if not np.all(np.isfinite(y_arr)):
93+
raise ValueError("y must contain only finite values.")
94+
if not np.all(np.isfinite(T)):
95+
raise ValueError("scores must contain only finite values.")
96+
if not np.all(np.isfinite(Q)):
97+
raise ValueError("y_loadings must contain only finite values.")
9498
total = _safe_total_ss(y_arr - y_arr.mean(axis=0, keepdims=True))
9599
out = np.empty(T.shape[1], dtype=np.float64)
96100
for i in range(T.shape[1]):
@@ -105,16 +109,9 @@ def explained_x_variance(
105109
loadings: NDArray[np.float64],
106110
) -> float:
107111
"""Nominal ``SS(T @ P.T) / SS(X)``; not clipped to ``[0, 1]``."""
108-
if X.ndim != 2 or scores.ndim != 2 or loadings.ndim != 2:
109-
raise ValueError("X, scores and loadings must all be 2D arrays.")
110-
if scores.shape[0] != X.shape[0]:
111-
raise ValueError("scores must have one row per sample of X.")
112-
if loadings.shape[0] != X.shape[1]:
113-
raise ValueError("loadings must have one row per feature of X.")
112+
X, scores, loadings = _validate_x_scores_loadings(X, scores, loadings)
114113
if scores.shape[1] == 0:
115114
return 0.0
116-
if scores.shape[1] != loadings.shape[1]:
117-
raise ValueError("scores and loadings must have the same number of components.")
118115
total = float(np.sum(X**2))
119116
if total <= 0.0:
120117
return 0.0
@@ -124,19 +121,9 @@ def explained_x_variance(
124121
def _weighted_vip(
125122
weights: NDArray[np.float64], ss_per_component: NDArray[np.float64]
126123
) -> NDArray[np.float64]:
127-
"""VIP from per-component weight vectors and their importance weights.
128-
129-
Parameters
130-
----------
131-
weights : ndarray of shape (n_features, n_components)
132-
Per-component weight vectors.
133-
ss_per_component : ndarray of shape (n_components,)
134-
Non-negative variance explained by each component.
124+
"""Return VIP scores from component weights and importance values.
135125
136-
Returns
137-
-------
138-
vip : ndarray of shape (n_features,)
139-
VIP scores; all-zero when there are no components or zero total variance.
126+
Zeros for empty components or zero total importance.
140127
"""
141128
if weights.ndim != 2:
142129
raise ValueError(f"weights must be 2D, got shape {weights.shape}.")
@@ -171,26 +158,15 @@ def predictive_vip(
171158
x_scores: NDArray[np.float64],
172159
y_loadings: NDArray[np.float64],
173160
) -> NDArray[np.float64]:
174-
"""Predictive VIP from the engine's weights/scores/Y-loadings.
175-
176-
Parameters
177-
----------
178-
x_weights : ndarray of shape (n_features, n_components)
179-
Predictive weight vectors.
180-
x_scores : ndarray of shape (n_samples, n_components)
181-
Predictive scores.
182-
y_loadings : ndarray of shape (n_components,) or (1, n_components)
183-
Y-loadings of the predictive components.
184-
185-
Returns
186-
-------
187-
vip : ndarray of shape (n_features,)
188-
Predictive VIP scores.
189-
"""
161+
"""Return predictive PLS VIP from weights, scores and Y-loadings."""
190162
if x_weights.ndim != 2:
191163
raise ValueError(f"x_weights must be 2D, got shape {x_weights.shape}.")
192164
if x_scores.ndim != 2:
193165
raise ValueError(f"x_scores must be 2D, got shape {x_scores.shape}.")
166+
if not np.all(np.isfinite(x_weights)):
167+
raise ValueError("x_weights must contain only finite values.")
168+
if not np.all(np.isfinite(x_scores)):
169+
raise ValueError("x_scores must contain only finite values.")
194170

195171
_, n_components = x_weights.shape
196172
if x_scores.shape[1] != n_components:
@@ -213,6 +189,8 @@ def predictive_vip(
213189
y_loadings_2d = y_loadings
214190
else:
215191
raise ValueError(f"y_loadings must be 1D or 2D, got shape {y_loadings.shape}.")
192+
if not np.all(np.isfinite(y_loadings_2d)):
193+
raise ValueError("y_loadings must contain only finite values.")
216194

217195
# Standard PLS VIP weights each component by the Y sum of squares explained by
218196
# that component: loading strength times score energy.
@@ -225,22 +203,7 @@ def orthogonal_vip(
225203
x_ortho_scores: NDArray[np.float64],
226204
x_ortho_loadings: NDArray[np.float64],
227205
) -> NDArray[np.float64]:
228-
"""Orthogonal VIP, each component weighted by the X variance it captures.
229-
230-
Parameters
231-
----------
232-
x_ortho_weights : ndarray of shape (n_features, n_orthogonal)
233-
Orthogonal weight vectors.
234-
x_ortho_scores : ndarray of shape (n_samples, n_orthogonal)
235-
Orthogonal scores.
236-
x_ortho_loadings : ndarray of shape (n_features, n_orthogonal)
237-
Orthogonal loadings.
238-
239-
Returns
240-
-------
241-
vip : ndarray of shape (n_features,)
242-
Orthogonal VIP scores.
243-
"""
206+
"""Return orthogonal VIP weighted by removed X variance."""
244207
if x_ortho_weights.ndim != 2:
245208
raise ValueError(
246209
f"x_ortho_weights must be 2D, got shape {x_ortho_weights.shape}."
@@ -253,6 +216,12 @@ def orthogonal_vip(
253216
raise ValueError(
254217
f"x_ortho_loadings must be 2D, got shape {x_ortho_loadings.shape}."
255218
)
219+
if not np.all(np.isfinite(x_ortho_weights)):
220+
raise ValueError("x_ortho_weights must contain only finite values.")
221+
if not np.all(np.isfinite(x_ortho_scores)):
222+
raise ValueError("x_ortho_scores must contain only finite values.")
223+
if not np.all(np.isfinite(x_ortho_loadings)):
224+
raise ValueError("x_ortho_loadings must contain only finite values.")
256225

257226
n_features, n_components = x_ortho_weights.shape
258227
if x_ortho_scores.shape[1] != n_components:

tests/test_diagnostics.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def test_pls_engine_exposes_internal_x_mean_for_reconstruction():
257257
X, y = _regression_data()
258258
model = OPLS(n_components=1, n_orthogonal=0, scale="none").fit(X, y)
259259
assert hasattr(model.pls_, "_x_mean")
260-
assert model.pls_._x_mean.shape == (X.shape[1],)
260+
assert getattr(model.pls_, "_x_mean").shape == (X.shape[1],)
261261

262262

263263
# ==============================================================================
@@ -335,22 +335,15 @@ def test_oplsda_diagnostics_expect_raw_x_not_prescaled_x():
335335
)
336336

337337

338-
def test_component_r2_from_cumulative_empty_and_differences():
339-
"""Verify component_r2_from_cumulative with empty and non-empty inputs."""
340-
from scikit_opls._inspection import component_r2_from_cumulative
338+
def test_component_r2y_from_scores_1d_y_loadings_matches_2d():
339+
"""A 1D y_loadings is one value per component, matching the 2D (1, n) form."""
340+
from scikit_opls._inspection import component_r2y_from_scores
341341

342-
assert component_r2_from_cumulative(np.array([])).shape == (0,)
343-
np.testing.assert_allclose(
344-
component_r2_from_cumulative(np.array([0.2, 0.5, 0.8])),
345-
[0.2, 0.3, 0.3],
346-
)
347-
348-
349-
def test_cumulative_r2_from_residuals_decreases_with_smaller_residuals():
350-
"""Verify cumulative_r2_from_residuals decreases as residuals decrease."""
351-
from scikit_opls._inspection import cumulative_r2_from_residuals
342+
rng = np.random.default_rng(0)
343+
y = rng.normal(size=10)
344+
T = rng.normal(size=(10, 3))
345+
q_1d = rng.normal(size=3)
352346

353-
X = np.eye(4)
354-
out = cumulative_r2_from_residuals(X, [0.5 * X, 0.1 * X])
355-
assert out[1] > out[0]
356-
assert np.all((0.0 <= out) & (out <= 1.0))
347+
out_1d = component_r2y_from_scores(y, T, q_1d)
348+
out_2d = component_r2y_from_scores(y, T, q_1d.reshape(1, -1))
349+
np.testing.assert_allclose(out_1d, out_2d)

0 commit comments

Comments
 (0)