add vocab parallel embedding #315

Merged · 3 commits · Sep 10, 2024
9 changes: 5 additions & 4 deletions internlm/core/parallel/comm/isp.py
@@ -145,7 +145,7 @@ class EmbeddingWeightParallelCommunicator:

def __init__(self, parallel_mode: ParallelMode) -> None:
self.parallel_mode = parallel_mode
self.emb_column = 1
self.gather_dim = 0

self._cur_micro_step = 0
self._num_micro_step = gpc.config.data.micro_num
@@ -154,6 +154,7 @@ def register_module_hook(self, module: Embedding1D) -> None:
assert isinstance(module, Embedding1D), "Embedding weight parallel communicator only supports Embedding1D"

module.weight.evo_tensor = None
self.gather_dim = 0 if module.vocab_parallel else 1

class PreModuleWrapper(torch.autograd.Function):
"""
@@ -165,7 +166,7 @@ def forward(ctx, inputs: torch.Tensor): # pylint: disable=W0613
if module.weight.evo_tensor is None:
module.weight.evo_tensor = module.weight.data

module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.emb_column)
module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.gather_dim)
inputs = inputs.detach()
return inputs

@@ -188,7 +189,7 @@ def forward(ctx, output: torch.Tensor): # pylint: disable=W0613

@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor: # pylint: disable=W0613
module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.emb_column)
module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.gather_dim)
return grad_output

def _pre_forward_hook(module, inputs): # pylint: disable=W0613
@@ -205,7 +206,7 @@ def _post_forward_hook(module, inputs, output): # pylint: disable=W0613
def grad_reduce_hook(self, param: torch.Tensor):

_grad, _ = reduce_scatter_raw(
param.grad, gpc.get_group(self.parallel_mode), op=dist.ReduceOp.AVG, reduce_dim=self.emb_column
param.grad, gpc.get_group(self.parallel_mode), op=dist.ReduceOp.AVG, reduce_dim=self.gather_dim
)
if param.evo_tensor.grad is None:
param.evo_tensor.grad = _grad
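
Why gather_dim is 0 or 1: with vocab_parallel the embedding weight is sharded along dim 0 (rows of the vocabulary), otherwise along dim 1 (the embedding dimension), so the gather in the module wrappers and the reduce-scatter in grad_reduce_hook must use the matching axis. The following minimal single-process sketch (illustrative shapes and world size, not InternLM APIs) shows the two layouts:

import torch

# Hypothetical shard shapes for a world size of 4.
vocab, emb, world = 1024, 512, 4

# vocab_parallel=True: each rank owns a block of vocabulary rows -> gather along dim 0.
vocab_shard = torch.empty(vocab // world, emb)
full_rows = torch.cat([vocab_shard] * world, dim=0)  # (1024, 512)

# vocab_parallel=False: each rank owns a slice of the embedding dim -> gather along dim 1.
col_shard = torch.empty(vocab, emb // world)
full_cols = torch.cat([col_shard] * world, dim=1)  # (1024, 512)

assert full_rows.shape == full_cols.shape == (vocab, emb)

The gradient path mirrors this: grad_reduce_hook reduce-scatters along the same dimension, so each rank keeps only the gradient for its own shard.
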
15 changes: 13 additions & 2 deletions internlm/core/parallel/comm/tensor.py
@@ -19,6 +19,7 @@
all_gather_raw,
all_reduce_raw,
gather_forward_split_backward,
reduce_forward,
reduce_scatter_raw,
split_forward_gather_backward,
)
@@ -341,7 +342,12 @@ def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tup
"""
_emb_dim = 2 # [bsz, seqlen, emb_dim]

return gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
if module.vocab_parallel:
output = reduce_forward(output, self._parallel_mode)
else:
output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)

return output


class EmbeddingSequenceParallelCommunicator:
@@ -363,7 +369,12 @@ def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tup
"""
_emb_dim, _seq_dim = 2, 1 # [bsz, seqlen, emb_dim]

output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
# tp:
if module.vocab_parallel:
output = reduce_forward(output, self._parallel_mode)
else:
output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
# sp:
output = split_forward_gather_backward(output, self._parallel_mode, dim=_seq_dim)

return output
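
Why the output hooks branch: under vocab parallelism each rank looks up only the token ids it owns and leaves zeros elsewhere, so the per-rank outputs are disjoint partial sums and an all-reduce (reduce_forward) reconstructs the full embedding; under the original column-parallel layout every rank holds all rows but only a slice of the embedding dimension, so outputs are concatenated via gather_forward_split_backward. A tiny single-process illustration with made-up numbers (two ranks, two tokens, embedding dim 2):

import torch

# Vocab parallelism: each "rank" produces a partial output with zeros for tokens it does not own.
vocab_partial_rank0 = torch.tensor([[1.0, 2.0], [0.0, 0.0]])  # the first token's id falls in rank 0's vocab range
vocab_partial_rank1 = torch.tensor([[0.0, 0.0], [3.0, 4.0]])  # the second token's id falls in rank 1's range
full = vocab_partial_rank0 + vocab_partial_rank1  # what the all-reduce computes

# Column parallelism: every row is present, but each rank holds half the embedding dim.
col_rank0 = torch.tensor([[1.0], [3.0]])
col_rank1 = torch.tensor([[2.0], [4.0]])
assert torch.equal(torch.cat([col_rank0, col_rank1], dim=-1), full)
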
37 changes: 37 additions & 0 deletions internlm/core/parallel/comm/utils.py
@@ -117,6 +117,17 @@ def _gather(input_, parallel_mode, dim=-1):
return output


def _reduce(input_, parallel_mode):
# skip if only one rank involved
if gpc.get_world_size(parallel_mode) == 1:
return input_

group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
dist.all_reduce(input_, group=group)

return input_


class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatenate.

@@ -174,6 +185,32 @@ def split_forward_gather_backward(input_, parallel_mode, dim):
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)


class _ReduceForward(torch.autograd.Function):
"""
All-reduce the input from the model parallel region.

Args:
input_: input matrix.
parallel_mode: parallel mode.
"""

@staticmethod
def symbolic(input_):
return _reduce(input_, parallel_mode=None)

@staticmethod
def forward(ctx, input_, parallel_mode): # pylint: disable=W0613
return _reduce(input_, parallel_mode)

@staticmethod
def backward(ctx, grad_output): # pylint: disable=W0613
return grad_output, None


def reduce_forward(input_, parallel_mode):
return _ReduceForward.apply(input_, parallel_mode)


def all_gather_raw(
input_: Tensor,
process_group: ProcessGroup,
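
The _ReduceForward / reduce_forward pair added above all-reduces partial results in the forward pass and passes the incoming gradient through unchanged in the backward pass: the reduced output is a sum that is linear in each rank's partial, and its upstream gradient is identical on every rank, so an identity backward hands each rank the gradient of its own contribution. A simplified stand-alone sketch of the same pattern, assuming the default torch.distributed process group instead of InternLM's gpc groups:

import torch
import torch.distributed as dist


class AllReduceForwardIdentityBackward(torch.autograd.Function):
    """Simplified stand-in for _ReduceForward (default process group, no ParallelMode)."""

    @staticmethod
    def forward(ctx, tensor):
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            dist.all_reduce(tensor)  # sum the partial results across ranks, in place
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # gradient flows through unchanged


def reduce_forward_demo(tensor):
    return AllReduceForwardIdentityBackward.apply(tensor)
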
41 changes: 37 additions & 4 deletions internlm/model/modules/embedding.py
@@ -8,6 +8,7 @@
from einops import rearrange
from torch import Tensor, nn

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.ops.rotary_emb import apply_rotary_emb
from internlm.utils.parallel import is_using_isp
@@ -33,6 +34,7 @@ def __init__(
*args,
padding_idx: int = None,
dtype: torch.dtype = None,
vocab_parallel: bool = False,
**kwargs,
):
super().__init__()
@@ -42,14 +44,45 @@ def __init__(
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.vocab_parallel = vocab_parallel

_parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size
parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size

embed_dim_per_partition = embedding_dim // _parallel_size
self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype))
if vocab_parallel:
assert num_embeddings % parallel_size == 0, f"{num_embeddings} is not divisible by {parallel_size}"

self.num_embeddings_per_partition = num_embeddings // parallel_size
self.embed_dim_per_partition = embedding_dim
self.vocab_start_index = gpc.get_local_rank(ParallelMode.TENSOR) * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
else:
assert embedding_dim % parallel_size == 0, f"{embedding_dim} is not divisible by {parallel_size}"

self.num_embeddings_per_partition = num_embeddings
self.embed_dim_per_partition = embedding_dim // parallel_size
self.vocab_start_index = 0
self.vocab_end_index = self.num_embeddings_per_partition

self.weight = nn.Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), dtype=dtype)
)

def forward(self, input_: Tensor) -> Tensor:
return F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
if self.vocab_parallel and not is_using_isp():
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
[Review comment · Contributor] With is_using_isp, what gets split earlier is self.num_embeddings_per_partition = num_embeddings // gpc.weight_parallel_size, but vocab_start_index and the related indices are never built.

[Reply · Contributor Author] ISP gathers the parameters, so it must not take the vocab_parallel code path; it needs to keep the original logic.

# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_

output = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

if self.vocab_parallel and not is_using_isp():
output[input_mask, :] = 0.0

return output


class RotaryEmbedding(torch.nn.Module):
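
To see the vocab-parallel forward path end to end (mask out-of-range token ids, shift and look up, zero the foreign rows, sum the partials), here is a single-process walk-through that simulates two ranks; the sizes, tokens, and the explicit loop are illustrative assumptions standing in for the gpc rank bookkeeping and the distributed all-reduce performed by reduce_forward:

import torch
import torch.nn.functional as F

# Illustrative sizes standing in for the real parallel configuration.
vocab, emb, world = 8, 4, 2
torch.manual_seed(0)
full_weight = torch.randn(vocab, emb)
tokens = torch.tensor([1, 5, 6, 2])

partials = []
for rank in range(world):
    start = rank * (vocab // world)  # vocab_start_index on this "rank"
    end = start + vocab // world  # vocab_end_index
    shard = full_weight[start:end]  # this rank's weight partition

    mask = (tokens < start) | (tokens >= end)
    masked = tokens.clone() - start
    masked[mask] = 0  # any in-range id works as a placeholder
    out = F.embedding(masked, shard)
    out[mask, :] = 0.0  # zero the rows owned by other ranks
    partials.append(out)

# reduce_forward's all-reduce corresponds to this sum of per-rank partials.
assert torch.equal(sum(partials), F.embedding(tokens, full_weight))

Under ISP (is_using_isp()), the layer keeps the original unmasked lookup instead, which matches the review exchange above: the ISP communicator gathers the full weight before the forward lookup, so no vocabulary masking is needed.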