fix(model): fix bugs of batch generation & support min_new_tokens for inference #313

Merged · 3 commits · Sep 4, 2024
118 changes: 86 additions & 32 deletions internlm/apis/inference.py
@@ -1,6 +1,4 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# Copyright (c) InternLM. All rights reserved.
from typing import Dict, List, Tuple, Union

import torch
@@ -67,6 +65,7 @@ def generate(
top_p: float = 1.0,
repetition_penalty: float = 1,
length_penalty: float = 1.0,
min_new_tokens: int = 1,
):
"""
Args:
@@ -79,6 +78,9 @@ def generate(
temperature: it's meaningful when do_sample is True.
top_k: sampling from top_k.
top_p: sampling from top_p tokens(nucleus sampling).
repetition_penalty: the penalty degree for repetition tokens
length_penalty: the penalty for length.
min_new_tokens: minimum number of generated tokens.

Return:
the token sequence whose shape is [bsz, num_return_sequences, max_length]. If eos_token_id is not None,
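With these docstring additions, callers can bound the shortest completion directly from `generate()`. A hedged usage sketch — `generator`, its construction, and the prompt tensor are placeholders, not shown in this diff:

```python
import torch

# Left-padded prompt batch, shape (batch, prompt_len); pad == bos == 1 here.
prompt = torch.tensor([[1, 101, 102],
                       [1,   1, 201]])

out = generator.generate(
    prompt,
    do_sample=True,
    temperature=0.8,
    top_k=40,
    top_p=0.9,
    repetition_penalty=1.1,
    min_new_tokens=4,  # new in this PR: EOS is suppressed for the first 4 steps
)
```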
@@ -103,6 +105,7 @@ def generate(
length_penalty=length_penalty, # the penalty for length. if it > 1, then encourages long sequence.
# Otherwise, encourages short sequence.
bos_token_id=self.bos_token_id,
min_new_tokens=min_new_tokens,
)
else:
return greedy_generate(
@@ -118,6 +121,7 @@ def generate(
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
bos_token_id=self.bos_token_id,
min_new_tokens=min_new_tokens,
)

@torch.no_grad()
@@ -131,6 +135,7 @@ def streaming_generate(
top_p: float = 1.0,
repetition_penalty: float = 1,
length_penalty: float = 1.0,
min_new_tokens: int = 1,
):
if not do_sample:
temperature = 1
@@ -151,6 +156,7 @@ def streaming_generate(
length_penalty=length_penalty,
pad_token_id=self.pad_token_id,
bos_token_id=self.bos_token_id,
min_new_tokens=min_new_tokens,
)


@@ -164,10 +170,11 @@ def greedy_generate(
eos_token_id=None,
additional_eos_token_list=None,
add_eos_when_return=False,
pad_token_id=0,
pad_token_id=1,
repetition_penalty=1,
length_penalty=1.0,
bos_token_id=1,
min_new_tokens=1,
):
"""
Search sequence greedily.
@@ -181,6 +188,7 @@ def greedy_generate(
pad_token_id: the token id of pad.
repetition_penalty: the penalty degree for repetition tokens
length_penalty: the penalty for length.
min_new_tokens: minimum number of generated tokens.

"""
if num_beams == 1:
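Note the default `pad_token_id` flips from 0 to 1 here (and in every entry point below) so it matches the default `bos_token_id`; the search loops now assert that equality because batches are left-padded. A minimal sketch of that convention — the helper is illustrative, not part of the PR:

```python
import torch

def left_pad(prompts: list, pad_token_id: int = 1) -> torch.Tensor:
    # Pad on the left so every prompt ends at the same position;
    # generation then appends new tokens on the right.
    max_len = max(len(p) for p in prompts)
    return torch.tensor([[pad_token_id] * (max_len - len(p)) + p for p in prompts])

batch = left_pad([[1, 11, 12], [1, 21]])  # bos_token_id == pad_token_id == 1
# tensor([[ 1, 11, 12],
#         [ 1,  1, 21]])
```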
@@ -199,6 +207,7 @@ def greedy_generate(
length_penalty=length_penalty,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
min_new_tokens=min_new_tokens,
)
else:
token_ids = _beam_search_generate(
@@ -218,6 +227,7 @@ def greedy_generate(
length_penalty=length_penalty,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
min_new_tokens=min_new_tokens,
)

return token_ids
@@ -236,10 +246,11 @@ def sample_generate(
eos_token_id=None,
additional_eos_token_list=None,
add_eos_when_return=False,
pad_token_id=0,
pad_token_id=1,
repetition_penalty=1.0,
length_penalty=1.0,
bos_token_id=1,
min_new_tokens=1,
):
"""
generate sequence in sampling way.
@@ -257,6 +268,7 @@ def sample_generate(
pad_token_id: the token id of pad.
repetition_penalty: the penalty degree for repetition tokens
length_penalty: the penalty for length.
min_new_tokens: minimum number of generated tokens.

"""
if num_beams == 1:
@@ -275,6 +287,7 @@ def sample_generate(
length_penalty=length_penalty,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
min_new_tokens=min_new_tokens,
)
else:
token_ids = _beam_search_generate(
@@ -294,6 +307,7 @@ def sample_generate(
length_penalty=length_penalty,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
min_new_tokens=min_new_tokens,
)
return token_ids

@@ -313,10 +327,13 @@ def _streaming_no_beam_search_generate(
do_sample=True,
repetition_penalty=1.0,
length_penalty=1.0,
pad_token_id=0,
pad_token_id=1,
bos_token_id=1,
min_new_tokens=1,
):
batch_size = tokens.size(0)
batch_size, cur_len = tokens.shape
real_max_length = max_length
real_min_length = cur_len + min_new_tokens
if eos_token_id is not None:
if not isinstance(eos_token_id, (List, Tuple)):
eos_token_id = [eos_token_id]
@@ -326,7 +343,9 @@ def _streaming_no_beam_search_generate(
eos_token_id.extend(additional_eos_token_list)
eos_token_id = torch.LongTensor(eos_token_id).to(tokens.device)

assert bos_token_id == pad_token_id, "bos_token_id should be equal to left pad_token_id!"
has_bos = torch.all(tokens[:, 0].eq(bos_token_id))

attention_mask = get_attention_mask(tokens, has_bos, bos_token_id=bos_token_id)

if inference_params is None:
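The new assertion pins down the left-padding convention before the mask is built. `get_attention_mask` itself lives elsewhere in the repo and is not shown in this diff; a rough sketch of what such a mask plausibly computes under these assumptions (illustrative only):

```python
import torch

def make_left_pad_mask(tokens: torch.Tensor, has_bos: bool, pad_token_id: int = 1) -> torch.Tensor:
    """True where a position should be attended to, False on left padding."""
    is_pad = tokens.eq(pad_token_id)
    # cumprod stays 1 only while we are still inside the leading pad run
    leading_pad = torch.cumprod(is_pad.int(), dim=1).bool()
    if has_bos:
        # the last token of the leading run is the real BOS -- keep it
        run_len = leading_pad.sum(dim=1, keepdim=True)
        idx = torch.arange(tokens.size(1), device=tokens.device)
        leading_pad = leading_pad & (idx < run_len - 1)
    return ~leading_pad

mask = make_left_pad_mask(torch.tensor([[1, 1, 21], [1, 11, 12]]), has_bos=True)
# tensor([[False,  True,  True],
#         [ True,  True,  True]])
```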
@@ -355,21 +374,21 @@ def _streaming_no_beam_search_generate(
scores = scores[0]
scores = scores[:, -1].float()
inference_params.sequence_len_offset += tokens.size(1)
if eos_token_id is not None:
scores[:, eos_token_id] = -1e12
if eos_token_id is not None and min_new_tokens > 0:
scores[:, eos_token_id] = -float("inf")

# The first token generated.
next_tokens = scores.argmax(dim=-1, keepdim=True)
token_ids = torch.cat([tokens, next_tokens], dim=1)
yield token_ids
dones = next_tokens.new_zeros(batch_size, 1).eq(1)

if eos_token_id is not None:
end_mask = torch.any(next_tokens[:, None].eq(eos_token_id), dim=-1)
dones = dones.__or__(end_mask)
token_ids = torch.cat([tokens, next_tokens], dim=1)
cur_len = token_ids.size(1)
dones = token_ids.new_zeros(batch_size).eq(1)

real_max_length = max_length
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)

while cur_len < real_max_length:
while cur_len < real_max_length and dones.min() != 1:
# batch_size x vocab_size
attention_mask = get_attention_mask(token_ids, has_bos, bos_token_id=bos_token_id)
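Two load-bearing changes land in this hunk: while the sequence is shorter than `real_min_length`, EOS logits are masked with a hard `-float("inf")` (replacing the fixed `-1e12` sentinel), and the decode loop now exits as soon as every row is finished instead of always running to `max_length`. A compact sketch with illustrative tensors:

```python
import torch

scores = torch.randn(2, 32)        # (batch, vocab)
eos_token_id = torch.tensor([2])
cur_len, real_min_length = 5, 8

if cur_len < real_min_length:
    # -inf becomes exactly zero probability after softmax,
    # so EOS cannot be picked by argmax or by sampling
    scores[..., eos_token_id] = -float("inf")

dones = torch.tensor([True, False])
# dones.min() != 1  <=>  at least one row is still generating
assert bool(dones.min() != 1) == (not bool(dones.all()))
```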

@@ -391,6 +410,9 @@ def _streaming_no_beam_search_generate(
scores = scores[:, -1].float()
inference_params.sequence_len_offset += 1

if eos_token_id is not None and cur_len < real_min_length:
scores[..., eos_token_id] = -float("inf")

if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
lt_zero_mask = token_scores.lt(0).float()
Expand Down Expand Up @@ -422,8 +444,9 @@ def _streaming_no_beam_search_generate(
if eos_token_id is not None:
# When the generated result exceeds the length, its eos_token_id is set to the most basic terminator.
next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), eos_token_id[0])
next_tokens = next_tokens.masked_fill(dones, pad_token_id)

tokens = next_tokens.unsqueeze(1)
tokens = tokens.masked_fill(dones, pad_token_id)
token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len

yield token_ids
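This is the heart of the batch-generation fix: a row that has already finished gets its freshly sampled token overwritten with `pad_token_id` before being appended, so shorter sequences in a batch stop accumulating stray tokens after their EOS. In miniature:

```python
import torch

next_tokens = torch.tensor([[7], [9]])    # freshly sampled, one per row
dones = torch.tensor([[True], [False]])   # row 0 already emitted EOS
pad_token_id = 1

next_tokens = next_tokens.masked_fill(dones, pad_token_id)
# tensor([[1],
#         [9]])  -> finished rows append pad, live rows continue
```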
@@ -459,10 +482,14 @@ def _no_beam_search_generate(
do_sample=True,
repetition_penalty=1.0,
length_penalty=1.0,
pad_token_id=0,
pad_token_id=1,
bos_token_id=1,
min_new_tokens=1,
):
batch_size = tokens.size(0)

batch_size, cur_len = tokens.shape
real_max_length = max_length
real_min_length = cur_len + min_new_tokens
if eos_token_id is not None:
if not isinstance(eos_token_id, (List, Tuple)):
eos_token_id = [eos_token_id]
@@ -472,6 +499,7 @@ def _no_beam_search_generate(
eos_token_id.extend(additional_eos_token_list)
eos_token_id = torch.LongTensor(eos_token_id).to(tokens.device)

assert bos_token_id == pad_token_id, "bos_token_id should be equal to left pad_token_id!"
has_bos = torch.all(tokens[:, 0].eq(bos_token_id))

attention_mask = get_attention_mask(tokens, has_bos, bos_token_id)
@@ -502,30 +530,36 @@ def _no_beam_search_generate(
if isinstance(scores, (list, tuple)):
scores = scores[0]
scores = scores[:, -1].float()
if eos_token_id is not None:
scores[:, eos_token_id] = -1e12
if eos_token_id is not None and min_new_tokens > 0:
scores[:, eos_token_id] = -float("inf")

# The first token generated.
next_tokens = scores.argmax(dim=-1, keepdim=True)
else:
next_tokens = tokens.new_zeros([batch_size, 1])

if gpc.is_initialized(ParallelMode.PIPELINE):
# broadcast to other rank in PP group
torch.distributed.broadcast(
next_tokens,
src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1],
group=gpc.get_group(ParallelMode.PIPELINE),
)

dones = next_tokens.new_zeros(batch_size, 1).eq(1)

if eos_token_id is not None:
end_mask = torch.any(next_tokens[:, None].eq(eos_token_id), dim=-1)
dones = dones.__or__(end_mask)
token_ids = torch.cat([tokens, next_tokens], dim=1)
cur_len = token_ids.size(1)
dones = token_ids.new_zeros(batch_size).eq(1)

inference_params.sequence_len_offset += tokens.size(1)

real_max_length = max_length
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)

while cur_len < real_max_length:
while cur_len < real_max_length and dones.min() != 1:
# batch_size x vocab_size
attention_mask = get_attention_mask(token_ids, has_bos, bos_token_id=bos_token_id)
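Under pipeline parallelism only the last stage holds logits, so the chosen token is broadcast to the rest of the pipeline group before the next step; the other ranks allocate a zero placeholder (`tokens.new_zeros`) to receive it. Reduced to its core — the `gpc` context objects are stand-ins here:

```python
import torch
import torch.distributed as dist

def sync_next_tokens(next_tokens: torch.Tensor, pp_ranks, pp_group) -> torch.Tensor:
    # Every rank calls this; only the last pipeline rank has real data,
    # the rest receive into their placeholder tensor in place.
    dist.broadcast(next_tokens, src=pp_ranks[-1], group=pp_group)
    return next_tokens
```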

@@ -543,6 +577,10 @@ def _no_beam_search_generate(
raise NotImplementedError(f"Unsupported decoder type: {type(decoder)}")

inference_params.sequence_len_offset += 1

if eos_token_id is not None and cur_len < real_min_length:
scores[..., eos_token_id] = -float("inf")

if gpc.is_last_rank(ParallelMode.PIPELINE):
if isinstance(scores, (list, tuple)):
scores = scores[0]
@@ -585,15 +623,17 @@ def _no_beam_search_generate(
src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1],
group=gpc.get_group(ParallelMode.PIPELINE),
)

if eos_token_id is not None:
# When the generated result exceeds the length, its eos_token_id is set to the most basic terminator.
next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), eos_token_id[0])
next_tokens = next_tokens.masked_fill(dones, pad_token_id)

tokens = next_tokens.unsqueeze(1)
tokens = tokens.masked_fill(dones, pad_token_id)
token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len

if eos_token_id is not None:
end_mask = torch.any(next_tokens[:, None].eq(eos_token_id), dim=-1)
end_mask = torch.any(tokens[:, None].eq(eos_token_id), dim=-1)
dones = dones.__or__(end_mask)

cur_len += 1
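A subtle companion fix: the EOS check now inspects `tokens` — the pad-filled column actually appended to `token_ids` — rather than the raw `next_tokens`, so `dones` stays consistent with what the output really contains. In miniature:

```python
import torch

tokens = torch.tensor([[1], [2]])   # appended column; row 0 was pad-filled
eos_token_id = torch.tensor([2])

end_mask = torch.any(tokens[:, None].eq(eos_token_id), dim=-1)
dones = torch.tensor([[True], [False]]) | end_mask
# tensor([[True],
#         [True]])  -> row 1 just finished; row 0 stays done
```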
@@ -628,12 +668,15 @@ def _beam_search_generate(
do_sample=True,
repetition_penalty=1.0,
length_penalty=1.0,
pad_token_id=0,
pad_token_id=1,
bos_token_id=1,
min_new_tokens=1,
) -> torch.LongTensor:

device = tokens.device
batch_size = tokens.size(0)
batch_size, cur_len = tokens.shape
real_max_length = max_length
real_min_length = cur_len + min_new_tokens

if eos_token_id is not None:
if not isinstance(eos_token_id, (List, Tuple)):
@@ -644,6 +687,7 @@ def _beam_search_generate(
eos_token_id.extend(additional_eos_token_list)
eos_token_id = torch.LongTensor(eos_token_id).to(tokens.device)

assert bos_token_id == pad_token_id, "bos_token_id should be equal to left pad_token_id!"
has_bos = torch.all(tokens[:, 0].eq(bos_token_id))

attention_mask = get_attention_mask(tokens, has_bos, bos_token_id=bos_token_id)
@@ -676,8 +720,8 @@ def _beam_search_generate(
if isinstance(scores, (list, tuple)):
scores = scores[0]
scores = scores[:, -1].float()
if eos_token_id is not None:
scores[:, eos_token_id] = -1e12
if eos_token_id is not None and min_new_tokens > 0:
scores[:, eos_token_id] = -float("inf")
vocab_size = scores.size(1)
assert vocab_size >= num_beams, "num_beams should be smaller than " "the number of vocabulary size."

@@ -724,18 +768,22 @@ def _beam_search_generate(

cur_len = token_ids.size(1)

real_max_length = max_length
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
hypos = [
BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
]
# 0, num_beams, 2*num_beams, ...
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)

while cur_len < real_max_length:
attention_mask = get_attention_mask(token_ids, has_bos, bos_token_id=bos_token_id)
for batch_idx in range(batch_size):
dones[batch_idx] = (
dones[batch_idx]
or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item())
or max_lengths[batch_idx * num_beams] == cur_len + 1
)

# (bsz x num_beams, vocab_size)
while cur_len < real_max_length and not all(dones):
attention_mask = get_attention_mask(token_ids, has_bos, bos_token_id=bos_token_id)

if isinstance(decoder, torch.nn.Module):
inference_params.attention_mask = attention_mask
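The docstrings' claim that `length_penalty > 1` encourages long sequences comes from the usual beam-hypothesis normalization that `BeamHypotheses` applies when ranking finished beams; schematically:

```python
def beam_score(sum_logprobs: float, length: int, length_penalty: float) -> float:
    # Log-probabilities are <= 0, so a larger length ** length_penalty
    # denominator pulls the score toward 0: with length_penalty > 1,
    # longer beams rank higher; with < 1, shorter beams win.
    return sum_logprobs / (length ** length_penalty)
```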
@@ -757,13 +805,18 @@ def _beam_search_generate(
if isinstance(scores, (list, tuple)):
scores = scores[0]
scores = scores[:, -1].float()
inference_params.sequence_len_offset += 1
if eos_token_id is not None and cur_len < real_min_length:
scores[..., eos_token_id] = -float("inf")

if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
lt_zero_mask = token_scores.lt(0).float()
ge_zero_mask = lt_zero_mask.eq(0).float()
token_scores = (
lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
)

scores.scatter_(dim=1, index=token_ids, src=token_scores)

if eos_token_id is not None:
@@ -993,6 +1046,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf")
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value

return logits
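The hunk above is the tail of `top_k_top_p_filtering`, which overwrites filtered logits with `filter_value` (`-inf` by default) so they vanish after softmax. Typical use in the sampling path, with toy shapes:

```python
import torch

logits = torch.randn(2, 100)                                  # (batch, vocab)
filtered = top_k_top_p_filtering(logits, top_k=20, top_p=0.9)
probs = torch.softmax(filtered, dim=-1)                       # filtered ids get p = 0
next_tokens = torch.multinomial(probs, num_samples=1)         # (batch, 1)
```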

