From 942f1ca55647b6119b650856fbac701e052909c6 Mon Sep 17 00:00:00 2001 From: Vasily Shamporov Date: Thu, 6 May 2021 11:18:59 +0300 Subject: [PATCH] Release v1.7.1 of NNCF to master --- ReleaseNotes.md | 5 ++ examples/classification/main.py | 10 ++- examples/common/utils.py | 3 + nncf/dynamic_graph/context.py | 8 +- nncf/dynamic_graph/wrappers.py | 2 +- nncf/initialization.py | 59 +++++++++++---- nncf/model_creation.py | 17 +++++ nncf/nncf_network.py | 13 +++- nncf/quantization/algo.py | 10 +-- nncf/utils.py | 75 ++++++++++++++++--- nncf/version.py | 2 +- .../configs/inception_v3_mock_dataset.json | 16 ++++ tests/helpers.py | 1 + tests/quantization/test_algo_quantization.py | 20 +++++ .../test_saturation_issue_export.py | 16 ++++ tests/test_nncf_utils.py | 21 ++++++ tests/test_sanity_sample.py | 62 ++++++++++----- tests/test_utils.py | 74 ++++++++++++++++++ 18 files changed, 349 insertions(+), 65 deletions(-) create mode 100644 tests/data/configs/inception_v3_mock_dataset.json create mode 100644 tests/test_utils.py diff --git a/ReleaseNotes.md b/ReleaseNotes.md index ab6904acd56..05018974e99 100644 --- a/ReleaseNotes.md +++ b/ReleaseNotes.md @@ -7,6 +7,11 @@ samples distributed with the code. The samples demonstrate the usage of compres public models and datasets for three different use cases: Image Classification, Object Detection, and Semantic Segmentation. +## New in Release 1.7.1: Bugfixes: - Fixed a bug where compressed models that were supposed to return named tuples actually returned regular tuples - Fixed an issue with batch norm adaptation-enabled compression runs hanging in the DDP scenario ## New in Release 1.7: - Adjust Padding feature to support accurate execution of U4 on VPU - when setting "target_device" to "VPU", the training-time padding values for quantized convolutions will be adjusted to better reflect the VPU inference process. - Weighted layers that are "frozen" (i.e. have requires_grad set to False at compressed model creation time) are no longer considered for compression, to better handle transfer learning cases.
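The named-tuple bugfix above is needed because rebuilding a model's output with a plain tuple() call silently discards the namedtuple class and its field names, which is what broke models such as InceptionV3 that return named tuples. Below is a minimal sketch of the reconstruction approach this patch adds to nncf/utils.py (is_named_tuple/to_tuple mirror the helpers introduced later in the diff; the NetOutput type is a hypothetical stand-in for an output type like torchvision's InceptionOutputs):

    from collections import namedtuple
    from typing import List, Tuple, Type

    # Hypothetical output type for illustration; InceptionV3 returns a similar namedtuple
    NetOutput = namedtuple('NetOutput', ['logits', 'aux_logits'])

    def is_named_tuple(obj) -> bool:
        # namedtuples are the tuple subclasses whose class differs from the builtin tuple
        return isinstance(obj, tuple) and obj.__class__ != tuple

    def to_tuple(lst: List, named_tuple_class: Type = None,
                 named_tuple_fields: List[str] = None) -> Tuple:
        # Rebuild the original namedtuple when its class and fields were recorded,
        # otherwise fall back to a regular tuple
        if named_tuple_fields is None:
            return tuple(lst)
        return named_tuple_class(*lst)

    out = NetOutput(logits=1.0, aux_logits=2.0)
    assert tuple(list(out)) == (1.0, 2.0)  # plain tuple(): class and field names are lost
    restored = to_tuple(list(out), out.__class__, list(out._fields))
    assert is_named_tuple(restored) and restored.logits == 1.0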
diff --git a/examples/classification/main.py b/examples/classification/main.py index fe2a7dda65c..a5340c72e95 100644 --- a/examples/classification/main.py +++ b/examples/classification/main.py @@ -199,6 +199,7 @@ def autoq_eval_fn(model, eval_loader): if config.mode.lower() == 'train': train(config, compression_ctrl, model, criterion, train_criterion_fn, lr_scheduler, model_name, optimizer, train_loader, train_sampler, val_loader, best_acc1) + config.mlflow.end_run() def train(config, compression_ctrl, model, criterion, criterion_fn, lr_scheduler, model_name, optimizer, @@ -267,9 +268,11 @@ def get_dataset(dataset_config, config, transform, is_train): if dataset_config == 'imagenet': prefix = 'train' if is_train else 'val' return datasets.ImageFolder(osp.join(config.dataset_dir, prefix), transform) + # For testing purposes if dataset_config == 'mock_32x32': - # For testing purposes return MockDataset(img_size=(32, 32), transform=transform) + if dataset_config == 'mock_299x299': + return MockDataset(img_size=(299, 299), transform=transform) return create_cifar(config, dataset_config, is_train, transform) @@ -287,7 +290,8 @@ def create_cifar(config, dataset_config, is_train, transform): def create_datasets(config): dataset_config = config.dataset if config.dataset is not None else 'imagenet' dataset_config = dataset_config.lower() - assert dataset_config in ['imagenet', 'cifar100', 'cifar10', 'mock_32x32'], "Unknown dataset option" + assert dataset_config in ['imagenet', 'cifar100', 'cifar10', 'mock_32x32', 'mock_299x299'], \ + "Unknown dataset option" if dataset_config == 'imagenet': normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), @@ -295,7 +299,7 @@ def create_datasets(config): elif dataset_config == 'cifar100': normalize = transforms.Normalize(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2761)) - elif dataset_config in ['cifar10', 'mock_32x32']: + elif dataset_config in ['cifar10', 'mock_32x32', 'mock_299x299']: normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) diff --git a/examples/common/utils.py b/examples/common/utils.py index 56b79c5c924..4d05cf68f0a 100644 --- a/examples/common/utils.py +++ b/examples/common/utils.py @@ -128,6 +128,9 @@ def safe_call(self, func: str, *args, **kwargs) -> Maybe: return Maybe.from_value(self._get_mlflow()).bind( lambda obj: Maybe.from_value(getattr(obj, func)(*args, **kwargs))) + def end_run(self): + self.safe_call('end_run') + def _is_enabled(self): return self.is_suitable_mode and is_main_process() diff --git a/nncf/dynamic_graph/context.py b/nncf/dynamic_graph/context.py index 6379f26c0b9..3067b07f011 100644 --- a/nncf/dynamic_graph/context.py +++ b/nncf/dynamic_graph/context.py @@ -475,11 +475,11 @@ def _init_thread_local(self): tl.operator_counters = {} tl.node_call_tracker = {} - def register_node_call(self, node_key: str): - if node_key in self._thread_local.node_call_tracker: - self._thread_local.node_call_tracker[node_key] += 1 + def register_node_call(self, node: NNCFNode): + if node.node_id in self._thread_local.node_call_tracker: + self._thread_local.node_call_tracker[node.node_id] += 1 else: - self._thread_local.node_call_tracker[node_key] = 1 + self._thread_local.node_call_tracker[node.node_id] = 1 def reset_node_call_counters(self): for k, _ in self._thread_local.node_call_tracker.items(): diff --git a/nncf/dynamic_graph/wrappers.py b/nncf/dynamic_graph/wrappers.py index afbceb90f97..47f272f45f8 100644 --- a/nncf/dynamic_graph/wrappers.py +++ b/nncf/dynamic_graph/wrappers.py @@ -79,7 +79,7 
@@ def wrapped(*args, **kwargs): node = ctx.maybe_add_node(processed_input, tensor_metas, ia_op_exec_context, module_attrs) if is_debug(): - ctx.register_node_call(ctx.graph.get_node_key_by_id(node.node_id)) + ctx.register_node_call(node) args = tuple(processed_input.op_args) kwargs = processed_input.op_kwargs diff --git a/nncf/initialization.py b/nncf/initialization.py index 8c5e606fffd..c96b0702013 100644 --- a/nncf/initialization.py +++ b/nncf/initialization.py @@ -17,7 +17,7 @@ from nncf.structures import QuantizationRangeInitArgs from nncf.utils import is_tensor from nncf.utils import objwalk -from nncf.utils import training_mode_switcher +from contextlib import contextmanager class InitializingDataLoader: @@ -164,39 +164,66 @@ def __init__(self, model, init_device: str, num_bn_forget_steps): self.num_bn_forget_steps = num_bn_forget_steps self.momentum_bn_forget = 0.9 self.original_momenta_values = {} + self.original_training_state = {} @staticmethod def _apply_to_batchnorms(func): def func_apply_to_bns(module): - if isinstance(module, torch.nn.modules.batchnorm.BatchNorm2d): + if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm1d, + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.BatchNorm3d)): func(module) return func_apply_to_bns - def _run_model_inference(self, data_loader, num_init_steps, device): - num_bn_forget_steps = self.num_bn_forget_steps + @contextmanager + def _bn_training_state_switcher(self) -> None: + def save_original_bn_training_state(module: torch.nn.Module): + self.original_training_state[module] = module.training + + def set_bn_training_state(module: torch.nn.Module, state: bool): + module.training = state + + def restore_original_bn_training_state(module: torch.nn.Module): + module.training = self.original_training_state[module] + + self.model.apply(self._apply_to_batchnorms(save_original_bn_training_state)) + self.model.apply(self._apply_to_batchnorms(partial(set_bn_training_state, state=True))) + try: + yield + finally: + self.model.apply(self._apply_to_batchnorms(restore_original_bn_training_state)) + @contextmanager + def _bn_momentum_switcher(self) -> None: def set_bn_momentum(module, momentum_value): module.momentum = momentum_value - def save_original_bn_momenta(module): + def save_original_bn_momentum(module: torch.nn.Module): self.original_momenta_values[module] = module.momentum - def restore_original_bn_momenta(module): + def restore_original_bn_momentum(module: torch.nn.Module): module.momentum = self.original_momenta_values[module] - with training_mode_switcher(self.model, is_training=True): - self.model.apply(self._apply_to_batchnorms(save_original_bn_momenta)) - self.model.apply(self._apply_to_batchnorms(partial(set_bn_momentum, - momentum_value=self.momentum_bn_forget))) + self.model.apply(self._apply_to_batchnorms(save_original_bn_momentum)) + self.model.apply(self._apply_to_batchnorms(partial(set_bn_momentum, + momentum_value=self.momentum_bn_forget))) + try: + yield + finally: + self.model.apply(self._apply_to_batchnorms(restore_original_bn_momentum)) - for i, loaded_item in enumerate(data_loader): - if num_bn_forget_steps is not None and i >= num_bn_forget_steps: - break - args_kwargs_tuple = data_loader.get_inputs(loaded_item) - self._infer_batch(args_kwargs_tuple, device) + def _run_model_inference(self, data_loader, num_init_steps, device): + num_bn_forget_steps = self.num_bn_forget_steps - self.model.apply(self._apply_to_batchnorms(restore_original_bn_momenta)) + with
self._bn_training_state_switcher(): + if num_bn_forget_steps is not None and num_bn_forget_steps > 0: + with self._bn_momentum_switcher(): + for i, loaded_item in enumerate(data_loader): + if i >= num_bn_forget_steps: + break + args_kwargs_tuple = data_loader.get_inputs(loaded_item) + self._infer_batch(args_kwargs_tuple, device) for i, loaded_item in ProgressBar( enumerate(data_loader), diff --git a/nncf/model_creation.py b/nncf/model_creation.py index 0f662ed6131..caace60879e 100644 --- a/nncf/model_creation.py +++ b/nncf/model_creation.py @@ -14,6 +14,7 @@ from typing import Callable, Any, Tuple, Dict from torch.nn import Module +from torch.distributed import barrier from nncf.checkpoint_loading import load_state from nncf.composite_compression import PTCompositeCompressionAlgorithmBuilder @@ -24,6 +25,7 @@ from nncf.graph.graph_builder import GraphBuilder from nncf.nncf_network import NNCFNetwork from nncf.utils import is_main_process +from nncf.utils import is_dist_avail_and_initialized from nncf.algo_selector import COMPRESSION_ALGORITHMS from nncf.common.utils.logger import logger @@ -141,4 +143,19 @@ def create_compressed_model(model: Module, config: NNCFConfig, graph = compressed_graph_builder.build_graph(compressed_model, compressed_model.get_tracing_context()) graph.visualize_graph(osp.join(config.get("log_dir", "."), "compressed_graph.dot")) + # Synchronize all processes if run in distributed mode + if is_dist_avail_and_initialized(): + try: + barrier() + # An exception can be raised while running the barrier + # if the backend is not in the supported list: https://pytorch.org/docs/stable/distributed.html + except RuntimeError as err: + logger.warning(err) + logger.warning( + "NNCF will continue to work, but cannot guarantee that " + "all processes will finish the model's compression at the same time. " + "If your training pipeline requires the processes to be synchronized, " + "please pay attention to this error") + return compression_ctrl, compressed_model + return compression_ctrl, compressed_model diff --git a/nncf/nncf_network.py b/nncf/nncf_network.py index 9f256d00f70..4625be72657 100644 --- a/nncf/nncf_network.py +++ b/nncf/nncf_network.py @@ -504,10 +504,15 @@ def _set_nncf_wrapped_model(self, value): def get_clean_shallow_copy(self) -> 'NNCFNetwork': # WARNING: Will reset pre- and post-ops of the underlying model. Use save_nncf_module_additions # and load_nncf_module_additions to preserve these, or temporary_clean_view().
- return NNCFNetwork(self.get_nncf_wrapped_model(), self.input_infos, - self._user_dummy_forward_fn, self._wrap_inputs_fn, - self.scopes_without_shape_matching, self.ignored_scopes, self.target_scopes, - reset=True) + from nncf.utils import save_module_training_state, load_module_training_state + saved_state = {} + save_module_training_state(self, saved_state) + model_copy = NNCFNetwork(self.get_nncf_wrapped_model(), self.input_infos, + self._user_dummy_forward_fn, self._wrap_inputs_fn, + self.scopes_without_shape_matching, self.ignored_scopes, self.target_scopes, + reset=True) + load_module_training_state(model_copy, saved_state) + return model_copy def get_modules_in_nncf_modules_by_type(self, types) -> Dict['Scope', nn.Module]: nncf_modules = self.get_nncf_modules() diff --git a/nncf/quantization/algo.py b/nncf/quantization/algo.py index 7c3aef0f5c8..0c3efbb9655 100644 --- a/nncf/quantization/algo.py +++ b/nncf/quantization/algo.py @@ -606,7 +606,7 @@ def _get_transformation_layout(self, target_model: NNCFNetwork) -> PTTransformat target_model.register_compression_module_type(ExtraCompressionModuleType.EXTERNAL_QUANTIZER) single_config_quantizer_setup = self._get_quantizer_setup(target_model) minmax_values_for_range_init = {} - if self.should_init: + if is_main_process() and self.should_init: stats_for_range_init = self._get_statistics_for_final_range_init(target_model, single_config_quantizer_setup, self._range_init_params) @@ -1365,6 +1365,9 @@ def __init__(self): self.dump_dir = Path(DEBUG_LOG_DIR) / Path("debug_dumps") self.dump_dir.mkdir(parents=True, exist_ok=True) self.scale_dump_dir = self.dump_dir / Path("scale") + if self.scale_dump_dir.exists(): + shutil.rmtree(str(self.scale_dump_dir)) + self.scale_dump_dir.mkdir(parents=True, exist_ok=True) self.prop_graph_dump_dir = self.dump_dir / Path("quant_prop") if self.prop_graph_dump_dir.exists(): shutil.rmtree(str(self.prop_graph_dump_dir)) @@ -1383,9 +1386,6 @@ def init_actual(self, owner_model: NNCFNetwork): nncf_module_quantizations_id_list) self.call_trackers[self.ACTIVATION_QUANTIZERS_TRACKER_NAME].init_with_key_list( activation_quantizer_id_list) - if self.scale_dump_dir.exists(): - shutil.rmtree(str(self.scale_dump_dir)) - self.scale_dump_dir.mkdir(parents=True, exist_ok=True) self._strict_forward = True def pre_forward_actions(self, module: 'NNCFNetwork'): @@ -1428,7 +1428,7 @@ def dump_scale(self, quantizer_scale_params: Dict[str, torch.Tensor], quantizer_ quantizer_normalized_name = re.sub(r'[^\w\-_\. 
]', '_', quantizer_name) for scale_param_name, scale_param in quantizer_scale_params.items(): fname = "{}_{}.txt".format(quantizer_normalized_name, scale_param_name) - with safe_open(self.scale_dump_dir / fname, "ba") as file: + with safe_open(self.scale_dump_dir / fname, "ab") as file: np.savetxt(file, scale_param.cpu().numpy().flatten()) def reset_counters(self): diff --git a/nncf/utils.py b/nncf/utils.py index d2d6f94c418..d9140b8825e 100644 --- a/nncf/utils.py +++ b/nncf/utils.py @@ -12,6 +12,8 @@ """ from collections import OrderedDict from typing import Dict, Callable, Any, Mapping, Sequence, Set, List, Union +from typing import Tuple +from typing import Type import numpy as np import random @@ -20,6 +22,7 @@ from torch import distributed as dist, nn from torch.nn import Module +from nncf.common.utils.logger import logger as nncf_logger from nncf.dynamic_graph.graph_tracer import ModelInputInfo, create_dummy_forward_fn from nncf.dynamic_graph.trace_tensor import TracedTensor from nncf.graph.graph_builder import GraphBuilder @@ -311,14 +314,40 @@ def maybe_get_iterator(obj): return it +def to_tuple(lst: List, + named_tuple_class: Type = None, + named_tuple_fields: List[str] = None) -> Tuple: + # Able to produce namedtuples if a corresponding parameter is given + if named_tuple_fields is None: + return tuple(lst) + return named_tuple_class(*lst) + + +def is_tuple(obj) -> bool: + return isinstance(obj, tuple) + + +def is_named_tuple(obj) -> bool: + return is_tuple(obj) and (obj.__class__ != tuple) + + def objwalk(obj, unary_predicate: Callable[[Any], bool], apply_fn: Callable, memo=None): """Walks through the indexable container hierarchy of obj and replaces all sub-objects matching a criterion with the result of a given function application.""" + #pylint:disable=too-many-nested-blocks + #pylint:disable=too-many-branches if memo is None: memo = set() - is_tuple = isinstance(obj, tuple) - if is_tuple: + named_tuple_class = None + named_tuple_fields = None + if is_named_tuple(obj): + named_tuple_class = obj.__class__ + #pylint:disable=protected-access + named_tuple_fields = obj._fields + + was_tuple = is_tuple(obj) + if was_tuple: obj = list(obj) iterator = maybe_get_iterator(obj) @@ -327,30 +356,34 @@ def objwalk(obj, unary_predicate: Callable[[Any], bool], apply_fn: Callable, mem if id(obj) not in memo: memo.add(id(obj)) indices_to_apply_fn_to = set() - indices_vs_tuples_to_assign = {} # type: Dict[Any, list] + indices_vs_named_tuple_data = {} # type: Dict[Any, Tuple[list, Type, List[str]]] for idx, value in iterator(obj): next_level_it = maybe_get_iterator(value) if next_level_it is None: if unary_predicate(value): indices_to_apply_fn_to.add(idx) else: - if isinstance(value, tuple): + if is_tuple(value): processed_tuple = objwalk(value, unary_predicate, apply_fn, memo) - indices_vs_tuples_to_assign[idx] = processed_tuple + if is_named_tuple(value): + indices_vs_named_tuple_data[idx] = processed_tuple, value.__class__, value._fields + else: + indices_vs_named_tuple_data[idx] = processed_tuple, None, None else: objwalk(value, unary_predicate, apply_fn) for idx in indices_to_apply_fn_to: obj[idx] = apply_fn(obj[idx]) - for idx, tpl in indices_vs_tuples_to_assign.items(): - obj[idx] = tuple(tpl) + for idx, tpl_data in indices_vs_named_tuple_data.items(): + tpl, n_tpl_class, n_tpl_fields = tpl_data + obj[idx] = to_tuple(tpl, n_tpl_class, n_tpl_fields) memo.remove(id(obj)) else: if unary_predicate(obj): return apply_fn(obj) - if is_tuple: - return tuple(obj) + if was_tuple: + return 
to_tuple(obj, named_tuple_class, named_tuple_fields) return obj @@ -360,14 +393,34 @@ def should_consider_scope(scope_str: str, target_scopes: List[str], ignored_scop and not in_scope_list(scope_str, ignored_scopes) +def save_module_training_state(module: torch.nn.Module, saved_state: Dict[torch.nn.Module, bool]) -> None: + for ch in module.children(): + saved_state[ch] = ch.training + save_module_training_state(ch, saved_state) + + +def load_module_training_state(module: torch.nn.Module, state: Dict[torch.nn.Module, bool], strict=False) -> None: + for ch in module.children(): + try: + ch.train(state[ch]) + except KeyError as err: + # The set of modules may have changed during forward (e.g. the LSTM block in our examples) + if strict: + nncf_logger.error(err) + return + finally: + load_module_training_state(ch, state) + + @contextmanager def training_mode_switcher(model: torch.nn.Module, is_training: bool = True): - is_original_mode_training = model.training + saved_state = {} + save_module_training_state(model, saved_state) model.train(is_training) try: yield finally: - model.train(is_original_mode_training) + load_module_training_state(model, saved_state) def compute_FLOPs_hook(module, input_, output, dict_to_save, ctx: 'TracingContext'): diff --git a/nncf/version.py b/nncf/version.py index cfaff5ba2bc..176b7b6cf51 100644 --- a/nncf/version.py +++ b/nncf/version.py @@ -1,2 +1,2 @@ -__version__ = "1.7.0" +__version__ = "1.7.1" BKC_TORCH_VERSION = "1.8.1" diff --git a/tests/data/configs/inception_v3_mock_dataset.json b/tests/data/configs/inception_v3_mock_dataset.json new file mode 100644 index 00000000000..4d855cfaffd --- /dev/null +++ b/tests/data/configs/inception_v3_mock_dataset.json @@ -0,0 +1,16 @@ +{ + "model": "inception_v3", + "pretrained": false, + "input_info": { + "sample_size": [1, 3, 299, 299] + }, + "num_classes": 10, + "optimizer": { + "type": "sgd", + "base_lr": 1e-2, + "schedule_type": "multistep", + "steps": [ + 1 + ] + } +} diff --git a/tests/helpers.py b/tests/helpers.py index 5ccef88434e..b793647f0c2 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -74,6 +74,7 @@ def create_transpose_conv(in_channels, out_channels, kernel_size, weight_init, b class BasicConvTestModel(nn.Module): + INPUT_SIZE = [1, 1, 4, 4] def __init__(self, in_channels=1, out_channels=2, kernel_size=2, weight_init=-1, bias_init=-2): super().__init__() self.in_channels = in_channels diff --git a/tests/quantization/test_algo_quantization.py b/tests/quantization/test_algo_quantization.py index 16b8b121518..d87c0fbed27 100644 --- a/tests/quantization/test_algo_quantization.py +++ b/tests/quantization/test_algo_quantization.py @@ -10,6 +10,8 @@ See the License for the specific language governing permissions and limitations under the License.
""" +import logging +from contextlib import contextmanager from copy import deepcopy from typing import List from typing import Tuple @@ -504,6 +506,7 @@ def test_quantize_outputs(): quantizer = qctrl.non_weight_quantizers[matches[0]].quantizer_module_ref assert isinstance(quantizer, SymmetricQuantizer) + def test_quantize_outputs_with_scope_overrides(): config = get_quantization_config_without_range_init() config["input_info"] = [ @@ -526,3 +529,20 @@ def test_quantize_outputs_with_scope_overrides(): for q in output_quantizers: assert q.num_bits == 4 assert isinstance(q, AsymmetricQuantizer) + + +@contextmanager +def nncf_debug(): + from nncf import set_log_level + set_log_level(logging.DEBUG) + yield + set_log_level(logging.INFO) + + +def test_debug_mode(): + config = get_quantization_config_without_range_init() + model = BasicConvTestModel() + with nncf_debug(): + model, _ = create_compressed_model_and_algo_for_test(model, config) + model.forward(torch.zeros(BasicConvTestModel.INPUT_SIZE, + device=next(model.parameters()).device)) diff --git a/tests/quantization/test_saturation_issue_export.py b/tests/quantization/test_saturation_issue_export.py index a4446e98a97..c1e788ad32a 100644 --- a/tests/quantization/test_saturation_issue_export.py +++ b/tests/quantization/test_saturation_issue_export.py @@ -101,13 +101,29 @@ def __init__(self, in_out_ch=((1, 3), (3, 5), (5, 7), (7, 10))): super().__init__() self.features = [] self.features.append(create_conv(*in_out_ch[0], 2, -1, -2)) + self.features.append(nn.BatchNorm2d(in_out_ch[0][1])) + self.features.append(nn.ReLU()) self.features.append(create_conv(*in_out_ch[1], 5, 1, 1)) + self.features.append(nn.BatchNorm2d(in_out_ch[1][1])) + self.features.append(nn.ReLU()) self.features.append(create_conv(*in_out_ch[2], 1, 2, 2)) + self.features.append(nn.BatchNorm2d(in_out_ch[2][1])) + self.features.append(nn.ReLU()) self.features.append(create_conv(*in_out_ch[3], 9, -1, 0)) + self.features.append(nn.BatchNorm2d(in_out_ch[3][1])) + self.features.append(nn.ReLU()) self.features.append(create_conv(*reversed(in_out_ch[3]), 3, 0, 1)) + self.features.append(nn.BatchNorm2d(in_out_ch[3][0])) + self.features.append(nn.ReLU()) self.features.append(create_conv(*reversed(in_out_ch[2]), 1, -1, 9)) + self.features.append(nn.BatchNorm2d(in_out_ch[2][0])) + self.features.append(nn.ReLU()) self.features.append(create_conv(*reversed(in_out_ch[1]), 2, 10, 1)) + self.features.append(nn.BatchNorm2d(in_out_ch[1][0])) + self.features.append(nn.ReLU()) self.features.append(create_conv(*reversed(in_out_ch[0]), 1, 1, 1)) + self.features.append(nn.BatchNorm2d(in_out_ch[0][0])) + self.features.append(nn.ReLU()) self.features = nn.Sequential(*self.features) def forward(self, x): diff --git a/tests/test_nncf_utils.py b/tests/test_nncf_utils.py index a911455601a..ec666e2b45a 100644 --- a/tests/test_nncf_utils.py +++ b/tests/test_nncf_utils.py @@ -11,6 +11,7 @@ limitations under the License. 
""" from collections import namedtuple +from typing import Any import pytest from functools import partial @@ -97,3 +98,23 @@ def is_target_class(obj): test_obj = objwalk(start_obj, is_target_class, fn_to_apply) assert test_obj == ref_obj + +def assert_named_tuples_are_equal(ref_named_tuple: tuple, test_obj: Any): + assert test_obj.__class__.__qualname__ == ref_named_tuple.__class__.__qualname__ + assert hasattr(test_obj, "_fields") + assert all([f in test_obj._fields for f in ref_named_tuple._fields]) + assert all([f in ref_named_tuple._fields for f in test_obj._fields]) + + +def test_objwalk_retains_named_tuple(): + named_tuple = NamedTuple(field1=ObjwalkTestClass(OBJWALK_INIT_VAL), + field2=NamedTuple(field1=ObjwalkTestClass(OBJWALK_INIT_VAL), + field2=-8)) + + def is_target_class(obj): + return isinstance(obj, ObjwalkTestClass) + + fn_to_apply = partial(ObjwalkTestClass.member_fn, val=OBJWALK_REF_VAL) + test_obj = objwalk(named_tuple, is_target_class, fn_to_apply) + assert_named_tuples_are_equal(named_tuple, test_obj) + assert_named_tuples_are_equal(named_tuple.field2, test_obj.field2) diff --git a/tests/test_sanity_sample.py b/tests/test_sanity_sample.py index 4ffa1b511dd..0b26420f806 100644 --- a/tests/test_sanity_sample.py +++ b/tests/test_sanity_sample.py @@ -37,6 +37,7 @@ from nncf.common.quantization.structs import QuantizerConfig from nncf.config import NNCFConfig from nncf.hw_config import HWConfigType +from pytest_dependency import depends from tests.conftest import EXAMPLES_DIR from tests.conftest import PROJECT_ROOT from tests.conftest import TEST_ROOT @@ -149,13 +150,14 @@ def create_command_line(args, sample_type): SAMPLE_TYPES = ["classification", "semantic_segmentation", "object_detection"] DATASETS = { - "classification": ["mock_32x32", "mock_32x32", "mock_32x32", "mock_32x32"], + "classification": ["mock_32x32", "mock_299x299", "mock_32x32", "mock_32x32"], "semantic_segmentation": ["camvid", "camvid"], "object_detection": ["voc"], } CONFIGS = { "classification": [TEST_ROOT.joinpath("data", "configs", "squeezenet1_1_cifar10_rb_sparsity_int8.json"), + TEST_ROOT.joinpath("data", "configs", "inception_v3_mock_dataset.json"), TEST_ROOT.joinpath("data", "configs", "resnet18_cifar100_bin_xnor.json"), TEST_ROOT.joinpath("data", "configs", "resnet18_cifar10_staged_quant.json"), TEST_ROOT.joinpath("data", "configs", "resnet18_pruning_magnitude.json")], @@ -165,7 +167,7 @@ def create_command_line(args, sample_type): } BATCHSIZE_PER_GPU = { - "classification": [256, 256, 256, 128], + "classification": [256, 32, 256, 256, 128], "semantic_segmentation": [2, 2], "object_detection": [128], } @@ -201,9 +203,11 @@ def update_compression_algo_dict_with_reduced_bn_adapt_params(algo_dict): algo_dict['initializer'].update({'batchnorm_adaptation': {'num_bn_adaptation_samples': 5, 'num_bn_forget_samples': 5}}) +def _get_test_case_id(p) -> str: + return "-".join([p[0], p[1].name, p[2], str(p[3])]) @pytest.fixture(params=CONFIG_PARAMS, - ids=["-".join([p[0], p[1].name, p[2], str(p[3])]) for p in CONFIG_PARAMS]) + ids=[_get_test_case_id(p) for p in CONFIG_PARAMS]) def config(request, dataset_dir): sample_type, config_path, dataset_name, batch_size = request.param dataset_path = DATASET_PATHS[sample_type][dataset_name](dataset_dir) @@ -232,6 +236,7 @@ def config(request, dataset_dir): "model_name": jconfig["model"], "dataset_path": dataset_path, "batch_size": batch_size, + "test_case_id": _get_test_case_id(request.param) } @@ -266,10 +271,9 @@ def test_pretrained_model_eval(config, tmp_path, 
multiprocessing_distributed): runner.run() +@pytest.mark.dependency() @pytest.mark.parametrize( - "multiprocessing_distributed", [ - pytest.param(True, marks=pytest.mark.dependency(name="train_distributed")), - pytest.param(False, marks=pytest.mark.dependency(name="train_dataparallel"))], + "multiprocessing_distributed", [True, False], ids=['distributed', 'dataparallel']) def test_pretrained_model_train(config, tmp_path, multiprocessing_distributed, case_common_dirs): checkpoint_save_dir = os.path.join(case_common_dirs["checkpoint_save_dir"], @@ -291,20 +295,34 @@ def test_pretrained_model_train(config, tmp_path, multiprocessing_distributed, c args["--cpu-only"] = True elif multiprocessing_distributed: args["--multiprocessing-distributed"] = True + elif config['nncf_config']["model"] == "inception_v3": + pytest.skip("InceptionV3 cannot be trained in DataParallel " + "because it outputs a namedtuple, which DataParallel " + "still does not support.") runner = Command(create_command_line(args, config["sample_type"])) runner.run() last_checkpoint_path = os.path.join(checkpoint_save_dir, get_name(config_factory.config) + "_last.pth") assert os.path.exists(last_checkpoint_path) - assert torch.load(last_checkpoint_path)['compression_level'] in (CompressionLevel.FULL, CompressionLevel.PARTIAL) + if 'compression' in config['nncf_config']: + allowed_compression_levels = (CompressionLevel.FULL, CompressionLevel.PARTIAL) + else: + allowed_compression_levels = (CompressionLevel.NONE,) + assert torch.load(last_checkpoint_path)['compression_level'] in allowed_compression_levels + +def depends_on_pretrained_train(request, test_case_id: str, current_multiprocessing_distributed: bool): + full_test_case_id = test_case_id + ('-distributed' if current_multiprocessing_distributed else '-dataparallel') + primary_test_case_name = f'test_pretrained_model_train[{full_test_case_id}]' + depends(request, [primary_test_case_name]) + +@pytest.mark.dependency() @pytest.mark.parametrize( - "multiprocessing_distributed", [ - pytest.param(True, marks=pytest.mark.dependency(depends=["train_distributed"])), - pytest.param(False, marks=pytest.mark.dependency(depends=["train_dataparallel"]))], + "multiprocessing_distributed", [True, False], ids=['distributed', 'dataparallel']) -def test_trained_model_eval(config, tmp_path, multiprocessing_distributed, case_common_dirs): +def test_trained_model_eval(request, config, tmp_path, multiprocessing_distributed, case_common_dirs): + depends_on_pretrained_train(request, config["test_case_id"], multiprocessing_distributed) config_factory = ConfigFactory(config['nncf_config'], tmp_path / 'config.json') ckpt_path = os.path.join(case_common_dirs["checkpoint_save_dir"], "distributed" if multiprocessing_distributed else "data_parallel", @@ -335,12 +353,12 @@ def get_resuming_checkpoint_path(config_factory, multiprocessing_distributed, ch get_name(config_factory.config) + "_last.pth") +@pytest.mark.dependency() @pytest.mark.parametrize( - "multiprocessing_distributed", [ - pytest.param(True, marks=pytest.mark.dependency(depends=["train_distributed"])), - pytest.param(False, marks=pytest.mark.dependency(depends=["train_dataparallel"]))], + "multiprocessing_distributed", [True, False], ids=['distributed', 'dataparallel']) -def test_resume(config, tmp_path, multiprocessing_distributed, case_common_dirs): +def test_resume(request, config, tmp_path, multiprocessing_distributed, case_common_dirs): + depends_on_pretrained_train(request, config["test_case_id"], multiprocessing_distributed)
checkpoint_save_dir = os.path.join(str(tmp_path), "models") config_factory = ConfigFactory(config['nncf_config'], tmp_path / 'config.json') ckpt_path = get_resuming_checkpoint_path(config_factory, multiprocessing_distributed, @@ -369,15 +387,19 @@ def test_resume(config, tmp_path, multiprocessing_distributed, case_common_dirs) runner.run() last_checkpoint_path = os.path.join(checkpoint_save_dir, get_name(config_factory.config) + "_last.pth") assert os.path.exists(last_checkpoint_path) - assert torch.load(last_checkpoint_path)['compression_level'] in (CompressionLevel.FULL, CompressionLevel.PARTIAL) + if 'compression' in config['nncf_config']: + allowed_compression_levels = (CompressionLevel.FULL, CompressionLevel.PARTIAL) + else: + allowed_compression_levels = (CompressionLevel.NONE,) + assert torch.load(last_checkpoint_path)['compression_level'] in allowed_compression_levels +@pytest.mark.dependency() @pytest.mark.parametrize( - "multiprocessing_distributed", [ - pytest.param(True, marks=pytest.mark.dependency(depends=["train_distributed"])), - pytest.param(False, marks=pytest.mark.dependency(depends=["train_dataparallel"]))], + "multiprocessing_distributed", [True, False], ids=['distributed', 'dataparallel']) -def test_export_with_resume(config, tmp_path, multiprocessing_distributed, case_common_dirs): +def test_export_with_resume(request, config, tmp_path, multiprocessing_distributed, case_common_dirs): + depends_on_pretrained_train(request, config["test_case_id"], multiprocessing_distributed) config_factory = ConfigFactory(config['nncf_config'], tmp_path / 'config.json') ckpt_path = get_resuming_checkpoint_path(config_factory, multiprocessing_distributed, case_common_dirs["checkpoint_save_dir"]) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000000..b4dfee8bfe2 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,74 @@ +import pytest +from torch import nn + +from nncf.utils import training_mode_switcher +from nncf.initialization import DataLoaderBNAdaptationRunner + +from tests.helpers import BasicConvTestModel, TwoConvTestModel, MockModel +from tests.quantization.test_saturation_issue_export import DepthWiseConvTestModel, EightConvTestModel +# pylint:disable=unused-import +from tests.modules.test_rnn import _seed + + +def save_model_training_state(module, model_state): + for ch in module.children(): + model_state[ch] = ch.training + save_model_training_state(ch, model_state) + + +def compare_saved_model_state_and_current_model_state(module, model_state): + for ch in module.children(): + assert model_state[ch] == ch.training + compare_saved_model_state_and_current_model_state(ch, model_state) + + +def randomly_change_model_training_state(module): + import random + for ch in module.children(): + if random.uniform(0, 1) > 0.5: + ch.training = False + else: + ch.training = True + randomly_change_model_training_state(ch) + + +@pytest.mark.parametrize('model', [BasicConvTestModel(), TwoConvTestModel(), MockModel(), + DepthWiseConvTestModel(), EightConvTestModel()]) +def test_training_mode_switcher(_seed, model): + randomly_change_model_training_state(model) + + saved_model_state = {} + save_model_training_state(model, saved_model_state) + + with training_mode_switcher(model, True): + # pylint: disable=unnecessary-pass + pass + + compare_saved_model_state_and_current_model_state(model, saved_model_state) + + +@pytest.mark.parametrize('model', [BasicConvTestModel(), TwoConvTestModel(), MockModel(), + DepthWiseConvTestModel(), EightConvTestModel()]) +def 
test_bn_training_state_switcher(_seed, model): + runner = DataLoaderBNAdaptationRunner(model, 'cuda', 0) + saved_model_state = {} + + def check_were_only_bn_training_state_changed(module, saved_state): + for ch in module.children(): + if isinstance(ch, (nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d)): + assert ch.training + else: + assert ch.training == saved_state[ch] + check_were_only_bn_training_state_changed(ch, saved_state) + + randomly_change_model_training_state(model) + + save_model_training_state(model, saved_model_state) + + # pylint: disable=protected-access + with runner._bn_training_state_switcher(): + check_were_only_bn_training_state_changed(model, saved_model_state) + + compare_saved_model_state_and_current_model_state(model, saved_model_state)
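For reference, here is a minimal usage sketch of the reworked training_mode_switcher from nncf/utils.py that the tests above exercise. It demonstrates that per-module training states are now saved and restored individually instead of propagating a single root flag (the toy Sequential model is illustrative only, not part of the repo):

    from torch import nn

    from nncf.utils import training_mode_switcher

    # Toy model with deliberately mixed per-module training states
    model = nn.Sequential(nn.Conv2d(1, 2, 3), nn.BatchNorm2d(2), nn.ReLU())
    model.eval()
    model[1].train()  # leave only the BatchNorm submodule in training mode

    with training_mode_switcher(model, is_training=True):
        # inside the context every module runs in training mode
        assert all(m.training for m in model.modules())

    # on exit each child's original state is restored individually,
    # not overwritten by the root module's previous flag
    assert not model[0].training and model[1].training and not model[2].training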