Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Aug 30, 2024
1 parent 2f14913 commit 07befa2
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 10 deletions.
12 changes: 6 additions & 6 deletions molpipeline/explainability/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,12 @@ def _convert_shap_feature_weights_to_atom_weights(


# pylint: disable=R0903
class AbstractExplainer(abc.ABC):
class AbstractSHAPExplainer(abc.ABC):
"""Abstract class for explainer objects."""

# pylint: disable=C0103,W0613
@abc.abstractmethod
def explain(self, X: Any, **kwargs: Any) -> list[Explanation]:
def explain(self, X: Any, **kwargs: Any) -> list[SHAPExplanation]:
"""Explain the predictions for the input data.
Parameters
Expand All @@ -134,13 +134,13 @@ def explain(self, X: Any, **kwargs: Any) -> list[Explanation]:
Returns
-------
list[Explanation]
list[Explanation] | list[SHAPExplanation]
List of explanations corresponding to the input samples.
"""


# pylint: disable=R0903
class SHAPTreeExplainer(AbstractExplainer):
class SHAPTreeExplainer(AbstractSHAPExplainer):
"""Class for SHAP's TreeExplainer wrapper."""

def __init__(self, pipeline: Pipeline, **kwargs: Any) -> None:
Expand Down Expand Up @@ -228,7 +228,7 @@ def explain(self, X: Any, **kwargs: Any) -> list[SHAPExplanation]:
Returns
-------
list[Explanation]
list[SHAPExplanation]
List of explanations corresponding to the input data.
"""
featurization_element = self.featurization_subpipeline.steps[-1][1] # type: ignore[union-attr]
Expand All @@ -242,7 +242,7 @@ def explain(self, X: Any, **kwargs: Any) -> list[SHAPExplanation]:
prediction = _get_predictions(self.pipeline, input_sample)
if not self._prediction_is_valid(prediction):
# we use the prediction to check if the input is valid. If not, we cannot explain it.
explanation_results.append(Explanation())
explanation_results.append(SHAPExplanation())
continue

if prediction.ndim > 1:
Expand Down
6 changes: 3 additions & 3 deletions molpipeline/explainability/visualization/heatmaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np
import numpy.typing as npt
from matplotlib import cm, colors
from matplotlib import colors
from rdkit.Chem import Draw
from rdkit.Geometry.rdGeometry import Point2D

Expand Down Expand Up @@ -218,8 +218,8 @@ def map2color(
----------
c_map: colors.Colormap
Colormap to be used for mapping values to colors.
v_lim: Optional[Tuple[float, float]]
Limits for the colormap. If not given, the maximum absolute value of `self.values` is used as limit.
normalizer: colors.Normalize
Normalizer to be used for mapping values to colors.
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion molpipeline/explainability/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def structure_heatmap_shap(
f"$P(y=1|X) = {explanation.prediction[1]:.2f}$ ="
"\n"
"\n"
f" $expected \ value={explanation.expected_value[1]:.2f}$ + "
f" $expected \ value={explanation.expected_value[1]:.2f}$ + " # noqa: W605
f"$features_{{present}}= {sum_present_shap:.2f}$ + "
f"$features_{{absent}}={sum_absent_shap:.2f}$"
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test visualization methods for explanations."""

import unittest
from typing import ClassVar

import numpy as np
from rdkit import Chem
Expand Down Expand Up @@ -54,6 +55,8 @@ def _get_test_shap_explanations() -> list[SHAPExplanation]:
class TestExplainabilityVisualization(unittest.TestCase):
"""Test the public interface of the visualization methods for explanations."""

explanations: ClassVar[list[SHAPExplanation]]

@classmethod
def setUpClass(cls) -> None:
"""Set up the tests."""
Expand Down Expand Up @@ -90,6 +93,8 @@ def test_structure_heatmap_shap_explanation(self) -> None:
class TestSumOfGaussiansGrid(unittest.TestCase):
"""Test visualization methods for explanations."""

explanations: ClassVar[list[SHAPExplanation]]

@classmethod
def setUpClass(cls) -> None:
"""Set up the tests."""
Expand Down

0 comments on commit 07befa2

Please sign in to comment.