diff --git a/art/estimators/certification/derandomized_smoothing/vision_transformers/pytorch.py b/art/estimators/certification/derandomized_smoothing/vision_transformers/pytorch.py index ffa604fc63..1675dfa2c2 100644 --- a/art/estimators/certification/derandomized_smoothing/vision_transformers/pytorch.py +++ b/art/estimators/certification/derandomized_smoothing/vision_transformers/pytorch.py @@ -106,6 +106,7 @@ def __init__( :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`. """ import timm + import torch from timm.models.vision_transformer import VisionTransformer from art.estimators.certification.derandomized_smoothing.vision_transformers.vit import ArtViT @@ -218,6 +219,7 @@ def get_models(cls, generate_from_null: bool = False) -> List[str]: :return: A list of compatible models """ import timm + import torch supported_models = [ "vit_base_patch8_224", @@ -330,7 +332,7 @@ def update_batchnorm(self, x: np.ndarray, batch_size: int, nb_epochs: int = 1) - :param batch_size: Size of batches. :param nb_epochs: How many times to forward pass over the input data """ - + import torch self.model.train() ind = np.arange(len(x))