Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: utility module for Intel data parallel libs; import checks in one place #1936

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
13 changes: 3 additions & 10 deletions onedal/_device_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,18 @@

from ._config import _get_config
from .utils._array_api import _asarray, _is_numpy_namespace
from .utils._dppy_available import dpctl_available, dpnp_available

try:
if dpctl_available:
from dpctl import SyclQueue
from dpctl.memory import MemoryUSMDevice, as_usm_memory
from dpctl.tensor import usm_ndarray

dpctl_available = True
except ImportError:
dpctl_available = False

try:
if dpnp_available:
import dpnp

from .utils._array_api import _convert_to_dpnp

dpnp_available = True
except ImportError:
dpnp_available = False


class DummySyclQueue:
"""This class is designed to act like dpctl.SyclQueue
Expand Down
9 changes: 4 additions & 5 deletions onedal/datatypes/_data_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,14 @@
from onedal import _backend, _is_dpc_backend

from ..utils import _is_csr
from ..utils._dppy_available import is_dpctl_available

try:
dpctl_available = is_dpctl_available("0.14")

if dpctl_available:
import dpctl
import dpctl.tensor as dpt

dpctl_available = dpctl.__version__ >= "0.14"
except ImportError:
dpctl_available = False


def _apply_and_pass(func, *args):
if len(args) == 1:
Expand Down
9 changes: 4 additions & 5 deletions onedal/datatypes/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,14 @@
from onedal.datatypes import from_table, to_table
from onedal.primitives import linear_kernel
from onedal.tests.utils._device_selection import get_queues
from onedal.utils._dppy_available import is_dpctl_available

try:
dpctl_available = is_dpctl_available("0.14")

if dpctl_available:
import dpctl
import dpctl.tensor as dpt

dpctl_available = dpctl.__version__ >= "0.14"
except ImportError:
dpctl_available = False


def _test_input_format_c_contiguous_numpy(queue, dtype):
rng = np.random.RandomState(0)
Expand Down
14 changes: 4 additions & 10 deletions onedal/tests/utils/_dataframes_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,14 @@

from sklearnex import get_config

try:
import dpctl.tensor as dpt
from ...utils._dppy_available import dpctl_available, dpnp_available

dpctl_available = True
except ImportError:
dpctl_available = False
if dpctl_available:
import dpctl.tensor as dpt

try:
if dpnp_available:
import dpnp

dpnp_available = True
except ImportError:
dpnp_available = False

try:
# This should be lazy imported in the
# future along with other popular
Expand Down
9 changes: 4 additions & 5 deletions onedal/tests/utils/_device_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import pytest

from ...utils._dppy_available import dpctl_available


def get_queues(filter_="cpu,gpu"):
"""Get available dpctl.SycQueues for testing.
Expand Down Expand Up @@ -62,9 +64,7 @@ def get_memory_usm():


def is_dpctl_available(targets=None):
try:
import dpctl

if dpctl_available:
samir-nasibli marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if dpctl_available check can be made an import-time check, rather than runtime.

if targets is None:
return True
for device in targets:
Expand All @@ -73,8 +73,7 @@ def is_dpctl_available(targets=None):
if device == "gpu" and not dpctl.has_gpu_devices():
return False
return True
except ImportError:
return False
return False


def device_type_to_str(queue):
Expand Down
16 changes: 3 additions & 13 deletions onedal/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,10 @@

from collections.abc import Iterable

try:
from dpctl.tensor import usm_ndarray

dpctl_available = True
except ImportError:
dpctl_available = False

try:
import dpnp

dpnp_available = True
except ImportError:
dpnp_available = False
from ._dppy_available import dpctl_available, dpnp_available

if dpctl_available:
from dpctl.tensor import usm_ndarray

if dpnp_available:
import dpnp
Expand Down
48 changes: 48 additions & 0 deletions onedal/utils/_dppy_available.py
samir-nasibli marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# ==============================================================================
# Copyright 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Check availability of DPPY imports in one place"""


def is_dpctl_available(version=""):
"""Checks availability of DPCtl package"""
try:
import dpctl
import dpctl.tensor as dpt

dpctl_available = True
except ImportError:
dpctl_available = False
if dpctl_available and not version == "":
dpctl_available = dpctl.__version__ >= version
samir-nasibli marked this conversation as resolved.
Show resolved Hide resolved
return dpctl_available


def is_dpnp_available(version=""):
"""Checks availability of DPNP package"""
try:
import dpnp

dpnp_available = True
except ImportError:
dpnp_available = False
if dpnp_available and not version == "":
dpnp_available = dpnp.__version__ >= version
samir-nasibli marked this conversation as resolved.
Show resolved Hide resolved
return dpnp_available


dpctl_available = is_dpctl_available()
dpnp_available = is_dpnp_available()
2 changes: 1 addition & 1 deletion tests/run_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from daal4py import __has_dist__
from daal4py.sklearn._utils import get_daal_version
from onedal._device_offload import dpctl_available
from onedal.utils._dppy_available import dpctl_available

print("Starting examples validation")
# First item is major version - 2021,
Expand Down
Loading