ENH: functional support for Array API #1861

Merged 105 commits on Sep 10, 2024
Commits
bec3ebb
BUG: fixing circular import in daal4py/sklearnex device_offloading
samir-nasibli May 13, 2024
9052a73
removed onedal4py sklearnex dependence
samir-nasibli May 15, 2024
b070c89
minor fix
samir-nasibli May 15, 2024
0fbc680
Merge branch 'intel:main' into fix/device_offload
samir-nasibli May 16, 2024
8a6c4da
Merge branch 'intel:main' into fix/device_offload
samir-nasibli Jun 1, 2024
89ac903
Merge branch 'intel:main' into fix/device_offload
samir-nasibli Jun 6, 2024
01d73da
minor update
samir-nasibli Jun 8, 2024
d4587aa
Merge branch 'intel:main' into fix/device_offload
samir-nasibli Jun 8, 2024
b5f8921
added daal4py.sklearn._config for exposing sklearnex settings
samir-nasibli Jun 8, 2024
46dea79
Merge branch 'intel:main' into fix/device_offload
samir-nasibli Jun 11, 2024
8ce81ea
removed daal4py device_offloading
samir-nasibli Jun 11, 2024
fc01bae
integrating changes of device offloading for sklearnex primitives/est…
samir-nasibli Jun 11, 2024
18e6599
minor fixes
samir-nasibli Jun 11, 2024
4318064
minor fix for daal4py/sklearn/linear_model/_coordinate_descent.py
samir-nasibli Jun 11, 2024
864c1a1
minor fix for daal4py/sklearn/linear_model/_linear.py
samir-nasibli Jun 11, 2024
4142bf1
fix for sklearnex/_device_offload.py
samir-nasibli Jun 11, 2024
25d87bc
fix for onedal._config
samir-nasibli Jun 11, 2024
778b88d
wrapping daal4py.sklearne Kmeans with onedal4py's support_usm_ndarray
samir-nasibli Jun 11, 2024
ecd731d
ENH: functional support for Array API
samir-nasibli Jun 12, 2024
f18070f
minor update for support_usm_ndarray decorator
samir-nasibli Jun 14, 2024
df682e6
update sklearnex/dispatcher.py
samir-nasibli Jun 14, 2024
23c84f4
fixed dispatcher
samir-nasibli Jun 14, 2024
0483c5c
fixed decorator name
samir-nasibli Jun 14, 2024
459e638
minor update for onedal/_device_offload.py
samir-nasibli Jun 15, 2024
75f9c4b
minor updates
samir-nasibli Jun 18, 2024
27b7a7c
update docstrings for onedal._config._get_config
samir-nasibli Jun 18, 2024
dec939e
reverted changes for LogReg refactoring
samir-nasibli Jun 19, 2024
38cc61e
using _get_config instead of _get_onedal_threadlocal_config
samir-nasibli Jun 19, 2024
cea22ac
minor fix of _get_config
samir-nasibli Jun 19, 2024
1bb3697
removed TODO, that covered by ticket
samir-nasibli Jun 20, 2024
9ec11b9
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Jun 20, 2024
4179f22
added todo comments
samir-nasibli Jun 21, 2024
694d94b
moved out from _DataParallelInteropPolicy init import of DummySyclQueue
samir-nasibli Jun 20, 2024
8bf2585
removed outdated comment; will be covered on #1813
samir-nasibli Jun 20, 2024
e903fbb
removed TODO comment from ridge.py
samir-nasibli Jun 20, 2024
1491b38
Added ElasticNet, Lasso, Ridge into sklearnex patching map
samir-nasibli Jun 21, 2024
81de19e
Merge branch 'main' into enh/functional_array_api
samir-nasibli Jun 27, 2024
3dd5000
lint
samir-nasibli Jun 27, 2024
adc6e68
removed debug print
samir-nasibli Jun 27, 2024
652d014
enabled more array api test
samir-nasibli Jun 27, 2024
d9f2ef8
Merge branch 'main' into enh/functional_array_api
samir-nasibli Jul 5, 2024
85c279c
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Jul 9, 2024
14c675d
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Jul 11, 2024
669d68f
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Jul 12, 2024
c6e158c
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Jul 15, 2024
ad42750
created seperate array api module
samir-nasibli Jul 16, 2024
7bef6c4
currently disabled array_api for test_memory_leaks
samir-nasibli Jul 17, 2024
88e2a22
update _convert_to_dataframe
samir-nasibli Jul 17, 2024
383ce0d
Merge branch 'main' into enh/functional_array_api
samir-nasibli Jul 17, 2024
d8d0dc4
linting
samir-nasibli Jul 17, 2024
45b920f
update condition for _transfer_to_host
samir-nasibli Jul 18, 2024
84da80a
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Jul 22, 2024
eda98a4
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Jul 26, 2024
f74ff13
fixed for bs and ridge
samir-nasibli Jul 26, 2024
7e83e4a
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Aug 5, 2024
4b86143
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Aug 6, 2024
b8348ee
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Aug 7, 2024
0300f72
update fallback when array_api_dispatch enabled
samir-nasibli Aug 8, 2024
5cae80b
refactor sklearnex get_namespace usage
samir-nasibli Aug 8, 2024
ac3aacb
updated array apiconditions for get_dataframes_and_queues
samir-nasibli Aug 8, 2024
80da711
fix import in test_memory_usage.py
samir-nasibli Aug 8, 2024
1de8fc9
first temp commit address py312 fails
samir-nasibli Aug 9, 2024
621dfc2
small tmp workaround for py312
samir-nasibli Aug 9, 2024
d2f4383
removed from_dlpack from the non-zero support logic
samir-nasibli Aug 10, 2024
16ddb69
fixing tests for incremenatal estimators
samir-nasibli Aug 11, 2024
9be33c2
using asarray instead of dlpack conversions
samir-nasibli Aug 12, 2024
e2fa37a
FIX: fixing spmd tests utilities for dpctl inputs
samir-nasibli Aug 12, 2024
016f550
Deselect LOF stability test with array api
samir-nasibli Aug 12, 2024
3e833d7
Revert "Deselect LOF stability test with array api"
samir-nasibli Aug 12, 2024
d0fc95d
MAINT: minor refactoring and docstrings added
samir-nasibli Aug 12, 2024
2ff418f
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Aug 26, 2024
df1efc6
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Aug 30, 2024
2094fad
refactoring
samir-nasibli Aug 30, 2024
3004faa
Merge branch 'main' into enh/functional_array_api
samir-nasibli Sep 2, 2024
0fcaafe
minor update for _extract_array_attr
samir-nasibli Sep 2, 2024
5521720
minor update
samir-nasibli Sep 2, 2024
7bd3df0
update conditions for support_array_api
samir-nasibli Sep 2, 2024
19df8f1
update dispatch for array api
samir-nasibli Sep 2, 2024
442826e
covered by tickets TODOs
samir-nasibli Sep 2, 2024
fc34a6c
Merge branch 'main' into enh/functional_array_api
samir-nasibli Sep 4, 2024
ec634c1
refactor _asarray
samir-nasibli Sep 4, 2024
844745f
update a bit logic
samir-nasibli Sep 4, 2024
752df22
addressed test failes
samir-nasibli Sep 4, 2024
fe38790
update docstring
samir-nasibli Sep 4, 2024
3dd9521
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Sep 4, 2024
6989049
Update _array_api.py
samir-nasibli Sep 4, 2024
047d698
minor update
samir-nasibli Sep 4, 2024
c5f8281
minor updatte try
samir-nasibli Sep 4, 2024
1b45a24
renamed wrapper for inputs support
samir-nasibli Sep 4, 2024
ac7288f
minor refactoring
samir-nasibli Sep 5, 2024
d5f3a35
Merge branch 'main' into enh/functional_array_api
samir-nasibli Sep 5, 2024
3915066
fix refactoring
samir-nasibli Sep 5, 2024
95fa66f
addressed test fails
samir-nasibli Sep 5, 2024
8a9e497
remove unnecessary comments
samir-nasibli Sep 5, 2024
c66836f
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Sep 5, 2024
3ac9c82
enabled transform_output check
samir-nasibli Sep 6, 2024
819aa8f
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Sep 6, 2024
aa2bf82
update use of transform_output flag for input handlers
samir-nasibli Sep 6, 2024
3539ee3
reverted changes for test_incremental_pca.py
samir-nasibli Sep 6, 2024
ad71cd0
Revert "update use of transform_output flag for input handlers"
samir-nasibli Sep 6, 2024
4634c73
Merge branch 'main' into enh/functional_array_api
samir-nasibli Sep 6, 2024
41d1efc
minor refactoring
samir-nasibli Sep 9, 2024
52e9257
fixing
samir-nasibli Sep 9, 2024
dd78da4
Merge branch 'intel:main' into enh/functional_array_api
samir-nasibli Sep 10, 2024
4a44e12
minor refactoring
samir-nasibli Sep 10, 2024
56 changes: 32 additions & 24 deletions onedal/_device_offload.py
@@ -19,8 +19,10 @@
from functools import wraps

import numpy as np
from sklearn import get_config

from ._config import _get_config
from .utils._array_api import _asarray, _is_numpy_namespace

try:
from dpctl import SyclQueue
@@ -34,6 +36,8 @@
try:
import dpnp

from .utils._array_api import _convert_to_dpnp

dpnp_available = True
except ImportError:
dpnp_available = False
@@ -94,6 +98,9 @@ def _transfer_to_host(queue, *data):
host_data = []
for item in data:
usm_iface = getattr(item, "__sycl_usm_array_interface__", None)
array_api = getattr(item, "__array_namespace__", None)
if array_api:
array_api = array_api()
Review comment (Contributor):

Suggested change, replacing

    array_api = getattr(item, "__array_namespace__", None)
    if array_api:
        array_api = array_api()

with

    array_api = getattr(item, "__array_namespace__", print)()

A simple way of one-lining it that removes the branching. You could also replace print with lambda: None if you wanted. Ideally this should only be run if usm_iface is None.

Reply (Author, samir-nasibli, Sep 6, 2024):

This print actually causes some test failures, so I cannot apply the suggestion. Generally, the namespace getter will be updated and used here; that will be done as a follow-up outside the scope of the current PR. A ticket has been created.

Reply (Contributor):

Could you send me the failures? I am curious; the suggestion should only yield a self-contained return of None in the case where __array_namespace__ doesn't exist.

Reply (Author):

https://github.com/intel/scikit-learn-intelex/actions/runs/10721335064/job/29729604965
The print there caused some test failures; an example job is linked above.

Reply (Contributor):

Try array_api = getattr(item, "__array_namespace__", lambda: None)() instead.

Reply (Author):

There were some issues with this as well. Let me check it once again.
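The lambda: None variant discussed in the thread can be exercised in isolation. A minimal sketch (the class name is illustrative, not from the PR): the fallback in getattr replaces the explicit branch, because calling the result yields the namespace when the protocol exists and None otherwise.

```python
class WithNamespace:
    """Toy object implementing the array API discovery protocol."""

    def __array_namespace__(self):
        return "fake-namespace"


# Protocol present: the bound method is returned and called.
ns = getattr(WithNamespace(), "__array_namespace__", lambda: None)()

# Protocol absent: the fallback lambda is called instead, yielding None.
no_ns = getattr(object(), "__array_namespace__", lambda: None)()
```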

if usm_iface is not None:
if not dpctl_available:
raise RuntimeError(
@@ -120,6 +127,11 @@
order=order,
)
has_usm_data = True
elif array_api and not _is_numpy_namespace(array_api):
# The `copy` param for `asarray` is not set;
# the object is copied only if needed.
item = np.asarray(item)
has_host_data = True
else:
has_host_data = True

@@ -153,34 +165,17 @@ def _get_host_inputs(*args, **kwargs):
return q, hostargs, hostkwargs


def _extract_usm_iface(*args, **kwargs):
allargs = (*args, *kwargs.values())
if len(allargs) == 0:
return None
return getattr(allargs[0], "__sycl_usm_array_interface__", None)


def _run_on_device(func, obj=None, *args, **kwargs):
if obj is not None:
return func(obj, *args, **kwargs)
return func(*args, **kwargs)


if dpnp_available:

def _convert_to_dpnp(array):
if isinstance(array, usm_ndarray):
return dpnp.array(array, copy=False)
elif isinstance(array, Iterable):
for i in range(len(array)):
array[i] = _convert_to_dpnp(array[i])
return array


def support_usm_ndarray(freefunc=False, queue_param=True):
def support_input_format(freefunc=False, queue_param=True):
"""
Handles USMArray input. Puts SYCLQueue from data to decorated function arguments.
Converts output of decorated function to dpctl.tensor/dpnp.ndarray if input was of this type.
Converts and moves the output arrays of the decorated function
to match the input array type and device.
Puts SYCLQueue from data to decorated function arguments.

Parameters
----------
@@ -194,17 +189,30 @@ def support_usm_ndarray(freefunc=False, queue_param=True):

def decorator(func):
def wrapper_impl(obj, *args, **kwargs):
usm_iface = _extract_usm_iface(*args, **kwargs)
data = (*args, *kwargs.values())
if len(data) == 0:
return _run_on_device(func, obj, *args, **kwargs)
data_queue, hostargs, hostkwargs = _get_host_inputs(*args, **kwargs)
if queue_param and not (
"queue" in hostkwargs and hostkwargs["queue"] is not None
):
hostkwargs["queue"] = data_queue
result = _run_on_device(func, obj, *hostargs, **hostkwargs)
if usm_iface is not None and hasattr(result, "__array_interface__"):
usm_iface = getattr(data[0], "__sycl_usm_array_interface__", None)
if usm_iface is not None:
result = _copy_to_usm(data_queue, result)
if dpnp_available and len(args) > 0 and isinstance(args[0], dpnp.ndarray):
if dpnp_available and isinstance(args[0], dpnp.ndarray):
result = _convert_to_dpnp(result)
return result
config = get_config()
if not ("transform_output" in config and config["transform_output"]):
input_array_api = getattr(data[0], "__array_namespace__", None)
if input_array_api:
input_array_api = input_array_api()
input_array_api_device = data[0].device
result = _asarray(
result, input_array_api, device=input_array_api_device
)
return result

if freefunc:
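The diff above renames support_usm_ndarray to support_input_format and extends it: compute runs on host data, and outputs are converted back to match the first input's array type. A minimal standalone sketch of that round-trip, under the assumption that numpy is the host format (helper names here are illustrative, not the PR's actual implementation):

```python
import numpy as np


def _pull_back(result, xp):
    # Convert a host (numpy) result back into the caller's namespace,
    # recursing into tuples/lists of arrays.
    if hasattr(result, "__array_interface__"):
        return xp.asarray(result)
    if isinstance(result, (list, tuple)):
        return type(result)(_pull_back(r, xp) for r in result)
    return result


def support_input_format(func):
    # Sketch of the decorator: detect the input's array namespace via the
    # `__array_namespace__` protocol, run `func` on host numpy data, then
    # mirror the input's namespace on the outputs.
    def wrapper(*args, **kwargs):
        data = (*args, *kwargs.values())
        if not data:
            return func(*args, **kwargs)
        ns_getter = getattr(data[0], "__array_namespace__", None)
        xp = ns_getter() if ns_getter is not None else None
        host_args = [np.asarray(a) for a in args]
        result = func(*host_args, **kwargs)
        return _pull_back(result, xp) if xp is not None else result

    return wrapper


@support_input_format
def double(x):
    return x * 2


out = double(np.array([1.0, 2.0]))
```

With a numpy input the wrapper is a no-op on the output type; with a dpnp or dpctl input, the real decorator additionally moves data back to the originating device and queue.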
6 changes: 3 additions & 3 deletions onedal/spmd/basic_statistics/basic_statistics.py
@@ -16,15 +16,15 @@

from onedal.basic_statistics import BasicStatistics as BasicStatistics_Batch

from ..._device_offload import support_usm_ndarray
from ..._device_offload import support_input_format
from .._base import BaseEstimatorSPMD


class BasicStatistics(BaseEstimatorSPMD, BasicStatistics_Batch):
@support_usm_ndarray()
@support_input_format()
def compute(self, data, weights=None, queue=None):
return super().compute(data, weights=weights, queue=queue)

@support_usm_ndarray()
@support_input_format()
def fit(self, data, sample_weight=None, queue=None):
return super().fit(data, sample_weight=sample_weight, queue=queue)
8 changes: 4 additions & 4 deletions onedal/spmd/cluster/kmeans.py
@@ -18,7 +18,7 @@
from onedal.cluster import KMeansInit as KMeansInit_Batch
from onedal.spmd.basic_statistics import BasicStatistics

from ..._device_offload import support_usm_ndarray
from ..._device_offload import support_input_format
from .._base import BaseEstimatorSPMD


@@ -37,15 +37,15 @@ def _get_basic_statistics_backend(self, result_options):
def _get_kmeans_init(self, cluster_count, seed, algorithm):
return KMeansInit(cluster_count=cluster_count, seed=seed, algorithm=algorithm)

@support_usm_ndarray()
@support_input_format()
def fit(self, X, y=None, queue=None):
return super().fit(X, queue=queue)

@support_usm_ndarray()
@support_input_format()
def predict(self, X, queue=None):
return super().predict(X, queue=queue)

@support_usm_ndarray()
@support_input_format()
def fit_predict(self, X, y=None, queue=None):
return super().fit_predict(X, queue=queue)

4 changes: 2 additions & 2 deletions onedal/spmd/covariance/covariance.py
@@ -16,11 +16,11 @@

from onedal.covariance import EmpiricalCovariance as EmpiricalCovariance_Batch

from ..._device_offload import support_usm_ndarray
from ..._device_offload import support_input_format
from .._base import BaseEstimatorSPMD


class EmpiricalCovariance(BaseEstimatorSPMD, EmpiricalCovariance_Batch):
@support_usm_ndarray()
@support_input_format()
def fit(self, X, y=None, queue=None):
return super().fit(X, queue=queue)
4 changes: 2 additions & 2 deletions onedal/spmd/decomposition/pca.py
@@ -16,11 +16,11 @@

from onedal.decomposition.pca import PCA as PCABatch

from ..._device_offload import support_usm_ndarray
from ..._device_offload import support_input_format
from .._base import BaseEstimatorSPMD


class PCA(BaseEstimatorSPMD, PCABatch):
@support_usm_ndarray()
@support_input_format()
def fit(self, X, y=None, queue=None):
return super().fit(X, queue=queue)
6 changes: 3 additions & 3 deletions onedal/spmd/linear_model/linear_model.py
@@ -16,15 +16,15 @@

from onedal.linear_model import LinearRegression as LinearRegression_Batch

from ..._device_offload import support_usm_ndarray
from ..._device_offload import support_input_format
from .._base import BaseEstimatorSPMD


class LinearRegression(BaseEstimatorSPMD, LinearRegression_Batch):
@support_usm_ndarray()
@support_input_format()
def fit(self, X, y, queue=None):
return super().fit(X, y, queue=queue)

@support_usm_ndarray()
@support_input_format()
def predict(self, X, queue=None):
return super().predict(X, queue=queue)
10 changes: 5 additions & 5 deletions onedal/spmd/linear_model/logistic_regression.py
@@ -16,23 +16,23 @@

from onedal.linear_model import LogisticRegression as LogisticRegression_Batch

from ..._device_offload import support_usm_ndarray
from ..._device_offload import support_input_format
from .._base import BaseEstimatorSPMD


class LogisticRegression(BaseEstimatorSPMD, LogisticRegression_Batch):
@support_usm_ndarray()
@support_input_format()
def fit(self, X, y, queue=None):
return super().fit(X, y, queue=queue)

@support_usm_ndarray()
@support_input_format()
def predict(self, X, queue=None):
return super().predict(X, queue=queue)

@support_usm_ndarray()
@support_input_format()
def predict_proba(self, X, queue=None):
return super().predict_proba(X, queue=queue)

@support_usm_ndarray()
@support_input_format()
def predict_log_proba(self, X, queue=None):
return super().predict_log_proba(X, queue=queue)
20 changes: 10 additions & 10 deletions onedal/spmd/neighbors/neighbors.py
@@ -17,30 +17,30 @@
from onedal.neighbors import KNeighborsClassifier as KNeighborsClassifier_Batch
from onedal.neighbors import KNeighborsRegressor as KNeighborsRegressor_Batch

from ..._device_offload import support_usm_ndarray
from ..._device_offload import support_input_format
from .._base import BaseEstimatorSPMD


class KNeighborsClassifier(BaseEstimatorSPMD, KNeighborsClassifier_Batch):
@support_usm_ndarray()
@support_input_format()
def fit(self, X, y, queue=None):
return super().fit(X, y, queue=queue)

@support_usm_ndarray()
@support_input_format()
def predict(self, X, queue=None):
return super().predict(X, queue=queue)

@support_usm_ndarray()
@support_input_format()
def predict_proba(self, X, queue=None):
raise NotImplementedError("predict_proba not supported in distributed mode.")

@support_usm_ndarray()
@support_input_format()
def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
return super().kneighbors(X, n_neighbors, return_distance, queue=queue)


class KNeighborsRegressor(BaseEstimatorSPMD, KNeighborsRegressor_Batch):
@support_usm_ndarray()
@support_input_format()
def fit(self, X, y, queue=None):
if queue is not None and queue.sycl_device.is_gpu:
return super()._fit(X, y, queue=queue)
@@ -50,11 +50,11 @@ def fit(self, X, y, queue=None):
"CPU. Consider running on it on GPU."
)

@support_usm_ndarray()
@support_input_format()
def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
return super().kneighbors(X, n_neighbors, return_distance, queue=queue)

@support_usm_ndarray()
@support_input_format()
def predict(self, X, queue=None):
return self._predict_gpu(X, queue=queue)

@@ -66,10 +66,10 @@ def _get_onedal_params(self, X, y=None):


class NearestNeighbors(BaseEstimatorSPMD):
@support_usm_ndarray()
@support_input_format()
def fit(self, X, y, queue=None):
return super().fit(X, y, queue=queue)

@support_usm_ndarray()
@support_input_format()
def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
return super().kneighbors(X, n_neighbors, return_distance, queue=queue)
29 changes: 13 additions & 16 deletions onedal/tests/utils/_dataframes_support.py
@@ -16,10 +16,10 @@

import pytest
import scipy.sparse as sp
from sklearn import get_config

from sklearnex import get_config

try:
import dpctl
import dpctl.tensor as dpt

dpctl_available = True
@@ -40,7 +40,6 @@
# GPU-no-copy.
import array_api_strict

# Run check if "array_api_dispatch" is configurable
array_api_enabled = lambda: get_config()["array_api_dispatch"]
array_api_enabled()
array_api_modules = {"array_api": array_api_strict}
@@ -58,7 +57,7 @@


def get_dataframes_and_queues(
dataframe_filter_="numpy,pandas,dpnp,dpctl", device_filter_="cpu,gpu"
dataframe_filter_="numpy,pandas,dpnp,dpctl,array_api", device_filter_="cpu,gpu"
):
"""Get supported dataframes for testing.

@@ -107,13 +106,18 @@ def get_df_and_q(dataframe: str):
dataframes_and_queues.extend(get_df_and_q("dpctl"))
if dpnp_available and "dpnp" in dataframe_filter_:
dataframes_and_queues.extend(get_df_and_q("dpnp"))
if "array_api" in dataframe_filter_ or array_api_enabled():
if (
"array_api" in dataframe_filter_
and "array_api" in array_api_modules
or array_api_enabled()
):
dataframes_and_queues.append(pytest.param("array_api", None, id="array_api"))

return dataframes_and_queues


def _as_numpy(obj, *args, **kwargs):
"""Converted input object to numpy.ndarray format."""
if dpnp_available and isinstance(obj, dpnp.ndarray):
return obj.asnumpy(*args, **kwargs)
if dpctl_available and isinstance(obj, dpt.usm_ndarray):
@@ -155,17 +159,10 @@ def _convert_to_dataframe(obj, sycl_queue=None, target_df=None, *args, **kwargs)
# DPCtl tensor.
return dpt.asarray(obj, usm_type="device", sycl_queue=sycl_queue, *args, **kwargs)
elif target_df in array_api_modules:
# use dpctl to define gpu devices via queues and
# move data to the device. This is necessary as
# the standard for defining devices is
# purposefully not defined in the array_api
# standard, but maintaining data on a device
# using the method `from_dlpack` is.
# Array API input other than DPNP ndarray, DPCtl tensor or
# Numpy ndarray.

xp = array_api_modules[target_df]
return xp.from_dlpack(
_convert_to_dataframe(
obj, sycl_queue=sycl_queue, target_df="dpctl", *args, **kwargs
)
)
return xp.asarray(obj)

raise RuntimeError("Unsupported dataframe conversion")
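The change above replaces the dlpack round-trip through dpctl with a direct xp.asarray(obj) on the target namespace. A minimal sketch of the simplified conversion path, using numpy as a stand-in namespace module (array_api_strict may not be installed; the registry name mirrors the test utilities but this is not the actual helper):

```python
import numpy as np

# Stand-in for the `array_api_modules` registry in the test utilities.
array_api_modules = {"array_api": np}


def convert_to_dataframe(obj, target_df="array_api"):
    # The updated helper hands the raw object straight to the target
    # namespace's `asarray`, instead of converting to a dpctl tensor
    # first and importing it with `from_dlpack`.
    if target_df in array_api_modules:
        xp = array_api_modules[target_df]
        return xp.asarray(obj)
    raise RuntimeError("Unsupported dataframe conversion")


arr = convert_to_dataframe([1, 2, 3])
```

The trade-off discussed in the removed comment block still applies: asarray may copy to the namespace's default device, whereas from_dlpack could keep data in place on a device.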