diff --git a/bdpy/dl/torch/torch.py b/bdpy/dl/torch/torch.py index 5b927d30..66825d42 100644 --- a/bdpy/dl/torch/torch.py +++ b/bdpy/dl/torch/torch.py @@ -1,7 +1,7 @@ '''PyTorch module.''' from typing import Iterable, List, Dict, Union, Tuple, Any, Callable, Optional - +from collections import OrderedDict import os import numpy as np @@ -93,6 +93,19 @@ def run(self, x: _tensor_t) -> Dict[str, _tensor_t]: } return features + + def __del__(self): + ''' + Remove forward hooks for the FeatureExtractor while keeping + other forward hooks in the model. + ''' + for layer in self.__layers: + if self.__layer_map is not None: + layer = self.__layer_map[layer] + layer_object = models._parse_layer_name(self._encoder, layer) + for key, hook in layer_object._forward_hooks.items(): + if hook == self._extractor: + del layer_object._forward_hooks[key] class FeatureExtractorHandle(object):