fix mypy errors
Signed-off-by: GiulioZizzo <[email protected]>
GiulioZizzo committed Jun 27, 2023
1 parent 80ff747 commit 06178f4
Showing 4 changed files with 82 additions and 292 deletions.
177 changes: 42 additions & 135 deletions art/estimators/certification/derandomized_smoothing/pytorch.py
@@ -39,7 +39,8 @@
if TYPE_CHECKING:
# pylint: disable=C0412
import torch

import torchvision
from timm.models.vision_transformer import VisionTransformer
from art.utils import CLIP_VALUES_TYPE, PREPROCESSING_TYPE
from art.defences.preprocessor import Preprocessor
from art.defences.postprocessor import Postprocessor
@@ -149,149 +150,39 @@ def _fit_classifier(self, x: np.ndarray, y: np.ndarray, batch_size: int, nb_epoc
x = x.astype(ART_NUMPY_DTYPE)
return PyTorchClassifier.fit(self, x, y, batch_size=batch_size, nb_epochs=nb_epochs, **kwargs)

def fit_old( # pylint: disable=W0221
self,
x: np.ndarray,
y: np.ndarray,
batch_size: int = 128,
nb_epochs: int = 10,
training_mode: bool = True,
drop_last: bool = False,
scheduler: Optional[Any] = None,
update_batchnorm: bool = True,
batchnorm_update_epochs: int = 1,
transform: Optional["torchvision.transforms.transforms.Compose"] = None,
verbose: bool = True,
**kwargs,
) -> None:
"""
Fit the classifier on the training set `(x, y)`.
:param x: Training data.
:param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of
shape (nb_samples,).
:param batch_size: Size of batches.
:param nb_epochs: Number of epochs to use for training.
:param training_mode: `True` for model set to training mode and `False` for model set to evaluation mode.
:param drop_last: Set to ``True`` to drop the last incomplete batch, if the dataset size is not divisible by
the batch size. If ``False`` and the size of dataset is not divisible by the batch size, then
the last batch will be smaller. (default: ``False``)
:param scheduler: Learning rate scheduler to run at the start of every epoch.
:param update_batchnorm: ViT-specific argument. Whether to run the training data through the model to update any
batch norm statistics prior to training. Useful on small datasets when using pre-trained ViTs.
:param batchnorm_update_epochs: ViT-specific argument. How many forward passes over the training data to use
to pre-adjust the batchnorm statistics.
:param transform: ViT-specific argument. Torchvision compose of relevant augmentation transformations to apply.
:param verbose: Whether to display training progress bars.
:param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for PyTorch
and providing it has no effect.
"""
import torch

# Check if we have a VIT

# Set model mode
self._model.train(mode=training_mode)

if self._optimizer is None: # pragma: no cover
raise ValueError("An optimizer is needed to train the model, but none for provided.")

y = check_and_transform_label_format(y, nb_classes=self.nb_classes)

# Apply preprocessing
x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)

if update_batchnorm: # VIT specific
self.update_batchnorm(x_preprocessed, batch_size, nb_epochs=batchnorm_update_epochs)

# Check label shape
y_preprocessed = self.reduce_labels(y_preprocessed)

num_batch = len(x_preprocessed) / float(batch_size)
if drop_last:
num_batch = int(np.floor(num_batch))
else:
num_batch = int(np.ceil(num_batch))
ind = np.arange(len(x_preprocessed))

# Start training
for _ in tqdm(range(nb_epochs)):
# Shuffle the examples
random.shuffle(ind)

epoch_acc = []
epoch_loss = []
epoch_batch_sizes = []

pbar = tqdm(range(num_batch), disable=not verbose)

# Train for one epoch
for m in pbar:
i_batch = np.copy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]])
i_batch = self.ablator.forward(i_batch)

if transform is not None: # VIT specific
i_batch = transform(i_batch)

i_batch = torch.from_numpy(i_batch).to(self._device)
o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)

# Zero the parameter gradients
self._optimizer.zero_grad()

# Perform prediction
try:
model_outputs = self.model(i_batch)
except ValueError as err:
if "Expected more than 1 value per channel when training" in str(err):
logger.exception(
"Try dropping the last incomplete batch by setting drop_last=True in "
"method PyTorchClassifier.fit."
)
raise err

loss = self.loss(model_outputs, o_batch)
acc = self.get_accuracy(preds=model_outputs, labels=o_batch)

# Do training
if self._use_amp: # pragma: no cover
from apex import amp # pylint: disable=E0611

with amp.scale_loss(loss, self._optimizer) as scaled_loss:
scaled_loss.backward()

else:
loss.backward()

self.optimizer.step()

epoch_acc.append(acc)
epoch_loss.append(loss.cpu().detach().numpy())
epoch_batch_sizes.append(len(i_batch))

if verbose:
pbar.set_description(
f"Loss {np.average(epoch_loss, weights=epoch_batch_sizes):.3f} "
f"Acc {np.average(epoch_acc, weights=epoch_batch_sizes):.3f} "
)

if scheduler is not None:
scheduler.step()
class PyTorchDeRandomizedSmoothing(PyTorchDeRandomizedSmoothingCNN, PyTorchSmoothedViT):
"""
Interface class for the two De-randomized smoothing approaches supported by ART for PyTorch.
If a regular PyTorch neural network is fed in then (De)Randomized Smoothing as introduced in Levine et al. (2020)
is used.
Otherwise, if a timm vision transformer is fed in then Certified Patch Robustness via Smoothed Vision Transformers
as introduced in Salman et al. (2021) is used.
"""
def __init__(self, model: Union[str, "VisionTransformer", "torch.nn.Module"], **kwargs):
import torch

self.mode = None
if importlib.util.find_spec("timm") is not None:
from timm.models.vision_transformer import VisionTransformer

if isinstance(model, (VisionTransformer, str)):
PyTorchSmoothedViT.__init__(self, model, **kwargs)
self.mode = "ViT"
else:
if isinstance(model, torch.nn.Module):
PyTorchDeRandomizedSmoothingCNN.__init__(self, model, **kwargs)
self.mode = "CNN"

elif isinstance(model, torch.nn.Module):
PyTorchDeRandomizedSmoothingCNN.__init__(self, model, **kwargs)
self.mode = "CNN"

if self.mode is None:
raise ValueError("Model type not recognized.")


def fit( # pylint: disable=W0221
self,
@@ -373,10 +264,11 @@ def fit( # pylint: disable=W0221
i_batch = np.copy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]])
i_batch = self.ablator.forward(i_batch)

if transform is not None and self.mode == "ViT":  # VIT specific
i_batch = transform(i_batch)

if isinstance(i_batch, np.ndarray):
i_batch = torch.from_numpy(i_batch).to(self._device)
o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)

# Zero the parameter gradients
@@ -421,4 +313,19 @@ def fit( # pylint: disable=W0221
if scheduler is not None:
scheduler.step()
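
Continuing the instantiation sketch above, a hedged example of how the ViT-specific training arguments documented for fit() might be passed. The data, scheduler, and augmentation choices are illustrative only, and the sketch assumes vit_estimator was constructed as above with an optimizer supplied.

    import numpy as np
    import torch
    import torchvision

    # Dummy data purely for illustration.
    x_train = np.random.rand(256, 3, 224, 224).astype(np.float32)
    y_train = np.random.randint(0, 10, size=256)

    # Optional extras from the fit() signature; assumes an optimizer was supplied at construction time.
    scheduler = torch.optim.lr_scheduler.StepLR(vit_estimator.optimizer, step_size=10)
    transform = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip()])

    vit_estimator.fit(
        x_train,
        y_train,
        batch_size=32,
        nb_epochs=30,
        update_batchnorm=True,       # forward-pass the data first to refresh batch-norm statistics
        batchnorm_update_epochs=1,
        scheduler=scheduler,         # stepped once per epoch
        transform=transform,         # applied to each ablated batch when mode == "ViT"
        verbose=True,
    )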

@staticmethod
def get_accuracy(preds: Union[np.ndarray, "torch.Tensor"], labels: Union[np.ndarray, "torch.Tensor"]) -> np.ndarray:
"""
Helper function to get the accuracy during training.
:param preds: model predictions.
:param labels: ground truth labels (not one hot).
:return: prediction accuracy.
"""
if not isinstance(preds, np.ndarray):
preds = preds.detach().cpu().numpy()

if not isinstance(labels, np.ndarray):
labels = labels.detach().cpu().numpy()

return np.sum(np.argmax(preds, axis=1) == labels) / len(labels)
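
A quick numeric check of the accuracy helper above, with invented values:

    import numpy as np

    preds = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  # per-class scores for 3 samples
    labels = np.array([1, 0, 0])                             # index labels (not one-hot)

    # argmax over the class axis matches the label for the first two samples only.
    accuracy = np.sum(np.argmax(preds, axis=1) == labels) / len(labels)
    print(accuracy)  # 0.666...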