From 7940477c587562fa368c417651e021ce9fa3bf5b Mon Sep 17 00:00:00 2001 From: x54-729 Date: Tue, 3 Sep 2024 12:09:13 +0000 Subject: [PATCH 1/3] fix bugs of generation --- internlm/apis/inference.py | 108 +++++++++++++++++++++------- internlm/model/modules/embedding.py | 4 +- internlm/model/modules/mha.py | 4 ++ 3 files changed, 87 insertions(+), 29 deletions(-) diff --git a/internlm/apis/inference.py b/internlm/apis/inference.py index d3b5de87..f78e70d6 100644 --- a/internlm/apis/inference.py +++ b/internlm/apis/inference.py @@ -1,6 +1,4 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - +# Copyright (c) InternLM. All rights reserved. from typing import Dict, List, Tuple, Union import torch @@ -67,6 +65,7 @@ def generate( top_p: float = 1.0, repetition_penalty: float = 1, length_penalty: float = 1.0, + min_new_tokens: int = 1, ): """ Args: @@ -79,6 +78,9 @@ def generate( temperature: it's meaningful when do_sample is True. top_k: sampling from top_k. top_p: sampling from top_p tokens(nucleus sampling). + repetition_penalty: the penalty degree for repetition tokens + length_penalty: the penalty for length. + min_new_tokens: minimum number of generated tokens. Return: the token sequence whose shape is [bsz, num_return_sequences, max_length]. If eos_token_id is not None, @@ -103,6 +105,7 @@ def generate( length_penalty=length_penalty, # the penalty for length. if it > 1, then encourages long sequence. # Otherwise, encourages short sequence. bos_token_id=self.bos_token_id, + min_new_tokens=min_new_tokens, ) else: return greedy_generate( @@ -118,6 +121,7 @@ def generate( repetition_penalty=repetition_penalty, length_penalty=length_penalty, bos_token_id=self.bos_token_id, + min_new_tokens=min_new_tokens, ) @torch.no_grad() @@ -131,6 +135,7 @@ def streaming_generate( top_p: float = 1.0, repetition_penalty: float = 1, length_penalty: float = 1.0, + min_new_tokens: int = 1, ): if not do_sample: temperature = 1 @@ -151,6 +156,7 @@ def streaming_generate( length_penalty=length_penalty, pad_token_id=self.pad_token_id, bos_token_id=self.bos_token_id, + min_new_tokens=min_new_tokens, ) @@ -168,6 +174,7 @@ def greedy_generate( repetition_penalty=1, length_penalty=1.0, bos_token_id=1, + min_new_tokens=1, ): """ Search sequence greedily. @@ -181,6 +188,7 @@ def greedy_generate( pad_token_id: the token id of pad. repetition_penalty: the penalty degree for repetition tokens length_penalty: the penalty for length. + min_new_tokens: minimum number of generated tokens. """ if num_beams == 1: @@ -199,6 +207,7 @@ def greedy_generate( length_penalty=length_penalty, pad_token_id=pad_token_id, bos_token_id=bos_token_id, + min_new_tokens=min_new_tokens, ) else: token_ids = _beam_search_generate( @@ -218,6 +227,7 @@ def greedy_generate( length_penalty=length_penalty, pad_token_id=pad_token_id, bos_token_id=bos_token_id, + min_new_tokens=min_new_tokens, ) return token_ids @@ -240,6 +250,7 @@ def sample_generate( repetition_penalty=1.0, length_penalty=1.0, bos_token_id=1, + min_new_tokens=1, ): """ generate sequence in sampling way. @@ -257,6 +268,7 @@ def sample_generate( pad_token_id: the token id of pad. repetition_penalty: the penalty degree for repetition tokens length_penalty: the penalty for length. + min_new_tokens: minimum number of generated tokens. 
""" if num_beams == 1: @@ -275,6 +287,7 @@ def sample_generate( length_penalty=length_penalty, pad_token_id=pad_token_id, bos_token_id=bos_token_id, + min_new_tokens=min_new_tokens, ) else: token_ids = _beam_search_generate( @@ -294,6 +307,7 @@ def sample_generate( length_penalty=length_penalty, pad_token_id=pad_token_id, bos_token_id=bos_token_id, + min_new_tokens=min_new_tokens, ) return token_ids @@ -315,8 +329,11 @@ def _streaming_no_beam_search_generate( length_penalty=1.0, pad_token_id=0, bos_token_id=1, + min_new_tokens=1, ): - batch_size = tokens.size(0) + batch_size, cur_len = tokens.shape + real_max_length = max_length + real_min_length = cur_len + min_new_tokens if eos_token_id is not None: if not isinstance(eos_token_id, (List, Tuple)): eos_token_id = [eos_token_id] @@ -326,7 +343,9 @@ def _streaming_no_beam_search_generate( eos_token_id.extend(additional_eos_token_list) eos_token_id = torch.LongTensor(eos_token_id).to(tokens.device) + assert bos_token_id == pad_token_id, "bos_token_id should be equal to left pad_token_id!" has_bos = torch.all(tokens[:, 0].eq(bos_token_id)) + attention_mask = get_attention_mask(tokens, has_bos, bos_token_id=bos_token_id) if inference_params is None: @@ -355,21 +374,21 @@ def _streaming_no_beam_search_generate( scores = scores[0] scores = scores[:, -1].float() inference_params.sequence_len_offset += tokens.size(1) - if eos_token_id is not None: - scores[:, eos_token_id] = -1e12 + if eos_token_id is not None and min_new_tokens > 0: + scores[:, eos_token_id] = -float("inf") # The first token generated. next_tokens = scores.argmax(dim=-1, keepdim=True) - token_ids = torch.cat([tokens, next_tokens], dim=1) - yield token_ids + dones = next_tokens.new_zeros(batch_size, 1).eq(1) + if eos_token_id is not None: + end_mask = torch.any(next_tokens[:, None].eq(eos_token_id), dim=-1) + dones = dones.__or__(end_mask) + token_ids = torch.cat([tokens, next_tokens], dim=1) cur_len = token_ids.size(1) - dones = token_ids.new_zeros(batch_size).eq(1) - - real_max_length = max_length max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long) - while cur_len < real_max_length: + while cur_len < real_max_length and dones.min() != 1: # batch_size x vocab_size attention_mask = get_attention_mask(token_ids, has_bos, bos_token_id=bos_token_id) @@ -391,6 +410,9 @@ def _streaming_no_beam_search_generate( scores = scores[:, -1].float() inference_params.sequence_len_offset += 1 + if eos_token_id is not None and cur_len < real_min_length: + scores[..., eos_token_id] = -float("inf") + if repetition_penalty != 1.0: token_scores = scores.gather(dim=1, index=token_ids) lt_zero_mask = token_scores.lt(0).float() @@ -422,8 +444,9 @@ def _streaming_no_beam_search_generate( if eos_token_id is not None: # When the generated result exceeds the length, its eos_token_id is set to the most basic terminator. 
next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), eos_token_id[0]) - next_tokens = next_tokens.masked_fill(dones, pad_token_id) + tokens = next_tokens.unsqueeze(1) + tokens = tokens.masked_fill(dones, pad_token_id) token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len yield token_ids @@ -461,8 +484,12 @@ def _no_beam_search_generate( length_penalty=1.0, pad_token_id=0, bos_token_id=1, + min_new_tokens=1, ): - batch_size = tokens.size(0) + + batch_size, cur_len = tokens.shape + real_max_length = max_length + real_min_length = cur_len + min_new_tokens if eos_token_id is not None: if not isinstance(eos_token_id, (List, Tuple)): eos_token_id = [eos_token_id] @@ -472,6 +499,7 @@ def _no_beam_search_generate( eos_token_id.extend(additional_eos_token_list) eos_token_id = torch.LongTensor(eos_token_id).to(tokens.device) + assert bos_token_id == pad_token_id, "bos_token_id should be equal to left pad_token_id!" has_bos = torch.all(tokens[:, 0].eq(bos_token_id)) attention_mask = get_attention_mask(tokens, has_bos, bos_token_id) @@ -502,13 +530,14 @@ def _no_beam_search_generate( if isinstance(scores, (list, tuple)): scores = scores[0] scores = scores[:, -1].float() - if eos_token_id is not None: - scores[:, eos_token_id] = -1e12 + if eos_token_id is not None and min_new_tokens > 0: + scores[:, eos_token_id] = -float("inf") # The first token generated. next_tokens = scores.argmax(dim=-1, keepdim=True) else: next_tokens = tokens.new_zeros([batch_size, 1]) + if gpc.is_initialized(ParallelMode.PIPELINE): # broadcast to other rank in PP group torch.distributed.broadcast( @@ -516,16 +545,21 @@ def _no_beam_search_generate( src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], group=gpc.get_group(ParallelMode.PIPELINE), ) + + dones = next_tokens.new_zeros(batch_size, 1).eq(1) + + if eos_token_id is not None: + end_mask = torch.any(next_tokens[:, None].eq(eos_token_id), dim=-1) + dones = dones.__or__(end_mask) token_ids = torch.cat([tokens, next_tokens], dim=1) cur_len = token_ids.size(1) - dones = token_ids.new_zeros(batch_size).eq(1) inference_params.sequence_len_offset += tokens.size(1) real_max_length = max_length max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long) - while cur_len < real_max_length: + while cur_len < real_max_length and dones.min() != 1: # batch_size x vocab_size attention_mask = get_attention_mask(token_ids, has_bos, bos_token_id=bos_token_id) @@ -543,6 +577,10 @@ def _no_beam_search_generate( raise NotImplementedError(f"Unsupported decoder type: {type(decoder)}") inference_params.sequence_len_offset += 1 + + if eos_token_id is not None and cur_len < real_min_length: + scores[..., eos_token_id] = -float("inf") + if gpc.is_last_rank(ParallelMode.PIPELINE): if isinstance(scores, (list, tuple)): scores = scores[0] @@ -585,15 +623,17 @@ def _no_beam_search_generate( src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], group=gpc.get_group(ParallelMode.PIPELINE), ) + if eos_token_id is not None: # When the generated result exceeds the length, its eos_token_id is set to the most basic terminator. 
next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), eos_token_id[0]) - next_tokens = next_tokens.masked_fill(dones, pad_token_id) + tokens = next_tokens.unsqueeze(1) + tokens = tokens.masked_fill(dones, pad_token_id) token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len if eos_token_id is not None: - end_mask = torch.any(next_tokens[:, None].eq(eos_token_id), dim=-1) + end_mask = torch.any(tokens[:, None].eq(eos_token_id), dim=-1) dones = dones.__or__(end_mask) cur_len += 1 @@ -630,10 +670,13 @@ def _beam_search_generate( length_penalty=1.0, pad_token_id=0, bos_token_id=1, + min_new_tokens=1, ) -> torch.LongTensor: device = tokens.device - batch_size = tokens.size(0) + batch_size, cur_len = tokens.shape + real_max_length = max_length + real_min_length = cur_len + min_new_tokens if eos_token_id is not None: if not isinstance(eos_token_id, (List, Tuple)): @@ -644,6 +687,7 @@ def _beam_search_generate( eos_token_id.extend(additional_eos_token_list) eos_token_id = torch.LongTensor(eos_token_id).to(tokens.device) + assert bos_token_id == pad_token_id, "bos_token_id should be equal to left pad_token_id!" has_bos = torch.all(tokens[:, 0].eq(bos_token_id)) attention_mask = get_attention_mask(tokens, has_bos, bos_token_id=bos_token_id) @@ -676,8 +720,8 @@ def _beam_search_generate( if isinstance(scores, (list, tuple)): scores = scores[0] scores = scores[:, -1].float() - if eos_token_id is not None: - scores[:, eos_token_id] = -1e12 + if eos_token_id is not None and min_new_tokens > 0: + scores[:, eos_token_id] = -float("inf") vocab_size = scores.size(1) assert vocab_size >= num_beams, "num_beams should be smaller than " "the number of vocabulary size." @@ -724,7 +768,6 @@ def _beam_search_generate( cur_len = token_ids.size(1) - real_max_length = max_length max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long) hypos = [ BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size) @@ -732,10 +775,15 @@ def _beam_search_generate( # 0, num_beams, 2*num_beams, ... 
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) - while cur_len < real_max_length: - attention_mask = get_attention_mask(token_ids, has_bos, bos_token_id=bos_token_id) + for batch_idx in range(batch_size): + dones[batch_idx] = ( + dones[batch_idx] + or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) + or max_lengths[batch_idx * num_beams] == cur_len + 1 + ) - # (bsz x num_beams, vocab_size) + while cur_len < real_max_length and not all(dones): + attention_mask = get_attention_mask(token_ids, has_bos, bos_token_id=bos_token_id) if isinstance(decoder, torch.nn.Module): inference_params.attention_mask = attention_mask @@ -757,6 +805,10 @@ def _beam_search_generate( if isinstance(scores, (list, tuple)): scores = scores[0] scores = scores[:, -1].float() + inference_params.sequence_len_offset += 1 + if eos_token_id is not None and cur_len < real_min_length: + scores[..., eos_token_id] = -float("inf") + if repetition_penalty != 1.0: token_scores = scores.gather(dim=1, index=token_ids) lt_zero_mask = token_scores.lt(0).float() @@ -764,6 +816,7 @@ def _beam_search_generate( token_scores = ( lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores ) + scores.scatter_(dim=1, index=token_ids, src=token_scores) if eos_token_id is not None: @@ -993,6 +1046,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf") # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) logits[indices_to_remove] = filter_value + return logits diff --git a/internlm/model/modules/embedding.py b/internlm/model/modules/embedding.py index d65096e2..5c4e9f65 100644 --- a/internlm/model/modules/embedding.py +++ b/internlm/model/modules/embedding.py @@ -146,10 +146,10 @@ def _convert_padding( if convert_type == "left2right": ret[i][: -empties[i]] = x[i][empties[i] :] - ret[i][empties[i] :] = x[i][: -empties[i]] + ret[i][-empties[i] :] = x[i][: empties[i]] else: # right2left ret[i][empties[i] :] = x[i][: -empties[i]] - ret[i][: -empties[i]] = x[i][empties[i] :] + ret[i][: empties[i]] = x[i][-empties[i] :] return ret diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index 459efb96..0aa6ed93 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -271,8 +271,12 @@ def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613 empties = attention_mask[..., -1].sum(dim=-1) indexes4q = sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - empties indexes4k = sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - empties + q = rearrange(q, "b s h d -> s b h d", d=self.head_dim) # pack input + k = rearrange(k, "b s h d -> s b h d", d=self.head_dim) q = self.rotary_emb(q, offsets=indexes4q, cache_type="query", interleaved=self.interleaved) k = self.rotary_emb(k, offsets=indexes4k, cache_type="key", interleaved=self.interleaved) + q = rearrange(q, "s b h d -> b s h d", d=self.head_dim) # unpack + k = rearrange(k, "s b h d -> b s h d", d=self.head_dim) kv = torch.stack([k, v], dim=2) # update kv cache after rotary embedding when disable dynamic ntk rope. 
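Reviewer note (not part of the patch): the mha.py hunk above packs q and k from (batch, seq, heads, dim) to (seq, batch, heads, dim) before the rotary call and unpacks them afterwards; the follow-up GQA patch adds a TODO explaining this is done to fit the flash_attn APIs. A minimal standalone sketch of that round trip, with the helper name and the rotary_emb stand-in argument being illustrative rather than code from the patch:

from einops import rearrange

def apply_rope_then_unpack(q, k, rotary_emb, indexes4q, indexes4k, interleaved=False):
    # Pack (batch, seq, heads, dim) -> (seq, batch, heads, dim) so the
    # flash_attn-style rotary call sees the layout it expects.
    q = rearrange(q, "b s h d -> s b h d")
    k = rearrange(k, "b s h d -> s b h d")
    q = rotary_emb(q, offsets=indexes4q, cache_type="query", interleaved=interleaved)
    k = rotary_emb(k, offsets=indexes4k, cache_type="key", interleaved=interleaved)
    # Unpack back to (batch, seq, heads, dim) before building the kv cache.
    q = rearrange(q, "s b h d -> b s h d")
    k = rearrange(k, "s b h d -> b s h d")
    return q, k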
From 4fdeb3cba9d7218e2e7134db3128ca2d827fa539 Mon Sep 17 00:00:00 2001 From: x54-729 Date: Tue, 3 Sep 2024 12:37:29 +0000 Subject: [PATCH 2/3] fix for GQA --- internlm/model/modules/mha.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index 0aa6ed93..6bb75c52 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -271,6 +271,8 @@ def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613 empties = attention_mask[..., -1].sum(dim=-1) indexes4q = sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - empties indexes4k = sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - empties + # TODO To fit flash_attn apis, we rearrange q&k to pack them here and + # calculate rope for this batch input. Waiting to be optimized q = rearrange(q, "b s h d -> s b h d", d=self.head_dim) # pack input k = rearrange(k, "b s h d -> s b h d", d=self.head_dim) q = self.rotary_emb(q, offsets=indexes4q, cache_type="query", interleaved=self.interleaved) @@ -538,8 +540,14 @@ def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613 empties = attention_mask[..., -1].sum(dim=-1) indexes4q = sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - empties indexes4k = sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - empties + # TODO To fit flash_attn apis, we rearrange q&k to pack them here and + # calculate rope for this batch input. Waiting to be optimized + q = rearrange(q, "b s h d -> s b h d", d=self.head_dim) # pack input + k = rearrange(k, "b s h d -> s b h d", d=self.head_dim) q = self.rotary_emb(q, offsets=indexes4q, cache_type="query", interleaved=self.interleaved) k = self.rotary_emb(k, offsets=indexes4k, cache_type="key", interleaved=self.interleaved) + q = rearrange(q, "s b h d -> b s h d", d=self.head_dim) # unpack + k = rearrange(k, "s b h d -> b s h d", d=self.head_dim) kv = torch.stack([k, v], dim=2) From 7a5b4f6006df069443217e817ac44e337fdfcc9f Mon Sep 17 00:00:00 2001 From: x54-729 Date: Tue, 3 Sep 2024 12:37:50 +0000 Subject: [PATCH 3/3] change default pad_token_id for generation --- internlm/apis/inference.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/internlm/apis/inference.py b/internlm/apis/inference.py index f78e70d6..a45cf27d 100644 --- a/internlm/apis/inference.py +++ b/internlm/apis/inference.py @@ -170,7 +170,7 @@ def greedy_generate( eos_token_id=None, additional_eos_token_list=None, add_eos_when_return=False, - pad_token_id=0, + pad_token_id=1, repetition_penalty=1, length_penalty=1.0, bos_token_id=1, @@ -246,7 +246,7 @@ def sample_generate( eos_token_id=None, additional_eos_token_list=None, add_eos_when_return=False, - pad_token_id=0, + pad_token_id=1, repetition_penalty=1.0, length_penalty=1.0, bos_token_id=1, @@ -327,7 +327,7 @@ def _streaming_no_beam_search_generate( do_sample=True, repetition_penalty=1.0, length_penalty=1.0, - pad_token_id=0, + pad_token_id=1, bos_token_id=1, min_new_tokens=1, ): @@ -482,7 +482,7 @@ def _no_beam_search_generate( do_sample=True, repetition_penalty=1.0, length_penalty=1.0, - pad_token_id=0, + pad_token_id=1, bos_token_id=1, min_new_tokens=1, ): @@ -668,7 +668,7 @@ def _beam_search_generate( do_sample=True, repetition_penalty=1.0, length_penalty=1.0, - pad_token_id=0, + pad_token_id=1, bos_token_id=1, min_new_tokens=1, ) -> torch.LongTensor:
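Reviewer note (not part of the patch): patch 3 changes the default pad_token_id from 0 to 1 so it matches the default bos_token_id, which is what the assertion bos_token_id == pad_token_id added in patch 1 expects for left-padded prompts. An illustrative sketch of preparing such an input batch (the token ids and prompts below are made up):

import torch

bos_token_id = pad_token_id = 1                 # defaults after this series
prompts = [[1, 5, 6, 7], [1, 9, 10]]            # <bos>-prefixed, already-tokenized prompts
max_len = max(len(p) for p in prompts)
tokens = torch.full((len(prompts), max_len), pad_token_id, dtype=torch.long)
for i, p in enumerate(prompts):
    tokens[i, max_len - len(p):] = torch.tensor(p)  # left-pad the shorter prompt
# tokens == [[1, 5, 6, 7],
#            [1, 1, 9, 10]]
# Because pad and bos share id 1, the first column is all bos, so the
# has_bos check and the new assertion in the generation helpers both hold.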