dynamically update torch.compile cache config to ensure async tp support, enhance async tp UX (#471)

This PR adds several enhancements to support async tensor parallel (async TP):

1 - If async TP is active, auto-updates torch._dynamo.config.cache_size_limit to 10K. Without this update, async TP will not be applied on larger models, because compilation quietly stops once the 'cache limit reached' threshold is hit, with no information surfaced to the user. The config update is logged.

2 - If async TP is enabled, verifies that torch.compile is set to true in the job config. If not, it warns and then enables torch.compile so that the user gets working async TP. (See the WARNING in the screenshot below.)

<img width="1345" alt="Screenshot 2024-07-20 at 4 33 04 PM"
src="https://github.com/user-attachments/assets/26e5a48e-4bb8-4f33-b1b5-8939c1517c1d">

3 - Updates the 'Applied Tensor Parallelism to the model' log message to 'Applied Async Tensor Parallelism to the model' when async TP is active, to make it clear in the logs which TP variant is running. (See the screenshot above.) A distilled sketch of the combined logic follows below.
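
For readers who want the gist before the diff, here is a minimal, self-contained sketch of the combined behavior. The `_JobConfig` stand-in classes and the `maybe_enable_async_tp` helper are invented for this illustration; only `torch._dynamo.config.cache_size_limit`, the check on `job_config.training.compile`, and the log wording come from the actual change below.

```python
import logging

import torch
import torch._dynamo  # ensure the dynamo config module is loaded

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class _Experimental:
    """Hypothetical stand-in for job_config.experimental (illustration only)."""
    enable_async_tensor_parallel = True


class _Training:
    """Hypothetical stand-in for job_config.training (illustration only)."""
    compile = False


class _JobConfig:
    experimental = _Experimental()
    training = _Training()


def maybe_enable_async_tp(job_config: "_JobConfig") -> None:
    """Distilled version of the guard this PR adds (function name is illustrative)."""
    if not job_config.experimental.enable_async_tensor_parallel:
        return

    # 1 - raise dynamo's recompile cache limit so larger models do not
    #     silently stop compiling before async TP is applied.
    torch._dynamo.config.cache_size_limit = 10000
    logger.info("Updating torch._dynamo.config.cache_size_limit to 10000 to support Async TP")

    # 2 - async TP only takes effect under torch.compile, so warn and force it on.
    if not job_config.training.compile:
        logger.warning("Async TP requires compilation...auto enabling compile = True")
        job_config.training.compile = True

    # 3 - make the active TP variant explicit in the logs.
    logger.info("Applied Async Tensor Parallelism to the model")


maybe_enable_async_tp(_JobConfig())
```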
lessw2020 authored Jul 21, 2024
1 parent d76b77f commit 0f70507
Showing 2 changed files with 16 additions and 1 deletion.
torchtitan/parallelisms/parallelize_llama.py (15 additions, 1 deletion)
@@ -413,13 +413,27 @@ def apply_tp(
             parallelize_plan=layer_plan,
         )

+    # updates expressly for async tensor parallel
     if job_config.experimental.enable_async_tensor_parallel:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group

+        torch._dynamo.config.cache_size_limit = 10000
+        logger.info(
+            "Updating torch._dynamo.config.cache_size_limit to 10000 to support Async TP"
+        )
+
         torch._inductor.config._micro_pipeline_tp = True
         enable_symm_mem_for_group(tp_mesh.get_group().group_name)

-    logger.info("Applied Tensor Parallelism to the model")
+        if not job_config.training.compile:
+            logger.warning(
+                "Async TP requires compilation...auto enabling compile = True for this job to resolve."
+            )
+            job_config.training.compile = True
+
+    logger.info(
+        f"Applied{' Async ' if job_config.experimental.enable_async_tensor_parallel else ' '}Tensor Parallelism to the model"
+    )
     return model

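As a small usage note on the last added hunk, the new f-string resolves to one of two messages depending on the async-TP flag; a standalone check (plain Python, not torchtitan code):

```python
# Standalone illustration of the updated log line; prints both variants.
for enable_async_tensor_parallel in (True, False):
    print(
        f"Applied{' Async ' if enable_async_tensor_parallel else ' '}"
        "Tensor Parallelism to the model"
    )
# Applied Async Tensor Parallelism to the model
# Applied Tensor Parallelism to the model
```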
train_configs/debug_model.toml (1 addition, 0 deletions)
@@ -43,6 +43,7 @@ dataset = "c4_mini"  # supported datasets: c4_mini (45K), c4 (177M)

 [experimental]
 pipeline_parallel_degree = 1
+enable_async_tensor_parallel = false

 [checkpoint]
 enable_checkpoint = false
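The debug config ships with the new flag off. To try async TP, a run would flip it on; a sketch of the relevant sections, assuming the `compile` key under `[training]` that torchtitan job configs use elsewhere (all other fields omitted):

```toml
[training]
compile = true  # async TP needs torch.compile; per item 2 above it is auto-enabled if left false

[experimental]
pipeline_parallel_degree = 1
enable_async_tensor_parallel = true  # the flag added in this PR
```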
