Skip to content

Commit

Permalink
BC fix for ManualPipelineStage import (pytorch#388)
Browse files Browse the repository at this point in the history
  • Loading branch information
wanchaol authored Jun 9, 2024
1 parent 894dd42 commit 1ff2a8a
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 10 deletions.
1 change: 0 additions & 1 deletion test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str):
logger.info(result.stdout)

for override_arg in test_flavor.override_args:

cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
cmd += " " + dump_folder_arg
if override_arg:
Expand Down
10 changes: 2 additions & 8 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@
checkpoint_wrapper as ptd_checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.pipelining import (
ManualPipelineStage,
pipeline,
PipelineStage,
SplitPoint,
)
from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
Expand Down Expand Up @@ -236,12 +231,11 @@ def pipeline_llama_manual(
output = torch.rand(layers_io_shape, dtype=mp_dtype, device=device)

model.to_empty(device=device)
stage = ManualPipelineStage(
stage = PipelineStage(
model,
pp_rank,
pp_size,
device,
microbatches,
input_args=input.chunk(microbatches)[0],
output_args=output.chunk(microbatches)[0],
group=pp_mesh.get_group("pp"),
Expand Down
6 changes: 5 additions & 1 deletion torchtitan/parallelisms/pipelining_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ def build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn):
logger.info(
f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule}"
)
n_microbatches = job_config.experimental.pipeline_parallel_microbatches
if n_microbatches is None:
n_microbatches = job_config.experimental.pipeline_parallel_degree

return schedule_class(
stage,
n_microbatches=stage.chunks,
n_microbatches=n_microbatches,
loss_fn=loss_fn,
)

0 comments on commit 1ff2a8a

Please sign in to comment.