Fix(mha,linear): fix norm_head and mha inference (#234)
Co-authored-by: shidongxing <shidongxing@>
KimmiShi authored May 24, 2024
1 parent c355767 commit 41545ce
Showing 5 changed files with 32 additions and 23 deletions.
Empty file.
8 changes: 4 additions & 4 deletions internlm/core/parallel/comm/tensor.py
@@ -233,7 +233,7 @@ def grad_output_hook(
if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1:
return grad_output, DUMMY_HANDLE_CONST

- return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1)
+ return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST

def output_hook(
self, output: torch.Tensor, async_op: bool = False # pylint: disable=W0613
@@ -244,7 +244,7 @@ def output_hook(
if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1:
return output, DUMMY_HANDLE_CONST

- return _gather(output, parallel_mode=self._parallel_mode, dim=-1)
+ return _gather(output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST


class HeadSequenceParallelCommunicator(SequenceParallelCommunicator):
@@ -274,7 +274,7 @@ def grad_output_hook(
if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1:
return grad_output, DUMMY_HANDLE_CONST

- return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1)
+ return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST

# rewrite output communication hook
def output_hook(
@@ -286,7 +286,7 @@ def output_hook(
if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1:
return output, DUMMY_HANDLE_CONST

- return _gather(output, parallel_mode=self._parallel_mode, dim=-1)
+ return _gather(output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST


class MoESequenceParallelCommunicator:
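All four tensor.py hunks close the same gap: these communication hooks are expected to hand back a (tensor, handle) pair so callers can always unpack two values and wait on the handle, whether or not an asynchronous collective was launched, but the _split/_gather branches returned a bare tensor. A minimal sketch of that calling convention, with an illustrative DummyHandle standing in for whatever DUMMY_HANDLE_CONST actually is in InternLM:

import torch

class DummyHandle:
    """Illustrative no-op stand-in for an async communication handle."""
    def wait(self):
        pass

DUMMY_HANDLE_CONST = DummyHandle()

def output_hook(output: torch.Tensor, world_size: int):
    # Every branch returns (tensor, handle), so the caller never has to guess
    # whether a real collective (and a real handle) was involved.
    if world_size <= 1:
        return output, DUMMY_HANDLE_CONST
    gathered = torch.cat([output] * world_size, dim=-1)  # placeholder for the real _gather
    return gathered, DUMMY_HANDLE_CONST

out, handle = output_hook(torch.randn(2, 4), world_size=1)
handle.wait()  # safe on both code paths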
2 changes: 1 addition & 1 deletion internlm/model/modules/linear.py
@@ -492,7 +492,7 @@ def forward(self, input): # pylint: disable=W0622

return fused_dense_func(
input,
- self.weight,
+ weight,
communicator=self._communicator,
module=self,
bias=self.bias,
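The one-line linear.py change is what fixes norm_head: forward now passes the locally computed weight (the normalized copy when head normalization is enabled) into fused_dense_func instead of the raw self.weight. A hedged sketch of the pattern, not the real InternLM head class:

import torch
import torch.nn as nn
import torch.nn.functional as F

class NormHeadLinear(nn.Module):
    """Sketch of a norm_head-style output projection: each weight row is
    L2-normalized before the matmul. The bug pattern fixed here is passing
    self.weight (raw) downstream instead of the local, normalized `weight`."""

    def __init__(self, in_features: int, out_features: int, norm_head: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.normal_(self.weight, std=0.02)
        self.norm_head = norm_head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        if self.norm_head:
            weight = F.normalize(weight, dim=-1)  # normalize each output row
        # Must use the local `weight` here, not self.weight.
        return F.linear(x, weight)

logits = NormHeadLinear(8, 16)(torch.randn(2, 8))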
38 changes: 23 additions & 15 deletions internlm/model/modules/mha.py
@@ -184,18 +184,22 @@ def _convert_unpacked_qkv_to_packed(
max_seqlen_q = attention_mask.shape[-1]
max_seqlen_k = attention_mask.shape[-1]

- q_packed = q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1])
- kv_packed = kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1)).view(
- -1, kv.shape[-3], kv.shape[-2], kv.shape[-1]
+ q_packed = (
+ q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]).unsqueeze(0)
+ )
+ kv_packed = (
+ kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1))
+ .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1])
+ .unsqueeze(0)
)

return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
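# Note on the added .unsqueeze(0): masked_select flattens the kept tokens, and
# .view(-1, heads, head_dim) gives (total_tokens, heads, head_dim); the packed
# (varlen) attention path downstream appears to expect a leading batch-of-one
# dimension, i.e. (1, total_tokens, heads, head_dim). Shape-only illustration
# with made-up sizes, not the InternLM code itself:
import torch
batch, seqlen, heads, head_dim = 2, 4, 3, 8
q = torch.randn(batch, seqlen, heads, head_dim)
attention_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]], dtype=torch.bool)
q_packed = (
    q.masked_select(attention_mask.view(batch, -1, 1, 1)).view(-1, heads, head_dim).unsqueeze(0)
)
assert q_packed.shape == (1, int(attention_mask.sum()), heads, head_dim)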

def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613
assert inference_params is not None, "inference_params is required for inference"
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
- attention_mask = inference_params.get("attention_mask", None)
- sequence_len_offset = inference_params.get("sequence_len_offset", 0)
+ attention_mask = inference_params.attention_mask
+ sequence_len_offset = inference_params.sequence_len_offset
batch_size = x.shape[0]

# wqkv, output: q, kv
@@ -230,21 +234,21 @@ def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613
q = self.rotary_emb(
q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved
)
- k = kv[:, :, 0].squeueze(2)
+ k = kv[:, :, 0].squeeze(2)
self.rotary_emb(
k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True
) # in-place is important
else:
if self.rotary_emb_dim > 0:
q = self.rotary_emb(q, offsets=0, cache_type="query", interleaved=self.interleaved)
- k = kv[:, :, 0].squeueze(2)
+ k = kv[:, :, 0].squeeze(2)
self.rotary_emb(
k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True
) # in-place is important
else:
assert self.rotary_emb_dim > 0, "You should use rotary_emb."

- k, v = kv[:, :, 0].squeueze(2), kv[:, :, 1].squeueze(2)
+ k, v = kv[:, :, 0].squeeze(2), kv[:, :, 1].squeeze(2)

if attention_mask is None:
q = self.rotary_emb(q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved)
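# The other inference fix in this file: inference_params is an object with
# attributes, not a dict, so the old .get("attention_mask", None) calls failed
# during generation. Hedged sketch of the assumed container (field names taken
# from this diff; the real InferenceParams class may differ):
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class InferenceParamsSketch:
    attention_mask: Optional[torch.Tensor] = None
    sequence_len_offset: int = 0
    window_size: Optional[int] = None

params = InferenceParamsSketch(sequence_len_offset=16)
attention_mask = params.attention_mask          # attribute access, as in the fix
sequence_len_offset = params.sequence_len_offset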
@@ -474,27 +478,31 @@ def _convert_unpacked_qkv_to_packed(
max_seqlen_q = attention_mask.shape[-1]
max_seqlen_k = attention_mask.shape[-1]

- q_packed = q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1])
- kv_packed = kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1)).view(
- -1, kv.shape[-3], kv.shape[-2], kv.shape[-1]
+ q_packed = (
+ q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]).unsqueeze(0)
+ )
+ kv_packed = (
+ kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1))
+ .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1])
+ .unsqueeze(0)
)

return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k

def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613
assert inference_params is not None, "inference_params is required for inference"
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
- attention_mask = inference_params.get("attention_mask", None)
- sequence_len_offset = inference_params.get("sequence_len_offset", 0)
- window_size = inference_params.get("window_size", None)
+ attention_mask = inference_params.attention_mask
+ sequence_len_offset = inference_params.sequence_len_offset
+ window_size = inference_params.window_size

batch_size = x.shape[0]

# wqkv, output: q, k, v
if self.enable_qkv_fusion:
qkv = self.wqkv(x)
qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=self.q_per_kv + 2, d=self.head_dim)
- q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :].unsqueeze(-2), qkv[..., -1, :].unsqueeze(-2))
+ q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :])
q = rearrange(q, "b s h gs d -> b s (h gs) d")
else:
q, k, v = self.wq(x), self.wk(x), self.wv(x)
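The last mha.py change touches the fused-QKV path of the grouped-query variant: after the rearrange to "b s h gs d" with gs = q_per_kv + 2, slicing the key and value at group positions -2 and -1 already yields (batch, seqlen, num_kv_heads, head_dim), so the extra .unsqueeze(-2) the commit removes would have produced a spurious dimension. A shape-only sketch under assumed sizes, using einops as the file itself does:

import torch
from einops import rearrange

b, s, num_kv_heads, q_per_kv, head_dim = 2, 5, 4, 3, 16
qkv = torch.randn(b, s, num_kv_heads * (q_per_kv + 2) * head_dim)
# Each kv-head group carries q_per_kv query heads plus one key and one value head.
qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=q_per_kv + 2, d=head_dim)
q, k, v = qkv[..., :q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :]
q = rearrange(q, "b s h gs d -> b s (h gs) d")
assert q.shape == (b, s, num_kv_heads * q_per_kv, head_dim)
assert k.shape == v.shape == (b, s, num_kv_heads, head_dim)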
7 changes: 4 additions & 3 deletions internlm/utils/utils.py
@@ -62,14 +62,15 @@ def __kv_checker(num_args: int):
# kv: [batch, seqlen, 3, n_head, headdim]
return len(args[2].shape) == 5

- def __cu_seqlens_checker(num_args: int, check_idx: int):
+ def __cu_seqlens_checker(args, check_idx: int):
+ num_args = len(args)
if num_args < (check_idx + 1):
if check_idx == 2:
return "cu_seqlens" in kwargs and kwargs["cu_seqlens"] is not None
else:
return "cu_seqlens_q" in kwargs and kwargs["cu_seqlens_q"] is not None
else:
- return isinstance(num_args[check_idx], torch.Tensor)
+ return isinstance(args[check_idx], torch.Tensor)

if __qkv_checker(len(args)):
# qkv packed, and we should check cu_seqlens with index 2
@@ -81,7 +82,7 @@ def __cu_seqlens_checker(args, check_idx: int):
# qkv splited, and we should check cu_seqlens with index 4
qkv_pack_type = int(QKVPackType.QKVSPLITED)

- with_cu_seqlens = __cu_seqlens_checker(len(args), qkv_pack_type)
+ with_cu_seqlens = __cu_seqlens_checker(args, qkv_pack_type)

return str(qkv_pack_type), str(with_cu_seqlens)

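The utils.py change fixes the argument checker: it previously received len(args) and then tried to index into that integer, which can never work. The corrected helper takes the args tuple itself, counts it internally, and falls back to keyword arguments only when the positional slot is missing. A simplified, self-contained sketch (the real __cu_seqlens_checker is a closure over args/kwargs; here they are passed explicitly):

import torch

def cu_seqlens_checker(args: tuple, kwargs: dict, check_idx: int) -> bool:
    num_args = len(args)
    if num_args < check_idx + 1:
        # Positional slot absent: look for the keyword form instead.
        key = "cu_seqlens" if check_idx == 2 else "cu_seqlens_q"
        return key in kwargs and kwargs[key] is not None
    # Positional slot present: it must actually be a tensor.
    return isinstance(args[check_idx], torch.Tensor)

# Too few positional args: fall back to kwargs.
assert cu_seqlens_checker((None, None), {"cu_seqlens": torch.tensor([0, 4])}, 2)
# Enough positional args: type-check the one at check_idx.
assert not cu_seqlens_checker((None, None, "not a tensor"), {}, 2)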
