Moved logits .float() to loss and compiled it if compiling
ghstack-source-id: 99a696d59af53f173d0af0b5c589056b4d76c7de
Pull Request resolved: #551
awgu committed Aug 23, 2024
1 parent 9515a14 commit 07e51e6
Showing 2 changed files with 6 additions and 3 deletions.
torchtitan/models/llama/model.py (2 changes: 1 addition & 1 deletion)
@@ -439,7 +439,7 @@ def forward(self, tokens: torch.Tensor):
             h = layer(h, self.freqs_cis)

         h = self.norm(h) if self.norm else h
-        output = self.output(h).float() if self.output else h
+        output = self.output(h) if self.output else h
         return output

     @classmethod
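With this change the model's forward no longer upcasts the output projection to fp32, so the returned logits stay in the compute dtype (typically bf16). A rough back-of-the-envelope sketch of what the separate fp32 logits copy costs, using assumed sizes rather than values from this commit:

# Illustrative only: assumed batch size, sequence length, and Llama-3 vocab size.
bs, seq_len, vocab_size = 2, 8192, 128256
bf16_logits_gib = bs * seq_len * vocab_size * 2 / 2**30  # ~3.9 GiB
fp32_logits_gib = bs * seq_len * vocab_size * 4 / 2**30  # ~7.8 GiB
print(f"bf16 logits ~{bf16_logits_gib:.1f} GiB, fp32 copy ~{fp32_logits_gib:.1f} GiB")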
train.py (7 changes: 5 additions & 2 deletions)
@@ -137,7 +137,7 @@ def main(job_config: JobConfig):
     # loss function to be shared by Pipeline Parallel and SPMD training
     def loss_fn(pred, labels):
         return torch.nn.functional.cross_entropy(
-            pred.flatten(0, 1), labels.flatten(0, 1)
+            pred.flatten(0, 1).float(), labels.flatten(0, 1)
         )

     # apply parallelisms and initialization
@@ -289,7 +289,10 @@ def loss_fn(pred, labels):
             # Non-PP forward / backward
             with train_context():
                 pred = model(input_ids)
-                loss = loss_fn(pred, labels)
+                if job_config.training.compile:
+                    loss = torch.compile(loss_fn)(pred, labels)
+                else:
+                    loss = loss_fn(pred, labels)
                 # pred.shape=(bs, seq_len, vocab_size)
                 # need to free to before bwd to avoid peaking memory
                 del pred
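Taken together, the two hunks implement one pattern: the model emits logits in its compute dtype, the upcast to fp32 happens inside the loss, and when compilation is enabled the loss is wrapped in torch.compile so the cast and cross-entropy can fuse instead of materializing a full fp32 logits tensor up front. A self-contained sketch of that pattern, not the torchtitan training loop itself, with made-up shapes and a stand-in compile_enabled flag:

import torch

def loss_fn(pred, labels):
    # Upcast happens inside the loss; pred has shape (bs, seq_len, vocab_size).
    return torch.nn.functional.cross_entropy(
        pred.flatten(0, 1).float(), labels.flatten(0, 1)
    )

compile_enabled = True  # stand-in for job_config.training.compile
train_loss_fn = torch.compile(loss_fn) if compile_enabled else loss_fn

bs, seq_len, vocab_size = 2, 128, 1000  # small illustrative sizes
pred = torch.randn(bs, seq_len, vocab_size, dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, vocab_size, (bs, seq_len))

loss = train_loss_fn(pred, labels)
del pred  # drop the logits reference before backward, as train.py does
loss.backward()

Wrapping once up front should behave like the inline torch.compile(loss_fn)(pred, labels) call in the diff, since torch.compile caches compiled code per function rather than recompiling on every call.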
