Fix: remove only added forward hooks
ShunsukeOnoo committed Jul 26, 2024
1 parent: 8021eef · commit: 3560c80
Showing 1 changed file with 9 additions and 2 deletions.
bdpy/dl/torch/torch.py (11 changes: 9 additions & 2 deletions)
@@ -95,10 +95,17 @@ def run(self, x: _tensor_t) -> Dict[str, _tensor_t]:
         return features
 
     def __del__(self):
-        '''Remove forward hooks from the encoder.'''
+        '''
+        Remove forward hooks for the FeatureExtractor while keeping
+        other forward hooks in the model.
+        '''
         for layer in self.__layers:
             if self.__layer_map is not None:
                 layer = self.__layer_map[layer]
             layer_object = models._parse_layer_name(self._encoder, layer)
-            layer_object._forward_hooks = OrderedDict()
+            for key, hook in layer_object._forward_hooks.items():
+                if hook == self._extractor:
+                    del layer_object._forward_hooks[key]
 
 
 class FeatureExtractorHandle(object):
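A note on the new removal loop as rendered above: in Python 3, deleting entries from a dict while iterating over its live .items() view raises a RuntimeError, so the pattern only works safely if the items are snapshotted first (for example with list(...)). Below is a minimal, self-contained sketch of the same selective-removal idea against a plain PyTorch module; the names target_hook, other_hook, and layer are hypothetical stand-ins for self._extractor and the encoder layer resolved by models._parse_layer_name, and are not part of bdpy.

import torch.nn as nn

# Hypothetical stand-in for the FeatureExtractor hook (self._extractor).
def target_hook(module, inputs, output):
    return None

# A second hook that should survive the removal.
def other_hook(module, inputs, output):
    return None

layer = nn.Linear(4, 4)
layer.register_forward_hook(target_hook)
layer.register_forward_hook(other_hook)

# Snapshot the items with list() so entries can be deleted safely;
# deleting while iterating over the live dict raises a RuntimeError.
for key, hook in list(layer._forward_hooks.items()):
    if hook == target_hook:
        del layer._forward_hooks[key]

# Only the targeted hook is gone; unrelated hooks are kept.
assert list(layer._forward_hooks.values()) == [other_hook]

Where the registration site is under your control, the supported alternative is to keep the torch.utils.hooks.RemovableHandle returned by register_forward_hook and call handle.remove() later, which removes exactly that hook without touching the private _forward_hooks dict.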
