diff --git a/tests/recon/torch/modules/test_generator.py b/tests/recon/torch/modules/test_generator.py index 08187efd..d4f735d5 100644 --- a/tests/recon/torch/modules/test_generator.py +++ b/tests/recon/torch/modules/test_generator.py @@ -6,16 +6,56 @@ import torch import torch.nn as nn +from torchvision.models import get_model from bdpy.recon.torch.modules import generator as generator_module class TestResetAllParameters(unittest.TestCase): """Tests for bdpy.recon.torch.modules.generator.reset_all_parameters.""" + def setUp(self): + self.model_ids = [ + "alexnet", + "efficientnet_b0", + "fasterrcnn_resnet50_fpn", + "inception_v3", + "resnet18", + "vgg11", + "vit_b_16", + ] + # NOTE: The following modules are excluded from validation because they + # initialize their parameters as constants every time. + self.excluded_modules = [ + nn.modules.batchnorm._BatchNorm, + nn.LayerNorm, + ] + + def _validate_module(self, module: nn.Module, module_copy: nn.Module, parent_name: str = ""): + if isinstance(module, tuple(self.excluded_modules)): + return + for (name_p1, p1), (_, p2) in zip(module.named_parameters(recurse=False), module_copy.named_parameters(recurse=False)): + # NOTE: skip parameters that are prbably not randomly initialized + if "weight" not in name_p1: + continue + self.assertFalse( + torch.equal(p1, p2), + msg=f"Parameter {parent_name}.{name_p1} does not change after reset_all_parameters." + ) + for (name_m1, m1), (_, m2) in zip(module.named_children(), module_copy.named_children()): + self._validate_module(m1, m2, f"{parent_name}.{name_m1}") def test_reset_all_parameters(self): """Test reset_all_parameters.""" - pass + for model_id in self.model_ids: + model = get_model(model_id) + model_copy = copy.deepcopy(model) + for (name_p1, p1), (_, p2) in zip(model.named_parameters(), model_copy.named_parameters()): + self.assertTrue( + torch.equal(p1, p2), + msg=f"Parameter {name_p1} of {model_id} has been changed by deepcopy." + ) + model.apply(generator_module.reset_all_parameters) + self._validate_module(model, model_copy, model_id) class TestBaseGenerator(unittest.TestCase):