|
| 1 | +"""Synthetic two-block O2PLS example.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from sklearn.metrics import r2_score |
| 7 | + |
| 8 | +from scikit_opls import O2PLS |
| 9 | + |
| 10 | + |
| 11 | +def _make_blocks(seed: int = 0): |
| 12 | + """Generate two blocks with joint and block-specific latent structure.""" |
| 13 | + rng = np.random.default_rng(seed) |
| 14 | + n_samples = 100 |
| 15 | + n_joint = 2 |
| 16 | + latent, _ = np.linalg.qr(rng.normal(size=(n_samples, 4))) |
| 17 | + joint = latent[:, :n_joint] * np.array([5.0, 2.0]) |
| 18 | + x_specific = latent[:, 2:3] * 4.0 |
| 19 | + y_specific = latent[:, 3:4] * 3.0 |
| 20 | + |
| 21 | + x_basis, _ = np.linalg.qr(rng.normal(size=(12, 3))) |
| 22 | + y_basis, _ = np.linalg.qr(rng.normal(size=(8, 3))) |
| 23 | + x_joint = joint @ x_basis[:, :n_joint].T |
| 24 | + y_joint = joint @ y_basis[:, :n_joint].T |
| 25 | + X = x_joint + x_specific @ x_basis[:, 2:3].T |
| 26 | + Y = y_joint + y_specific @ y_basis[:, 2:3].T |
| 27 | + X += 0.01 * rng.normal(size=X.shape) |
| 28 | + Y += 0.01 * rng.normal(size=Y.shape) |
| 29 | + return X, Y, x_joint, y_joint |
| 30 | + |
| 31 | + |
| 32 | +def main() -> None: |
| 33 | + """Fit O2PLS and report joint-structure recovery on synthetic data.""" |
| 34 | + X, Y, x_joint, y_joint = _make_blocks() |
| 35 | + model = O2PLS( |
| 36 | + n_components=2, |
| 37 | + n_x_orthogonal=1, |
| 38 | + n_y_orthogonal=1, |
| 39 | + scale="center", |
| 40 | + ).fit(X, Y) |
| 41 | + |
| 42 | + print(f"X shape: {X.shape}") |
| 43 | + print(f"Y shape: {Y.shape}") |
| 44 | + print(f"X-orthogonal components: {model.n_x_orthogonal_}") |
| 45 | + print(f"Y-orthogonal components: {model.n_y_orthogonal_}") |
| 46 | + print(f"R2X joint / orthogonal: {model.r2x_:.3f} / {model.r2x_ortho_:.3f}") |
| 47 | + print(f"R2Y joint / orthogonal: {model.r2y_:.3f} / {model.r2y_ortho_:.3f}") |
| 48 | + print(f"Predicted joint Y R2: {r2_score(y_joint, model.predict(X)):.3f}") |
| 49 | + print(f"Predicted joint X R2: {r2_score(x_joint, model.predict_x(Y)):.3f}") |
| 50 | + |
| 51 | + |
| 52 | +if __name__ == "__main__": |
| 53 | + main() |
0 commit comments