some compile-related updates
ghstack-source-id: 32fd853f1ed51db1ddba7bd7cb44b780b056591a
Pull Request resolved: #443
tianyu-l committed Jul 10, 2024
1 parent 7afe902 commit 261c4be
Showing 2 changed files with 16 additions and 5 deletions.
11 changes: 11 additions & 0 deletions test_runner.py
@@ -153,6 +153,17 @@ def build_test_list():
             "1D compile",
             "1d_compile",
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--training.compile",
+                    "--activation_checkpoint.mode selective",
+                    "--activation_checkpoint.selective_ac_option op",
+                ],
+            ],
+            "1D compile with selective op AC",
+            "1d_compile_sac_op",
+        ),
         OverrideDefinitions(
             [
                 [
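The new entry follows the existing OverrideDefinitions pattern: a list of flag lists (each inner list is one training run), a human-readable description, and a test name. The sketch below is a hypothetical illustration of how such definitions could be expanded into command lines by a test runner; the field names (override_args, test_descr, test_name), the run_test helper, and the "python train.py --job.config_file=..." entry point are assumptions for illustration, not the actual test_runner.py implementation.

# Hypothetical sketch only: field names, run_test, and the train.py entry point
# are assumptions, not the real test_runner.py code.
import subprocess
from dataclasses import dataclass, field
from typing import List


@dataclass
class OverrideDefinitions:
    override_args: List[List[str]] = field(default_factory=list)
    test_descr: str = "default"
    test_name: str = "default"


def run_test(test: OverrideDefinitions, config_file: str) -> None:
    # Each inner list of flags corresponds to one training run.
    for run_args in test.override_args:
        cmd = f"python train.py --job.config_file={config_file} " + " ".join(run_args)
        print(f"[{test.test_descr}] running: {cmd}")
        subprocess.run(cmd, shell=True, check=True)


test = OverrideDefinitions(
    [
        [
            "--training.compile",
            "--activation_checkpoint.mode selective",
            "--activation_checkpoint.selective_ac_option op",
        ],
    ],
    "1D compile with selective op AC",
    "1d_compile_sac_op",
)
# run_test(test, "train_configs/debug_model.toml")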
10 changes: 5 additions & 5 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -415,13 +415,13 @@ def apply_compile(model, job_config: JobConfig):
             "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
         )
 
+    # TODO(anijain): the following flag is on to accelarate compilation
+    # remove it after it's enabled in pytorch by default
+    torch._dynamo.config.inline_inbuilt_nn_modules = True
+
     for layer_id, transformer_block in model.layers.named_children():
         # turn on per-transformer block compile after AC wrapping and before FSDP
-        # TODO: dynamic shape have some issues so we turn it off for now.
-        # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
-        # compile time.
-        # torch._dynamo.config.inline_inbuilt_nn_modules = True
-        transformer_block = torch.compile(transformer_block, dynamic=False)
+        transformer_block = torch.compile(transformer_block, fullgraph=True)
         model.layers.register_module(layer_id, transformer_block)
 
     logger.info("Compiled each TransformerBlock with torch.compile")
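For context, the following is a minimal, self-contained sketch of the per-TransformerBlock compile pattern the new code applies: enable inline_inbuilt_nn_modules, compile each block with fullgraph=True, and re-register the compiled wrapper in place of the original child module. The TinyBlock/TinyModel modules and their sizes are made-up stand-ins for the real TransformerBlock layers, not torchtitan code.

# Minimal sketch of per-block compilation; TinyBlock/TinyModel are illustrative only.
import torch
import torch._dynamo
import torch.nn as nn


class TinyBlock(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


class TinyModel(nn.Module):
    def __init__(self, n_layers: int = 4, dim: int = 64):
        super().__init__()
        self.layers = nn.ModuleList(TinyBlock(dim) for _ in range(n_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


model = TinyModel()

# Flag added in this commit: inline built-in nn.Module calls to speed up compilation.
torch._dynamo.config.inline_inbuilt_nn_modules = True

# Compile each block individually, then re-register it so the compiled wrapper
# replaces the original child module in the ModuleList.
for layer_id, block in model.layers.named_children():
    compiled_block = torch.compile(block, fullgraph=True)
    model.layers.register_module(layer_id, compiled_block)

out = model(torch.randn(2, 64))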
