Skip to content

Commit

Permalink
move vit functionality into derandomised smoothing toolset
Browse files Browse the repository at this point in the history
Signed-off-by: GiulioZizzo <[email protected]>
  • Loading branch information
GiulioZizzo committed Jun 27, 2023
1 parent 8f9ebdc commit b154580
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
:param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
"""
import timm
import torch
from timm.models.vision_transformer import VisionTransformer
from art.estimators.certification.derandomized_smoothing.vision_transformers.vit import ArtViT

Expand Down Expand Up @@ -218,6 +219,7 @@ def get_models(cls, generate_from_null: bool = False) -> List[str]:
:return: A list of compatible models
"""
import timm
import torch

supported_models = [
"vit_base_patch8_224",
Expand Down Expand Up @@ -330,7 +332,7 @@ def update_batchnorm(self, x: np.ndarray, batch_size: int, nb_epochs: int = 1) -
:param batch_size: Size of batches.
:param nb_epochs: How many times to forward pass over the input data
"""

import torch
self.model.train()

ind = np.arange(len(x))
Expand Down

0 comments on commit b154580

Please sign in to comment.