Skip to content

Commit

Permalink
isort + black
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w-feldmann committed Apr 23, 2024
1 parent cdeba6d commit f39d951
Showing 1 changed file with 38 additions and 38 deletions.
76 changes: 38 additions & 38 deletions test_extras/test_chemprop/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
import logging
import unittest

from chemprop.nn.loss import LossFunction, BCELoss, MSELoss
from chemprop.nn.loss import BCELoss, LossFunction, MSELoss
from lightning import pytorch as pl
from sklearn.base import clone
from torch import nn
from torch import Tensor
from torch import Tensor, nn

from molpipeline.estimators.chemprop.component_wrapper import (
MPNN,
BinaryClassificationFFN,
RegressionFFN,
BondMessagePassing,
MeanAggregation,
RegressionFFN,
SumAggregation,
)
from molpipeline.estimators.chemprop.models import (
Expand Down Expand Up @@ -48,40 +47,40 @@ def get_model() -> ChempropModel:


DEFAULT_PARAMS = {
"batch_size": 64,
"lightning_trainer": pl.Trainer,
"model": MPNN,
"model__agg__dim": 0,
"model__agg": SumAggregation,
"model__batch_norm": True,
"model__final_lr": 0.0001,
"model__init_lr": 0.0001,
"model__max_lr": 0.001,
"model__message_passing__activation": "relu",
"model__message_passing__bias": False,
"model__message_passing__d_e": 14,
"model__message_passing__d_h": 300,
"model__message_passing__d_v": 72,
"model__message_passing__d_vd": None,
"model__message_passing__depth": 3,
"model__message_passing__dropout_rate": 0.0,
"model__message_passing__undirected": False,
"model__message_passing": BondMessagePassing,
"model__metric_list": None,
"model__predictor__activation": "relu",
"model__warmup_epochs": 2,
"model__predictor": BinaryClassificationFFN,
"model__predictor__criterion": BCELoss,
"model__predictor__dropout": 0,
"model__predictor__hidden_dim": 300,
"model__predictor__input_dim": 300,
"model__predictor__n_layers": 1,
"model__predictor__n_tasks": 1,
"model__predictor__output_transform": nn.Identity,
"model__predictor__task_weights": Tensor([1.0]),
"model__predictor__threshold": None,
"n_jobs": 1,
}
"batch_size": 64,
"lightning_trainer": pl.Trainer,
"model": MPNN,
"model__agg__dim": 0,
"model__agg": SumAggregation,
"model__batch_norm": True,
"model__final_lr": 0.0001,
"model__init_lr": 0.0001,
"model__max_lr": 0.001,
"model__message_passing__activation": "relu",
"model__message_passing__bias": False,
"model__message_passing__d_e": 14,
"model__message_passing__d_h": 300,
"model__message_passing__d_v": 72,
"model__message_passing__d_vd": None,
"model__message_passing__depth": 3,
"model__message_passing__dropout_rate": 0.0,
"model__message_passing__undirected": False,
"model__message_passing": BondMessagePassing,
"model__metric_list": None,
"model__predictor__activation": "relu",
"model__warmup_epochs": 2,
"model__predictor": BinaryClassificationFFN,
"model__predictor__criterion": BCELoss,
"model__predictor__dropout": 0,
"model__predictor__hidden_dim": 300,
"model__predictor__input_dim": 300,
"model__predictor__n_layers": 1,
"model__predictor__n_tasks": 1,
"model__predictor__output_transform": nn.Identity,
"model__predictor__task_weights": Tensor([1.0]),
"model__predictor__threshold": None,
"n_jobs": 1,
}

NO_IDENTITY_CHECK = [
"model__agg",
Expand All @@ -93,6 +92,7 @@ def get_model() -> ChempropModel:
"model__predictor__output_transform",
]


class TestChempropModel(unittest.TestCase):
"""Test the Chemprop model."""

Expand Down

0 comments on commit f39d951

Please sign in to comment.