add vocab parallel embedding #315
Changes from 2 commits
@@ -8,6 +8,7 @@
 from einops import rearrange
 from torch import Tensor, nn

+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.ops.rotary_emb import apply_rotary_emb
 from internlm.utils.parallel import is_using_isp
@@ -33,6 +34,7 @@ def __init__(
         *args,
         padding_idx: int = None,
         dtype: torch.dtype = None,
+        vocab_parallel: bool = False,
         **kwargs,
     ):
         super().__init__()
@@ -42,14 +44,54 @@ def __init__(
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
+        self.vocab_parallel = vocab_parallel

+        if is_using_isp():
+            # isp: split vocab_size to support the sharing of parameters between embedding and head.
+            assert (
+                num_embeddings % gpc.weight_parallel_size == 0
+            ), f"{num_embeddings} is not divisible by {gpc.weight_parallel_size}"
+            self.num_embeddings_per_partition = num_embeddings // gpc.weight_parallel_size
Review thread on the line above:
- This assumes that ISP mode always implies vocab_parallel. Could vocab_parallel instead be a single, explicit switch tied only to whether the embedding and head share weights? If a user needs weight sharing, they set vocab_parallel to True by hand in the modeling file; in every other case the previous logic of splitting the embedding dimension stays the default. That avoids breaking backward compatibility of the existing code, in particular the llama models that load HF weights, which all split along the embedding dimension.
- For example, the error in CI.
- ok
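A minimal sketch of the opt-in flow the comment proposes. The import path and the num_embeddings/embedding_dim keyword names are assumptions; only the vocab_parallel argument itself is confirmed by this diff.

```python
# Hypothetical modeling-file usage: the vocab split is requested explicitly,
# and only when the embedding and the output head are meant to share weights.
from internlm.model.modules.embedding import Embedding1D  # import path assumed


def build_tok_embeddings(vocab_size: int, hidden_size: int, share_embedding_and_head: bool) -> Embedding1D:
    return Embedding1D(
        num_embeddings=vocab_size,
        embedding_dim=hidden_size,
        vocab_parallel=share_embedding_and_head,  # default False keeps the old embedding-dim split
    )
```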
+            self.embed_dim_per_partition = embedding_dim
+        elif vocab_parallel:
+            assert (
+                num_embeddings % gpc.tensor_parallel_size == 0
+            ), f"{num_embeddings} is not divisible by {gpc.tensor_parallel_size}"
+            self.num_embeddings_per_partition = num_embeddings // gpc.tensor_parallel_size
+            self.embed_dim_per_partition = embedding_dim
+            self.vocab_start_index = gpc.get_local_rank(ParallelMode.TENSOR) * self.num_embeddings_per_partition
+            self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
+        else:
+            # mtp/msp/fsp: do not support the sharing of parameters between embedding and head,
+            # use VocabParallelEmbedding1D instead.
+            assert (
+                embedding_dim % gpc.tensor_parallel_size == 0
+            ), f"{embedding_dim} is not divisible by {gpc.tensor_parallel_size}"
+            self.num_embeddings_per_partition = num_embeddings
+            self.embed_dim_per_partition = embedding_dim // gpc.tensor_parallel_size
+
+        self.weight = nn.Parameter(
+            torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), dtype=dtype)
+        )
+
+    def forward(self, input_: Tensor) -> Tensor:
+        if self.vocab_parallel:
+            # Build the mask.
+            input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
Review thread on the line above:
- The is_using_isp branch above splits self.num_embeddings_per_partition = num_embeddings // gpc.weight_parallel_size, but it never builds vocab_start_index and the related indices.
- ISP gathers the parameters, so it cannot take the vocab_parallel code path; it has to keep the original logic.
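For reference, a self-contained sketch of the usual vocab-parallel lookup this thread is describing, written in plain PyTorch rather than InternLM's API. The final all_reduce stands in for whatever reduction the framework applies elsewhere (e.g. through its communicator hooks); it is an assumption, not part of this diff.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def vocab_parallel_lookup(
    input_ids: torch.Tensor,     # token ids, any shape
    weight_shard: torch.Tensor,  # (vocab_size // tp_size, embed_dim) shard owned by this rank
    vocab_start: int,
    vocab_end: int,
    tp_group=None,
) -> torch.Tensor:
    # Tokens owned by other ranks are masked out of the local lookup.
    mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
    local_ids = input_ids.clone() - vocab_start
    local_ids[mask] = 0
    out = F.embedding(local_ids, weight_shard)
    out[mask, :] = 0.0
    # Partial embeddings are summed across the tensor-parallel group so that
    # every rank ends up with the complete embedding for every token.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(out, op=dist.ReduceOp.SUM, group=tp_group)
    return out
```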
+            # Mask the input.
+            masked_input = input_.clone() - self.vocab_start_index
+            masked_input[input_mask] = 0
+        else:
+            masked_input = input_
+
+        output = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
+
+        if self.vocab_parallel:
+            output[input_mask, :] = 0.0
+
+        return output
-        _parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size
-
-        embed_dim_per_partition = embedding_dim // _parallel_size
-        self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype))
-
-    def forward(self, input_: Tensor) -> Tensor:
-        return F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)


 class RotaryEmbedding(torch.nn.Module):
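To make the three __init__ branches above concrete, a tiny standalone example of the per-rank weight shapes they produce (the sizes are made up; this only repeats the arithmetic from the diff):

```python
num_embeddings, embedding_dim = 32000, 4096
weight_parallel_size = tensor_parallel_size = 4

# isp: the vocab dimension is split across the weight-parallel group.
print(num_embeddings // weight_parallel_size, embedding_dim)   # 8000 4096
# vocab_parallel: the vocab dimension is split across the tensor-parallel group.
print(num_embeddings // tensor_parallel_size, embedding_dim)   # 8000 4096
# default (mtp/msp/fsp): the embedding dimension is split instead.
print(num_embeddings, embedding_dim // tensor_parallel_size)   # 32000 1024
```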
Review thread:
- The awkward part here: if an ISP model keeps the original default split dimension while the vocab split is also offered as an option, how should that be handled?
- Maybe check for it when register_module_hook is called.
- ok
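One possible reading of the register_module_hook suggestion, sketched under heavy assumptions: the real hook-registration API and the communication hooks themselves are InternLM internals that this diff does not show, so the helper below is purely illustrative.

```python
# Hypothetical dispatch at hook-registration time: decide per embedding module
# whether to keep the ISP/default handling or attach vocab-parallel handling.
from internlm.model.modules.embedding import Embedding1D  # import path assumed
from internlm.utils.parallel import is_using_isp


def register_embedding_hooks(model, register_module_hook, isp_hook, vocab_parallel_hook, default_hook):
    for module in model.modules():
        if not isinstance(module, Embedding1D):
            continue
        if is_using_isp():
            # ISP gathers parameters itself and keeps the original split.
            register_module_hook(module, isp_hook)
        elif getattr(module, "vocab_parallel", False):
            # Vocab split: partial outputs need a sum over the tensor-parallel group.
            register_module_hook(module, vocab_parallel_hook)
        else:
            register_module_hook(module, default_hook)
```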