Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: DBSCAN via Array API #2014

Draft
wants to merge 114 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
114 commits
Select commit Hold shift + click to select a range
bec3ebb
BUG: fixing circular import in daal4py/sklearnex device_offloading
samir-nasibli May 13, 2024
9052a73
removed onedal4py sklearnex dependence
samir-nasibli May 15, 2024
b070c89
minor fix
samir-nasibli May 15, 2024
0fbc680
Merge branch 'intel:main' into fix/device_offload
samir-nasibli May 16, 2024
8a6c4da
Merge branch 'intel:main' into fix/device_offload
samir-nasibli Jun 1, 2024
89ac903
Merge branch 'intel:main' into fix/device_offload
samir-nasibli Jun 6, 2024
01d73da
minor update
samir-nasibli Jun 8, 2024
d4587aa
Merge branch 'intel:main' into fix/device_offload
samir-nasibli Jun 8, 2024
b5f8921
added daal4py.sklearn._config for exposing sklearnex settings
samir-nasibli Jun 8, 2024
46dea79
Merge branch 'intel:main' into fix/device_offload
samir-nasibli Jun 11, 2024
8ce81ea
removed daal4py device_offloading
samir-nasibli Jun 11, 2024
fc01bae
integrating changes of device offloading for sklearnex primitives/est…
samir-nasibli Jun 11, 2024
18e6599
minor fixes
samir-nasibli Jun 11, 2024
4318064
minor fix for daal4py/sklearn/linear_model/_coordinate_descent.py
samir-nasibli Jun 11, 2024
864c1a1
minor fix for daal4py/sklearn/linear_model/_linear.py
samir-nasibli Jun 11, 2024
4142bf1
fix for sklearnex/_device_offload.py
samir-nasibli Jun 11, 2024
25d87bc
fix for onedal._config
samir-nasibli Jun 11, 2024
778b88d
wrapping daal4py.sklearne Kmeans with onedal4py's support_usm_ndarray
samir-nasibli Jun 11, 2024
ecd731d
ENH: functional support for Array API
samir-nasibli Jun 12, 2024
f18070f
minor update for support_usm_ndarray decorator
samir-nasibli Jun 14, 2024
df682e6
update sklearnex/dispatcher.py
samir-nasibli Jun 14, 2024
23c84f4
fixed dispatcher
samir-nasibli Jun 14, 2024
0483c5c
fixed decorator name
samir-nasibli Jun 14, 2024
459e638
minor update for onedal/_device_offload.py
samir-nasibli Jun 15, 2024
75f9c4b
minor updates
samir-nasibli Jun 18, 2024
27b7a7c
update docstrings for onedal._config._get_config
samir-nasibli Jun 18, 2024
dec939e
reverted changes for LogReg refactoring
samir-nasibli Jun 19, 2024
38cc61e
using _get_config instead of _get_onedal_threadlocal_config
samir-nasibli Jun 19, 2024
cea22ac
minor fix of _get_config
samir-nasibli Jun 19, 2024
1bb3697
removed TODO, that covered by ticket
samir-nasibli Jun 20, 2024
9ec11b9
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Jun 20, 2024
4179f22
added todo comments
samir-nasibli Jun 21, 2024
694d94b
moved out from _DataParallelInteropPolicy init import of DummySyclQueue
samir-nasibli Jun 20, 2024
8bf2585
removed outdated comment; will be covered on #1813
samir-nasibli Jun 20, 2024
e903fbb
removed TODO comment from ridge.py
samir-nasibli Jun 20, 2024
1491b38
Added ElasticNet, Lasso, Ridge into sklearnex patching map
samir-nasibli Jun 21, 2024
81de19e
Merge branch 'main' into enh/functional_array_api
samir-nasibli Jun 27, 2024
3dd5000
lint
samir-nasibli Jun 27, 2024
adc6e68
removed debug print
samir-nasibli Jun 27, 2024
652d014
enabled more array api test
samir-nasibli Jun 27, 2024
d9f2ef8
Merge branch 'main' into enh/functional_array_api
samir-nasibli Jul 5, 2024
85c279c
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Jul 9, 2024
14c675d
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Jul 11, 2024
669d68f
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Jul 12, 2024
c6e158c
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Jul 15, 2024
ad42750
created seperate array api module
samir-nasibli Jul 16, 2024
7bef6c4
currently disabled array_api for test_memory_leaks
samir-nasibli Jul 17, 2024
88e2a22
update _convert_to_dataframe
samir-nasibli Jul 17, 2024
383ce0d
Merge branch 'main' into enh/functional_array_api
samir-nasibli Jul 17, 2024
d8d0dc4
linting
samir-nasibli Jul 17, 2024
45b920f
update condition for _transfer_to_host
samir-nasibli Jul 18, 2024
84da80a
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Jul 22, 2024
eda98a4
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Jul 26, 2024
f74ff13
fixed for bs and ridge
samir-nasibli Jul 26, 2024
7e83e4a
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Aug 5, 2024
4b86143
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Aug 6, 2024
b8348ee
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Aug 7, 2024
0300f72
update fallback when array_api_dispatch enabled
samir-nasibli Aug 8, 2024
5cae80b
refactor sklearnex get_namespace usage
samir-nasibli Aug 8, 2024
ac3aacb
updated array apiconditions for get_dataframes_and_queues
samir-nasibli Aug 8, 2024
80da711
fix import in test_memory_usage.py
samir-nasibli Aug 8, 2024
1de8fc9
first temp commit address py312 fails
samir-nasibli Aug 9, 2024
621dfc2
small tmp workaround for py312
samir-nasibli Aug 9, 2024
d2f4383
removed from_dlpack from the non-zero support logic
samir-nasibli Aug 10, 2024
16ddb69
fixing tests for incremenatal estimators
samir-nasibli Aug 11, 2024
9be33c2
using asarray instead of dlpack conversions
samir-nasibli Aug 12, 2024
e2fa37a
FIX: fixing spmd tests utilities for dpctl inputs
samir-nasibli Aug 12, 2024
016f550
Deselect LOF stability test with array api
samir-nasibli Aug 12, 2024
3e833d7
Revert "Deselect LOF stability test with array api"
samir-nasibli Aug 12, 2024
d0fc95d
MAINT: minor refactoring and docstrings added
samir-nasibli Aug 12, 2024
2ff418f
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Aug 26, 2024
77b8231
ENH: DBSCAN via Array API
samir-nasibli Aug 27, 2024
df1efc6
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Aug 30, 2024
2094fad
refactoring
samir-nasibli Aug 30, 2024
3004faa
Merge branch 'main' into enh/functional_array_api
samir-nasibli Sep 2, 2024
0fcaafe
minor update for _extract_array_attr
samir-nasibli Sep 2, 2024
5521720
minor update
samir-nasibli Sep 2, 2024
7bd3df0
update conditions for support_array_api
samir-nasibli Sep 2, 2024
19df8f1
update dispatch for array api
samir-nasibli Sep 2, 2024
442826e
covered by tickets TODOs
samir-nasibli Sep 2, 2024
fc34a6c
Merge branch 'main' into enh/functional_array_api
samir-nasibli Sep 4, 2024
ec634c1
refactor _asarray
samir-nasibli Sep 4, 2024
844745f
update a bit logic
samir-nasibli Sep 4, 2024
752df22
addressed test failes
samir-nasibli Sep 4, 2024
fe38790
update docstring
samir-nasibli Sep 4, 2024
3dd9521
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Sep 4, 2024
6989049
Update _array_api.py
samir-nasibli Sep 4, 2024
047d698
minor update
samir-nasibli Sep 4, 2024
c5f8281
minor updatte try
samir-nasibli Sep 4, 2024
1b45a24
renamed wrapper for inputs support
samir-nasibli Sep 4, 2024
ac7288f
minor refactoring
samir-nasibli Sep 5, 2024
d5f3a35
Merge branch 'main' into enh/functional_array_api
samir-nasibli Sep 5, 2024
3915066
fix refactoring
samir-nasibli Sep 5, 2024
95fa66f
addressed test fails
samir-nasibli Sep 5, 2024
8a9e497
remove unnecessary comments
samir-nasibli Sep 5, 2024
c66836f
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Sep 5, 2024
3ac9c82
enabled transform_output check
samir-nasibli Sep 6, 2024
819aa8f
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Sep 6, 2024
aa2bf82
update use of transform_output flag for input handlers
samir-nasibli Sep 6, 2024
3539ee3
reverted changes for test_incremental_pca.py
samir-nasibli Sep 6, 2024
ad71cd0
Revert "update use of transform_output flag for input handlers"
samir-nasibli Sep 6, 2024
4634c73
Merge branch 'main' into enh/functional_array_api
samir-nasibli Sep 6, 2024
870b0be
ENH: DBSCAN via Array API
samir-nasibli Aug 27, 2024
288ae25
Merge branch 'enh/dbscan_array_api' of https://github.com/samir-nasib…
samir-nasibli Sep 7, 2024
9bcbd4e
backup local changes
samir-nasibli Sep 9, 2024
41d1efc
minor refactoring
samir-nasibli Sep 9, 2024
52e9257
fixing
samir-nasibli Sep 9, 2024
dd78da4
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Sep 10, 2024
019ff56
Merge remote-tracking branch 'origin/enh/functional_array_api' into e…
samir-nasibli Sep 10, 2024
a09c3d4
minor fix
samir-nasibli Sep 10, 2024
8a43f8f
Merge branch 'main' into enh/dbscan_array_api
samir-nasibli Sep 10, 2024
be8fce8
Merge branch 'intel:main' into enh/dbscan_array_api
samir-nasibli Sep 24, 2024
d9b8cc7
Merge branch 'main' into enh/dbscan_array_api
samir-nasibli Oct 1, 2024
f95197f
added array-api-compat as test dep
samir-nasibli Oct 1, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 35 additions & 20 deletions onedal/cluster/dbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
# limitations under the License.
# ===============================================================================

import numpy as np
from sklearn.utils import check_array

from daal4py.sklearn._utils import get_dtype, make2d
from onedal.datatypes._data_conversion import get_dtype, make2d

from ..common._base import BaseEstimator
from ..common._mixin import ClusterMixin
from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _check_array
from ..utils._array_api import get_namespace


class BaseDBSCAN(BaseEstimator, ClusterMixin):
Expand All @@ -46,38 +46,54 @@ def __init__(
self.p = p
self.n_jobs = n_jobs

def _get_onedal_params(self, dtype=np.float32):
def _get_onedal_params(self, xp, dtype):
return {
"fptype": "float" if dtype == np.float32 else "double",
"fptype": "float" if dtype == xp.float32 else "double",
"method": "by_default",
"min_observations": int(self.min_samples),
"epsilon": float(self.eps),
"mem_save_mode": False,
"result_options": "core_observation_indices|responses",
}

def _fit(self, X, y, sample_weight, module, queue):
def _fit(self, X, xp, is_array_api_compliant, y, sample_weight, queue):
policy = self._get_policy(queue, X)
X = _check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
X = check_array(X, accept_sparse="csr", dtype=[xp.float64, xp.float32])
sample_weight = make2d(sample_weight) if sample_weight is not None else None
X = make2d(X)
if xp:
X_device = X.device

types = [np.float32, np.float64]
# TODO:
# revice once again the flow.
types = [xp.float32, xp.float64]
if get_dtype(X) not in types:
X = X.astype(np.float64)
X = _convert_to_supported(policy, X)
X = X.astype(xp.float64)
X = _convert_to_supported(policy, X, xp)
sample_weight = (
_convert_to_supported(policy, sample_weight, xp)
if sample_weight is not None
else None
)
dtype = get_dtype(X)
params = self._get_onedal_params(dtype)
result = module.compute(policy, params, to_table(X), to_table(sample_weight))
params = self._get_onedal_params(xp, dtype)
result = self._get_backend("dbscan", "clustering", None).compute(
policy, params, to_table(X, xp), to_table(sample_weight, xp)
)

self.labels_ = from_table(result.responses).ravel()
self.labels_ = from_table(
result.responses, xp=xp, queue=queue, array_api_device=X_device
).reshape(-1)
if result.core_observation_indices is not None:
self.core_sample_indices_ = from_table(
result.core_observation_indices
).ravel()
result.core_observation_indices,
xp=xp,
queue=queue,
array_api_device=X_device,
).reshape(-1)
else:
self.core_sample_indices_ = np.array([], dtype=np.intc)
self.components_ = np.take(X, self.core_sample_indices_, axis=0)
self.core_sample_indices_ = xp.array([], dtype=xp.int32)
self.components_ = xp.take(X, self.core_sample_indices_, axis=0)
self.n_features_in_ = X.shape[1]
return self

Expand Down Expand Up @@ -105,6 +121,5 @@ def __init__(
self.n_jobs = n_jobs

def fit(self, X, y=None, sample_weight=None, queue=None):
return super()._fit(
X, y, sample_weight, self._get_backend("dbscan", "clustering", None), queue
)
xp, is_array_api_compliant = get_namespace(X)
return super()._fit(X, xp, is_array_api_compliant, y, sample_weight, queue)
72 changes: 50 additions & 22 deletions onedal/datatypes/_data_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,48 +18,76 @@

import numpy as np

from daal4py.sklearn._utils import make2d
from daal4py.sklearn._utils import get_dtype
from daal4py.sklearn._utils import make2d as d4p_make2d
from onedal import _backend, _is_dpc_backend

from ..utils import _is_csr

try:
import dpctl
import dpctl.tensor as dpt
# from ..utils._array_api import get_namespace
from ..utils._array_api import _is_numpy_namespace

dpctl_available = dpctl.__version__ >= "0.14"
except ImportError:
dpctl_available = False

# TODO:
# move to proper module.
# TODO
# def make2d(arg, xp=None, is_array_api_compliant=None):
def make2d(arg, xp=None):
if xp and not _is_numpy_namespace(xp) and arg.ndim == 1:
return xp.reshape(arg, (arg.size, 1)) if arg.ndim == 1 else arg
# TODO:
# reimpl via is_array_api_compliant usage.
return d4p_make2d(arg)

def _apply_and_pass(func, *args):

# TODO:
# remove such kind of func calls
def _apply_and_pass(func, *args, **kwargs):
# TODO:
# refactor.
if len(args) == 1:
return func(args[0])
return tuple(map(func, args))

return tuple(map(func, args, kwargs))

def from_table(*args):
return _apply_and_pass(_backend.from_table, *args)

def convert_one_from_table(arg, xp=None, queue=None, array_api_device=None):
# TODO:
# use `array_api_device`.
result = _backend.from_table(arg)
if xp:
if xp.__name__ in {"dpctl", "dpctl.tensor"}:
result = xp.asarray(arg, sycl_queue=queue) if queue else xp.asarray(arg)
elif not _is_numpy_namespace(xp):
results = xp.asarray(result)
return result

def convert_one_to_table(arg):
if dpctl_available:
if isinstance(arg, dpt.usm_ndarray):
return _backend.dpctl_to_table(arg)

def convert_one_to_table(arg, xp=None):
if not _is_csr(arg):
if xp and not _is_numpy_namespace(xp):
arg = np.asarray(arg)
# TODO:
# Check. Probably should be removed from here
# Not really realted with converting to table.
arg = make2d(arg)
return _backend.to_table(arg)


def to_table(*args):
return _apply_and_pass(convert_one_to_table, *args)
def from_table(*args, xp=None, queue=None, array_api_device=None):
return _apply_and_pass(convert_one_from_table, *args)


def to_table(*args, xp=None):
return _apply_and_pass(convert_one_to_table, *args, xp=xp)


if _is_dpc_backend:
from ..common._policy import _HostInteropPolicy

def _convert_to_supported(policy, *data):
def _convert_to_supported(policy, *data, xp=None):
if xp is None:
xp = np

def func(x):
return x

Expand All @@ -71,13 +99,13 @@ def func(x):
device = policy._queue.sycl_device

def convert_or_pass(x):
if (x is not None) and (x.dtype == np.float64):
if (x is not None) and (x.dtype == xp.float64):
warnings.warn(
"Data will be converted into float32 from "
"float64 because device does not support it",
RuntimeWarning,
)
return x.astype(np.float32)
return x.astype(xp.float32)
else:
return x

Expand All @@ -88,7 +116,7 @@ def convert_or_pass(x):

else:

def _convert_to_supported(policy, *data):
def _convert_to_supported(policy, *data, xp=None):
def func(x):
return x

Expand Down
25 changes: 25 additions & 0 deletions onedal/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

from collections.abc import Iterable

import numpy as np


try:
from dpctl.tensor import usm_ndarray

Expand Down Expand Up @@ -89,3 +92,25 @@ def _get_sycl_namespace(*arrays):
raise ValueError(f"SYCL type not recognized: {sycl_type}")

return sycl_type, None, False


def get_namespace(*arrays):
"""Get namespace of arrays.
TBD.
Parameters
----------
*arrays : array objects
Array objects.
Returns
-------
namespace : module
Namespace shared by array objects.
is_array_api : bool
True of the arrays are containers that implement the Array API spec.
"""
sycl_type, xp, is_array_api_compliant = _get_sycl_namespace(*arrays)

if sycl_type:
return xp, is_array_api_compliant
else:
return np, True
1 change: 1 addition & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ lightgbm==4.5.0
catboost==1.2.7 ; python_version < '3.11' # TODO: Remove 3.11 condition when catboost supports numpy 2.0
shap==0.46.0
array-api-strict==2.0.1
array-api-compat==1.8.0
34 changes: 34 additions & 0 deletions sklearnex/_device_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,40 @@ def _get_backend(obj, queue, method_name, *data):
raise RuntimeError("Device support is not implemented")


# def dispatch_with_array_api(obj, method_name, branches, xp, is_array_api_compliant, *args, **kwargs):
def dispatch_with_array_api(
obj, method_name, branches, xp, is_array_api_compliant, *args, **kwargs
):
q = _get_global_queue()
# if "array_api_support_sklearnex" in obj._get_tags() and obj._get_tags()["array_api_support_sklearnex"]:

backend, q, patching_status = _get_backend(obj, q, method_name, *args)

if backend == "onedal":
patching_status.write_log(queue=q)
return branches[backend](obj, *args, **kwargs, queue=q)
if backend == "sklearn":
if (
"array_api_dispatch" in get_config()
and get_config()["array_api_dispatch"]
and "array_api_support" in obj._get_tags()
and obj._get_tags()["array_api_support"]
):
# If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn,
# then raw inputs are used for the fallback.
patching_status.write_log()
return branches[backend](obj, *args, **kwargs)
else:
patching_status.write_log()
_, hostargs = _transfer_to_host(q, *args)
_, hostvalues = _transfer_to_host(q, *kwargs.values())
hostkwargs = dict(zip(kwargs.keys(), hostvalues))
return branches[backend](obj, *hostargs, **hostkwargs)
raise RuntimeError(
f"Undefined backend {backend} in " f"{obj.__class__.__name__}.{method_name}"
)


def dispatch(obj, method_name, branches, *args, **kwargs):
q = _get_global_queue()
q, hostargs = _transfer_to_host(q, *args)
Expand Down
15 changes: 12 additions & 3 deletions sklearnex/cluster/dbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
from daal4py.sklearn._utils import sklearn_check_version
from onedal.cluster import DBSCAN as onedal_DBSCAN

from .._device_offload import dispatch
from .._device_offload import dispatch, dispatch_with_array_api
from .._utils import PatchingConditionsChain
from ..utils._array_api import get_namespace

if sklearn_check_version("1.1") and not sklearn_check_version("1.2"):
from sklearn.utils import check_scalar
Expand Down Expand Up @@ -91,6 +92,7 @@ def __init__(
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
if sklearn_check_version("1.0"):
X = validate_data(self, X, force_all_finite=False)
xp, is_array_api_compliant = get_namespace(X)

onedal_params = {
"eps": self.eps,
Expand All @@ -104,7 +106,9 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
}
self._onedal_estimator = self._onedal_dbscan(**onedal_params)

self._onedal_estimator.fit(X, y=y, sample_weight=sample_weight, queue=queue)
self._onedal_estimator._fit(
X, xp, is_array_api_compliant, y, sample_weight, queue=queue
)
self._save_attributes()

def _onedal_supported(self, method_name, *data):
Expand Down Expand Up @@ -178,9 +182,11 @@ def fit(self, X, y=None, sample_weight=None):
if self.eps <= 0.0:
raise ValueError(f"eps == {self.eps}, must be > 0.0.")

# TODO:
# should be checked for Array API inputs.
if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X)
dispatch(
dispatch_with_array_api(
self,
"fit",
{
Expand All @@ -194,4 +200,7 @@ def fit(self, X, y=None, sample_weight=None):

return self

def _more_tags(self):
return {"array_api_support_sklearnex": True}

fit.__doc__ = sklearn_DBSCAN.fit.__doc__
Loading
Loading