Skip to content

Commit de342d4

Browse files
committed
harden O2PLS: guard permutation_test against multi-output y, simplify _stack_columns, cache _ssq in orthogonal extraction, add all-scale predict and multi-output score tests
1 parent e5ab62d commit de342d4

3 files changed

Lines changed: 35 additions & 11 deletions

File tree

src/scikit_opls/_o2pls_core.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,8 @@ def _extract_one_orthogonal_component(
267267
# sequential filter that will later be replayed on new samples.
268268
score = X @ weight
269269
score_ssq = float(score @ score)
270-
block_ssq = max(_ssq(X), 1.0)
270+
x_ssq = _ssq(X)
271+
block_ssq = max(x_ssq, 1.0)
271272
if score_ssq <= tol * block_ssq:
272273
return None
273274

@@ -279,7 +280,7 @@ def _extract_one_orthogonal_component(
279280
if not np.all(np.isfinite(filtered)):
280281
return None
281282
# Refuse components that do not measurably reduce the block sum of squares.
282-
if _ssq(filtered) >= _ssq(X) - tol * max(_ssq(X), 1.0):
283+
if _ssq(filtered) >= x_ssq - tol * block_ssq:
283284
return None
284285

285286
return OrthogonalBlockComponent(
@@ -348,12 +349,12 @@ def _replay_orthogonal_filter(
348349

349350

350351
def _stack_columns(
351-
values: list[NDArray[np.float64]], n_rows: int, n_columns: int
352+
values: list[NDArray[np.float64]], n_rows: int
352353
) -> NDArray[np.float64]:
353354
"""Column-stack values or return a correctly shaped empty matrix."""
354355
if values:
355356
return np.column_stack(values)
356-
return np.zeros((n_rows, n_columns), dtype=np.float64)
357+
return np.zeros((n_rows, 0), dtype=np.float64)
357358

358359

359360
def _reconstruction_r2(
@@ -505,12 +506,12 @@ def o2pls_fit(
505506
B_T = _lstsq_map(T, U)
506507
B_U = _lstsq_map(U, T)
507508

508-
X_orth_weights = _stack_columns(x_weights, n_x_features, 0)
509-
X_orth_scores = _stack_columns(x_scores, n_samples, 0)
510-
X_orth_loadings = _stack_columns(x_loadings, n_x_features, 0)
511-
Y_orth_weights = _stack_columns(y_weights, n_y_features, 0)
512-
Y_orth_scores = _stack_columns(y_scores, n_samples, 0)
513-
Y_orth_loadings = _stack_columns(y_loadings, n_y_features, 0)
509+
X_orth_weights = _stack_columns(x_weights, n_x_features)
510+
X_orth_scores = _stack_columns(x_scores, n_samples)
511+
X_orth_loadings = _stack_columns(x_loadings, n_x_features)
512+
Y_orth_weights = _stack_columns(y_weights, n_y_features)
513+
Y_orth_scores = _stack_columns(y_scores, n_samples)
514+
Y_orth_loadings = _stack_columns(y_loadings, n_y_features)
514515

515516
# Reconstruct nominal joint, orthogonal and residual parts in the original
516517
# preprocessed coordinates for diagnostics and estimator attributes.

src/scikit_opls/validation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ def permutation_test(
180180
raise ValueError(f"n_permutations must be >= 1, got {n_permutations}")
181181

182182
X = check_array(X, dtype=np.float64)
183-
y = np.asarray(y, dtype=np.float64).ravel()
183+
y = np.asarray(y, dtype=np.float64)
184+
if y.ndim != 1:
185+
raise ValueError("permutation_test currently requires a 1D response.")
184186
check_consistent_length(X, y)
185187
if not np.all(np.isfinite(y)):
186188
raise ValueError("y must contain only finite values.")

tests/test_o2pls.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ def test_score_matches_r2_score_for_1d_y():
125125
assert model.score(X, y) == pytest.approx(r2_score(y, model.predict(X)))
126126

127127

128+
def test_score_matches_r2_score_uniform_average_for_multi_output_y():
129+
X, Y = _make_o2pls_data(seed=7)
130+
model = O2PLS(n_components=2).fit(X, Y)
131+
132+
assert model.score(X, Y) == pytest.approx(r2_score(Y, model.predict(X)))
133+
134+
128135
@pytest.mark.parametrize(
129136
("method", "arg_name"),
130137
[
@@ -246,6 +253,20 @@ def test_o2pls_predict_unscales_y():
246253
assert_allclose(model.predict(X_raw), expected)
247254

248255

256+
@pytest.mark.parametrize("scale", ["none", "center", "pareto", "standard"])
257+
def test_o2pls_predict_matches_scaled_filtered_coefficient_path_all_scales(scale):
258+
X, y = _regression_data()
259+
rng = np.random.default_rng(0)
260+
Y = np.column_stack([y, y + 0.1 * rng.normal(size=y.shape)])
261+
X_raw = 10.0 + 3.0 * X
262+
Y_raw = -5.0 + 2.0 * Y
263+
model = O2PLS(scale=scale).fit(X_raw, Y_raw)
264+
expected = (
265+
model.filter_transform_x(X_raw) @ model.coef_filtered_
266+
) * model.y_std_ + model.y_mean_
267+
assert_allclose(model.predict(X_raw), expected)
268+
269+
249270
def test_o2pls_predict_x_unscales_x():
250271
X, y = _regression_data()
251272
rng = np.random.default_rng(0)

0 commit comments

Comments
 (0)