Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions optbinning/_sklearn_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Compatibility with multiple sklearn versions."""

import functools

from packaging.version import parse, Version
import sklearn
from sklearn.utils import check_array as _check_array


@functools.lru_cache(maxsize=0)
def _sklearn_version() -> Version:
"""Return the version of sklearn as a version tuple.

This function is cached to avoid calling it multiple times.
"""
return parse(sklearn.__version__)


@functools.lru_cache(maxsize=0)
def _new_check_array_api():
"""Sklearn uses the new api after version 1.6.0+

https://github.com/scikit-learn/scikit-learn/issues/29262
"""
return _sklearn_version() >= parse("1.6.0")

def check_array(*args, **kwargs):
"""Wrapper around check array to preserve backwards compatibility.

https://github.com/scikit-learn/scikit-learn/issues/29262
"""
if _new_check_array_api() is False:
return _check_array(*args, **kwargs)

# Only replace if it's in the kwargs
finite = kwargs.pop("force_all_finite", None)
if finite is not None:
kwargs["ensure_all_finite"] = finite
return _check_array(*args, **kwargs)
2 changes: 1 addition & 1 deletion optbinning/binning/binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

import numpy as np

from sklearn.utils import check_array

import json

from ..information import solver_statistics
from ..logging import Logger
from .._sklearn_compat import check_array
from .auto_monotonic import auto_monotonic
from .auto_monotonic import peak_valley_trend_change_heuristic
from .base import BaseOptimalBinning
Expand Down
2 changes: 1 addition & 1 deletion optbinning/binning/binning_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from joblib import Parallel, delayed, effective_n_jobs
from sklearn.base import BaseEstimator
from sklearn.exceptions import NotFittedError
from sklearn.utils import check_array
from sklearn.utils import check_consistent_length
from sklearn.utils.multiclass import type_of_target

from ..logging import Logger
from .._sklearn_compat import check_array
from .base import Base
from .binning import OptimalBinning
from .binning_process_information import print_binning_process_information
Expand Down
2 changes: 1 addition & 1 deletion optbinning/binning/continuous_binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
import time
import json

from sklearn.utils import check_array

import numpy as np

from ..information import solver_statistics
from ..logging import Logger
from .._sklearn_compat import check_array
from .auto_monotonic import auto_monotonic_continuous
from .auto_monotonic import peak_valley_trend_change_heuristic
from .binning import OptimalBinning
Expand Down
3 changes: 2 additions & 1 deletion optbinning/binning/mdlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from scipy import special
from sklearn.base import BaseEstimator
from sklearn.exceptions import NotFittedError
from sklearn.utils import check_array

from .._sklearn_compat import check_array


def _check_parameters(min_samples_split, min_samples_leaf, max_candidates):
Expand Down
3 changes: 2 additions & 1 deletion optbinning/binning/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@

from scipy import special
from scipy import stats
from sklearn.utils import check_array
from sklearn.utils import check_consistent_length

from .._sklearn_compat import check_array


def _check_x_y(x, y):
x = check_array(x, ensure_2d=False, force_all_finite=True)
Expand Down
2 changes: 1 addition & 1 deletion optbinning/binning/multiclass_binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

import numpy as np

from sklearn.utils import check_array

from ..information import solver_statistics
from ..logging import Logger
from .._sklearn_compat import check_array
from .auto_monotonic import auto_monotonic
from .auto_monotonic import peak_valley_trend_change_heuristic
from .binning import OptimalBinning
Expand Down
2 changes: 1 addition & 1 deletion optbinning/binning/multidimensional/preprocessing_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import numpy as np
import pandas as pd

from sklearn.utils import check_array
from sklearn.utils import check_consistent_length

from ..._sklearn_compat import check_array
from ..preprocessing import categorical_transform


Expand Down
3 changes: 2 additions & 1 deletion optbinning/binning/multidimensional/transformations_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import numpy as np
import pandas as pd

from sklearn.utils import check_array


from ..._sklearn_compat import check_array
from ..transformations import _check_metric_special_missing
from ..transformations import _check_show_digits
from .binning_statistics_2d import bin_categorical
Expand Down
2 changes: 1 addition & 1 deletion optbinning/binning/piecewise/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import numpy as np
import pandas as pd

from sklearn.utils import check_array

from ...binning.transformations import transform_event_rate_to_woe
from ...binning.transformations import _check_metric_special_missing
from ...binning.transformations import _mask_special_missing
from ..._sklearn_compat import check_array


def _apply_transform(x, c, lb, ub, special_codes, metric_special,
Expand Down
2 changes: 1 addition & 1 deletion optbinning/binning/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import pandas as pd

from sklearn.preprocessing import LabelEncoder
from sklearn.utils import check_array
from sklearn.utils import check_consistent_length
from sklearn.utils import compute_class_weight
from sklearn.utils.validation import _check_sample_weight

from .._sklearn_compat import check_array
from .outlier import ModifiedZScoreDetector
from .outlier import RangeDetector
from .outlier import YQuantileDetector
Expand Down
3 changes: 1 addition & 2 deletions optbinning/binning/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import numpy as np
import pandas as pd

from sklearn.utils import check_array

from .._sklearn_compat import check_array
from .binning_statistics import bin_categorical
from .binning_statistics import bin_str_format

Expand Down
2 changes: 1 addition & 1 deletion optbinning/binning/uncertainty/binning_scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@

import numpy as np

from sklearn.utils import check_array

from ...information import solver_statistics
from ...logging import Logger
from ...binning.preprocessing import split_data_scenarios
from ..._sklearn_compat import check_array
from ..binning import OptimalBinning
from ..binning_statistics import bin_info
from ..binning_statistics import BinningTable
Expand Down
3 changes: 2 additions & 1 deletion optbinning/scorecard/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
import matplotlib.pyplot as plt

from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.utils import check_array
from sklearn.utils import check_consistent_length

from .._sklearn_compat import check_array


def _check_arrays(y, y_pred):
y = check_array(y, ensure_2d=False, force_all_finite=True)
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def run(self):
'matplotlib',
'numpy>=1.16.1',
'ortools>=9.4,<9.12',
'packaging',
'pandas',
'ropwr>=1.0.0',
'scikit-learn>=1.0.2',
Expand Down
51 changes: 51 additions & 0 deletions tests/test_sklearn_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Tests for sklearn compat."""

from unittest import mock

import pytest

from optbinning import _sklearn_compat


@pytest.fixture()
def _cache_clear_version():
"""Clear the cache each time it's called so multiple tests can run."""
_sklearn_compat._sklearn_version.cache_clear()
yield
_sklearn_compat._sklearn_version.cache_clear()


@pytest.fixture()
def _cache_clear_new_check_array_api():
"""Clear the cache each time it's called so multiple tests can run."""
_sklearn_compat._new_check_array_api.cache_clear()
yield
_sklearn_compat._new_check_array_api.cache_clear()


@pytest.mark.parametrize(
("sklearn_version", "want"),
[
("1.0.2", {"force_all_finite": True}),
# post releases
("1.0.2.post", {"force_all_finite": True}),
# dev releases
("1.0.dev0", {"force_all_finite": True}),
("1.6.0", {"ensure_all_finite": True}),
("1.7.0", {"ensure_all_finite": True}),
# release candidate
("1.7.0rc1", {"ensure_all_finite": True}),
],
)
def test__check_array_ensure_finite_kwargs(
sklearn_version, want, _cache_clear_version, _cache_clear_new_check_array_api
):
with mock.patch.object(
_sklearn_compat.sklearn, "__version__", new=sklearn_version
), mock.patch.object(
_sklearn_compat, "_check_array", return_value=""
) as mock_check_array:
_sklearn_compat.check_array(force_all_finite=True)
kwargs = mock_check_array.call_args.kwargs
assert kwargs == want

Loading