diff --git a/optbinning/_sklearn_compat.py b/optbinning/_sklearn_compat.py new file mode 100644 index 0000000..ec0eb4f --- /dev/null +++ b/optbinning/_sklearn_compat.py @@ -0,0 +1,39 @@ +"""Compatibility with multiple sklearn versions.""" + +import functools + +from packaging.version import parse, Version +import sklearn +from sklearn.utils import check_array as _check_array + + +@functools.lru_cache(maxsize=None) +def _sklearn_version() -> Version: + """Return the installed sklearn version as a packaging ``Version``. + + Cached with maxsize=None; maxsize=0 would disable caching entirely. + """ + return parse(sklearn.__version__) + + +@functools.lru_cache(maxsize=None) +def _new_check_array_api(): + """Return True when sklearn is 1.6.0 or newer, which uses the new api. + + https://github.com/scikit-learn/scikit-learn/issues/29262 + """ + return _sklearn_version() >= parse("1.6.0") + +def check_array(*args, **kwargs): + """Wrap check_array, mapping force_all_finite to ensure_all_finite. + + https://github.com/scikit-learn/scikit-learn/issues/29262 + """ + if not _new_check_array_api(): + return _check_array(*args, **kwargs) + + # Rename only when the caller passed force_all_finite (and it is not None) + finite = kwargs.pop("force_all_finite", None) + if finite is not None: + kwargs["ensure_all_finite"] = finite + return _check_array(*args, **kwargs) diff --git a/optbinning/binning/binning.py b/optbinning/binning/binning.py index e2c1899..d9ed7c5 100644 --- a/optbinning/binning/binning.py +++ b/optbinning/binning/binning.py @@ -10,12 +10,12 @@ import numpy as np -from sklearn.utils import check_array import json from ..information import solver_statistics from ..logging import Logger +from .._sklearn_compat import check_array from .auto_monotonic import auto_monotonic from .auto_monotonic import peak_valley_trend_change_heuristic from .base import BaseOptimalBinning diff --git a/optbinning/binning/binning_process.py b/optbinning/binning/binning_process.py index 4c53022..b144ce3 100644 --- a/optbinning/binning/binning_process.py +++ 
b/optbinning/binning/binning_process.py @@ -17,11 +17,11 @@ from joblib import Parallel, delayed, effective_n_jobs from sklearn.base import BaseEstimator from sklearn.exceptions import NotFittedError -from sklearn.utils import check_array from sklearn.utils import check_consistent_length from sklearn.utils.multiclass import type_of_target from ..logging import Logger +from .._sklearn_compat import check_array from .base import Base from .binning import OptimalBinning from .binning_process_information import print_binning_process_information diff --git a/optbinning/binning/continuous_binning.py b/optbinning/binning/continuous_binning.py index 0b6d3f3..cfe2e0b 100644 --- a/optbinning/binning/continuous_binning.py +++ b/optbinning/binning/continuous_binning.py @@ -9,12 +9,12 @@ import time import json -from sklearn.utils import check_array import numpy as np from ..information import solver_statistics from ..logging import Logger +from .._sklearn_compat import check_array from .auto_monotonic import auto_monotonic_continuous from .auto_monotonic import peak_valley_trend_change_heuristic from .binning import OptimalBinning diff --git a/optbinning/binning/mdlp.py b/optbinning/binning/mdlp.py index 835ec7b..f43e242 100644 --- a/optbinning/binning/mdlp.py +++ b/optbinning/binning/mdlp.py @@ -12,7 +12,8 @@ from scipy import special from sklearn.base import BaseEstimator from sklearn.exceptions import NotFittedError -from sklearn.utils import check_array + +from .._sklearn_compat import check_array def _check_parameters(min_samples_split, min_samples_leaf, max_candidates): diff --git a/optbinning/binning/metrics.py b/optbinning/binning/metrics.py index e0220ce..58451f8 100644 --- a/optbinning/binning/metrics.py +++ b/optbinning/binning/metrics.py @@ -9,9 +9,10 @@ from scipy import special from scipy import stats -from sklearn.utils import check_array from sklearn.utils import check_consistent_length +from .._sklearn_compat import check_array + def _check_x_y(x, y): x = 
check_array(x, ensure_2d=False, force_all_finite=True) diff --git a/optbinning/binning/multiclass_binning.py b/optbinning/binning/multiclass_binning.py index cc80add..2c9f090 100644 --- a/optbinning/binning/multiclass_binning.py +++ b/optbinning/binning/multiclass_binning.py @@ -11,10 +11,10 @@ import numpy as np -from sklearn.utils import check_array from ..information import solver_statistics from ..logging import Logger +from .._sklearn_compat import check_array from .auto_monotonic import auto_monotonic from .auto_monotonic import peak_valley_trend_change_heuristic from .binning import OptimalBinning diff --git a/optbinning/binning/multidimensional/preprocessing_2d.py b/optbinning/binning/multidimensional/preprocessing_2d.py index 11d4b74..de5f50c 100644 --- a/optbinning/binning/multidimensional/preprocessing_2d.py +++ b/optbinning/binning/multidimensional/preprocessing_2d.py @@ -8,9 +8,9 @@ import numpy as np import pandas as pd -from sklearn.utils import check_array from sklearn.utils import check_consistent_length +from ..._sklearn_compat import check_array from ..preprocessing import categorical_transform diff --git a/optbinning/binning/multidimensional/transformations_2d.py b/optbinning/binning/multidimensional/transformations_2d.py index c3ff7c4..d4b499a 100644 --- a/optbinning/binning/multidimensional/transformations_2d.py +++ b/optbinning/binning/multidimensional/transformations_2d.py @@ -8,8 +8,9 @@ import numpy as np import pandas as pd -from sklearn.utils import check_array + +from ..._sklearn_compat import check_array from ..transformations import _check_metric_special_missing from ..transformations import _check_show_digits from .binning_statistics_2d import bin_categorical diff --git a/optbinning/binning/piecewise/transformations.py b/optbinning/binning/piecewise/transformations.py index 9ead4b1..1ed06de 100644 --- a/optbinning/binning/piecewise/transformations.py +++ b/optbinning/binning/piecewise/transformations.py @@ -8,11 +8,11 @@ import numpy 
as np import pandas as pd -from sklearn.utils import check_array from ...binning.transformations import transform_event_rate_to_woe from ...binning.transformations import _check_metric_special_missing from ...binning.transformations import _mask_special_missing +from ..._sklearn_compat import check_array def _apply_transform(x, c, lb, ub, special_codes, metric_special, diff --git a/optbinning/binning/preprocessing.py b/optbinning/binning/preprocessing.py index fc299dc..7c12f34 100644 --- a/optbinning/binning/preprocessing.py +++ b/optbinning/binning/preprocessing.py @@ -11,11 +11,11 @@ import pandas as pd from sklearn.preprocessing import LabelEncoder -from sklearn.utils import check_array from sklearn.utils import check_consistent_length from sklearn.utils import compute_class_weight from sklearn.utils.validation import _check_sample_weight +from .._sklearn_compat import check_array from .outlier import ModifiedZScoreDetector from .outlier import RangeDetector from .outlier import YQuantileDetector diff --git a/optbinning/binning/transformations.py b/optbinning/binning/transformations.py index 5c35d7d..59cac9c 100644 --- a/optbinning/binning/transformations.py +++ b/optbinning/binning/transformations.py @@ -10,8 +10,7 @@ import numpy as np import pandas as pd -from sklearn.utils import check_array - +from .._sklearn_compat import check_array from .binning_statistics import bin_categorical from .binning_statistics import bin_str_format diff --git a/optbinning/binning/uncertainty/binning_scenarios.py b/optbinning/binning/uncertainty/binning_scenarios.py index 8b5c8d3..9455805 100644 --- a/optbinning/binning/uncertainty/binning_scenarios.py +++ b/optbinning/binning/uncertainty/binning_scenarios.py @@ -11,11 +11,11 @@ import numpy as np -from sklearn.utils import check_array from ...information import solver_statistics from ...logging import Logger from ...binning.preprocessing import split_data_scenarios +from ..._sklearn_compat import check_array from ..binning 
import OptimalBinning from ..binning_statistics import bin_info from ..binning_statistics import BinningTable diff --git a/optbinning/scorecard/plots.py b/optbinning/scorecard/plots.py index 1873ac5..c766fcb 100644 --- a/optbinning/scorecard/plots.py +++ b/optbinning/scorecard/plots.py @@ -11,9 +11,10 @@ import matplotlib.pyplot as plt from sklearn.metrics import roc_curve, roc_auc_score -from sklearn.utils import check_array from sklearn.utils import check_consistent_length +from .._sklearn_compat import check_array + def _check_arrays(y, y_pred): y = check_array(y, ensure_2d=False, force_all_finite=True) diff --git a/setup.py b/setup.py index 3c90bbc..2acf3d4 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,7 @@ def run(self): 'matplotlib', 'numpy>=1.16.1', 'ortools>=9.4,<9.12', + 'packaging', 'pandas', 'ropwr>=1.0.0', 'scikit-learn>=1.0.2', diff --git a/tests/test_sklearn_compat.py b/tests/test_sklearn_compat.py new file mode 100644 index 0000000..45c8315 --- /dev/null +++ b/tests/test_sklearn_compat.py @@ -0,0 +1,51 @@ +"""Tests for sklearn compat.""" + +from unittest import mock + +import pytest + +from optbinning import _sklearn_compat + + +@pytest.fixture() +def _cache_clear_version(): + """Clear the cached sklearn version before and after the test.""" + _sklearn_compat._sklearn_version.cache_clear() + yield + _sklearn_compat._sklearn_version.cache_clear() + + +@pytest.fixture() +def _cache_clear_new_check_array_api(): + """Clear the cached new-api check result before and after the test.""" + _sklearn_compat._new_check_array_api.cache_clear() + yield + _sklearn_compat._new_check_array_api.cache_clear() + + +@pytest.mark.parametrize( + ("sklearn_version", "want"), + [ + ("1.0.2", {"force_all_finite": True}), + # post release of an old version: kwarg is passed through unchanged + ("1.0.2.post", {"force_all_finite": True}), + # dev release of an old version: kwarg is passed through unchanged + ("1.0.dev0", {"force_all_finite": True}), + ("1.6.0", {"ensure_all_finite": True}), + ("1.7.0", {"ensure_all_finite": True}), + # release candidate of a newer version: kwarg is renamed + ("1.7.0rc1", 
{"ensure_all_finite": True}), + ], +) +def test__check_array_ensure_finite_kwargs( + sklearn_version, want, _cache_clear_version, _cache_clear_new_check_array_api +): + with mock.patch.object( + _sklearn_compat.sklearn, "__version__", new=sklearn_version + ), mock.patch.object( + _sklearn_compat, "_check_array", return_value="" + ) as mock_check_array: + _sklearn_compat.check_array(force_all_finite=True) + kwargs = mock_check_array.call_args.kwargs + assert kwargs == want +