
add vocab parallel embedding #315

Merged · 3 commits · Sep 10, 2024
8 changes: 4 additions & 4 deletions internlm/core/parallel/comm/isp.py
@@ -145,7 +145,7 @@ class EmbeddingWeightParallelCommunicator:

def __init__(self, parallel_mode: ParallelMode) -> None:
self.parallel_mode = parallel_mode
- self.emb_column = 1
+ self.vocab_dim = 0
Contributor: The awkward part here is: if the ISP model keeps its original default sharding dimension but vocab sharding is now also an option, how should that be handled?

Contributor: Maybe add a check at the time register_module_hook is called.

Contributor Author: ok


self._cur_micro_step = 0
self._num_micro_step = gpc.config.data.micro_num
@@ -165,7 +165,7 @@ def forward(ctx, inputs: torch.Tensor): # pylint: disable=W0613
if module.weight.evo_tensor is None:
module.weight.evo_tensor = module.weight.data

- module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.emb_column)
+ module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.vocab_dim)
inputs = inputs.detach()
return inputs

@@ -188,7 +188,7 @@ def forward(ctx, output: torch.Tensor): # pylint: disable=W0613

@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> torch.Tensor: # pylint: disable=W0613
- module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.emb_column)
+ module.weight.data = _gather(module.weight, self.parallel_mode, dim=self.vocab_dim)
return grad_output

def _pre_forward_hook(module, inputs): # pylint: disable=W0613
@@ -205,7 +205,7 @@ def _post_forward_hook(module, inputs, output): # pylint: disable=W0613
def grad_reduce_hook(self, param: torch.Tensor):

_grad, _ = reduce_scatter_raw(
- param.grad, gpc.get_group(self.parallel_mode), op=dist.ReduceOp.AVG, reduce_dim=self.emb_column
+ param.grad, gpc.get_group(self.parallel_mode), op=dist.ReduceOp.AVG, reduce_dim=self.vocab_dim
)
if param.evo_tensor.grad is None:
param.evo_tensor.grad = _grad
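For intuition about the change above, the only difference between `emb_column = 1` and `vocab_dim = 0` is which axis of the weight each rank holds a shard of, and along which axis the shards are gathered back. A single-process illustration, with plain `torch.cat` standing in for the distributed all-gather and arbitrary sizes:

```python
import torch

vocab_size, emb_dim, world_size = 8, 4, 2

# vocab_dim = 0: each rank holds a [vocab/ws, emb] shard; gathering on dim 0 restores the table.
vocab_shards = [torch.randn(vocab_size // world_size, emb_dim) for _ in range(world_size)]
assert torch.cat(vocab_shards, dim=0).shape == (vocab_size, emb_dim)

# emb_column = 1: each rank holds a [vocab, emb/ws] shard; gathering on dim 1 restores the table.
column_shards = [torch.randn(vocab_size, emb_dim // world_size) for _ in range(world_size)]
assert torch.cat(column_shards, dim=1).shape == (vocab_size, emb_dim)
```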
15 changes: 13 additions & 2 deletions internlm/core/parallel/comm/tensor.py
@@ -19,6 +19,7 @@
all_gather_raw,
all_reduce_raw,
gather_forward_split_backward,
reduce_forward,
reduce_scatter_raw,
split_forward_gather_backward,
)
@@ -341,7 +342,12 @@ def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tup
"""
_emb_dim = 2 # [bsz, seqlen, emb_dim]

- return gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
+ if module.vocab_parallel:
+     output = reduce_forward(output, self._parallel_mode)
+ else:
+     output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
+
+ return output


class EmbeddingSequenceParallelCommunicator:
@@ -363,7 +369,12 @@ def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tup
"""
_emb_dim, _seq_dim = 2, 1 # [bsz, seqlen, emb_dim]

- output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
+ # tp:
+ if module.vocab_parallel:
+     output = reduce_forward(output, self._parallel_mode)
+ else:
+     output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
+ # sp:
output = split_forward_gather_backward(output, self._parallel_mode, dim=_seq_dim)

return output
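The reason the two branches communicate differently: with the embedding dim sharded, every rank produces a distinct slice of each hidden vector, so the slices must be gathered and concatenated; with the vocab sharded, every rank produces either the full hidden vector (for tokens it owns) or zeros, so summing across ranks, which is what `reduce_forward` does, reconstructs the output. A single-process sketch of the vocab-parallel case, with illustrative shapes and no real process group:

```python
import torch
import torch.nn.functional as F

vocab_size, emb_dim, world_size = 8, 4, 2
tokens = torch.tensor([[1, 5, 6, 2]])                  # [bsz, seqlen]
full_weight = torch.randn(vocab_size, emb_dim)

partials = []
for rank in range(world_size):
    start = rank * (vocab_size // world_size)
    end = start + vocab_size // world_size
    shard = full_weight[start:end]                     # this rank's slice of the vocab
    mask = (tokens < start) | (tokens >= end)          # tokens owned by other ranks
    local_ids = (tokens - start).clamp(0, end - start - 1)
    out = F.embedding(local_ids, shard)
    out[mask] = 0.0                                    # zero the rows of foreign tokens
    partials.append(out)

# Summing the partial outputs (the all-reduce in reduce_forward) matches the unsharded lookup.
assert torch.allclose(sum(partials), F.embedding(tokens, full_weight))
```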
37 changes: 37 additions & 0 deletions internlm/core/parallel/comm/utils.py
@@ -117,6 +117,17 @@ def _gather(input_, parallel_mode, dim=-1):
return output


def _reduce(input_, parallel_mode):
# skip if only one rank involved
if gpc.get_world_size(parallel_mode) == 1:
return input_

group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
dist.all_reduce(input_, group=group)

return input_


class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatenate.

@@ -174,6 +185,32 @@ def split_forward_gather_backward(input_, parallel_mode, dim):
return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)


class _ReduceForward(torch.autograd.Function):
"""
All-reduce the input from the model parallel region.

Args:
input_: input matrix.
parallel_mode: parallel mode.
"""

@staticmethod
def symbolic(input_):
return _reduce(input_, parallel_mode=None)

@staticmethod
def forward(ctx, input_, parallel_mode): # pylint: disable=W0613
return _reduce(input_, parallel_mode)

@staticmethod
def backward(ctx, grad_output): # pylint: disable=W0613
return grad_output, None


def reduce_forward(input_, parallel_mode):
return _ReduceForward.apply(input_, parallel_mode)


def all_gather_raw(
input_: Tensor,
process_group: ProcessGroup,
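For reference, the autograd contract of the new `_ReduceForward` is: sum the partial results across the group in the forward pass, and pass the upstream gradient through unchanged in the backward pass, since each rank's contribution enters the sum with coefficient 1. A single-process stand-in that mirrors that contract without a process group (the `clone` replaces the in-place `dist.all_reduce`); illustration only, not the repo's code:

```python
import torch


class _ReduceForwardLocal(torch.autograd.Function):
    """Single-process analogue of _ReduceForward, for illustration only."""

    @staticmethod
    def forward(ctx, input_):
        # Stand-in for dist.all_reduce over the tensor-parallel group.
        return input_.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # All-reduce(sum) has an identity Jacobian w.r.t. each rank's input.
        return grad_output


x = torch.randn(2, 3, requires_grad=True)
_ReduceForwardLocal.apply(x).sum().backward()
assert torch.allclose(x.grad, torch.ones_like(x))
```

This identity backward is only correct when the partial outputs from the ranks cover disjoint parts of the result, which is exactly what the vocab masking in `Embedding1D.forward` guarantees.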
52 changes: 47 additions & 5 deletions internlm/model/modules/embedding.py
@@ -8,6 +8,7 @@
from einops import rearrange
from torch import Tensor, nn

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.ops.rotary_emb import apply_rotary_emb
from internlm.utils.parallel import is_using_isp
@@ -33,6 +34,7 @@ def __init__(
*args,
padding_idx: int = None,
dtype: torch.dtype = None,
vocab_parallel: bool = False,
**kwargs,
):
super().__init__()
@@ -42,14 +44,54 @@ def __init__(
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.vocab_parallel = vocab_parallel

if is_using_isp():
# isp: split vocab_size to support the sharing of parameters between embedding and head.
assert (
num_embeddings % gpc.weight_parallel_size == 0
), f"{num_embeddings} is not divisible by {gpc.weight_parallel_size}"
self.num_embeddings_per_partition = num_embeddings // gpc.weight_parallel_size
Contributor: This assumes that ISP mode always uses vocab_parallel. Could vocab_parallel instead be a single, unified switch that is only tied to whether the embedding and the head share weights? If a user needs weight sharing, they set vocab_parallel to true by hand in the modeling file; in all other cases the default stays on the previous embedding-dim sharding logic. That avoids breaking backward compatibility with the existing code, in particular the llama models that load HF checkpoints, which all rely on sharding the embedding dim. (See the usage sketch after this file's diff.)

Contributor: For example, that CI failure.

Contributor Author: ok

self.embed_dim_per_partition = embedding_dim
elif vocab_parallel:

assert (
num_embeddings % gpc.tensor_parallel_size == 0
), f"{num_embeddings} is not divisible by {gpc.tensor_parallel_size}"

self.num_embeddings_per_partition = num_embeddings // gpc.tensor_parallel_size
self.embed_dim_per_partition = embedding_dim
self.vocab_start_index = gpc.get_local_rank(ParallelMode.TENSOR) * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
else:
# mtp/msp/fsp: do not support the sharing of parameters between embedding and head,
# use VocabParallelEmbedding1D instead.
assert (
embedding_dim % gpc.tensor_parallel_size == 0
), f"{embedding_dim} is not divisible by {gpc.tensor_parallel_size}"
self.num_embeddings_per_partition = num_embeddings
self.embed_dim_per_partition = embedding_dim // gpc.tensor_parallel_size

self.weight = nn.Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), dtype=dtype)
)

def forward(self, input_: Tensor) -> Tensor:
if self.vocab_parallel:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
Contributor: The is_using_isp branch above shards self.num_embeddings_per_partition = num_embeddings // gpc.weight_parallel_size, but it never builds vocab_start_index and the related fields.

Contributor Author: ISP gathers the full parameters, so it cannot go through the vocab_parallel code path; it needs to follow the original logic.

# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_

- _parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size
+ output = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

- embed_dim_per_partition = embedding_dim // _parallel_size
- self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype))
+ if self.vocab_parallel:
+     output[input_mask, :] = 0.0

- def forward(self, input_: Tensor) -> Tensor:
-     return F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
+ return output


class RotaryEmbedding(torch.nn.Module):
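A possible shape for the reviewer's proposal above, keeping the embedding-dim split as the default and opting into vocab sharding only when the embedding and head share weights, seen from the caller's side. The sizes are illustrative and an initialized parallel context (gpc) is assumed:

```python
import torch

from internlm.model.modules.embedding import Embedding1D

vocab_size, hidden_size = 32000, 4096  # illustrative sizes

# Default: the previous embedding-dim sharding, so existing configs
# (e.g. llama models loading HF checkpoints) keep working unchanged.
tok_emb = Embedding1D(vocab_size, hidden_size, dtype=torch.bfloat16)

# Explicit opt-in: shard along the vocab dim, e.g. when the modeling file
# ties the embedding weight to the output head and both must be split alike.
tied_tok_emb = Embedding1D(vocab_size, hidden_size, dtype=torch.bfloat16, vocab_parallel=True)
```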