Skip to content

Commit

Permalink
fix(embedding): fix incorrect computing of indexes in _update_cos_sin…
Browse files Browse the repository at this point in the history
…_cache (#311)
  • Loading branch information
li126com authored Sep 2, 2024
1 parent 673f1a2 commit 408be4b
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions internlm/model/modules/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,11 @@ def _update_cos_sin_cache(
if max_seqlen is not None:
seqlen = max_seqlen
elif isinstance(indexes, int):
# logic changed temporaryly
# seqlen = indexes + x.shape[1] + 1
seqlen = gpc.config.data.seq_len
seqlen = indexes + x.shape[1]
else:
# Note that this statement may cause synchronization between CPU and GPU,
# so it's best to precompute and pass in max_seqlen ahead of time
seqlen = indexes.max().item() + 1
seqlen = indexes.max().item()

# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
Expand Down Expand Up @@ -219,9 +217,9 @@ def __init__(
def _update_cos_sin_cache(self, x, indexes):
"""x: (batch, seqlen, nheads, headdim)"""
if not isinstance(indexes, int):
seqlen = indexes.max().item() + 1
seqlen = indexes.max().item()
else:
seqlen = indexes + x.shape[1] + 1
seqlen = indexes + x.shape[1]

t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
Expand Down Expand Up @@ -286,9 +284,9 @@ def _update(self, seqlen, x):
def _update_cos_sin_cache(self, x, indexes):
"""x: (batch, seqlen, nheads, headdim)"""
if not isinstance(indexes, int):
seqlen = indexes.max().item() + 1
seqlen = indexes.max().item()
else:
seqlen = indexes + x.shape[1] + 1 # eval_forward
seqlen = indexes + x.shape[1] # eval_forward
if seqlen <= self.max_position_embeddings:
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
Expand Down

0 comments on commit 408be4b

Please sign in to comment.