From 57c7990030e6ff50972831429ff0f5468da8a7bc Mon Sep 17 00:00:00 2001
From: "chenxun.p" <759046501@qq.com>
Date: Wed, 4 Sep 2024 19:29:16 +0800
Subject: [PATCH 1/3] add vocab parallel embedding

---
 internlm/core/parallel/comm/isp.py    |  8 ++---
 internlm/core/parallel/comm/tensor.py | 15 ++++++--
 internlm/core/parallel/comm/utils.py  | 37 +++++++++++++++++++
 internlm/model/modules/embedding.py   | 52 ++++++++++++++++++++++++---
 4 files changed, 101 insertions(+), 11 deletions(-)

diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py
index 3fcea13a..54e158a5 100644
--- a/internlm/core/parallel/comm/isp.py
+++ b/internlm/core/parallel/comm/isp.py
@@ -145,7 +145,7 @@ class EmbeddingWeightParallelCommunicator:
 
     def __init__(self, parallel_mode: ParallelMode) -> None:
         self.parallel_mode = parallel_mode
-        self.emb_column = 1
+        self.vocab_dim = 0
 
         self._cur_micro_step = 0
         self._num_micro_step = gpc.config.data.micro_num
@@ -165,7 +165,7 @@ def forward(ctx, inputs: torch.Tensor):  # pylint: disable=W0613
                 if module.weight.evo_tensor is None:
                     module.weight.evo_tensor = module.weight.data
 
-                module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.emb_column)
+                module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.vocab_dim)
                 inputs = inputs.detach()
                 return inputs
 
@@ -188,7 +188,7 @@ def forward(ctx, output: torch.Tensor):  # pylint: disable=W0613
 
             @staticmethod
             def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:  # pylint: disable=W0613
-                module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.emb_column)
+                module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.vocab_dim)
                 return grad_output
 
         def _pre_forward_hook(module, inputs):  # pylint: disable=W0613
@@ -205,7 +205,7 @@ def _post_forward_hook(module, inputs, output):  # pylint: disable=W0613
 
     def grad_reduce_hook(self, param: torch.Tensor):
         _grad, _ = reduce_scatter_raw(
-            param.grad, gpc.get_group(self.parallel_mode), op=dist.ReduceOp.AVG, reduce_dim=self.emb_column
+            param.grad, gpc.get_group(self.parallel_mode), op=dist.ReduceOp.AVG, reduce_dim=self.vocab_dim
         )
         if param.evo_tensor.grad is None:
             param.evo_tensor.grad = _grad
diff --git a/internlm/core/parallel/comm/tensor.py b/internlm/core/parallel/comm/tensor.py
index 2dfc8bd2..453b33f1 100644
--- a/internlm/core/parallel/comm/tensor.py
+++ b/internlm/core/parallel/comm/tensor.py
@@ -19,6 +19,7 @@
     all_gather_raw,
     all_reduce_raw,
     gather_forward_split_backward,
+    reduce_forward,
     reduce_scatter_raw,
     split_forward_gather_backward,
 )
@@ -341,7 +342,12 @@ def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tup
         """
         _emb_dim = 2  # [bsz, seqlen, emb_dim]
 
-        return gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
+        if module.vocab_parallel:
+            output = reduce_forward(output, self._parallel_mode)
+        else:
+            output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
+
+        return output
 
 
 class EmbeddingSequenceParallelCommunicator:
@@ -363,7 +369,12 @@ def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tup
         """
         _emb_dim, _seq_dim = 2, 1  # [bsz, seqlen, emb_dim]
 
-        output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
+        # tp:
+        if module.vocab_parallel:
+            output = reduce_forward(output, self._parallel_mode)
+        else:
+            output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
+        # sp:
         output = split_forward_gather_backward(output, self._parallel_mode, dim=_seq_dim)
 
         return output
diff --git a/internlm/core/parallel/comm/utils.py b/internlm/core/parallel/comm/utils.py
index abd291cf..aead5fbe 100644
--- a/internlm/core/parallel/comm/utils.py
+++ b/internlm/core/parallel/comm/utils.py
@@ -117,6 +117,17 @@ def _gather(input_, parallel_mode, dim=-1):
     return output
 
 
+def _reduce(input_, parallel_mode):
+    # skip if only one rank involved
+    if gpc.get_world_size(parallel_mode) == 1:
+        return input_
+
+    group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
+    dist.all_reduce(input_, group=group)
+
+    return input_
+
+
 class _GatherForwardSplitBackward(torch.autograd.Function):
     """Gather the input from model parallel region and concatenate.
 
@@ -174,6 +185,32 @@ def split_forward_gather_backward(input_, parallel_mode, dim):
     return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
 
 
+class _ReduceForward(torch.autograd.Function):
+    """
+    All-reduce the input from the model parallel region.
+
+    Args:
+        input_: input matrix.
+        parallel_mode: parallel mode.
+    """
+
+    @staticmethod
+    def symbolic(input_):
+        return _reduce(input_, parallel_mode=None)
+
+    @staticmethod
+    def forward(ctx, input_, parallel_mode):
+        return _reduce(input_, parallel_mode)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None
+
+
+def reduce_forward(input_, parallel_mode):
+    return _ReduceForward.apply(input_, parallel_mode)
+
+
 def all_gather_raw(
     input_: Tensor,
     process_group: ProcessGroup,
diff --git a/internlm/model/modules/embedding.py b/internlm/model/modules/embedding.py
index d65096e2..363f3e2b 100644
--- a/internlm/model/modules/embedding.py
+++ b/internlm/model/modules/embedding.py
@@ -8,6 +8,7 @@
 from einops import rearrange
 from torch import Tensor, nn
 
+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.ops.rotary_emb import apply_rotary_emb
 from internlm.utils.parallel import is_using_isp
@@ -33,6 +34,7 @@ def __init__(
         *args,
         padding_idx: int = None,
         dtype: torch.dtype = None,
+        vocab_parallel: bool = False,
         **kwargs,
     ):
         super().__init__()
@@ -42,14 +44,54 @@ def __init__(
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
+        self.vocab_parallel = vocab_parallel
+
+        if is_using_isp():
+            # isp: split vocab_size to support the sharing of parameters between embedding and head.
+            assert (
+                num_embeddings % gpc.weight_parallel_size == 0
+            ), f"{num_embeddings} is not divisible by {gpc.weight_parallel_size}"
+            self.num_embeddings_per_partition = num_embeddings // gpc.weight_parallel_size
+            self.embed_dim_per_partition = embedding_dim
+        elif vocab_parallel:
+
+            assert (
+                num_embeddings % gpc.tensor_parallel_size == 0
+            ), f"{num_embeddings} is not divisible by {gpc.tensor_parallel_size}"
+
+            self.num_embeddings_per_partition = num_embeddings // gpc.tensor_parallel_size
+            self.embed_dim_per_partition = embedding_dim
+            self.vocab_start_index = gpc.get_local_rank(ParallelMode.TENSOR) * self.num_embeddings_per_partition
+            self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
+        else:
+            # mtp/msp/fsp: do not support the sharing of parameters between embedding and head,
+            # use VocabParallelEmbedding1D instead.
+            assert (
+                embedding_dim % gpc.tensor_parallel_size == 0
+            ), f"{embedding_dim} is not divisible by {gpc.tensor_parallel_size}"
+            self.num_embeddings_per_partition = num_embeddings
+            self.embed_dim_per_partition = embedding_dim // gpc.tensor_parallel_size
+
+        self.weight = nn.Parameter(
+            torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), dtype=dtype)
+        )
+
+    def forward(self, input_: Tensor) -> Tensor:
+        if self.vocab_parallel:
+            # Build the mask.
+            input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
+            # Mask the input.
+            masked_input = input_.clone() - self.vocab_start_index
+            masked_input[input_mask] = 0
+        else:
+            masked_input = input_
 
-        _parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size
+        output = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
 
-        embed_dim_per_partition = embedding_dim // _parallel_size
-        self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype))
+        if self.vocab_parallel:
+            output[input_mask, :] = 0.0
 
-    def forward(self, input_: Tensor) -> Tensor:
-        return F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
+        return output
 
 
 class RotaryEmbedding(torch.nn.Module):

From 93767df39e10a9fc00405f8b9c375e724da2e233 Mon Sep 17 00:00:00 2001
From: "chenxun.p" <759046501@qq.com>
Date: Thu, 5 Sep 2024 14:06:51 +0800
Subject: [PATCH 2/3] fix lint

---
 internlm/core/parallel/comm/utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/internlm/core/parallel/comm/utils.py b/internlm/core/parallel/comm/utils.py
index aead5fbe..a7f93c3b 100644
--- a/internlm/core/parallel/comm/utils.py
+++ b/internlm/core/parallel/comm/utils.py
@@ -199,11 +199,11 @@ def symbolic(input_):
         return _reduce(input_, parallel_mode=None)
 
     @staticmethod
-    def forward(ctx, input_, parallel_mode):
+    def forward(ctx, input_, parallel_mode):  # pylint: disable=W0613
         return _reduce(input_, parallel_mode)
 
     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output):  # pylint: disable=W0613
         return grad_output, None
 
 

From 35c53cd787cafc7e1133ec75fadf9621e84f1ccd Mon Sep 17 00:00:00 2001
From: "chenxun.p" <759046501@qq.com>
Date: Sat, 7 Sep 2024 12:22:13 +0800
Subject: [PATCH 3/3] restore isp embedding default split dim

---
 internlm/core/parallel/comm/isp.py  |  9 +++++----
 internlm/model/modules/embedding.py | 31 ++++++++++------------------
 2 files changed, 16 insertions(+), 24 deletions(-)

diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py
index 54e158a5..93078fa0 100644
--- a/internlm/core/parallel/comm/isp.py
+++ b/internlm/core/parallel/comm/isp.py
@@ -145,7 +145,7 @@ class EmbeddingWeightParallelCommunicator:
 
     def __init__(self, parallel_mode: ParallelMode) -> None:
         self.parallel_mode = parallel_mode
-        self.vocab_dim = 0
+        self.gather_dim = 0
 
         self._cur_micro_step = 0
         self._num_micro_step = gpc.config.data.micro_num
@@ -154,6 +154,7 @@ def register_module_hook(self, module: Embedding1D) -> None:
         assert isinstance(module, Embedding1D), "Embbeding weight parallel communicator is only support Embedding1D"
 
         module.weight.evo_tensor = None
+        self.gather_dim = 0 if module.vocab_parallel else 1
 
         class PreModuleWrapper(torch.autograd.Function):
             """
@@ -165,7 +166,7 @@ def forward(ctx, inputs: torch.Tensor):  # pylint: disable=W0613
                 if module.weight.evo_tensor is None:
                     module.weight.evo_tensor = module.weight.data
 
-                module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.vocab_dim)
+                module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.gather_dim)
                 inputs = inputs.detach()
                 return inputs
 
@@ -188,7 +189,7 @@ def forward(ctx, output: torch.Tensor):  # pylint: disable=W0613
 
             @staticmethod
             def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor:  # pylint: disable=W0613
-                module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.vocab_dim)
+                module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.gather_dim)
                 return grad_output
 
         def _pre_forward_hook(module, inputs):  # pylint: disable=W0613
@@ -205,7 +206,7 @@ def _post_forward_hook(module, inputs, output):  # pylint: disable=W0613
 
     def grad_reduce_hook(self, param: torch.Tensor):
         _grad, _ = reduce_scatter_raw(
-            param.grad, gpc.get_group(self.parallel_mode), op=dist.ReduceOp.AVG, reduce_dim=self.vocab_dim
+            param.grad, gpc.get_group(self.parallel_mode), op=dist.ReduceOp.AVG, reduce_dim=self.gather_dim
         )
         if param.evo_tensor.grad is None:
             param.evo_tensor.grad = _grad
diff --git a/internlm/model/modules/embedding.py b/internlm/model/modules/embedding.py
index 363f3e2b..4a24172f 100644
--- a/internlm/model/modules/embedding.py
+++ b/internlm/model/modules/embedding.py
@@ -46,38 +46,29 @@ def __init__(
         self.embed_kwargs = kwargs
         self.vocab_parallel = vocab_parallel
 
-        if is_using_isp():
-            # isp: split vocab_size to support the sharing of parameters between embedding and head.
-            assert (
-                num_embeddings % gpc.weight_parallel_size == 0
-            ), f"{num_embeddings} is not divisible by {gpc.weight_parallel_size}"
-            self.num_embeddings_per_partition = num_embeddings // gpc.weight_parallel_size
-            self.embed_dim_per_partition = embedding_dim
-        elif vocab_parallel:
+        parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size
 
-            assert (
-                num_embeddings % gpc.tensor_parallel_size == 0
-            ), f"{num_embeddings} is not divisible by {gpc.tensor_parallel_size}"
+        if vocab_parallel:
+            assert num_embeddings % parallel_size == 0, f"{num_embeddings} is not divisible by {parallel_size}"
 
-            self.num_embeddings_per_partition = num_embeddings // gpc.tensor_parallel_size
+            self.num_embeddings_per_partition = num_embeddings // parallel_size
             self.embed_dim_per_partition = embedding_dim
             self.vocab_start_index = gpc.get_local_rank(ParallelMode.TENSOR) * self.num_embeddings_per_partition
             self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
         else:
-            # mtp/msp/fsp: do not support the sharing of parameters between embedding and head,
-            # use VocabParallelEmbedding1D instead.
-            assert (
-                embedding_dim % gpc.tensor_parallel_size == 0
-            ), f"{embedding_dim} is not divisible by {gpc.tensor_parallel_size}"
+            assert embedding_dim % parallel_size == 0, f"{embedding_dim} is not divisible by {parallel_size}"
+
             self.num_embeddings_per_partition = num_embeddings
-            self.embed_dim_per_partition = embedding_dim // gpc.tensor_parallel_size
+            self.embed_dim_per_partition = embedding_dim // parallel_size
+            self.vocab_start_index = 0
+            self.vocab_end_index = self.num_embeddings_per_partition
 
         self.weight = nn.Parameter(
             torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), dtype=dtype)
         )
 
     def forward(self, input_: Tensor) -> Tensor:
-        if self.vocab_parallel:
+        if self.vocab_parallel and not is_using_isp():
             # Build the mask.
             input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
             # Mask the input.
@@ -88,7 +79,7 @@ def forward(self, input_: Tensor) -> Tensor:
 
         output = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
 
-        if self.vocab_parallel:
+        if self.vocab_parallel and not is_using_isp():
             output[input_mask, :] = 0.0
 
         return output
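
Editor's note: the forward pass these patches add to Embedding1D follows the usual vocab-parallel embedding pattern: mask the token ids that fall outside this rank's vocabulary shard, look up the local shard, zero the masked rows, then sum partial results across the tensor-parallel group (the reduce_forward hook above). The snippet below is a minimal standalone sketch of that pattern for readers outside the InternLM codebase; the explicit rank/world_size/group arguments are illustrative stand-ins for gpc.get_local_rank(ParallelMode.TENSOR) and gpc.get_group(...), not part of the patches, and the backward-through-all-reduce handling done by _ReduceForward is omitted.

import torch
import torch.distributed as dist
import torch.nn.functional as F

def vocab_parallel_embedding_forward(input_ids, local_weight, rank, world_size, group=None):
    # local_weight holds this rank's shard: [vocab_size // world_size, embed_dim]
    vocab_start = rank * local_weight.shape[0]
    vocab_end = vocab_start + local_weight.shape[0]

    # Mask tokens owned by other ranks and shift ids into the local index range.
    out_of_range = (input_ids < vocab_start) | (input_ids >= vocab_end)
    masked_ids = input_ids.clone() - vocab_start
    masked_ids[out_of_range] = 0

    # Local lookup, then zero the rows this rank does not own.
    output = F.embedding(masked_ids, local_weight)
    output[out_of_range, :] = 0.0

    # Exactly one rank contributes a non-zero row per token, so summing across
    # the tensor-parallel group reconstructs the full embedding.
    if world_size > 1 and dist.is_available() and dist.is_initialized():
        dist.all_reduce(output, group=group)
    return output

With world_size=1 the function degenerates to a plain F.embedding call, which makes a convenient single-process sanity check against torch.nn.Embedding.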