From 3dfb540b10071ecfde42d6e9505e7f42878d8564 Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Fri, 6 Sep 2024 15:30:53 +0800
Subject: [PATCH] fix(shard.py): fix isp unpack data indexes err in rotary emb
 (#316)

---
 internlm/core/parallel/comm/isp.py          | 46 +++++++++++++++++----
 internlm/core/parallel/shard.py             | 21 ++++++++--
 internlm/core/trainer_builder.py            | 16 +++----
 internlm/data/utils.py                      |  8 +++-
 internlm/eval/evaluation.py                 | 26 +++++++++++-
 internlm/model/metrics.py                   | 18 ++++----
 internlm/model/modules/mha.py               | 18 +++++---
 internlm/model/ops/cross_entropy.py         | 19 ++++-----
 internlm/model/ops/ring_flash_attn/utils.py | 12 ++++--
 internlm/model/ops/rotary_emb.py            |  1 +
 internlm/utils/common.py                    |  4 +-
 tests/test_model/test_model_internlm.py     |  1 +
 tests/test_training/test_loss.py            | 20 ++++-----
 13 files changed, 148 insertions(+), 62 deletions(-)

diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py
index 3fcea13a..8a65052f 100644
--- a/internlm/core/parallel/comm/isp.py
+++ b/internlm/core/parallel/comm/isp.py
@@ -79,8 +79,10 @@ def __init__(
         self.weight_process_group = weight_process_group
         self.seq_process_group = seq_process_group
         self._seq_parallel_mode = ParallelMode.TENSOR
-        self._seq_dim = 1
+        self._seq_world_size = gpc.get_world_size(ParallelMode.TENSOR)
         self._retain_out_sharded = retain_out_sharded
+        self._seq_dim = 1
+        self._hid_dim = 2

     def communication_mode(self) -> str:
         return "wp"
@@ -120,10 +122,25 @@ def grad_output_hook(
         """
         split grad_output if retain_out_sharded is False.
         """
-        if self._retain_out_sharded or dist.get_world_size(self.seq_process_group) <= 1:
-            return grad_output, DUMMY_HANDLE_CONST
-
-        return _split(grad_output, parallel_mode=self._seq_parallel_mode, dim=self._seq_dim), DUMMY_HANDLE_CONST
+        # gather hidden_states dim and split seq dim when parallel_output is True
+        if self._retain_out_sharded:
+            if self._seq_world_size <= 1:
+                return grad_output, DUMMY_HANDLE_CONST
+            else:
+                _seq_splited_list = [
+                    t.contiguous() for t in torch.tensor_split(grad_output, self._seq_world_size, dim=self._seq_dim)
+                ]
+                output_list = [torch.empty_like(_seq_splited_list[0]) for _ in range(self._seq_world_size)]
+                dist.all_to_all(output_list, _seq_splited_list, group=self.seq_process_group)
+                grad_output = torch.cat(output_list, dim=self._hid_dim).contiguous()
+                return grad_output, DUMMY_HANDLE_CONST
+        # split seq dim when parallel_output is False
+        else:
+            if self._seq_world_size <= 1:
+                return grad_output, DUMMY_HANDLE_CONST
+            else:
+                return _split(grad_output, parallel_mode=self._seq_parallel_mode, dim=self._seq_dim), DUMMY_HANDLE_CONST

     # rewrite ouput communication hook
     def output_hook(
@@ -132,10 +149,25 @@ def output_hook(
         """
         all gather output for head layer if retain_out_sharded is False.
         """
-        if self._retain_out_sharded or dist.get_world_size(self.seq_process_group) <= 1:
-            return output, DUMMY_HANDLE_CONST
-
-        return _gather(output, parallel_mode=self._seq_parallel_mode, dim=self._seq_dim), DUMMY_HANDLE_CONST
+        # gather seq dim and split hidden_states dim when parallel_output is True
+        if self._retain_out_sharded:
+            if self._seq_world_size <= 1:
+                return output, DUMMY_HANDLE_CONST
+            else:
+                _hid_splited_list = [
+                    t.contiguous() for t in torch.tensor_split(output, self._seq_world_size, dim=self._hid_dim)
+                ]
+                output_list = [torch.empty_like(_hid_splited_list[0]) for _ in range(self._seq_world_size)]
+                dist.all_to_all(output_list, _hid_splited_list, group=self.seq_process_group)
+                output = torch.cat(output_list, dim=self._seq_dim).contiguous()
+                return output, DUMMY_HANDLE_CONST
+        # gather seq dim when parallel_output is False
+        else:
+            if self._seq_world_size <= 1:
+                return output, DUMMY_HANDLE_CONST
+            else:
+                return _gather(output, parallel_mode=self._seq_parallel_mode, dim=self._seq_dim), DUMMY_HANDLE_CONST


 class EmbeddingWeightParallelCommunicator:
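
A standalone sketch of what the rewritten output_hook above does: a single all_to_all turns a sequence-sharded head output into a hidden/vocab-sharded one, and grad_output_hook applies the mirror-image exchange in backward. Shapes, names and the world size below are illustrative assumptions, not the patch's API; the all_to_all is simulated on one process purely to show why splitting on the hidden dim, exchanging chunks, and concatenating on the sequence dim gives every rank the full sequence with 1/world_size of the hidden features.

# Single-process simulation of the seq <-> hidden resharding used in output_hook.
# Illustrative only: world_size, seq and hid are made-up numbers.
import torch

world_size = 4          # sequence-parallel world size (assumed)
seq, hid = 16, 32       # full sequence length and head output width (assumed)

# Per-rank head output: each rank holds seq/world_size tokens with the full hidden dim.
rank_outputs = [torch.randn(1, seq // world_size, hid) for _ in range(world_size)]

# What dist.all_to_all does, simulated: rank src sends its dst-th hidden chunk to rank dst.
send = [list(torch.tensor_split(out, world_size, dim=2)) for out in rank_outputs]
recv = [[send[src][dst] for src in range(world_size)] for dst in range(world_size)]

# After the exchange, concatenating along the sequence dim gives every rank the
# full sequence but only 1/world_size of the hidden features.
resharded = [torch.cat(chunks, dim=1) for chunks in recv]
for t in resharded:
    assert t.shape == (1, seq, hid // world_size)

# Sanity check against the unsharded tensor.
full = torch.cat(rank_outputs, dim=1)                    # (1, seq, hid)
ref = torch.tensor_split(full, world_size, dim=2)        # hidden shards of the full sequence
for r in range(world_size):
    assert torch.equal(resharded[r], ref[r])
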
""" - if self._retain_out_sharded or dist.get_world_size(self.seq_process_group) <= 1: - return output, DUMMY_HANDLE_CONST - return _gather(output, parallel_mode=self._seq_parallel_mode, dim=self._seq_dim), DUMMY_HANDLE_CONST + # gather seq dim and split hidden_states dim when parallel_output is True + if self._retain_out_sharded: + if self._seq_world_size <= 1: + return output, DUMMY_HANDLE_CONST + else: + _hid_splited_list = [ + t.contiguous() for t in torch.tensor_split(output, self._seq_world_size, dim=self._hid_dim) + ] + output_list = [torch.empty_like(_hid_splited_list[0]) for _ in range(self._seq_world_size)] + dist.all_to_all(output_list, _hid_splited_list, group=self.seq_process_group) + output = torch.cat(output_list, dim=self._seq_dim).contiguous() + return output, DUMMY_HANDLE_CONST + # gather seq dim when parallel_output is False + else: + if self._seq_world_size <= 1: + return output, DUMMY_HANDLE_CONST + else: + return _gather(output, parallel_mode=self._seq_parallel_mode, dim=self._seq_dim), DUMMY_HANDLE_CONST class EmbeddingWeightParallelCommunicator: diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py index b1afff2e..b27e974b 100644 --- a/internlm/core/parallel/shard.py +++ b/internlm/core/parallel/shard.py @@ -9,7 +9,7 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.parallel.comm.utils import _split +from internlm.core.parallel.comm.utils import _gather, _split from internlm.utils.logger import get_logger from internlm.utils.utils import TensorParallelMode @@ -39,8 +39,8 @@ def _split_data_for_sequence_parallel(data, label): data["input_ids"] = _split(data["input_ids"], ParallelMode.TENSOR, dim=_seq_dim) - if gpc.config.model.parallel_output: - label = _split(label, ParallelMode.TENSOR, dim=_seq_dim) + # if gpc.config.model.parallel_output: + # label = _split(label, ParallelMode.TENSOR, dim=_seq_dim) return data, label @@ -49,7 +49,7 @@ def _split_data_for_2D_sequence_parallel(data, label): if gpc.config.parallel.sequence_2D.enable is False or gpc.get_world_size(ParallelMode.TENSOR) <= 1: return data, label - assert len(data.keys()) == 1 and "input_ids" in data + assert len(data.keys()) == 3 and "input_ids" in data and "indexes" in data and "max_seqlen" in data sp_size = gpc.get_world_size(ParallelMode.TENSOR) hp_size = gpc.get_world_size(ParallelMode.HEAD) @@ -59,6 +59,7 @@ def _split_data_for_2D_sequence_parallel(data, label): stride = 2 assert len(data["input_ids"].shape) == 2 + assert len(data["indexes"].shape) == 1 assert len(label.shape) == 2 seq_dim = 1 data["input_ids"] = data["input_ids"].view( @@ -66,6 +67,11 @@ def _split_data_for_2D_sequence_parallel(data, label): 2 * sp_size, data["input_ids"].shape[seq_dim] // (2 * sp_size), ) + _index_seq_dim = 0 + data["indexes"] = data["indexes"].view( + 2 * sp_size, + data["indexes"].shape[_index_seq_dim] // (2 * sp_size), + ) label = label.view( *label.shape[0:seq_dim], 2 * sp_size, @@ -108,9 +114,16 @@ def _split_data_for_2D_sequence_parallel(data, label): data["input_ids"] = data["input_ids"].view( *data["input_ids"].shape[0:seq_dim], -1, *data["input_ids"].shape[(seq_dim + 2) :] ) + data["indexes"] = data["indexes"].index_select(_index_seq_dim, index) + data["indexes"] = data["indexes"].view( + *data["indexes"].shape[0:_index_seq_dim], -1, *data["indexes"].shape[(_index_seq_dim + 2) :] + ) label = label.index_select(seq_dim, index) label = label.view(*label.shape[0:seq_dim], -1, *label.shape[(seq_dim + 2) :]) + 
diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py
index b94bc0d6..eb319111 100644
--- a/internlm/core/trainer_builder.py
+++ b/internlm/core/trainer_builder.py
@@ -39,7 +39,7 @@
 from internlm.utils.gputest import empty_cache_and_diag
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.utils.parallel import get_parallel_log_file_name, is_using_isp
+from internlm.utils.parallel import get_parallel_log_file_name
 from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler
 from internlm.utils.utils import DataType
 from internlm.utils.writer import Writer
@@ -205,12 +205,14 @@ def _initialize_writer(self, train_state, config_lines) -> Writer:
     def _initialize_metric(self, dataset_types) -> AccPerplex:
         # initialize metric for calculating accuracy and perplexity
         # if isp mode, head output is parallel in sequence dim, metric dp group should be SP*DP
-        _dp_pg = (
-            gpc.get_group(ParallelMode.ISP_DATA)
-            if is_using_isp() and gpc.config.model.parallel_output
-            else gpc.get_group(ParallelMode.DATA)
-        )
-        _tp_pg = dist.new_group([gpc.get_global_rank()]) if is_using_isp() else gpc.get_group(ParallelMode.TENSOR)
+        # _dp_pg = (
+        #     gpc.get_group(ParallelMode.ISP_DATA)
+        #     if is_using_isp() and gpc.config.model.parallel_output
+        #     else gpc.get_group(ParallelMode.DATA)
+        # )
+        # _tp_pg = dist.new_group([gpc.get_global_rank()]) if is_using_isp() else gpc.get_group(ParallelMode.TENSOR)
+        _dp_pg = gpc.get_group(ParallelMode.DATA)
+        _tp_pg = gpc.get_group(ParallelMode.TENSOR)
         return AccPerplex(
             device=get_current_device(),
             tp_pg=_tp_pg,
diff --git a/internlm/data/utils.py b/internlm/data/utils.py
index 0686a94a..bd87de4c 100644
--- a/internlm/data/utils.py
+++ b/internlm/data/utils.py
@@ -52,10 +52,16 @@ def unpack_type_ids(type_ids, cu_seqlens):


 def unpack_data(data, label):
     data["input_ids"] = _unpack_data(data["input_ids"], data["cu_seqlens"], padding_v=0).squeeze(0)
+    data["indexes"] = _unpack_data(data["indexes"], data["cu_seqlens"], padding_v=0).squeeze(0)
     label = _unpack_data(label, data["cu_seqlens"], padding_v=-100).squeeze(0)

+    data["max_seqlen"] = gpc.config.data.seq_len
+
     data.pop("cu_seqlens")
-    data.pop("indexes")
+    # indexes will be used in rotary emb when using isp and sp_size > 1
+    # data.pop("indexes")
+    # per batch's index should be equal, so we select first batch
+    data["indexes"] = data["indexes"][0]

     return data, label
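
Context for the unpack_data change above: in a packed batch, `indexes` carries rotary position ids that restart at every `cu_seqlens` boundary, and after unpacking each row becomes the same 0-based ramp, which is why keeping only the first row (`data["indexes"][0]`) loses nothing. The sketch below illustrates that idea; `toy_unpack` is a stand-in written for this illustration, not InternLM's `_unpack_data`.

# Toy stand-in for unpacking a packed batch with cu_seqlens (assumed semantics):
# each packed sub-sequence gets its own row, padded to max_len.
import torch

def toy_unpack(packed: torch.Tensor, cu_seqlens: torch.Tensor, max_len: int, padding_v: int) -> torch.Tensor:
    rows = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        row = torch.full((max_len,), padding_v, dtype=packed.dtype)
        row[: end - start] = packed[start:end]
        rows.append(row)
    return torch.stack(rows)

max_len = 4
cu_seqlens = torch.tensor([0, 4, 8, 12])          # three packed sub-sequences of length 4
packed_ids = torch.randint(5, 100, (12,))         # fake token ids
packed_idx = torch.cat([torch.arange(4)] * 3)     # rotary positions restart at each boundary

ids = toy_unpack(packed_ids, cu_seqlens, max_len, padding_v=0)
idx = toy_unpack(packed_idx, cu_seqlens, max_len, padding_v=0)
assert ids.shape == idx.shape == (3, 4)

# Every unpacked row of `idx` is the same 0..max_len-1 ramp, so keeping idx[0]
# is enough -- the rationale behind data["indexes"] = data["indexes"][0].
assert all(torch.equal(row, idx[0]) for row in idx)
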
diff --git a/internlm/eval/evaluation.py b/internlm/eval/evaluation.py
index 71fae100..50d17c01 100644
--- a/internlm/eval/evaluation.py
+++ b/internlm/eval/evaluation.py
@@ -7,9 +7,11 @@
 from internlm.accelerator import get_accelerator
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.core.parallel.shard import split_data_for_sequence_parallel
 from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape
 from internlm.model.metrics import AccPerplex, SchedulerMetricHook
 from internlm.utils.common import get_current_device
+from internlm.utils.parallel import is_using_isp

 internlm_accelerator = get_accelerator()

@@ -32,7 +34,29 @@ def switch_evaluation_mode(trainer, metric_hook_list):
     prev_metric_hooks = trainer.schedule._hooks
     try:
         gpc.is_evaluating = True
-        trainer.schedule.data_process_func = None
+
+        data_fns = []
+
+        def add_indexes_to_data(data, label):
+            _indexes = torch.arange(gpc.config.data.seq_len, dtype=torch.int32).to(get_current_device())
+            assert "indexes" not in data
+            data["indexes"] = _indexes
+            data["max_seqlen"] = gpc.config.data.seq_len
+
+            return data, label
+
+        # support sequence parallel for isp
+        if is_using_isp():
+            data_fns.append(add_indexes_to_data)
+            data_fns.append(split_data_for_sequence_parallel)
+
+        def _data_preparation_func(_data, _label):
+            for fn in data_fns:
+                _data, _label = fn(_data, _label)
+
+            return _data, _label
+
+        trainer.schedule.data_process_func = _data_preparation_func
         trainer.schedule._hooks = metric_hook_list

         yield
diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py
index a20136a6..ff7fcd6e 100644
--- a/internlm/model/metrics.py
+++ b/internlm/model/metrics.py
@@ -3,13 +3,11 @@
 import torch

 from internlm.accelerator import AcceleratorType, get_accelerator
-from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.model.ops.cross_entropy import new_cross_entropy
 from internlm.utils.common import SchedulerHook, get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.utils.parallel import is_using_isp

 try:
     from torch_scatter import scatter as cuda_scatter
@@ -116,10 +114,10 @@ def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str

     def set_current_type_ids(self, type_ids: torch.Tensor):
         self.batch_shift = 0
-        if is_using_isp() and gpc.config.model.parallel_output:
-            step_seqlen = type_ids.shape[-1] // gpc.get_world_size(ParallelMode.TENSOR)
-            sp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
-            type_ids = type_ids[..., step_seqlen * sp_rank : step_seqlen * (sp_rank + 1)]
+        # if is_using_isp() and gpc.config.model.parallel_output:
+        #     step_seqlen = type_ids.shape[-1] // gpc.get_world_size(ParallelMode.TENSOR)
+        #     sp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
+        #     type_ids = type_ids[..., step_seqlen * sp_rank : step_seqlen * (sp_rank + 1)]
         self.type_ids = type_ids.to(get_current_device())

     def set_cu_seqlens(self, cu_seqlens: List):
@@ -302,10 +300,10 @@ def update(self, logits, labels, type_ids=None):
             loss_list = self.loss_fn(logits, labels)

             # get current rank part loss_list
-            if is_using_isp() and gpc.config.model.parallel_output:
-                step_seqlen = logits.shape[0]
-                sp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
-                loss_list = loss_list[step_seqlen * sp_rank : step_seqlen * (sp_rank + 1)]
+            # if is_using_isp() and gpc.config.model.parallel_output:
+            #     step_seqlen = logits.shape[0]
+            #     sp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
+            #     loss_list = loss_list[step_seqlen * sp_rank : step_seqlen * (sp_rank + 1)]

             cond = labels != -100
             real_loss_list = loss_list[cond]
diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py
index 6bb75c52..cd8eaff2 100644
--- a/internlm/model/modules/mha.py
+++ b/internlm/model/modules/mha.py
@@ -8,6 +8,7 @@
 from torch import nn
 from torch.nn import functional as F

+from internlm.core.context import global_context as gpc
 from internlm.model.modules.embedding import new_rotary_embedding
 from internlm.model.modules.linear import new_linear
 from internlm.model.modules.utils import update_kv_cache
@@ -24,6 +25,8 @@ def _convert_cu_seqlens_for_qksplited(kwargs: Dict):

     if cu_seqlens is not None:
         kwargs["cu_seqlens_q"] = cu_seqlens
         kwargs["cu_seqlens_k"] = cu_seqlens
+
+    if max_seqlen is not None:
         kwargs["max_seqlen_q"] = max_seqlen
         kwargs["max_seqlen_k"] = max_seqlen
@@ -153,15 +156,14 @@ def _training(self, x, **kwargs):
         # rotary embedding
         indexes = kwargs.pop("indexes", 0)
         max_seqlen = kwargs.get("max_seqlen", None)
-        q = self.rotary_emb(
-            q, offsets=indexes, cache_type="query", interleaved=self.interleaved, max_seqlen=max_seqlen, in_place=True
-        )
-        k = self.rotary_emb(
-            k, offsets=indexes, cache_type="key", interleaved=self.interleaved, max_seqlen=max_seqlen, in_place=True
-        )
+        q = self.rotary_emb(q, offsets=indexes, cache_type="query", interleaved=self.interleaved, max_seqlen=max_seqlen)
+        k = self.rotary_emb(k, offsets=indexes, cache_type="key", interleaved=self.interleaved, max_seqlen=max_seqlen)

         # self attention
         kwargs = _convert_cu_seqlens_for_qksplited(kwargs)
+        if gpc.config.data.use_packed_dataset is False:
+            kwargs.pop("max_seqlen_q", None)
+            kwargs.pop("max_seqlen_k", None)
         context = self.inner_attn(q, k, v, **kwargs)

         # wo
@@ -465,6 +467,10 @@ def _training(self, x, **kwargs):

         kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2)

+        if gpc.config.data.use_packed_dataset is False:
+            kwargs.pop("max_seqlen_q", None)
+            kwargs.pop("max_seqlen_k", None)
+
         # self attention
         context = self.inner_attn(q, kv, **kwargs)
diff --git a/internlm/model/ops/cross_entropy.py b/internlm/model/ops/cross_entropy.py
index ecfb93cf..eba7f4dc 100644
--- a/internlm/model/ops/cross_entropy.py
+++ b/internlm/model/ops/cross_entropy.py
@@ -13,7 +13,6 @@
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.utils.logger import get_logger
-from internlm.utils.parallel import is_using_isp

 try:
     from flash_attn.losses.cross_entropy import (
@@ -161,15 +160,15 @@ def new_cross_entropy(
     parallel_output: bool = False,
     **kwargs,
 ):
-    if is_using_isp() and parallel_output:
-        if gpc.is_rank_for_log():
-            logger.warning("Use VocabSequenceParallelCrossEntropyLoss.")
-        return VocabSequenceParallelCrossEntropyLoss(
-            ignore_index=ignore_index,
-            reduction=reduction,
-            label_smoothing=label_smoothing,
-            process_group=gpc.get_group(ParallelMode.TENSOR),
-        )
+    # if is_using_isp() and parallel_output:
+    #     if gpc.is_rank_for_log():
+    #         logger.warning("Use VocabSequenceParallelCrossEntropyLoss.")
+    #     return VocabSequenceParallelCrossEntropyLoss(
+    #         ignore_index=ignore_index,
+    #         reduction=reduction,
+    #         label_smoothing=label_smoothing,
+    #         process_group=gpc.get_group(ParallelMode.TENSOR),
+    #     )

     if parallel_output:
         assert (
diff --git a/internlm/model/ops/ring_flash_attn/utils.py b/internlm/model/ops/ring_flash_attn/utils.py
index bf59e00a..c15f074a 100644
--- a/internlm/model/ops/ring_flash_attn/utils.py
+++ b/internlm/model/ops/ring_flash_attn/utils.py
@@ -4,6 +4,7 @@

 import torch
 import torch.distributed as dist
+import torch.nn.functional as F

 __all__ = ["update_out_and_lse", "RingComm"]

@@ -15,14 +16,17 @@ def _update_out_and_lse(
     block_out: torch.Tensor,
     block_lse: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    block_out = block_out.to(torch.float32)
     block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

-    new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
-
-    out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
+    # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
+    # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
+    # For additional context and discussion, please refer to:
+    # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
+    out = out - F.sigmoid(block_lse - lse) * (out - block_out)
+    lse = lse - F.logsigmoid(lse - block_lse)

-    lse = new_lse
     return out, lse
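
The utils.py hunk above swaps the explicit log-sum-exp accumulation for the sigmoid/logsigmoid form discussed in the linked ring-flash-attention thread; the two are algebraically identical, but the new form avoids exponentiating large positive differences. A standalone check with toy tensors (shapes and names are illustrative, not InternLM's API):

# Check that the two update rules in _update_out_and_lse agree numerically.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
out = torch.randn(2, 8, 1, 4, dtype=torch.float64)        # accumulated attention output
block_out = torch.randn(2, 8, 1, 4, dtype=torch.float64)  # new block's output
lse = torch.randn(2, 8, 1, 1, dtype=torch.float64)        # accumulated log-sum-exp
block_lse = torch.randn(2, 8, 1, 1, dtype=torch.float64)  # new block's log-sum-exp

# Old formulation: explicit log-sum-exp merge.
new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
out_old = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out

# New formulation from the patch: the same merge written with sigmoid/logsigmoid,
# which stays finite when block_lse is much larger than lse.
out_new = out - F.sigmoid(block_lse - lse) * (out - block_out)
lse_new = lse - F.logsigmoid(lse - block_lse)

assert torch.allclose(out_old, out_new)
assert torch.allclose(new_lse, lse_new)
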
diff --git a/internlm/model/ops/rotary_emb.py b/internlm/model/ops/rotary_emb.py
index d4627cdb..c7109537 100644
--- a/internlm/model/ops/rotary_emb.py
+++ b/internlm/model/ops/rotary_emb.py
@@ -107,6 +107,7 @@ def _apply_torch_npu_rotary_mul(x: Tensor, cos: Tensor, sin: Tensor):
         cos (Tensor): cos, shape is [1, S, 1, D].
         sin (Tensor): sin, shape is [1, S, 1, D].
     """
+    # NOTE: This could probably be moved to Triton.

     def rotate_half(_x):
         x1, x2 = _x.chunk(2, dim=-1)
diff --git a/internlm/utils/common.py b/internlm/utils/common.py
index 956d08a8..5f53a4f9 100644
--- a/internlm/utils/common.py
+++ b/internlm/utils/common.py
@@ -95,10 +95,10 @@ def check_data_is_packed(data):
         return False
     elif isinstance(data, (list, tuple)):
         if isinstance(data[0], dict):
-            return "indexes" in data[0]
+            return "cu_seqlens" in data[0]
         return False
     elif isinstance(data, dict):
-        return "indexes" in data[0]
+        return "cu_seqlens" in data[0]


 def filter_kwargs(func, kwargs):
diff --git a/tests/test_model/test_model_internlm.py b/tests/test_model/test_model_internlm.py
index c33f188c..c62d062f 100644
--- a/tests/test_model/test_model_internlm.py
+++ b/tests/test_model/test_model_internlm.py
@@ -47,6 +47,7 @@
         pack_sample_into_one=False,
         min_length=0,
         total_steps=9999,
+        use_packed_dataset=True,
     ),
     model=dict(
         checkpoint=False,
diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py
index 37cf9517..c20d8398 100644
--- a/tests/test_training/test_loss.py
+++ b/tests/test_training/test_loss.py
@@ -465,16 +465,16 @@ def test_training_with_isp():
     global CONFIG_FILE_PATH, BASELINE_LOSS_LIST
     CONFIG_FILE_PATH = "./configs/7B_isp_sft.py"
     BASELINE_LOSS_LIST = [
-        10.711931228637695,
-        7.549415588378906,
-        6.495877742767334,
-        5.944756507873535,
-        5.246580123901367,
-        5.334012031555176,
-        4.999225616455078,
-        4.70023250579834,
-        4.591017723083496,
-        4.589826583862305,
+        11.595988273620605,
+        7.988386154174805,
+        6.821506500244141,
+        6.2768449783325195,
+        5.478013515472412,
+        5.4622697830200195,
+        5.162247180938721,
+        4.854615211486816,
+        4.744818210601807,
+        4.75523567199707,
     ]

     # model training