diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py
index 8676dff8..12080c08 100644
--- a/internlm/train/pipeline.py
+++ b/internlm/train/pipeline.py
@@ -906,7 +906,16 @@ def inject_config(model: nn.Module) -> None:
 
 
 def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Optional[Dict] = None) -> None:
-    # get inject_info
+    '''
+    Helper for injecting modules and configs into the model.
+
+    Args:
+        model (Union[nn.Module, nn.ModuleList]):
+            For built-in models, it is nn.Module without pipeline parallelism and nn.ModuleList with it.
+            For injected models, it is nn.Module.
+        inject_info (Optional[Dict]): configurations for injected models.
+    '''
+    # parse inject_info
     if inject_info is not None:
         inject = inject_info.get("inject", False)
         interactive = inject_info.get("interactive", False)
@@ -928,33 +937,37 @@ def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Opt
         "norm": inject_norm,
     }
 
+    # Special case for pure dp mode: do nothing
+    if (
+        isinstance(gpc.config.parallel["tensor"], dict)
+        and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.mtp.name
+        and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL)
+    ):
+        return
+
+    # inject config
+    if inject:
+        inject_config(model)
+
     if not isinstance(model, nn.ModuleList):
         model = [model]
-
-    # inject modules
     for _chunk in model:
-        if (
-            isinstance(gpc.config.parallel["tensor"], dict)
-            and gpc.config.parallel["tensor"].get("mode", TensorParallelMode.mtp.name) == TensorParallelMode.mtp.name
-            and gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL)
-        ):
-            continue
+        # In-place replacement or check for modules: "embed", "linear", "norm"
+        # (1) If inject=True, replace the modules in place
+        # (2) If inject=False, only check the modules
         for mod in modules:
             inject_funcs[mod](_chunk, inject, interactive)
-
-    # reset parameters and move model to device
-    for _chunk in model:
+        # reset parameters if needed; the model must provide a reset_parameters() method
+        if reset_params:
+            _chunk.reset_parameters()
+        # If inject=True, the model is initialized on cpu and should be moved to the cuda device after injection
         if inject:
-            if reset_params:
-                _chunk.reset_parameters()
             _chunk.to(get_current_device())
 
-    # inject configs
-    if inject:
-        inject_config(model[0])
-        if gpc.is_rank_for_log():
-            logger.info(
-                f"inject is enabled, please check the model carefully, "
-                f"if there are any problems, please report issue to us. "
-                f"The injected model is \n {model}"
-            )
+    # log the injected model
+    if inject and gpc.is_rank_for_log():
+        logger.info(
+            f"inject is enabled, please check the model carefully, "
+            f"if there are any problems, please report issue to us. "
+            f"The injected model is \n {model}"
+        )
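
To make the new control flow concrete, here is a minimal usage sketch (not part of the patch). It assumes a training process in which `gpc` and the parallel context are already initialized, uses a hypothetical `TinyModel` as a stand-in for an injected (non built-in) model, and sets only the `inject` and `interactive` keys that appear in the diff; the keys that drive `reset_params` and the injected module list are parsed in the elided part of the function and are not reproduced here.

```python
# Hypothetical usage sketch: TinyModel and the inject_info values below are
# illustrative, not taken from the repository.
from torch import nn

from internlm.train.pipeline import inject_model_helper


class TinyModel(nn.Module):
    """Stand-in for an injected model, constructed on CPU before injection."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(1024, 64)
        self.linear = nn.Linear(64, 64)
        self.norm = nn.LayerNorm(64)

    def reset_parameters(self):
        # needed if the config requests parameter resetting (reset_params=True)
        nn.init.normal_(self.embed.weight, std=0.02)
        self.linear.reset_parameters()
        self.norm.reset_parameters()


model = TinyModel()
inject_info = {
    "inject": True,        # replace embed/linear/norm modules in place
    "interactive": False,
}

# In pure data-parallel mode (mtp tensor mode with DATA world size == GLOBAL
# world size) the helper now returns immediately; otherwise it injects the
# config, swaps the supported modules, optionally resets parameters, and
# finally moves each chunk to the current device.
inject_model_helper(model, inject_info=inject_info)
```

Note that the refactor hoists the pure data-parallel check and `inject_config()` out of the per-chunk loop, so they run once per call rather than once per chunk, and the device move now happens only after all replacements and parameter resets for a chunk are done.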