support parameter sync with unbalanced pipe stages #106

Open · wants to merge 9 commits into base: main
6 changes: 6 additions & 0 deletions chatlearn/models/base_module.py
@@ -1022,3 +1022,9 @@ def get_data_parallel_rank(self):

    def get_data_parallel_size(self):
        return self.data_parallel_size

    def get_pipeline_stage_layer_num(self):
        pass

    def get_pipeline_stage_layer_offset(self):
        return 0
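
For orientation, a minimal sketch (not part of the diff) of the contract behind these two hooks; the toy class and its constructor arguments are hypothetical, only the hook names come from the PR:

```python
# Hypothetical module with an unbalanced layer split, e.g. 8 layers as 3/3/2
# across 3 pipeline stages. Each stage reports its own layer count and the
# number of layers owned by all earlier stages (its global layer offset).
class ToyUnbalancedModule:
    def __init__(self, stage_layer_nums, stage_rank):
        self.stage_layer_nums = stage_layer_nums
        self.stage_rank = stage_rank

    def get_pipeline_stage_layer_num(self):
        return self.stage_layer_nums[self.stage_rank]

    def get_pipeline_stage_layer_offset(self):
        return sum(self.stage_layer_nums[:self.stage_rank])


assert ToyUnbalancedModule([3, 3, 2], 2).get_pipeline_stage_layer_offset() == 6
```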
52 changes: 43 additions & 9 deletions chatlearn/models/megatron_module.py
@@ -14,6 +14,8 @@
# ==============================================================================
"""Megatron module"""
import inspect
import re
import torch

try:
    from chatlearn.utils.megatron_import_helper import get_args
@@ -114,6 +116,28 @@ def model_setup(self):
                self.runtime_args.bucket_size_mb_in_memory_manager,
            )
        self.offload()
        self.set_pipe_layer_num_offset()

    def set_pipe_layer_num_offset(self):
        self.stage2layer_num = [None] * self.pipeline_model_parallel_size()
        self.stage2offset = [0] * self.pipeline_model_parallel_size()
        stage_layer_num = self.get_pipeline_stage_layer_num()
        world_size = torch.distributed.get_world_size()
        rank_layer_num = torch.tensor([self.pipeline_parallel_rank(), stage_layer_num], device='cuda')
        # Gather all tensors to all processes
        all_stage_layer_nums = [torch.zeros_like(rank_layer_num, device='cuda') for _ in range(world_size)]
        torch.distributed.all_gather(all_stage_layer_nums, rank_layer_num)
        for item in all_stage_layer_nums:
            rank = item[0].item()
            num = item[1].item()
            if self.stage2layer_num[rank] is None:
                self.stage2layer_num[rank] = num
            else:
                assert self.stage2layer_num[rank] == num
        for i, num in enumerate(self.stage2layer_num):
            if i+1 == len(self.stage2offset):
                break
            self.stage2offset[i+1] = self.stage2offset[i] + num

    @property
    def megatron_args(self):
@@ -178,26 +202,21 @@ def megatron_model(self):
        model = self.model
        return model

    def build_pipeline_layer_name_mapping(self, num_target_pipe_stage, target_pipe_rank, requires_grad=True):
    def build_pipeline_layer_name_mapping(self, num_target_pipe_stage, target_pipe_rank, tgt_layer_offset, requires_grad=True):
        """
        build name mapping from src model to tgt model
        Args:
            num_target_pipe_stage: number of pipeline stage in target model
            target_pipe_rank: target model pipeline rank
            tgt_layer_offset: target model pipeline stage layer offset
            requires_grad: whether the returned layer requires_grad, as we only need to sync parameters that have changed

        :meta private:
        """
        src_layers_per_stage = self.num_layers() // self.pipeline_model_parallel_size()
        dst_layers_per_stage = self.num_layers() // num_target_pipe_stage
        assert dst_layers_per_stage % src_layers_per_stage == 0, \
            "We assume pipeline stage of target model is smaller than src model, and is divisible by src model"
        mapping_interval = dst_layers_per_stage // src_layers_per_stage
        src_rank = mpu.get_pipeline_model_parallel_rank()
        self._logger.debug(f"build mapping for rank {src_rank} =========")
        src_layer_offset = self.get_pipeline_stage_layer_offset()
        model = self.megatron_model()
        is_tgt_last_stage = target_pipe_rank == num_target_pipe_stage - 1 and target_pipe_rank != 0
        name_mapping = build_pipeline_layer_name_mapping(src_layers_per_stage, src_rank, mapping_interval,
        name_mapping = build_pipeline_layer_name_mapping(src_layer_offset, tgt_layer_offset,
                                                         is_tgt_last_stage, model, requires_grad)
        return name_mapping

@@ -277,3 +296,18 @@ def build_grad_buffers(self):
        """
        if self.module_args.free_grad_buffers:
            self._memory_manager.build_grad_buffers()

    def get_pipeline_stage_layer_num(self):
        if self.stage2layer_num[self.pipeline_parallel_rank()] is not None:
            return self.stage2layer_num[self.pipeline_parallel_rank()]
        layer_re = re.compile(r'layers\.([0-9]+)')
        layer_set = set()
        for name in self.named_parameters:
            layer_num = re.findall(layer_re, name)
Collaborator: Does this only match parameters under `layers.`? Can it handle things like the embedding?

Collaborator (Author): For the embedding, `layer_num` comes back empty, so it is never added to `layer_set`.

Collaborator: Then how does the embedding get indexed afterwards?

Collaborator (Author): This function only counts the number of layers; the embedding is currently not counted as a separate layer.
            if layer_num:
                layer_set.add(layer_num[0])
        stage_layer_num = len(layer_set)
        return stage_layer_num

    def get_pipeline_stage_layer_offset(self):
        return self.stage2offset[self.pipeline_parallel_rank()]
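
To make the bookkeeping in `set_pipe_layer_num_offset` easier to follow, here is a small distributed-free sketch of the same logic; the helper name and the in-memory list standing in for `torch.distributed.all_gather` are assumptions for illustration only:

```python
# Sketch of the offset computation: collect (pipeline_rank, stage_layer_num)
# pairs (gathered via all_gather in the PR), check that replicas of the same
# stage agree, then turn per-stage layer counts into prefix-sum offsets.
def compute_stage_offsets(gathered, pipeline_parallel_size):
    stage2layer_num = [None] * pipeline_parallel_size
    stage2offset = [0] * pipeline_parallel_size
    for rank, num in gathered:
        if stage2layer_num[rank] is None:
            stage2layer_num[rank] = num
        else:
            # every replica of a pipeline stage must report the same layer count
            assert stage2layer_num[rank] == num
    for i, num in enumerate(stage2layer_num[:-1]):
        stage2offset[i + 1] = stage2offset[i] + num
    return stage2layer_num, stage2offset


# Unbalanced example: 8 layers split 3/3/2 over 3 stages, each stage reported
# by two ranks (e.g. two tensor- or data-parallel replicas).
nums, offsets = compute_stage_offsets(
    [(0, 3), (1, 3), (2, 2), (0, 3), (1, 3), (2, 2)], pipeline_parallel_size=3)
assert nums == [3, 3, 2] and offsets == [0, 3, 6]
```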
4 changes: 3 additions & 1 deletion chatlearn/runtime/parameter_sync.py
@@ -343,8 +343,10 @@ def _set_sync_param_names(self, send_actor, recv_actor, requires_grad=None):
            requires_grad = False
        if self.num_src_pipeline_stage > 1:
            dst_pipe_rank = self.get_actor_pipe_rank(recv_actor)
            dst_layer_offset = self.get_or_cache(recv_actor, "get_pipeline_stage_layer_offset")
            dst_src_mappings = future.get(send_actor.build_pipeline_layer_name_mapping.remote(
                self.num_dst_pipeline_stage, dst_pipe_rank, requires_grad=requires_grad))
                self.num_dst_pipeline_stage, dst_pipe_rank, dst_layer_offset,
                requires_grad=requires_grad))
            dst_names = list(dst_src_mappings.keys())
            src_names = list(dst_src_mappings.values())
        else:
22 changes: 12 additions & 10 deletions chatlearn/utils/megatron_utils.py
@@ -40,21 +40,19 @@
layer_re = re.compile(r'layers\.([0-9]+)')


def update_layer_num(layers_per_part, rank, m):
def update_layer_num(start_layer_num, m):
    # This assumes no interleaved pipeline execution
    layer = int(m.group(1))
    layer += rank * layers_per_part
    layer += start_layer_num
    return f'layers.{layer}'


def build_pipeline_layer_name_mapping(src_layers_per_stage, src_rank, map_interval, tgt_last_stage, model, requires_grad):
def build_pipeline_layer_name_mapping(src_layer_offset, tgt_layer_offset, tgt_last_stage, model, requires_grad):
    """
    remap pipeline layer_name. For each pipeline stage, the layer number starts with 0.
    Args:
        src_layers_per_stage: layer_per_stage in src model
        src_rank: src model pipeline rank
        map_interval: map interval from tgt to src, i.e. if src_layers_per_stage is 2, and tgt_layers_per_stage is 4,
            then the map_iterval is tgt_layers_per_stage/src_layers_per_stage = 2
        src_layer_offset: layer offset of src model
        tgt_layer_offset: layer offset of target model
        tgt_last_stage: is target model in last stage
        model: megatron model
        requires_grad: whether the layer requires grad
@@ -66,7 +64,7 @@ def build_pipeline_layer_name_mapping(src_layers_per_stage, src_rank, map_interv
            continue
        if src_name.endswith("word_embeddings.weight") \
                and "language_model" not in src_name \
                and hasattr(model, "language_model"):
                and hasattr(unwrap_model(model), "language_model"):
            # See comment in MegatronModule.initialize_word_embeddings()
            if not tgt_last_stage:
                tgt_name = src_name.replace("word_embeddings.weight", "language_model.embedding.word_embeddings.weight")
@@ -75,8 +73,12 @@ def build_pipeline_layer_name_mapping(src_layers_per_stage, src_rank, map_interv
        else:
            # Translate destination layer number (0-N for each partition)
            # to source layer number (single-model layer number)
            rank = src_rank % map_interval
            _update_layer_num = functools.partial(update_layer_num, src_layers_per_stage, rank)
            # e.g. for src model with 8 layers, src_num_stage=4, dst_num_stage=2
            # for src_model, stage offsets are [0, 2, 4, 6]. for dst model, stage offsets are [0, 4]
            # then the start layer_num of src->dst is as follows:
            # stage0 0->0 stage1 0->(2-0) stage2 0->(4-4) stage3 0->(6-4)
            start_layer_num = src_layer_offset - tgt_layer_offset
            _update_layer_num = functools.partial(update_layer_num, start_layer_num)
            tgt_name = re.sub(layer_re, _update_layer_num, src_name)
        name_mapping[tgt_name] = src_name
    return name_mapping
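
A small usage sketch of the new offset-based renumbering (standalone, using only `re` and `functools`; the parameter name below is illustrative, not taken from a real checkpoint): with 8 layers, 4 source stages (offsets 0/2/4/6) and 2 target stages (offsets 0/4), source stage 3 maps its local `layers.1` onto the target stage's `layers.3`.

```python
import functools
import re

layer_re = re.compile(r'layers\.([0-9]+)')

def update_layer_num(start_layer_num, m):
    # shift the stage-local layer index by (src_layer_offset - tgt_layer_offset)
    return f'layers.{int(m.group(1)) + start_layer_num}'

start_layer_num = 6 - 4  # src stage 3 starts at layer 6, the covering tgt stage at layer 4
rename = functools.partial(update_layer_num, start_layer_num)

src_name = "module.language_model.encoder.layers.1.self_attention.query_key_value.weight"
tgt_name = re.sub(layer_re, rename, src_name)
assert tgt_name.endswith("layers.3.self_attention.query_key_value.weight")
```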
2 changes: 1 addition & 1 deletion examples/megatron/configs/gpt/policy_shared.yaml
@@ -4,6 +4,6 @@ hidden_size: ${policy_hidden_size}
num_attention_heads: ${policy_num_attention_heads}
use_distributed_optimizer: ${policy_use_distributed_optimizer:True}
tensor_model_parallel_size: ${policy_tp}
pipeline_model_parallel_size: 1
pipeline_model_parallel_size: ${policy_pp:1}


2 changes: 1 addition & 1 deletion examples/megatron/configs/gpt/reward_shared.yaml
@@ -5,7 +5,7 @@ hidden_size: ${reward_hidden_size}
num_attention_heads: ${reward_num_attention_heads}
use_distributed_optimizer: ${reward_use_distributed_optimizer:False}
tensor_model_parallel_size: ${reward_tp}
pipeline_model_parallel_size: 1
pipeline_model_parallel_size: ${reward_pp:1}
seq_length: ${max_seq_len}
max_position_embeddings: ${max_seq_len}
