Skip to content

Commit

Permalink
Fix GPT-NeoX training accuracy issue.
Browse files Browse the repository at this point in the history
  • Loading branch information
yeonsily committed Oct 4, 2024
1 parent e625fce commit aada5f2
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,8 @@ def prepare_inputs_for_generation(
def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
if q.device.type == "hpu" and FusedRoPE:
if training:
rope_q = FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
rope_k = FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
rope_q = FusedRoPE.apply(q.to(torch.float), cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
rope_k = FusedRoPE.apply(k.to(torch.float), cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
else:
if q.dtype == torch.bfloat16:
rope_q = FusedRoPE.apply(
Expand All @@ -481,4 +481,4 @@ def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
rope_k = FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
return rope_q, rope_k
else:
return apply_rotary_pos_emb(q, k, cos[position_ids], sin[position_ids])
return apply_rotary_pos_emb(q.to(torch.float), k.to(torch.float), cos[position_ids], sin[position_ids])

0 comments on commit aada5f2

Please sign in to comment.