oneDAL LinReg and Covariance hyperparameters API (#1594)
Alexsandruss authored Dec 7, 2023
1 parent d8ed189 commit b85de64
Showing 12 changed files with 358 additions and 35 deletions.
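In short, this commit plumbs oneDAL's per-algorithm hyperparameter objects (available since oneDAL 2024.0) through the pybind11 layer and wraps them in a small Python API. A minimal usage sketch, assuming a 2024.0+ oneDAL build (the block-size value is illustrative, not a tuned recommendation):

from onedal.common.hyperparameters import get_hyperparameters

hparams = get_hyperparameters("linear_regression", "train")
if hparams is not None:  # None (with a warning) on oneDAL < 2024.0
    print(hparams.to_dict())        # current values, read via the backend getters
    hparams.cpu_macro_block = 8192  # routed to the backend's set_cpu_macro_block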
58 changes: 58 additions & 0 deletions onedal/common/dispatch_utils.hpp
@@ -18,6 +18,8 @@

#include <pybind11/pybind11.h>

#include "onedal/version.hpp"

#include "oneapi/dal/train.hpp"
#include "oneapi/dal/infer.hpp"
#include "oneapi/dal/compute.hpp"
@@ -68,6 +70,34 @@ struct compute_ops {
    Ops ops;
};

#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240000

template <typename Policy, typename Input, typename Ops, typename Hyperparams>
struct compute_ops_with_hyperparams {
    using Task = typename Input::task_t;

    compute_ops_with_hyperparams(const Policy& policy,
                                 const Input& input,
                                 const Ops& ops,
                                 const Hyperparams& hyperparams)
            : policy(policy),
              input(input),
              ops(ops),
              hyperparams(hyperparams) {}

    template <typename Float, typename Method, typename... Args>
    auto operator()(const pybind11::dict& params) {
        auto desc = ops.template operator()<Float, Method, Task, Args...>(params);
        return dal::compute(policy, desc, hyperparams, input);
    }

    Policy policy;
    Input input;
    Ops ops;
    Hyperparams hyperparams;
};

#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240000

template <typename Policy, typename Input, typename Ops>
struct train_ops {
    using Task = typename Input::task_t;
@@ -88,6 +118,34 @@ struct train_ops {
    Ops ops;
};

#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240000

template <typename Policy, typename Input, typename Ops, typename Hyperparams>
struct train_ops_with_hyperparams {
    using Task = typename Input::task_t;

    train_ops_with_hyperparams(const Policy& policy,
                               const Input& input,
                               const Ops& ops,
                               const Hyperparams& hyperparams)
            : policy(policy),
              input(input),
              ops(ops),
              hyperparams(hyperparams) {}

    template <typename Float, typename Method, typename... Args>
    auto operator()(const pybind11::dict& params) {
        auto desc = ops.template operator()<Float, Method, Task, Args...>(params);
        return dal::train(policy, desc, hyperparams, input);
    }

    Policy policy;
    Input input;
    Ops ops;
    Hyperparams hyperparams;
};

#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240000

template <typename Policy, typename Input, typename Ops>
struct infer_ops {
    using Task = typename Input::task_t;
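The two `*_ops_with_hyperparams` functors above mirror the existing `compute_ops`/`train_ops` dispatchers; the only difference is that they capture a hyperparameters object and forward it as an extra argument to `dal::compute`/`dal::train`. A rough Python analogue of the pattern, for illustration only (not part of the commit):

class ComputeOpsWithHyperparams:
    """Stand-in for compute_ops_with_hyperparams: captures the call context
    at construction and forwards hyperparams to the oneDAL entry point."""

    def __init__(self, compute_fn, policy, input_, ops, hyperparams):
        # compute_fn stands in for dal::compute; ops for the params2desc factory
        self.compute_fn = compute_fn
        self.policy = policy
        self.input = input_
        self.ops = ops
        self.hyperparams = hyperparams

    def __call__(self, params):
        desc = self.ops(params)  # build the algorithm descriptor from params
        return self.compute_fn(self.policy, desc, self.hyperparams, self.input)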
116 changes: 116 additions & 0 deletions onedal/common/hyperparameters.py
@@ -0,0 +1,116 @@
# ==============================================================================
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import logging
from warnings import warn

from daal4py.sklearn._utils import daal_check_version
from onedal import _backend

if daal_check_version((2024, "P", 0)):
    _hparams_reserved_words = [
        "algorithm",
        "op",
        "setters",
        "getters",
        "backend",
        "is_default",
        "to_dict",
    ]

    class HyperParameters:
        """Class for simplified interaction with oneDAL hyperparameters.

        Overrides `__getattribute__` and `__setattr__` to route attribute access
        through the getters and setters of the hyperparameter class exposed by
        the onedal backend.
        """

        def __init__(self, algorithm, op, setters, getters, backend):
            self.algorithm = algorithm
            self.op = op
            self.setters = setters
            self.getters = getters
            self.backend = backend
            self.is_default = True

        def __getattribute__(self, __name):
            if __name in _hparams_reserved_words:
                if __name == "backend":
                    # the `backend` attribute is accessed only for oneDAL kernel calls
                    logging.getLogger("sklearnex").debug(
                        "Using the following hyperparameters for "
                        f"'{self.algorithm}.{self.op}': {self.to_dict()}"
                    )
                return super().__getattribute__(__name)
            elif __name in self.getters.keys():
                return self.getters[__name]()
            else:
                raise ValueError(
                    f"Unknown '{__name}' name in "
                    f"'{self.algorithm}.{self.op}' hyperparameters"
                )

        def __setattr__(self, __name, __value):
            if __name in _hparams_reserved_words:
                super().__setattr__(__name, __value)
            elif __name in self.setters.keys():
                self.is_default = False
                self.setters[__name](__value)
            else:
                raise ValueError(
                    f"Unknown '{__name}' name in "
                    f"'{self.algorithm}.{self.op}' hyperparameters"
                )

        def to_dict(self):
            return {name: getter() for name, getter in self.getters.items()}
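With this routing in place, reads and writes on a `HyperParameters` instance go straight to the backend object, and any write clears `is_default`. For example (attribute names are those bound for linear-regression training further below):

hp = get_hyperparameters("linear_regression", "train")
hp.cpu_macro_block = 4096   # calls the backend's set_cpu_macro_block; is_default -> False
print(hp.cpu_macro_block)   # calls the backend's get_cpu_macro_block
print(hp.to_dict())         # all current values as a plain dict
# hp.unknown_option = 1     # would raise ValueError: unknown name in hyperparameters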

    def get_methods_with_prefix(obj, prefix):
        return {
            method.replace(prefix, ""): getattr(obj, method)
            for method in filter(lambda f: f.startswith(prefix), dir(obj))
        }
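This helper collects the backend object's prefixed methods into name-keyed dicts. A runnable illustration with a hypothetical stand-in object (the real backend objects are constructed just below):

class _DummyHparams:  # hypothetical backend object with one knob
    def set_cpu_macro_block(self, value):
        self._block = value

    def get_cpu_macro_block(self):
        return self._block

setters = get_methods_with_prefix(_DummyHparams(), "set_")
# -> {"cpu_macro_block": <bound method _DummyHparams.set_cpu_macro_block>}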

    hyperparameters_backend = {
        (
            "linear_regression",
            "train",
        ): _backend.linear_model.regression.train_hyperparameters(),
        ("covariance", "compute"): _backend.covariance.compute_hyperparameters(),
    }
    hyperparameters_map = {}

    for (algorithm, op), hyperparameters in hyperparameters_backend.items():
        setters = get_methods_with_prefix(hyperparameters, "set_")
        getters = get_methods_with_prefix(hyperparameters, "get_")

        if set(setters.keys()) != set(getters.keys()):
            raise ValueError(
                f"Setters and getters in '{algorithm}.{op}' "
                "hyperparameters wrapper do not match."
            )

        hyperparameters_map[(algorithm, op)] = HyperParameters(
            algorithm, op, setters, getters, hyperparameters
        )

    def get_hyperparameters(algorithm, op):
        return hyperparameters_map[(algorithm, op)]

else:

    def get_hyperparameters(algorithm, op):
        warn("Hyperparameters are supported in oneDAL starting with version 2024.0.0.")
        return None
20 changes: 16 additions & 4 deletions onedal/decomposition/pca.py
@@ -16,10 +16,11 @@

import numpy as np

-from daal4py.sklearn._utils import sklearn_check_version
+from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
from onedal import _backend

from ..common._policy import _get_policy
+from ..common.hyperparameters import get_hyperparameters
from ..datatypes import _convert_to_supported, from_table, to_table


@@ -54,9 +55,20 @@ def fit(self, X, queue):
        X = _convert_to_supported(policy, X)

        params = self.get_onedal_params(X)
-        cov_result = _backend.covariance.compute(
-            policy, {"fptype": params["fptype"], "method": "dense"}, to_table(X)
-        )
+        hparams = get_hyperparameters("covariance", "compute")
+        if hparams is not None and not hparams.is_default:
+            cov_result = _backend.covariance.compute(
+                policy,
+                {"fptype": params["fptype"], "method": "dense"},
+                hparams.backend,
+                to_table(X),
+            )
+        else:
+            cov_result = _backend.covariance.compute(
+                policy,
+                {"fptype": params["fptype"], "method": "dense"},
+                to_table(X),
+            )
        covariance_matrix = from_table(cov_result.cov_matrix)
        self.mean_ = from_table(cov_result.means)
        result = _backend.decomposition.dim_reduction.train(
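The effect in PCA: once any covariance hyperparameter has been changed from its default, `fit` passes `hparams.backend` so the hyperparameters-aware `covariance.compute` overload is taken. A hedged sketch of enabling that path (the covariance hyperparameter names are not shown in this diff and depend on the backend build):

from onedal.common.hyperparameters import get_hyperparameters

cov_hp = get_hyperparameters("covariance", "compute")
if cov_hp is not None:
    print(cov_hp.to_dict())  # inspect which knobs this build exposes
    # setting any of them flips cov_hp.is_default, so subsequent PCA fits
    # take the hyperparameters-aware covariance.compute branch above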
70 changes: 59 additions & 11 deletions onedal/linear_model/linear_model.cpp
@@ -113,17 +113,34 @@ template <typename Policy>
struct init_train_ops_dispatcher<Policy, linear_regression::task::regression> {
    void operator()(py::module_& m) {
        using Task = linear_regression::task::regression;
-        m.def("train",
-              [](const Policy& policy,
-                 const py::dict& params,
-                 const table& data,
-                 const table& responses) {
-                  using namespace dal::linear_regression;
-                  using input_t = train_input<Task>;
-
-                  train_ops ops(policy, input_t{ data, responses }, params2desc{});
-                  return fptype2t{ method2t{ Task{}, ops } }(params);
-              });
+#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240000
+        using train_hyperparams_t = dal::linear_regression::detail::train_parameters<Task>;
+        m.def("train",
+              [](const Policy& policy,
+                 const py::dict& params,
+                 const train_hyperparams_t& hyperparams,
+                 const table& data,
+                 const table& responses) {
+                  using namespace dal::linear_regression;
+                  using input_t = train_input<Task>;
+                  train_ops_with_hyperparams ops(policy,
+                                                 input_t{ data, responses },
+                                                 params2desc{},
+                                                 hyperparams);
+                  return fptype2t{ method2t{ Task{}, ops } }(params);
+              });
+#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240000
+        m.def("train",
+              [](const Policy& policy,
+                 const py::dict& params,
+                 const table& data,
+                 const table& responses) {
+                  using namespace dal::linear_regression;
+                  using input_t = train_input<Task>;
+                  train_ops ops(policy, input_t{ data, responses }, params2desc{});
+                  return fptype2t{ method2t{ Task{}, ops } }(params);
+              });
    }
};

@@ -188,11 +205,39 @@ void init_infer_result(py::module_& m) {
        .DEF_ONEDAL_PY_PROPERTY(responses, result_t);
}

#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240000

template <typename Task>
void init_train_hyperparameters(py::module_& m) {
    using namespace dal::linear_regression::detail;
    using train_hyperparams_t = train_parameters<Task>;

    auto cls = py::class_<train_hyperparams_t>(m, "train_hyperparameters")
                   .def(py::init())
                   .def("set_cpu_macro_block",
                        [](train_hyperparams_t& self, int64_t cpu_macro_block) {
                            self.set_cpu_macro_block(cpu_macro_block);
                        })
                   .def("set_gpu_macro_block",
                        [](train_hyperparams_t& self, int64_t gpu_macro_block) {
                            self.set_gpu_macro_block(gpu_macro_block);
                        })
                   .def("get_cpu_macro_block", [](const train_hyperparams_t& self) {
                       return self.get_cpu_macro_block();
                   })
                   .def("get_gpu_macro_block", [](const train_hyperparams_t& self) {
                       return self.get_gpu_macro_block();
                   });
}

#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240000
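These bindings surface as the `train_hyperparameters` class that onedal/common/hyperparameters.py instantiates; the raw object can also be exercised directly (the value is illustrative):

from onedal import _backend

raw = _backend.linear_model.regression.train_hyperparameters()
raw.set_cpu_macro_block(8192)
assert raw.get_cpu_macro_block() == 8192
# normally reached through the HyperParameters wrapper via
# get_hyperparameters("linear_regression", "train")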

ONEDAL_PY_DECLARE_INSTANTIATOR(init_model);
ONEDAL_PY_DECLARE_INSTANTIATOR(init_train_result);
ONEDAL_PY_DECLARE_INSTANTIATOR(init_infer_result);
ONEDAL_PY_DECLARE_INSTANTIATOR(init_train_ops);
ONEDAL_PY_DECLARE_INSTANTIATOR(init_infer_ops);
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240000
ONEDAL_PY_DECLARE_INSTANTIATOR(init_train_hyperparameters);
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240000

} // namespace linear_model

@@ -215,6 +260,9 @@ ONEDAL_PY_INIT_MODULE(linear_model) {
    ONEDAL_PY_INSTANTIATE(init_model, sub, task_list);
    ONEDAL_PY_INSTANTIATE(init_train_result, sub, task_list);
    ONEDAL_PY_INSTANTIATE(init_infer_result, sub, task_list);
#if defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240000
    ONEDAL_PY_INSTANTIATE(init_train_hyperparameters, sub, task_list);
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240000
}

ONEDAL_PY_TYPE2STR(dal::linear_regression::task::regression, "regression");
9 changes: 7 additions & 2 deletions onedal/linear_model/linear_model.py
@@ -20,12 +20,13 @@
import numpy as np
from sklearn.base import BaseEstimator

-from daal4py.sklearn._utils import get_dtype, make2d
+from daal4py.sklearn._utils import daal_check_version, get_dtype, make2d
from onedal import _backend

from ..common._estimator_checks import _check_is_fitted
from ..common._mixin import RegressorMixin
from ..common._policy import _get_policy
+from ..common.hyperparameters import get_hyperparameters
from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _check_array, _check_n_features, _check_X_y, _num_features

@@ -72,7 +73,11 @@ def _fit(self, X, y, module, queue):
        params = self._get_onedal_params(get_dtype(X_loc))
        X_table, y_table = to_table(X_loc, y_loc)

-        result = module.train(policy, params, X_table, y_table)
+        hparams = get_hyperparameters("linear_regression", "train")
+        if hparams is not None and not hparams.is_default:
+            result = module.train(policy, params, hparams.backend, X_table, y_table)
+        else:
+            result = module.train(policy, params, X_table, y_table)

        self._onedal_model = result.model

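End to end, one non-default setting is enough to switch `_fit` onto the hyperparameters overload of `module.train`. A minimal sketch (block size illustrative):

from onedal.common.hyperparameters import get_hyperparameters

hp = get_hyperparameters("linear_regression", "train")
if hp is not None:
    hp.cpu_macro_block = 8192  # any setter call clears is_default
# every subsequent linear-regression fit through this module now calls
# module.train(policy, params, hp.backend, X_table, y_table)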
[The remaining 7 of 12 changed files are not rendered here.]
