[TP] Infer local n_heads instead of ad-hoc model changes
ghstack-source-id: 587e3d6e5270714ca734b8031ce41a962e6394ea
Pull Request resolved: #498
kwen2501 committed Aug 2, 2024
1 parent e457deb commit 72a1614
Showing 2 changed files with 6 additions and 8 deletions.
9 changes: 6 additions & 3 deletions torchtitan/models/llama/model.py
@@ -190,9 +190,12 @@ def forward(
  bs, seqlen, _ = x.shape
  xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

- xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
- xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim)
- xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim)
+ # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual
+ # local heads from sizes of xq, xk, and xv as TP may have sharded them
+ # after the above linear ops.
+ xq = xq.view(bs, seqlen, -1, self.head_dim)
+ xk = xk.view(bs, seqlen, -1, self.head_dim)
+ xv = xv.view(bs, seqlen, -1, self.head_dim)

  xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
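Why `-1` works here: under tensor parallelism the column-wise sharded wq/wk/wv emit only each rank's local slice of heads, so the last dimension of xq/xk/xv already encodes the local head count. A minimal standalone sketch of the shape arithmetic, with made-up sizes and a hypothetical TP degree of 4 (not code from the repo):

import torch

bs, seqlen = 2, 16
n_heads, head_dim = 32, 128
tp_size = 4  # hypothetical TP degree

# Under column-wise sharding, each rank's wq output covers only
# n_heads // tp_size heads' worth of features.
local_out_features = (n_heads // tp_size) * head_dim
xq_local = torch.randn(bs, seqlen, local_out_features)

# -1 lets view() infer the local head count (8 here) from the sharded
# tensor, so the model needs no knowledge of the TP degree.
xq_local = xq_local.view(bs, seqlen, -1, head_dim)
assert xq_local.shape == (bs, seqlen, n_heads // tp_size, head_dim)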
5 changes: 0 additions & 5 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -398,11 +398,6 @@ def apply_tp(
"feed_forward.w3": colwise_parallel_weight(),
}

# Adjust attention module to use the local number of heads
attn_layer = transformer_block.attention
attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
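For context, a hedged sketch of the kind of TP plan that makes the removal above safe. The plan keys mirror the module names in the diff, but torchtitan uses its own helpers (e.g. colwise_parallel_weight above) and apply_attention_tp is a hypothetical name, so treat this as an approximation rather than the repo's actual plan:

import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


def apply_attention_tp(transformer_block: nn.Module, tp_mesh: DeviceMesh) -> None:
    """Shard wq/wk/wv column-wise and wo row-wise across tp_mesh."""
    layer_plan = {
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(),
    }
    # After this call, each rank's wq/wk/wv emit only their local shard of
    # heads, so the attention module can recover the local head count with
    # view(..., -1, head_dim) instead of having n_heads patched in place.
    parallelize_module(
        module=transformer_block,
        device_mesh=tp_mesh,
        parallelize_plan=layer_plan,
    )

Once the projections are sharded this way, xq/xk/xv already carry only the local heads on each rank, which is exactly what the view(..., -1, head_dim) change above relies on.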
