Skip to content

Commit

Permalink
update test case for reset_all_parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
ganow committed Dec 21, 2023
1 parent c3babf4 commit 9458842
Showing 1 changed file with 41 additions and 1 deletion.
42 changes: 41 additions & 1 deletion tests/recon/torch/modules/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,56 @@

import torch
import torch.nn as nn
from torchvision.models import get_model

from bdpy.recon.torch.modules import generator as generator_module


class TestResetAllParameters(unittest.TestCase):
"""Tests for bdpy.recon.torch.modules.generator.reset_all_parameters."""
def setUp(self):
self.model_ids = [
"alexnet",
"efficientnet_b0",
"fasterrcnn_resnet50_fpn",
"inception_v3",
"resnet18",
"vgg11",
"vit_b_16",
]
# NOTE: The following modules are excluded from validation because they
# initialize their parameters as constants every time.
self.excluded_modules = [
nn.modules.batchnorm._BatchNorm,
nn.LayerNorm,
]

def _validate_module(self, module: nn.Module, module_copy: nn.Module, parent_name: str = ""):
if isinstance(module, tuple(self.excluded_modules)):
return
for (name_p1, p1), (_, p2) in zip(module.named_parameters(recurse=False), module_copy.named_parameters(recurse=False)):
# NOTE: skip parameters that are prbably not randomly initialized
if "weight" not in name_p1:
continue
self.assertFalse(
torch.equal(p1, p2),
msg=f"Parameter {parent_name}.{name_p1} does not change after reset_all_parameters."
)
for (name_m1, m1), (_, m2) in zip(module.named_children(), module_copy.named_children()):
self._validate_module(m1, m2, f"{parent_name}.{name_m1}")

def test_reset_all_parameters(self):
"""Test reset_all_parameters."""
pass
for model_id in self.model_ids:
model = get_model(model_id)
model_copy = copy.deepcopy(model)
for (name_p1, p1), (_, p2) in zip(model.named_parameters(), model_copy.named_parameters()):
self.assertTrue(
torch.equal(p1, p2),
msg=f"Parameter {name_p1} of {model_id} has been changed by deepcopy."
)
model.apply(generator_module.reset_all_parameters)
self._validate_module(model, model_copy, model_id)


class TestBaseGenerator(unittest.TestCase):
Expand Down

0 comments on commit 9458842

Please sign in to comment.