From 83c206c60d349834ef98bd17a5fb1209b62bf842 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Thu, 1 Aug 2024 19:48:24 -0700
Subject: [PATCH] [TP] Infer local n_heads instead of ad-hoc model changes

ghstack-source-id: 77baf1a00d48e781d89c8bcd952661b3b51dccc5
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/498
---
 torchtitan/models/llama/model.py             | 9 ++++++---
 torchtitan/parallelisms/parallelize_llama.py | 5 -----
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 49cda624..4f5529a6 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -190,9 +190,12 @@ def forward(
         bs, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
 
-        xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
-        xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim)
-        xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim)
+        # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
+        # local heads from sizes of xq, xk, and xv as TP may shard them after
+        # the above linear ops.
+        xq = xq.view(bs, seqlen, -1, self.head_dim)
+        xk = xk.view(bs, seqlen, -1, self.head_dim)
+        xv = xv.view(bs, seqlen, -1, self.head_dim)
 
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index be627432..5c3f1614 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -383,11 +383,6 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
             "ffn_norm": SequenceParallel(),
         }
 
-        # Adjust attention module to use the local number of heads
-        attn_layer = transformer_block.attention
-        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
         parallelize_module(
             module=transformer_block,
             device_mesh=tp_mesh,
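
Note (not part of the patch): a minimal sketch of why view(..., -1, head_dim)
recovers the local head count under tensor parallelism. The sizes and
tp_degree below are hypothetical; it assumes the q/k/v projections are
column-sharded, so each rank's projection output contains only its local
heads and no manual adjustment of n_heads is needed.

    import torch

    bs, seqlen, n_heads, head_dim, tp_degree = 2, 16, 32, 128, 8  # hypothetical sizes

    # On one TP rank, wq's output is column-sharded: only local heads remain.
    local_q = torch.randn(bs, seqlen, (n_heads // tp_degree) * head_dim)

    # -1 lets view() infer the per-rank head count from the sharded size.
    xq = local_q.view(bs, seqlen, -1, head_dim)
    assert xq.shape == (bs, seqlen, n_heads // tp_degree, head_dim)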