Skip to content

Commit

Permalink
[enhancement] Add guard for sklearn fallback in sklearnex testing (#1686
Browse files Browse the repository at this point in the history
)

* added conftest

* formatting

* change to CI

* add mark

* improved text

* that's a wrap for the night

* remove stderr from sklearnex logger during tests

* re-add monkeypatch to central test list

* Update run_test.sh

* Update test_memory_usage.py

* Update test_memory_usage.py

* Update test_memory_usage.py

* Update test_parallel.py

* Update run_test.sh

* remove mark due to update in EmpiricalCovariance

* Update _forest.py

* Update conftest.py

* Update test_parallel.py

* Update conftest.py

* isort issues
  • Loading branch information
icfaust authored Feb 27, 2024
1 parent 3529b0d commit cfae3b8
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
38 changes: 38 additions & 0 deletions sklearnex/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,49 @@
# limitations under the License.
# ==============================================================================

import io
import logging

import pytest

from sklearnex import patch_sklearn, unpatch_sklearn


def pytest_configure(config):
    """Register the ``allow_sklearn_fallback`` marker with pytest.

    Tests carrying this marker are exempt from the sklearnex-fallback
    guard installed by ``pytest_runtest_call``.
    """
    marker_spec = (
        "allow_sklearn_fallback: mark test to not check for sklearnex usage"
    )
    config.addinivalue_line("markers", marker_spec)


@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item):
    """Wrap each test call and fail it if sklearnex fell back to sklearn.

    Unless the test is marked ``allow_sklearn_fallback``, the ``sklearnex``
    logger output is captured for the duration of the test; if the capture
    contains the fallback message, a ``TypeError`` is raised so the test is
    reported as not actually exercising sklearnex.
    """
    # Marked tests run untouched — just hand control to the test call.
    if item.get_closest_marker("allow_sklearn_fallback"):
        yield
        return

    # Route all sklearnex log records into an in-memory buffer, replacing
    # the logger's existing handlers (e.g. its stderr handler) for the
    # duration of the test.
    capture = io.StringIO()
    capture_handler = logging.StreamHandler(capture)
    capture_handler.setLevel(logging.INFO)

    sklearnex_logger = logging.getLogger("sklearnex")
    saved_level = sklearnex_logger.level
    saved_handlers = sklearnex_logger.handlers
    sklearnex_logger.handlers = []
    sklearnex_logger.addHandler(capture_handler)
    sklearnex_logger.setLevel(logging.INFO)

    yield

    # Restore the logger exactly as it was before inspecting the capture.
    sklearnex_logger.handlers = saved_handlers
    sklearnex_logger.setLevel(saved_level)
    text = capture.getvalue()
    if "fallback to original Scikit-learn" in text:
        raise TypeError(
            f"test did not properly evaluate sklearnex functionality and fell back to sklearn:\n{text}"
        )


@pytest.fixture
def with_sklearnex():
patch_sklearn()
Expand Down
9 changes: 8 additions & 1 deletion sklearnex/tests/test_memory_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================


import gc
import logging
import tracemalloc
Expand Down Expand Up @@ -75,6 +76,8 @@ def fit(self, x, y):


# add all daal4py estimators enabled in patching (except banned)


def get_patched_estimators(ban_list, output_list):
patched_estimators = get_patch_map().values()
for listing in patched_estimators:
Expand Down Expand Up @@ -153,6 +156,7 @@ def split_train_inference(kf, x, y, estimator):
y_train, y_test = y.iloc[train_index], y.iloc[test_index]
# TODO: add parameters for all estimators to prevent
# fallback to stock scikit-learn with default parameters

alg = estimator()
alg.fit(x_train, y_train)
if hasattr(alg, "predict"):
Expand All @@ -163,7 +167,6 @@ def split_train_inference(kf, x, y, estimator):
alg.kneighbors(x_test)
del alg, x_train, x_test, y_train, y_test
mem_tracks.append(tracemalloc.get_traced_memory()[0])

return mem_tracks


Expand Down Expand Up @@ -215,6 +218,10 @@ def _kfold_function_template(estimator, data_transform_function, data_shape):
)


# disable fallback check as logging impacts memory use


@pytest.mark.allow_sklearn_fallback
@pytest.mark.parametrize("data_transform_function", data_transforms)
@pytest.mark.parametrize("estimator", estimators)
@pytest.mark.parametrize("data_shape", data_shapes)
Expand Down

0 comments on commit cfae3b8

Please sign in to comment.