Skip to content

Commit

Permalink
add vacab parallel embedding (#315)
Browse files Browse the repository at this point in the history
  • Loading branch information
mwiacx authored Sep 10, 2024
1 parent 4452ad6 commit 95dcc04
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 10 deletions.
9 changes: 5 additions & 4 deletions internlm/core/parallel/comm/isp.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class EmbeddingWeightParallelCommunicator:

def __init__(self, parallel_mode: ParallelMode) -> None:
self.parallel_mode = parallel_mode
self.emb_column = 1
self.gather_dim = 0

self._cur_micro_step = 0
self._num_micro_step = gpc.config.data.micro_num
Expand All @@ -186,6 +186,7 @@ def register_module_hook(self, module: Embedding1D) -> None:
assert isinstance(module, Embedding1D), "Embbeding weight parallel communicator is only support Embedding1D"

module.weight.evo_tensor = None
self.gather_dim = 0 if module.vocab_parallel else 1

class PreModuleWrapper(torch.autograd.Function):
"""
Expand All @@ -197,7 +198,7 @@ def forward(ctx, inputs: torch.Tensor): # pylint: disable=W0613
if module.weight.evo_tensor is None:
module.weight.evo_tensor = module.weight.data

module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.emb_column)
module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.gather_dim)
inputs = inputs.detach()
return inputs

Expand All @@ -220,7 +221,7 @@ def forward(ctx, output: torch.Tensor): # pylint: disable=W0613

@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor: # pylint: disable=W0613
module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.emb_column)
module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.gather_dim)
return grad_output

def _pre_forward_hook(module, inputs): # pylint: disable=W0613
Expand All @@ -237,7 +238,7 @@ def _post_forward_hook(module, inputs, output): # pylint: disable=W0613
def grad_reduce_hook(self, param: torch.Tensor):

_grad, _ = reduce_scatter_raw(
param.grad, gpc.get_group(self.parallel_mode), op=dist.ReduceOp.AVG, reduce_dim=self.emb_column
param.grad, gpc.get_group(self.parallel_mode), op=dist.ReduceOp.AVG, reduce_dim=self.gather_dim
)
if param.evo_tensor.grad is None:
param.evo_tensor.grad = _grad
Expand Down
15 changes: 13 additions & 2 deletions internlm/core/parallel/comm/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
all_gather_raw,
all_reduce_raw,
gather_forward_split_backward,
reduce_forward,
reduce_scatter_raw,
split_forward_gather_backward,
)
Expand Down Expand Up @@ -341,7 +342,12 @@ def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tup
"""
_emb_dim = 2 # [bsz, seqlen, emb_dim]

return gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
if module.vocab_parallel:
output = reduce_forward(output, self._parallel_mode)
else:
output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)

return output


class EmbeddingSequenceParallelCommunicator:
Expand All @@ -363,7 +369,12 @@ def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tup
"""
_emb_dim, _seq_dim = 2, 1 # [bsz, seqlen, emb_dim]

output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
# tp:
if module.vocab_parallel:
output = reduce_forward(output, self._parallel_mode)
else:
output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
# sp:
output = split_forward_gather_backward(output, self._parallel_mode, dim=_seq_dim)

return output
37 changes: 37 additions & 0 deletions internlm/core/parallel/comm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ def _gather(input_, parallel_mode, dim=-1):
return output


def _reduce(input_, parallel_mode):
# skip if only one rank involved
if gpc.get_world_size(parallel_mode) == 1:
return input_

group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
dist.all_reduce(input_, group=group)

return input_


class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatenate.
Expand Down Expand Up @@ -174,6 +185,32 @@ def split_forward_gather_backward(input_, parallel_mode, dim):
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)


class _ReduceForward(torch.autograd.Function):
"""
All-reduce the input from the model parallel region.
Args:
input_: input matrix.
parallel_mode: parallel mode.
"""

@staticmethod
def symbolic(input_):
return _reduce(input_, parallel_mode=None)

@staticmethod
def forward(ctx, input_, parallel_mode): # pylint: disable=W0613
return _reduce(input_, parallel_mode)

@staticmethod
def backward(ctx, grad_output): # pylint: disable=W0613
return grad_output, None


def reduce_forward(input_, parallel_mode):
return _ReduceForward.apply(input_, parallel_mode)


def all_gather_raw(
input_: Tensor,
process_group: ProcessGroup,
Expand Down
41 changes: 37 additions & 4 deletions internlm/model/modules/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from einops import rearrange
from torch import Tensor, nn

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.ops.rotary_emb import apply_rotary_emb
from internlm.utils.parallel import is_using_isp
Expand All @@ -33,6 +34,7 @@ def __init__(
*args,
padding_idx: int = None,
dtype: torch.dtype = None,
vocab_parallel: bool = False,
**kwargs,
):
super().__init__()
Expand All @@ -42,14 +44,45 @@ def __init__(
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.vocab_parallel = vocab_parallel

_parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size
parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size

embed_dim_per_partition = embedding_dim // _parallel_size
self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype))
if vocab_parallel:
assert num_embeddings % parallel_size == 0, f"{num_embeddings} is not divisible by {parallel_size}"

self.num_embeddings_per_partition = num_embeddings // parallel_size
self.embed_dim_per_partition = embedding_dim
self.vocab_start_index = gpc.get_local_rank(ParallelMode.TENSOR) * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
else:
assert embedding_dim % parallel_size == 0, f"{embedding_dim} is not divisible by {parallel_size}"

self.num_embeddings_per_partition = num_embeddings
self.embed_dim_per_partition = embedding_dim // parallel_size
self.vocab_start_index = 0
self.vocab_end_index = self.num_embeddings_per_partition

self.weight = nn.Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), dtype=dtype)
)

def forward(self, input_: Tensor) -> Tensor:
return F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
if self.vocab_parallel and not is_using_isp():
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_

output = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

if self.vocab_parallel and not is_using_isp():
output[input_mask, :] = 0.0

return output


class RotaryEmbedding(torch.nn.Module):
Expand Down

0 comments on commit 95dcc04

Please sign in to comment.