From 35c79968e7e510c432b9f903a375d55ec116e364 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Thu, 13 Jul 2023 13:43:30 -0700 Subject: [PATCH 1/2] Remove unnecessary model attribute assignment on 'freqs_cis' --- torchbenchmark/models/llama/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchbenchmark/models/llama/model.py b/torchbenchmark/models/llama/model.py index 1f0d1eb317..4019cf9e32 100644 --- a/torchbenchmark/models/llama/model.py +++ b/torchbenchmark/models/llama/model.py @@ -224,8 +224,8 @@ def forward(self, tokens: torch.Tensor, start_pos: int): h = self.tok_embeddings(tokens) - self.freqs_cis = self.freqs_cis.to(h.device) - freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + freqs_cis = self.freqs_cis.to(h.device) + freqs_cis = freqs_cis[start_pos : start_pos + seqlen] mask = None From 4e129911432e9cb2b4fc554ee50980d1a0709b4b Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Fri, 14 Jul 2023 16:24:56 -0700 Subject: [PATCH 2/2] Update torchbenchmark/models/llama/model.py --- torchbenchmark/models/llama/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchbenchmark/models/llama/model.py b/torchbenchmark/models/llama/model.py index 4019cf9e32..a01f4cae6a 100644 --- a/torchbenchmark/models/llama/model.py +++ b/torchbenchmark/models/llama/model.py @@ -224,6 +224,7 @@ def forward(self, tokens: torch.Tensor, start_pos: int): h = self.tok_embeddings(tokens) + # Reference: https://github.com/facebookresearch/llama/pull/349 freqs_cis = self.freqs_cis.to(h.device) freqs_cis = freqs_cis[start_pos : start_pos + seqlen]