diff --git a/bdpy/dl/torch/models.py b/bdpy/dl/torch/models.py index 8e5954c1..5ab0d61a 100644 --- a/bdpy/dl/torch/models.py +++ b/bdpy/dl/torch/models.py @@ -1,5 +1,6 @@ """Model definitions.""" +from __future__ import annotations from typing import Dict, Union, Optional, Sequence @@ -93,7 +94,7 @@ def _parse_layer_name(model: nn.Module, layer_name: str) -> nn.Module: Network model. layer_name : str Layer name. It accepts the following formats: 'layer_name', - 'layer_name[index]', 'parent_name.child_name', and combinations of them. + '[index]', 'parent_name.child_name', and combinations of them. Returns ------- @@ -123,18 +124,19 @@ def _get_value_by_indices(array, indices): model = _parse_layer_name(model, top_most_layer_name) return _parse_layer_name(model, child_layer_name) - # parse layer name having index (e.g., 'features[0]', 'backbone[0][1]') - pattern = re.compile(r'^(?P\w+)(?P(\[(\d+)\])+)$') + # parse layer name having index (e.g., '[0]', 'features[0]', 'backbone[0][1]') + pattern = re.compile(r'^(?P[a-zA-Z_]+[a-zA-Z0-9_]*)?(?P(\[(\d+)\])+)$') m = pattern.match(layer_name) if m is not None: - layer_name = m.group('layer_name') + layer_name: str | None = m.group('layer_name') # NOTE: layer_name can be None index_str = m.group('index') indeces = re.findall(r'\[(\d+)\]', index_str) indeces = [int(i) for i in indeces] - if hasattr(model, layer_name): - return _get_value_by_indices(getattr(model, layer_name), indeces) + if isinstance(layer_name, str) and hasattr(model, layer_name): + model = getattr(model, layer_name) + return _get_value_by_indices(model, indeces) raise ValueError( f"Invalid layer name: '{layer_name}'. Either the syntax of '{layer_name}' is not supported, " diff --git a/test/dl/torch/test_models.py b/test/dl/torch/test_models.py index 6e217d88..dbb33520 100644 --- a/test/dl/torch/test_models.py +++ b/test/dl/torch/test_models.py @@ -6,6 +6,13 @@ from bdpy.dl.torch import models +def _removeprefix(text: str, prefix: str) -> str: + """Remove prefix from text. (Workaround for Python 3.8)""" + if text.startswith(prefix): + return text[len(prefix):] + return text + + class MockModule(nn.Module): def __init__(self): super(MockModule, self).__init__() @@ -69,6 +76,22 @@ def test_parse_layer_name(self): self.assertRaises( ValueError, models._parse_layer_name, self.mock, 'layers["key"]') + def test_parse_layer_name_for_sequential(self): + """Test _parse_layer_name for nn.Sequential. + + nn.Sequential is a special case because the submodules are directly + accessible like a list. For example, `model[0]` will return the first + module in the model. + """ + sequential_module = self.mock.layers + accessors = [accessor for accessor in self.accessors if accessor['name'].startswith('layers')] + for accessor in accessors: + accsessor_key = _removeprefix(accessor['name'], 'layers') + layer = models._parse_layer_name(sequential_module, accsessor_key) + self.assertIsInstance(layer, accessor['type']) + for attr, value in accessor['attrs'].items(): + self.assertEqual(getattr(layer, attr), value) + class TestVGG19(unittest.TestCase): def setUp(self):