Monkey-patch for scikit-learn with accelerated ArrayFire components.

arrayfire/af-sklearn-monkeypatch

ArrayFire SKlearn MonkeyPatch

Monkey-patch scikit-learn with ArrayFire-accelerated variants. The tested classifiers match the sklearn interface and pass the sklearn tests. The project currently targets scikit-learn 0.22. Individual sklearn components can be patched explicitly with the Patcher class as follows:

import sklearn
from sklearn.neural_network import MLPClassifier

from afsklearn.patcher import Patcher
Patcher.patch("mlp_classifier") # patches MLPClassifier with accelerated variant

clf = MLPClassifier(random_state=1, max_iter=300) # accelerated arrayfire classifier

Patcher.rollback("mlp_classifier")
# returns sklearn package to default state w/o arrayfire
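
Because a patch stays in effect until it is rolled back, it can help to pair the two calls so that the rollback runs even when the code in between raises. The following is a minimal sketch and not part of afsklearn: `patched` and `RecordingPatcher` are hypothetical names, and the stand-in class only records calls so the example runs without ArrayFire installed. With the real package, the assumption is that `Patcher` from `afsklearn.patcher` could be passed in its place.

```python
from contextlib import contextmanager


class RecordingPatcher:
    """Hypothetical stand-in with the same patch/rollback method names as Patcher."""

    calls = []

    @classmethod
    def patch(cls, name):
        cls.calls.append(("patch", name))

    @classmethod
    def rollback(cls, name):
        cls.calls.append(("rollback", name))


@contextmanager
def patched(patcher, name):
    """Apply a patch for the duration of a with-block, then always roll it back."""
    patcher.patch(name)
    try:
        yield
    finally:
        # Runs on normal exit and when the block raises.
        patcher.rollback(name)


try:
    with patched(RecordingPatcher, "mlp_classifier"):
        raise RuntimeError("simulated failure inside the patched block")
except RuntimeError:
    pass

print(RecordingPatcher.calls)
```

The `finally` clause is the design choice that matters here: without it, an exception inside the block would leave sklearn in its patched state.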

Instead of manually replacing individual classifiers, all possible components can be replaced at once:

import sklearn
from afsklearn.patcher import Patcher

Patcher.patch_all() # patches scikit-learn with all accelerated classifiers
# sklearn functions here
Patcher.rollback_all() # returns sklearn package to default state w/o arrayfire

If you prefer not to modify any code, see the Autowrapt instructions in the Installation section below.

Installation

pip install -r requirements.txt

The Autowrapt package can globally and automatically replace sklearn components when Python imports them. After installing autowrapt (the Dockerfile commands below install it from https://github.com/syurkevi/autowrapt), set the environment variable AUTOWRAPT_BOOTSTRAP=afsklearn to enable the import hooks.
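
As a rough sketch of how the variable might be supplied from Python, the snippet below launches a child interpreter with AUTOWRAPT_BOOTSTRAP set in its environment. It assumes that the import hooks are registered at interpreter startup, so the variable has to be present before the process starts; setting `os.environ` inside an already running script would likely be too late. The child command here only echoes the variable and stands in for your real entry point.

```python
import os
import subprocess
import sys

# Copy the current environment and add the variable named in this README.
env = dict(os.environ, AUTOWRAPT_BOOTSTRAP="afsklearn")

# Placeholder child program: it just reports the variable it was started with.
child_code = "import os; print(os.environ.get('AUTOWRAPT_BOOTSTRAP'))"

result = subprocess.run(
    [sys.executable, "-c", child_code],
    env=env,
    capture_output=True,
    text=True,
)
seen_by_child = result.stdout.strip()
print(seen_by_child)
```

In a shell or a Dockerfile the same effect is achieved by exporting the variable before starting Python, as the `ENV` line further down does.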

When packaging af-sklearn-monkeypatch into a Docker image of a larger system such as AutoML (https://gitlab.com/sray/cmu-ta2) or AlphaD3M, add the following commands to the Dockerfile to include the required dependencies and permanently enable the patch:

RUN pip3 install arrayfire==3.8.0+cu102 -f https://repo.arrayfire.com/python/wheels/3.8.0/

RUN pip3 install git+https://github.com/arrayfire/af-sklearn-monkeypatch && \
    pip3 install git+https://github.com/syurkevi/autowrapt.git

ENV AUTOWRAPT_BOOTSTRAP=afsklearn

Tests

To run all tests:

pytest .

To run a specific test:

pytest tests/test_mlp.py
