fix(shard.py): fix isp unpack data indexes err in rotary emb (#316)
huangting4201 authored Sep 6, 2024
1 parent 8ab2aff commit 3dfb540
Showing 13 changed files with 148 additions and 62 deletions.
46 changes: 39 additions & 7 deletions internlm/core/parallel/comm/isp.py
@@ -79,8 +79,10 @@ def __init__(
self.weight_process_group = weight_process_group
self.seq_process_group = seq_process_group
self._seq_parallel_mode = ParallelMode.TENSOR
self._seq_dim = 1
self._seq_world_size = gpc.get_world_size(ParallelMode.TENSOR)
self._retain_out_sharded = retain_out_sharded
self._seq_dim = 1
self._hid_dim = 2

def communication_mode(self) -> str:
return "wp"
@@ -120,10 +120,25 @@ def grad_output_hook(
"""
split grad_output if retain_out_sharded is False.
"""
if self._retain_out_sharded or dist.get_world_size(self.seq_process_group) <= 1:
return grad_output, DUMMY_HANDLE_CONST

return _split(grad_output, parallel_mode=self._seq_parallel_mode, dim=self._seq_dim), DUMMY_HANDLE_CONST
# gather hidden_states dim and split seq dim when parallel_output is True
if self._retain_out_sharded:
if self._seq_world_size <= 1:
return grad_output, DUMMY_HANDLE_CONST
else:
_seq_splited_list = [
t.contiguous() for t in torch.tensor_split(grad_output, self._seq_world_size, dim=self._seq_dim)
]
output_list = [torch.empty_like(_seq_splited_list[0]) for _ in range(self._seq_world_size)]
dist.all_to_all(output_list, _seq_splited_list, group=self.seq_process_group)
grad_output = torch.cat(output_list, dim=self._hid_dim).contiguous()
return grad_output, DUMMY_HANDLE_CONST
# split seq dim when parallel_output is False
else:
if self._seq_world_size <= 1:
return grad_output, DUMMY_HANDLE_CONST
else:
return _split(grad_output, parallel_mode=self._seq_parallel_mode, dim=self._seq_dim), DUMMY_HANDLE_CONST

# rewrite output communication hook
def output_hook(
@@ -132,10 +149,25 @@ def output_hook(
"""
all gather output for head layer if retain_out_sharded is False.
"""
if self._retain_out_sharded or dist.get_world_size(self.seq_process_group) <= 1:
return output, DUMMY_HANDLE_CONST

return _gather(output, parallel_mode=self._seq_parallel_mode, dim=self._seq_dim), DUMMY_HANDLE_CONST
# gather seq dim and split hidden_states dim when parallel_output is True
if self._retain_out_sharded:
if self._seq_world_size <= 1:
return output, DUMMY_HANDLE_CONST
else:
_hid_splited_list = [
t.contiguous() for t in torch.tensor_split(output, self._seq_world_size, dim=self._hid_dim)
]
output_list = [torch.empty_like(_hid_splited_list[0]) for _ in range(self._seq_world_size)]
dist.all_to_all(output_list, _hid_splited_list, group=self.seq_process_group)
output = torch.cat(output_list, dim=self._seq_dim).contiguous()
return output, DUMMY_HANDLE_CONST
# gather seq dim when parallel_output is False
else:
if self._seq_world_size <= 1:
return output, DUMMY_HANDLE_CONST
else:
return _gather(output, parallel_mode=self._seq_parallel_mode, dim=self._seq_dim), DUMMY_HANDLE_CONST


class EmbeddingWeightParallelCommunicator:
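When `self._retain_out_sharded` is set (i.e. `parallel_output` is True), the new `output_hook` no longer leaves the head output sequence-sharded: it gathers the sequence dimension and shards the hidden dimension instead, using a single `all_to_all`, while `grad_output_hook` performs the inverse exchange in the backward pass. Below is a minimal standalone sketch of that forward reshard; it assumes a per-rank tensor of shape `[batch, seq/sp, hidden]`, and the helper name `reshard_seq_to_hidden` is illustrative, not part of the InternLM API.

```python
import torch
import torch.distributed as dist


def reshard_seq_to_hidden(output: torch.Tensor, group, seq_dim: int = 1, hid_dim: int = 2) -> torch.Tensor:
    """[batch, seq/sp, hidden] per rank -> [batch, seq, hidden/sp] per rank."""
    sp_size = dist.get_world_size(group)
    if sp_size <= 1:
        return output
    # Slice the local hidden dimension into one chunk per sequence-parallel rank ...
    send = [t.contiguous() for t in torch.tensor_split(output, sp_size, dim=hid_dim)]
    recv = [torch.empty_like(send[0]) for _ in range(sp_size)]
    # ... exchange chunks so rank r collects its hidden slice from every sequence shard ...
    dist.all_to_all(recv, send, group=group)
    # ... and stitch the received pieces together along the sequence dimension.
    return torch.cat(recv, dim=seq_dim).contiguous()
```

The backward hook is the mirror image: split the gradient along the sequence dimension, `all_to_all`, then concatenate along the hidden dimension.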
21 changes: 17 additions & 4 deletions internlm/core/parallel/shard.py
@@ -9,7 +9,7 @@

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.parallel.comm.utils import _split
from internlm.core.parallel.comm.utils import _gather, _split
from internlm.utils.logger import get_logger
from internlm.utils.utils import TensorParallelMode

Expand Down Expand Up @@ -39,8 +39,8 @@ def _split_data_for_sequence_parallel(data, label):

data["input_ids"] = _split(data["input_ids"], ParallelMode.TENSOR, dim=_seq_dim)

if gpc.config.model.parallel_output:
label = _split(label, ParallelMode.TENSOR, dim=_seq_dim)
# if gpc.config.model.parallel_output:
# label = _split(label, ParallelMode.TENSOR, dim=_seq_dim)

return data, label

@@ -49,7 +49,7 @@ def _split_data_for_2D_sequence_parallel(data, label):
if gpc.config.parallel.sequence_2D.enable is False or gpc.get_world_size(ParallelMode.TENSOR) <= 1:
return data, label

assert len(data.keys()) == 1 and "input_ids" in data
assert len(data.keys()) == 3 and "input_ids" in data and "indexes" in data and "max_seqlen" in data

sp_size = gpc.get_world_size(ParallelMode.TENSOR)
hp_size = gpc.get_world_size(ParallelMode.HEAD)
@@ -59,13 +59,19 @@ def _split_data_for_2D_sequence_parallel(data, label):
stride = 2

assert len(data["input_ids"].shape) == 2
assert len(data["indexes"].shape) == 1
assert len(label.shape) == 2
seq_dim = 1
data["input_ids"] = data["input_ids"].view(
*data["input_ids"].shape[0:seq_dim],
2 * sp_size,
data["input_ids"].shape[seq_dim] // (2 * sp_size),
)
_index_seq_dim = 0
data["indexes"] = data["indexes"].view(
2 * sp_size,
data["indexes"].shape[_index_seq_dim] // (2 * sp_size),
)
label = label.view(
*label.shape[0:seq_dim],
2 * sp_size,
@@ -108,9 +114,16 @@ def _split_data_for_2D_sequence_parallel(data, label):
data["input_ids"] = data["input_ids"].view(
*data["input_ids"].shape[0:seq_dim], -1, *data["input_ids"].shape[(seq_dim + 2) :]
)
data["indexes"] = data["indexes"].index_select(_index_seq_dim, index)
data["indexes"] = data["indexes"].view(
*data["indexes"].shape[0:_index_seq_dim], -1, *data["indexes"].shape[(_index_seq_dim + 2) :]
)
label = label.index_select(seq_dim, index)
label = label.view(*label.shape[0:seq_dim], -1, *label.shape[(seq_dim + 2) :])

# if gpc.config.model.parallel_output is False:
label = _gather(label, ParallelMode.TENSOR, dim=seq_dim)

return data, label


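`_split_data_for_2D_sequence_parallel` now carries `indexes` (the per-token rotary positions of a packed batch) through the same 2·sp-chunk reshuffle as `input_ids` and `label`, so each rank hands the correct position offsets to rotary embedding. The sketch below shows that split in isolation; the chunk selection `(r, 2*sp_size - 1 - r)` is the usual zigzag layout for ring attention and is an assumption here — the actual `index` tensor is built in the elided middle of this hunk and also accounts for head parallelism.

```python
import torch


def zigzag_split_indexes(indexes: torch.Tensor, sp_size: int, sp_rank: int) -> torch.Tensor:
    """indexes: [seq_len] position ids -> this rank's shard of length seq_len // sp_size."""
    # Cut the sequence into 2*sp_size equal chunks, exactly as the diff does for data["indexes"].
    chunks = indexes.view(2 * sp_size, indexes.shape[0] // (2 * sp_size))
    # Assumed zigzag selection: rank r keeps chunk r and its mirror chunk 2*sp_size - 1 - r.
    index = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=indexes.device)
    return chunks.index_select(0, index).reshape(-1)


# Toy check with seq_len=16, sp_size=2:
#   rank 0 keeps positions [0..3] and [12..15], rank 1 keeps [4..7] and [8..11].
print(zigzag_split_indexes(torch.arange(16), sp_size=2, sp_rank=0))
```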
16 changes: 9 additions & 7 deletions internlm/core/trainer_builder.py
@@ -39,7 +39,7 @@
from internlm.utils.gputest import empty_cache_and_diag
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import get_parallel_log_file_name, is_using_isp
from internlm.utils.parallel import get_parallel_log_file_name
from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler
from internlm.utils.utils import DataType
from internlm.utils.writer import Writer
@@ -205,12 +205,14 @@ def _initialize_writer(self, train_state, config_lines) -> Writer:
def _initialize_metric(self, dataset_types) -> AccPerplex:
# initialize metric for calculating accuracy and perplexity
# if isp mode, head output is parallel in sequence dim, metric dp group should be SP*DP
_dp_pg = (
gpc.get_group(ParallelMode.ISP_DATA)
if is_using_isp() and gpc.config.model.parallel_output
else gpc.get_group(ParallelMode.DATA)
)
_tp_pg = dist.new_group([gpc.get_global_rank()]) if is_using_isp() else gpc.get_group(ParallelMode.TENSOR)
# _dp_pg = (
# gpc.get_group(ParallelMode.ISP_DATA)
# if is_using_isp() and gpc.config.model.parallel_output
# else gpc.get_group(ParallelMode.DATA)
# )
# _tp_pg = dist.new_group([gpc.get_global_rank()]) if is_using_isp() else gpc.get_group(ParallelMode.TENSOR)
_dp_pg = gpc.get_group(ParallelMode.DATA)
_tp_pg = gpc.get_group(ParallelMode.TENSOR)
return AccPerplex(
device=get_current_device(),
tp_pg=_tp_pg,
8 changes: 7 additions & 1 deletion internlm/data/utils.py
@@ -52,10 +52,16 @@ def unpack_type_ids(type_ids, cu_seqlens):
def unpack_data(data, label):

data["input_ids"] = _unpack_data(data["input_ids"], data["cu_seqlens"], padding_v=0).squeeze(0)
data["indexes"] = _unpack_data(data["indexes"], data["cu_seqlens"], padding_v=0).squeeze(0)
label = _unpack_data(label, data["cu_seqlens"], padding_v=-100).squeeze(0)

data["max_seqlen"] = gpc.config.data.seq_len

data.pop("cu_seqlens")
data.pop("indexes")
# indexes will be used in rotary emb when using isp and sp_size > 1
# data.pop("indexes")
# per batch's index should be equal, so we select first batch
data["indexes"] = data["indexes"][0]

return data, label

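`unpack_data` now also unpacks `indexes` so position ids survive into the rotary-embedding path when ISP is used with sp_size > 1. The toy below illustrates what cu_seqlens-based unpacking of a 1-D tensor does; the `unpack_1d` helper and its shapes are assumptions for illustration, not the real `_unpack_data`.

```python
import torch


def unpack_1d(packed: torch.Tensor, cu_seqlens: torch.Tensor, seq_len: int, padding_v: int = 0) -> torch.Tensor:
    """Expand a packed 1-D tensor into [num_samples, seq_len], padding each sample's tail."""
    num_samples = cu_seqlens.numel() - 1
    out = torch.full((num_samples, seq_len), padding_v, dtype=packed.dtype)
    for i in range(num_samples):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        out[i, : end - start] = packed[start:end]
    return out


# Three packed samples whose position ids restart at 0:
packed_indexes = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])
cu = torch.tensor([0, 3, 5, 9])
print(unpack_1d(packed_indexes, cu, seq_len=4))
# tensor([[0, 1, 2, 0],
#         [0, 1, 0, 0],
#         [0, 1, 2, 3]])
```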
26 changes: 25 additions & 1 deletion internlm/eval/evaluation.py
@@ -7,9 +7,11 @@
from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.parallel.shard import split_data_for_sequence_parallel
from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape
from internlm.model.metrics import AccPerplex, SchedulerMetricHook
from internlm.utils.common import get_current_device
from internlm.utils.parallel import is_using_isp

internlm_accelerator = get_accelerator()

@@ -32,7 +34,29 @@ def switch_evaluation_mode(trainer, metric_hook_list):
prev_metric_hooks = trainer.schedule._hooks
try:
gpc.is_evaluating = True
trainer.schedule.data_process_func = None

data_fns = []

def add_indexes_to_data(data, label):
_indexes = torch.arange(gpc.config.data.seq_len, dtype=torch.int32).to(get_current_device())
assert "indexes" not in data
data["indexes"] = _indexes
data["max_seqlen"] = gpc.config.data.seq_len

return data, label

# support sequence parallel for isp
if is_using_isp():
data_fns.append(add_indexes_to_data)
data_fns.append(split_data_for_sequence_parallel)

def _data_preparation_func(_data, _label):
for fn in data_fns:
_data, _label = fn(_data, _label)

return _data, _label

trainer.schedule.data_process_func = _data_preparation_func
trainer.schedule._hooks = metric_hook_list

yield
18 changes: 8 additions & 10 deletions internlm/model/metrics.py
@@ -3,13 +3,11 @@
import torch

from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.ops.cross_entropy import new_cross_entropy
from internlm.utils.common import SchedulerHook, get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.parallel import is_using_isp

try:
from torch_scatter import scatter as cuda_scatter
@@ -116,10 +114,10 @@ def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str

def set_current_type_ids(self, type_ids: torch.Tensor):
self.batch_shift = 0
if is_using_isp() and gpc.config.model.parallel_output:
step_seqlen = type_ids.shape[-1] // gpc.get_world_size(ParallelMode.TENSOR)
sp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
type_ids = type_ids[..., step_seqlen * sp_rank : step_seqlen * (sp_rank + 1)]
# if is_using_isp() and gpc.config.model.parallel_output:
# step_seqlen = type_ids.shape[-1] // gpc.get_world_size(ParallelMode.TENSOR)
# sp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
# type_ids = type_ids[..., step_seqlen * sp_rank : step_seqlen * (sp_rank + 1)]
self.type_ids = type_ids.to(get_current_device())

def set_cu_seqlens(self, cu_seqlens: List):
Expand Down Expand Up @@ -302,10 +300,10 @@ def update(self, logits, labels, type_ids=None):
loss_list = self.loss_fn(logits, labels)

# get current rank part loss_list
if is_using_isp() and gpc.config.model.parallel_output:
step_seqlen = logits.shape[0]
sp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
loss_list = loss_list[step_seqlen * sp_rank : step_seqlen * (sp_rank + 1)]
# if is_using_isp() and gpc.config.model.parallel_output:
# step_seqlen = logits.shape[0]
# sp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
# loss_list = loss_list[step_seqlen * sp_rank : step_seqlen * (sp_rank + 1)]

cond = labels != -100
real_loss_list = loss_list[cond]
18 changes: 12 additions & 6 deletions internlm/model/modules/mha.py
@@ -8,6 +8,7 @@
from torch import nn
from torch.nn import functional as F

from internlm.core.context import global_context as gpc
from internlm.model.modules.embedding import new_rotary_embedding
from internlm.model.modules.linear import new_linear
from internlm.model.modules.utils import update_kv_cache
@@ -24,6 +25,8 @@ def _convert_cu_seqlens_for_qksplited(kwargs: Dict):
if cu_seqlens is not None:
kwargs["cu_seqlens_q"] = cu_seqlens
kwargs["cu_seqlens_k"] = cu_seqlens

if max_seqlen is not None:
kwargs["max_seqlen_q"] = max_seqlen
kwargs["max_seqlen_k"] = max_seqlen

@@ -153,15 +156,14 @@ def _training(self, x, **kwargs):
# rotary embedding
indexes = kwargs.pop("indexes", 0)
max_seqlen = kwargs.get("max_seqlen", None)
q = self.rotary_emb(
q, offsets=indexes, cache_type="query", interleaved=self.interleaved, max_seqlen=max_seqlen, in_place=True
)
k = self.rotary_emb(
k, offsets=indexes, cache_type="key", interleaved=self.interleaved, max_seqlen=max_seqlen, in_place=True
)
q = self.rotary_emb(q, offsets=indexes, cache_type="query", interleaved=self.interleaved, max_seqlen=max_seqlen)
k = self.rotary_emb(k, offsets=indexes, cache_type="key", interleaved=self.interleaved, max_seqlen=max_seqlen)

# self attention
kwargs = _convert_cu_seqlens_for_qksplited(kwargs)
if gpc.config.data.use_packed_dataset is False:
kwargs.pop("max_seqlen_q", None)
kwargs.pop("max_seqlen_k", None)
context = self.inner_attn(q, k, v, **kwargs)

# wo
@@ -465,6 +467,10 @@ def _training(self, x, **kwargs):

kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2)

if gpc.config.data.use_packed_dataset is False:
kwargs.pop("max_seqlen_q", None)
kwargs.pop("max_seqlen_k", None)

# self attention
context = self.inner_attn(q, kv, **kwargs)

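The rotary calls now rely on `offsets=indexes` (and no longer rotate in place), so tokens of a packed or sequence-sharded batch are rotated at their true positions rather than at 0..seq_len-1 of the local shard. A hedged, non-interleaved sketch of position-offset rotary embedding follows; it is not the `new_rotary_embedding` API, and the shapes are assumed.

```python
import torch


def apply_rotary_with_offsets(x: torch.Tensor, offsets: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """x: [seq, heads, head_dim]; offsets: [seq] integer positions used to look up cos/sin."""
    head_dim = x.shape[-1]
    half = head_dim // 2
    inv_freq = 1.0 / (base ** (torch.arange(half, dtype=torch.float32) / half))
    # Angles come from the provided positions, not from 0..seq_len-1.
    angles = offsets[:, None].float() * inv_freq[None, :]           # [seq, half]
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]   # broadcast over heads
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)


# e.g. a local shard holding global positions 8..15:
q = torch.randn(8, 4, 64)
q_rot = apply_rotary_with_offsets(q, offsets=torch.arange(8, 16))
```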
19 changes: 9 additions & 10 deletions internlm/model/ops/cross_entropy.py
@@ -13,7 +13,6 @@
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_using_isp

try:
from flash_attn.losses.cross_entropy import (
@@ -161,15 +160,15 @@ def new_cross_entropy(
parallel_output: bool = False,
**kwargs,
):
if is_using_isp() and parallel_output:
if gpc.is_rank_for_log():
logger.warning("Use VocabSequenceParallelCrossEntropyLoss.")
return VocabSequenceParallelCrossEntropyLoss(
ignore_index=ignore_index,
reduction=reduction,
label_smoothing=label_smoothing,
process_group=gpc.get_group(ParallelMode.TENSOR),
)
# if is_using_isp() and parallel_output:
# if gpc.is_rank_for_log():
# logger.warning("Use VocabSequenceParallelCrossEntropyLoss.")
# return VocabSequenceParallelCrossEntropyLoss(
# ignore_index=ignore_index,
# reduction=reduction,
# label_smoothing=label_smoothing,
# process_group=gpc.get_group(ParallelMode.TENSOR),
# )

if parallel_output:
assert (
12 changes: 8 additions & 4 deletions internlm/model/ops/ring_flash_attn/utils.py
@@ -4,6 +4,7 @@

import torch
import torch.distributed as dist
import torch.nn.functional as F

__all__ = ["update_out_and_lse", "RingComm"]

@@ -15,14 +16,17 @@ def _update_out_and_lse(
block_out: torch.Tensor,
block_lse: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:

block_out = block_out.to(torch.float32)
block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))

out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
# new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
# torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
# For additional context and discussion, please refer to:
# https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
lse = lse - F.logsigmoid(lse - block_lse)

lse = new_lse
return out, lse


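The rewritten block merge is the same log-sum-exp update in a numerically friendlier form that never materializes `new_lse` explicitly (see the ring-flash-attention discussion linked in the comment). A short derivation, with $\sigma$ the logistic sigmoid and $b$ standing for `block_lse`:

```latex
\begin{aligned}
\mathrm{lse}_{\mathrm{new}} &= \mathrm{lse} + \log\bigl(1 + e^{\,b - \mathrm{lse}}\bigr)
  = \mathrm{lse} + \operatorname{softplus}(b - \mathrm{lse})
  = \mathrm{lse} - \log\sigma(\mathrm{lse} - b), \\[2pt]
e^{\,\mathrm{lse} - \mathrm{lse}_{\mathrm{new}}} &= \sigma(\mathrm{lse} - b) = 1 - \sigma(b - \mathrm{lse}),
\qquad
e^{\,b - \mathrm{lse}_{\mathrm{new}}} = \sigma(b - \mathrm{lse}), \\[2pt]
\mathrm{out}_{\mathrm{new}} &= e^{\,\mathrm{lse} - \mathrm{lse}_{\mathrm{new}}}\,\mathrm{out}
  + e^{\,b - \mathrm{lse}_{\mathrm{new}}}\,\mathrm{out}_{\mathrm{blk}}
  = \mathrm{out} - \sigma(b - \mathrm{lse})\,(\mathrm{out} - \mathrm{out}_{\mathrm{blk}}).
\end{aligned}
```

This is exactly `out - F.sigmoid(block_lse - lse) * (out - block_out)` and `lse - F.logsigmoid(lse - block_lse)`; the sigmoid/logsigmoid forms stay stable when `block_lse - lse` is large in either direction.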
1 change: 1 addition & 0 deletions internlm/model/ops/rotary_emb.py
@@ -107,6 +107,7 @@ def _apply_torch_npu_rotary_mul(x: Tensor, cos: Tensor, sin: Tensor):
cos (Tensor): cos, shape is [1, S, 1, D].
sin (Tensor): sin, shape is [1, S, 1, D].
"""

# NOTE: This could probably be moved to Triton.
def rotate_half(_x):
x1, x2 = _x.chunk(2, dim=-1)