
add vocab parallel embedding #315

Merged: 3 commits into InternLM:develop on Sep 10, 2024

Conversation

@mwiacx (Contributor) commented Sep 4, 2024

add vocab parallel embedding

@blankde (Collaborator) commented Sep 5, 2024

LGTM

assert (
    num_embeddings % gpc.weight_parallel_size == 0
), f"{num_embeddings} is not divisible by {gpc.weight_parallel_size}"
self.num_embeddings_per_partition = num_embeddings // gpc.weight_parallel_size
Contributor:

This assumes that ISP mode always uses vocab_parallel. Could vocab_parallel instead serve as a single unified switch? It is only relevant to whether the embedding and the head share weights: if a user needs weight sharing in their modeling file, they can manually set vocab_parallel to true; in all other cases, the previous embedding-dim splitting logic stays the default. That avoids breaking backward compatibility in the existing code, in particular the llama designs that load HF weights, which all split along the embedding dimension. (See the sketch after this thread.)

Contributor:

For example, that CI failure.

Contributor (Author):

ok
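
A minimal sketch of what that unified switch could look like, assuming a constructor that takes explicit parallel_size and rank arguments; Embedding1D and these parameter names are illustrative, not the PR's actual API:

import torch
from torch import nn

class Embedding1D(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int,
                 parallel_size: int, rank: int, vocab_parallel: bool = False):
        super().__init__()
        self.vocab_parallel = vocab_parallel
        if vocab_parallel:
            # Shard the vocab dimension; needed when the embedding and the
            # output head share weights.
            assert num_embeddings % parallel_size == 0
            self.num_embeddings_per_partition = num_embeddings // parallel_size
            self.vocab_start_index = rank * self.num_embeddings_per_partition
            self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
            self.weight = nn.Parameter(
                torch.empty(self.num_embeddings_per_partition, embedding_dim))
        else:
            # Previous default: shard the embedding dimension, keeping
            # backward compatibility (e.g. llama models loading HF weights).
            assert embedding_dim % parallel_size == 0
            self.embed_dim_per_partition = embedding_dim // parallel_size
            self.weight = nn.Parameter(
                torch.empty(num_embeddings, self.embed_dim_per_partition))

With vocab_parallel=False the weight layout matches the old embedding-dim split, so existing HF-weight loading paths would stay untouched.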

def forward(self, input_: Tensor) -> Tensor:
    if self.vocab_parallel:
        # Build the mask.
        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
Contributor:

The earlier is_using_isp branch splits via self.num_embeddings_per_partition = num_embeddings // gpc.weight_parallel_size, but it never builds vocab_start_index and the related indices.

Contributor (Author):

ISP gathers the full parameter, so it cannot take the vocab_parallel code path and has to follow the original logic. (For reference, the standard vocab-parallel forward is sketched below.)
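
For context, the classic vocab-parallel lookup (Megatron-LM style) masks out-of-range token ids, shifts the rest into the local shard, and all-reduces so each token ends up with exactly one real embedding. A sketch only, assuming the module attributes from the snippet above and a plain torch.distributed all-reduce rather than the codebase's own communication op:

import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor

def vocab_parallel_forward(self, input_: Tensor) -> Tensor:
    # Mask token ids that fall outside this rank's vocab shard.
    input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
    # Shift ids into the local [0, num_embeddings_per_partition) range.
    masked_input = input_.clone() - self.vocab_start_index
    masked_input[input_mask] = 0
    output = F.embedding(masked_input, self.weight)
    # Zero the rows for tokens owned by other ranks, so the all-reduce
    # sums exactly one real embedding per token.
    output[input_mask, :] = 0.0
    dist.all_reduce(output)  # assumed to run over the tensor-parallel group
    return output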

@@ -145,7 +145,7 @@ class EmbeddingWeightParallelCommunicator:

    def __init__(self, parallel_mode: ParallelMode) -> None:
        self.parallel_mode = parallel_mode
        self.emb_column = 1
        self.vocab_dim = 0
Contributor:

The awkward part here: if ISP models keep the original default split dimension while a vocab split is also offered as an option, how should that be handled?

Contributor:

Perhaps add a check at register_module_hook time (see the sketch after this thread).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok
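
A sketch of that check, reusing the emb_column and vocab_dim attributes from the diff above; the per-module attribute test and the hook body are assumptions, not the merged implementation:

from torch import nn

def register_module_hook(self, module: nn.Module) -> None:
    # Hypothetical per-module check: pick the split dimension for each
    # embedding instead of assuming one global default for ISP.
    if getattr(module, "vocab_parallel", False):
        split_dim = self.vocab_dim   # 0: shard rows of the vocab
    else:
        split_dim = self.emb_column  # 1: original default, shard the embedding dim
    # ... then register the gather/split hooks along `split_dim`.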

@mwiacx force-pushed the feat/add-vocab-parallel-embedding branch from 0b67f09 to 35c53cd on Sep 7, 2024, 04:22
@sunpengsdu merged commit 95dcc04 into InternLM:develop on Sep 10, 2024
19 checks passed