405b more #552

Closed
wants to merge 250 commits into from
Commits (250)
054f088
check in tokenizer.model for ease of dev setup (#59)
wanchaol Feb 13, 2024
bfe2b58
Add truncated llama style model init via reset parameters() (#54)
lessw2020 Feb 14, 2024
60f021a
add model num params display, gpu memory metrics (#56)
lessw2020 Feb 15, 2024
ad69e62
add TensorBoard logging with loss and wps
tianyu-l Feb 15, 2024
a4663b1
add memory metrics to TensorBoard
tianyu-l Feb 17, 2024
2daf53f
modify data split to use HF api
tianyu-l Feb 21, 2024
50d69f6
add multinode support via slurm trainer, large scale race condition f…
lessw2020 Feb 22, 2024
8ad4dcb
add configurable unique layer init, clean up lr and loss display (#64)
lessw2020 Feb 22, 2024
8097c26
add bunch of cleanups and design principle section (#71)
wanchaol Feb 23, 2024
28f431f
delete the linter to see if re-adding it helps (#80)
wconstab Feb 23, 2024
ebbb1cb
Unified config manager for toml and command line (#76)
gnadathur Feb 24, 2024
bccad90
Whc/add linter (#81)
wconstab Feb 24, 2024
ab75dbd
Add 4GPU unit test (#82)
wconstab Feb 24, 2024
468ce8f
update readme (#74)
wanchaol Feb 24, 2024
3fce6bb
move config folder to root and adjust options (#83)
wanchaol Feb 24, 2024
3b48039
add iter time tracking via cuda events, add data loading times, add c…
lessw2020 Feb 26, 2024
df77f4e
Fill missing options in toml file wih argparse defaults (#91)
gnadathur Feb 26, 2024
325951f
support infinite loop over alpaca dataset
tianyu-l Feb 26, 2024
b12b6dd
Add color to console output if local logging, auto avoid color loggin…
lessw2020 Feb 27, 2024
254279f
update GPU metrics logging to GiB (gibibytes) (#95)
lessw2020 Feb 27, 2024
4c03475
improve TensorBoard instructions in README
tianyu-l Feb 27, 2024
7ea0679
Enable libUV for torchtrain (#98)
gnadathur Feb 28, 2024
e60c573
use warmup steps for lr scheduler, ban steps == -1 (#99)
wanchaol Feb 29, 2024
900b215
Add llama 7B config (#100)
wanchaol Feb 29, 2024
6e87471
add selective activation checkpointing
tianyu-l Feb 29, 2024
1b343f2
Add job description field in toml (#101)
gnadathur Mar 1, 2024
42f8907
fix 2D parallel crash caused by all-reduce on 2D world_mesh
tianyu-l Mar 2, 2024
4042b05
Load missing keys default from argparse (#111)
gnadathur Mar 5, 2024
6529af1
Add meta_init, enable it as default init process (#84)
lessw2020 Mar 5, 2024
5f0eaea
Fix feedback from PR 111 (#113)
gnadathur Mar 5, 2024
1ce8188
fix SP minor issues
tianyu-l Mar 5, 2024
bb5c4c6
enable loss parallel in SP
tianyu-l Mar 6, 2024
f31adb0
Float8_experimental option for training (#102)
drisspg Mar 6, 2024
6927e45
add miniPile dataset for pretraining, 1M entries (solves the 'out of …
lessw2020 Mar 7, 2024
d902a47
add data loading option to load from local file system
tianyu-l Mar 7, 2024
422910b
add llama 13B configs
wanchaol Mar 9, 2024
af221ce
add llama 70B toml
wanchaol Mar 9, 2024
5e36c74
set betas and weight decay for optimizers
wanchaol Mar 9, 2024
08b332c
Add c4 dataset (177M, streaming), update multi-node support for lates…
lessw2020 Mar 9, 2024
1d11cf5
Add openwebtext dataset for larger scale training without shuffling (…
lessw2020 Mar 12, 2024
2722865
[TorchTrain][Checkpoint] Fix TrainState state_dict to unblock loading…
wz337 Mar 12, 2024
2369861
improve logging
tianyu-l Mar 13, 2024
3262a8b
use SequenceParallel style in tp/sp (#133)
wanchaol Mar 13, 2024
d9253ee
support TP-only parallelism
tianyu-l Mar 13, 2024
b42ce91
disable verbose print from profiling
tianyu-l Mar 13, 2024
3ac610b
add Selective layer activation checkpointing, single control for tur…
lessw2020 Mar 14, 2024
af56ae0
remove per iter syncronize
tianyu-l Mar 14, 2024
073909b
Shorten nccl comm timeout and enable flight recorder dumping (#103)
wconstab Mar 15, 2024
e3204c6
fix up gpu memory monitoring and logging
tianyu-l Mar 15, 2024
a257bc3
Separate timeout during init and training (#149)
wconstab Mar 15, 2024
1d6100c
Update activation check with updates to config manager (#152)
drisspg Mar 20, 2024
ae9a966
Refactor to clean up parallelisms/__init__.py
wconstab Mar 20, 2024
47bb509
enable gc control scheduling to help avoid stragglers (#148)
lessw2020 Mar 20, 2024
fcca670
Add float8 specific parallel strategies (#153)
drisspg Mar 20, 2024
5d28009
add MFU to metrics
tianyu-l Mar 20, 2024
35d881e
disable buffer reuse for compile for now (#156)
wanchaol Mar 21, 2024
f080027
refactor config manager and support cmd overrides (#157)
wanchaol Mar 22, 2024
34732f5
Add support for generating debug traces on failure
chauhang Mar 24, 2024
e008027
rename sequence_parallel to tensor_parallel (#162)
wanchaol Mar 25, 2024
44808f9
add basic AC configs for 13B and 70B (#169)
wanchaol Mar 27, 2024
bb61af0
[TorchTrain][Checkpoint] Update train state to include global_avg_los…
wz337 Mar 27, 2024
6500bc6
Basic integration test infra (#170)
gnadathur Mar 27, 2024
479694f
Add 2D integration test (FSDP + TP) (#171)
gnadathur Mar 27, 2024
02923f0
Used per-parameter FSDP (#165)
awgu Mar 28, 2024
615f9c1
plot losses in loaded TrainState to TensorBoard
tianyu-l Mar 28, 2024
b1349da
Removed setting global flag for `swap_tensors` since not needed anymore
awgu Mar 29, 2024
65f0297
Add integration test with compile enabled (#183)
gnadathur Apr 2, 2024
e1e17c9
remove folding and unfolding of sequence dim in model.py
tianyu-l Apr 3, 2024
b9a4548
bump comm.train_timeout_seconds (#189)
wanchaol Apr 4, 2024
3686897
fix checkpoint parser
wz337 Apr 5, 2024
7872248
support sequence of tests and add checkpoint test
wz337 Apr 5, 2024
5ac3aa6
Make freqs_cis a persistent buffer for pp init
wconstab Apr 5, 2024
5379282
Delete grad scaler, which is unsupported/unused
wconstab Apr 5, 2024
d8e64cc
Factor out loss_fn to share code with pipeline par
wconstab Apr 5, 2024
0397fef
[TorchTrain] Minor fix for #197 (#204)
wz337 Apr 5, 2024
cd1e5e8
Add FusedRMSNorm (Triton kernel, +15% eager), Add NPLayerNorm, Enable…
lessw2020 Apr 5, 2024
f795361
remove .item() per iter
tianyu-l Apr 5, 2024
946780a
Removed cache_k and cache_v comments
awgu Apr 10, 2024
18adb2f
Some more cleanups
awgu Apr 10, 2024
ef4c5d2
avoid record streams and make color printing a config
tianyu-l Apr 10, 2024
6629659
fix SAC to use the correct reduce_scatter op (#215)
wanchaol Apr 10, 2024
ddf916e
Test runner raises exception on failures (#216)
gnadathur Apr 10, 2024
ecdbacc
Revert "Separate TransformerEmbedding layer (#33)"
wconstab Apr 10, 2024
656be68
Fix 2DParallel test (#219)
gnadathur Apr 10, 2024
97fe9a4
Added initial FSDP readme
awgu Apr 10, 2024
ce05f65
[TorchTrain][Checkpoint] Add model_weights_only option to train_confi…
wz337 Apr 11, 2024
00293cb
Rename to torchtitan (#221)
wanchaol Apr 11, 2024
c08f617
[TorchTitan] Add destory process group at the end of training (#223)
wz337 Apr 12, 2024
7712f72
Add 1 sec delay to rank 0 cleanup (#224)
gnadathur Apr 12, 2024
71621a2
[Torchtrain][Checkpoint] Add support to allow dtype conversion (#222)
wz337 Apr 12, 2024
5aa0aec
[TorchTitan] Remove checkpoint folder at the end in test_runner.py (#…
wz337 Apr 12, 2024
cb24eb5
codebase cleanup
tianyu-l Apr 15, 2024
3cfdbf2
Update README to reflect positioning (#229)
wanchaol Apr 16, 2024
db04c7e
First release readme (#227)
lessw2020 Apr 16, 2024
f504816
Update licenses and headers (#231)
wanchaol Apr 16, 2024
41fb267
use permalink for logo image (#232)
lessw2020 Apr 16, 2024
82c2518
[TorchTitan][Checkpoint] Move checkpoint folder under dump_folder and…
wz337 Apr 16, 2024
d42a7d1
use combo of html and local file src for logo (#234)
lessw2020 Apr 16, 2024
80103a9
add performance -- infra metrics and loss curves (#237) (#238)
lessw2020 Apr 16, 2024
09e7bec
add license section in readme (#239)
wanchaol Apr 16, 2024
22aa488
[TorchTitan][Checkpoint] Add a step-by-step instruction for checkpoin…
wz337 Apr 16, 2024
81138d6
more license headers (#240)
wanchaol Apr 16, 2024
04f5b82
Update README (#242)
wanchaol Apr 16, 2024
4f6ed9a
Add torchtune checkpoint link, modify product position statement loca…
lessw2020 Apr 16, 2024
cd55a38
Add pyproject and upgrade version (#236)
wanchaol Apr 16, 2024
78b843b
minor doc updates - remove asynch checkpt ref, grammar on prod positi…
lessw2020 Apr 16, 2024
ce0fff0
Fix multi-line string usage (#244)
gnadathur Apr 16, 2024
7b353c8
polish toml files
tianyu-l Apr 16, 2024
bc7fec5
[torchtitan][checkpoint][doc] Minor fix checkpoint doc (#246)
wz337 Apr 16, 2024
a682505
fix default max_seq_len for freq_cis init (#248)
wanchaol Apr 17, 2024
1ea4dee
set max_seq_len before training to make it align with input data (#249)
wanchaol Apr 17, 2024
55c8e48
fix pypi docs
tianyu-l Apr 17, 2024
fd9b498
update dataset to use c4
tianyu-l Apr 18, 2024
978c5c6
Add c4_mini, a local 45K dataset (subset of c4) (#253)
lessw2020 Apr 18, 2024
4020e92
remove logo, update pre-release date to 4/18 (#254)
lessw2020 Apr 18, 2024
51a6f6f
add intro video (#233)
lessw2020 Apr 18, 2024
6aafe3c
add performance file to show convergence with 64 a100s (#255)
lessw2020 Apr 18, 2024
35470ca
Support Llama3 8b/70b (#256)
wanchaol Apr 20, 2024
960e70f
polish llama 3 setup
tianyu-l Apr 22, 2024
e1c116a
reenable integration tests with a test tokenizer (#259)
wanchaol Apr 23, 2024
be432e1
warn supported dataset checks instead of throw (#260)
wanchaol Apr 24, 2024
192ed48
De-dup repeated `freqs_cis` computation code
awgu Apr 24, 2024
f38766e
update readme.md and performance.md
tianyu-l Apr 24, 2024
0eacbae
followup changes to allow unsupported datasets
tianyu-l Apr 24, 2024
217cc94
fix ac 'checkpointing' spelling, minor spacing tweaks (#265)
lessw2020 Apr 24, 2024
e3b47ea
Update legal terms (#269)
lessw2020 Apr 25, 2024
eed7495
apply less heavy profiling
tianyu-l Apr 25, 2024
3393c2a
Showcase where the product positioning lies more clearly (#272)
soumith Apr 25, 2024
568dad6
Doc Fixes (#273)
msaroufim Apr 25, 2024
3e13e24
fix lr scheduling by checkpointing scheduler
tianyu-l Apr 26, 2024
f03c128
insert barrier to profiler to resolve collectives timeout
tianyu-l Apr 25, 2024
42549a9
some misc changes (#278)
wanchaol Apr 26, 2024
0d09a32
inherit stateful protocol where appropriate
tianyu-l Apr 26, 2024
06da6c2
Fixed docs on HSDP sharding/replication dims
awgu Apr 29, 2024
a843abf
Add more Float8 description (#284)
drisspg Apr 29, 2024
d442743
Remove unneeded torchvision/audio deps
wconstab Apr 29, 2024
e7f2d28
fix 3d mesh order (#288)
wanchaol Apr 30, 2024
4e5ffaf
unify data loading from HF and from disk
tianyu-l Apr 30, 2024
58b1169
Add periodic integration test with signal (#289)
gnadathur May 1, 2024
4d8c245
exclude embedding in MFU computation
tianyu-l Apr 26, 2024
17cda29
Add support for seed checkpoint creation for meta-init flow
wconstab May 2, 2024
1a6caf2
remove unnecessary install of torchtitan
tianyu-l May 2, 2024
787a571
Remove unnecessary .to() inside model forward
wconstab May 2, 2024
695bd01
Fix the incorrect step log for profiler after resuming from a checkpo…
fegin May 3, 2024
143b586
turn off dynamic shape for torch.compile (#297)
wanchaol May 3, 2024
f72a2a0
Renamed `bsz` to `bs` for consistency; removed dead code
awgu May 3, 2024
3295448
Implement async_checkpoint
fegin May 7, 2024
f5a3ad7
simplify embedding + first transformer block TP (#314)
wanchaol May 8, 2024
e64c6ca
Only include checkpoints that have .metadata written (#315)
liangluofb May 10, 2024
3444c4c
Refactor freqs_cis slice to be safer for PP
wconstab May 11, 2024
7f92f45
Make Transformer tolerate missing layers for PP
wconstab May 11, 2024
0973bab
Use torch generic workflow for CI
wconstab May 15, 2024
5ba0a4b
[checkpointing] import async checkpoint with pinned memory only when …
tianyu-l May 15, 2024
847189d
Add a workflow to build torchtitan-ubuntu-20.04-clang12 Docker image …
huydhn May 16, 2024
a2ace60
Make pip install torch quiet
wconstab May 17, 2024
f2c3a11
Make test_runner.py warn on non-empty output dir
wconstab May 17, 2024
3bd14ec
Expose mixed_precision dtype arguments
wconstab May 21, 2024
99a73dd
Use stateful dataloader to checkpoint data iteration order and token …
gokulavasan May 21, 2024
e7c31be
Add Pipeline Parallel (and 2D PP+FSDP) support
wconstab May 21, 2024
c5a9718
fix i periodic integration test and add helper message on torchdata i…
tianyu-l May 22, 2024
60810a9
torch.compile each TransformerBlock instead of the whole model (#268)
wanchaol May 22, 2024
910662c
Make test_runner use separate logger with default INFO
wconstab May 22, 2024
6807909
Fix llama_13b.toml -> llama2_13b.toml in multinode_trainer.slurm (#350)
pbelevich May 22, 2024
638ec48
Fix bug in PP output layer shape
wconstab May 22, 2024
c87e8bc
Update pipelining import after change on pytorch
wconstab May 23, 2024
02ae169
update .gitignore to screen out slew of new temp files (#359)
lessw2020 May 24, 2024
1ceaa4e
Add test for PP tracer frontend
wconstab May 24, 2024
6a8455e
only produce tensorboard logs on rank 0 by default
tianyu-l May 29, 2024
5831e81
replace old torch dependency in requirements.txt
tianyu-l May 29, 2024
3343d1d
Add --test option to specify test to run (#368)
kwen2501 May 30, 2024
07fa503
use integration test as the badge shown on the homepage
tianyu-l May 29, 2024
54b5fa2
keep only latest k checkpoints (#366)
liangluofb May 31, 2024
a0f82d5
Make seed checkpoint creation work on CPU
wconstab Jun 3, 2024
8badf7e
Fix start/stop layer parsing
wconstab Jun 3, 2024
3050098
Use general way to access and update submodules
kwen2501 Jun 3, 2024
0594c04
Make metrics logging work for pipeline parallelism
wconstab Jun 4, 2024
c89aa40
[RFC] Allow ModelWrapper and OptimizerWrapper to accept multiple models
fegin Jun 5, 2024
3bc7678
Add 3D support
wconstab Jun 4, 2024
7cf41bb
[torchtitan][optim] Add fused as an option in train config (#355)
wz337 Jun 6, 2024
baa678c
[torchtitan] Fix test runner fused optim tests (#384)
wz337 Jun 6, 2024
104bd6c
Abstract out out optimizer params and update foreach calling conventi…
drisspg Jun 7, 2024
ccf1ed8
DeviceMesh BC fix (#387)
wanchaol Jun 9, 2024
6de5d31
BC fix for ManualPipelineStage import (#388)
wanchaol Jun 9, 2024
9700e0f
fix missing tb logs
tianyu-l Jun 10, 2024
3bb7bf3
add the 8-gpu test badge and use correct links for the integration te…
tianyu-l Jun 10, 2024
e858ab4
Fix 1D PP tracer test
kwen2501 Jun 10, 2024
c5d5c1f
del logits=(bs, seq_len, vocab_size) to save 3.9G memory (#391)
weifengpy Jun 12, 2024
e991ae4
Update contributing.md (#385)
H-Huang Jun 12, 2024
763b810
update all toml files to use experimental section (#392)
wanchaol Jun 12, 2024
e17e3b8
enable TP fp8 allgather with PrepareFloat8ModuleInput (#393)
wanchaol Jun 13, 2024
a4cd9ba
Update unit_test_cpu.yaml with cpu nightly (#396)
wanchaol Jun 13, 2024
33f301d
Fix SAC BC breaking and renaming to ac_freq (#397)
wanchaol Jun 13, 2024
093ba15
SAC API follow ups to restore old behavior (#401)
wanchaol Jun 13, 2024
d761994
enable TritonFusedRMSNorm with local_map annotation (#404)
XilunWu Jun 14, 2024
3e946a1
Cosmetic changes to train.py
kwen2501 Jun 14, 2024
d5b7525
Break down parallelize_llama for inference cases
kwen2501 Jun 14, 2024
4a1095a
Change debugmodel to have 8 layers
wconstab Jun 17, 2024
443ead2
Prepare train.py for model chunks for pipelining
wconstab Jun 17, 2024
8adbfa3
dump memory snapshot to analyze OOMs (#395)
weifengpy Jun 19, 2024
c88d0bc
whole_model for fp8 (#414)
weifengpy Jun 20, 2024
ac83f9c
Add train loop support for looped PP schedules
wconstab Jun 21, 2024
abb9e15
Set `record_shapes=True` for profiler
awgu Jun 24, 2024
4ba94bd
Improved `repeat_kv` eager perf
awgu Jun 24, 2024
236d2ff
Adding FSDP Memory Tracking and Estimation
sanketpurandare Jun 25, 2024
cb73810
Adding integration test for FSDP Memory Tracking and Estimation
sanketpurandare Jun 25, 2024
f3fecb2
by default disable heavy memory profiling
tianyu-l Jun 26, 2024
ea8c5c8
Add the option to turn on async-TP
yifuwang Jun 26, 2024
b0ed7f0
Modifying memory estimation options and minor changes
sanketpurandare Jul 1, 2024
21dd980
add comment pointing to Sequence Parallel optimization example
tianyu-l Jul 4, 2024
3fca883
switch float8 logic from Float8DynamicLinear to Float8Linear (#436)
vkuzo Jul 8, 2024
2f5285b
Removed `_experimental_support_context_fn_in_torch_utils_checkpoint`
awgu Jul 10, 2024
f0ca3e8
Reordered TP parallel plan to follow execution order
awgu Jul 10, 2024
c7a6a3e
Made some stylistic changes to `apply_dp`
awgu Jul 10, 2024
bc3ec02
Refactored activation checkpointing
awgu Jul 10, 2024
7afe902
compiled RMSNorm
tianyu-l Jul 10, 2024
0929280
Renamed parallel styles for transformer block weights
awgu Jul 10, 2024
040ea1d
Added type annotations and more stylistic changes
awgu Jul 10, 2024
e591f62
[Cleanup] Remove libuv from run_llama_train.sh
wconstab Jul 15, 2024
6f4d1d1
[Cleanup] Organize run_llama_train.sh options
wconstab Jul 15, 2024
db609d5
[Cleanup] Split run_llama_train.sh and run_memory_estimation.sh
wconstab Jul 15, 2024
07ab2f9
[Cleanup] Remove unused TRAINER_DIR
wconstab Jul 15, 2024
0bb6980
Add educational code pointers to top level README
wconstab Jul 15, 2024
f025335
enable FSDP2 + fp8 all-gather and fix TP fp8 all-gather (#413)
weifengpy Jul 16, 2024
298494b
import float8_experimental only when fp8 is enabled and install it in…
weifengpy Jul 17, 2024
183390e
skip fp8 CI on non-H100 GPUs (#465)
weifengpy Jul 17, 2024
4a2de42
clean up float8 configs in torchtitan (#466)
vkuzo Jul 17, 2024
2937167
Add support of DDP and experimental CompiledAutograd
fegin Jul 18, 2024
9b37408
add torch.compile + FSDP2 float8 all-gather in CI (#468)
weifengpy Jul 19, 2024
b502cdc
[float8] keep model.output as `nn.Linear` (high precision, not fp8) (…
weifengpy Jul 19, 2024
d76b77f
remove CI for FSDP2 + fp8 all-gather (#470)
weifengpy Jul 20, 2024
0f70507
dynamically update torch.compile cache config to ensure async tp supp…
lessw2020 Jul 21, 2024
00a3c21
Fix 8gpu PP failure due to 2D DCP disablement
wconstab Jul 15, 2024
5124c14
update float8 integration after UX changes (#484)
vkuzo Jul 26, 2024
43584e0
Re-enable FSDP2 Mem Tracker integration tests
Jul 26, 2024
668f6cd
Used `partial` instead of global vars for LR scheduling
awgu Jul 29, 2024
f13fe3f
[EZ] Add logs for some basic training params so that we can verify in…
fduwjj Jul 30, 2024
b012237
make float8 scaling type configurable (#489)
vkuzo Jul 30, 2024
9cf4b2f
[PP] add flexible interleaved 1f1b schedule #490 (#493)
H-Huang Jul 30, 2024
b069f70
move float8 callsites to torchao.float8 (#492)
vkuzo Jul 30, 2024
3ddce59
[BE][1/n] simplify train.py
tianyu-l Jul 31, 2024
7119d0c
[BE][2/n] use proper method signatures in parallelize_llama
tianyu-l Jul 31, 2024
04d219a
[BE][3/n] wrap fp8 logic using Float8Handler
tianyu-l Jul 31, 2024
e457deb
Bring LLaMa 3.1 405B to TorchTitan family (#481)
fduwjj Aug 1, 2024
72a1614
[TP] Infer local n_heads instead of ad-hoc model changes
kwen2501 Aug 2, 2024
de9fd2b
some compile-related updates
tianyu-l Aug 2, 2024
b3f2f58
[EZ][405B] Use scientific notation for 405B model lr
fduwjj Aug 5, 2024
904913f
[405B] Add performance data for 405B model
fduwjj Aug 21, 2024
7928650
Merge branch 'main' into 405b_more
fduwjj Aug 21, 2024
57d4725
Handle rebase
fduwjj Aug 21, 2024
cbcd3f2
Further handle merge
fduwjj Aug 21, 2024
Files changed
README.md (2 additions, 1 deletion)

@@ -22,7 +22,7 @@ Our guiding principles when building `torchtitan`:

You may want to see how the model is defined or how parallelism techniques are applied. For a guided tour, see these files first:
* [train.py](https://github.com/pytorch/torchtitan/blob/main/train.py) - the main training loop and high-level setup code
* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data Parallel, Tensor Parallel, activation checkpointing, and `torch.compile` to the model
* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data / Tensor / Pipeline Parallelisms to the model
* [torchtitan/parallelisms/pipeline_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/pipeline_llama.py) - helpers for applying Pipeline Parallel to the model
* [torchtitan/checkpoint.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints
* [torchtitan/float8.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/float8.py) - utils for applying Float8 techniques
@@ -64,6 +64,7 @@ git clone https://github.com/pytorch/torchtitan
```bash
cd torchtitan
pip install -r requirements.txt
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118
pip3 install --pre torchdata --index-url https://download.pytorch.org/whl/nightly
```

### Downloading a tokenizer
Binary file added assets/images/llama3_1_405B_loss_curves.png
docs/performance.md (15 additions, 0 deletions)

@@ -19,6 +19,21 @@ Next we show the loss curves for Llama 3 8B and Llama 3 70B training with both 1
![image](../assets/images/llama3_loss_curves.png)


## Llama 3.1 performance numbers

We measured the performance of the 405B model released in [Llama 3.1](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1). Because this model is much larger, we ran on 128 H100 GPUs to test both performance and loss curves. Below are the performance results for the 405B model with the optimizations we have developed. We hit OOM with 1D parallelism even at batch size 1, so we only tested the 2D (FSDP + TP) case.


| Model size | Batch size | Activation checkpointing | WPS | MFU | Optimizations |
| ----- | ----- | ----- | ----- | ----- | ----- |
| 405B | 2 | full | 118 | 37.1% | None |
| 405B | 2 | full | 177 | 27.77% | FP8 |
| 405B | 2 | full | 185 | 29.03% | FP8 + async TP |

And the loss curves are shown below:

![image](../assets/images/llama3_1_405B_loss_curves.png)
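
As a rough sanity check on how the WPS and MFU columns relate, here is a back-of-the-envelope sketch in Python. It is not torchtitan's metrics code (which accounts for FLOPs more carefully, e.g. excluding the embedding per commit 4d8c245 above), and both the peak-FLOPS figure and the per-GPU interpretation of WPS are assumptions, so expect it to land near, not exactly on, the table's numbers.

```python
# Back-of-the-envelope MFU estimate from a WPS figure (a sketch, not torchtitan's
# metrics code). Assumptions: flops/token ~= 6 * n_params, WPS is tokens/sec per
# GPU, and the per-GPU peak is the H100 BF16 dense figure of ~989 TFLOPS.

def approx_mfu(n_params: float, wps_per_gpu: float, peak_flops: float = 989e12) -> float:
    achieved_flops = 6.0 * n_params * wps_per_gpu  # model FLOPs per GPU per second
    return achieved_flops / peak_flops

# Example with the first table row (405B parameters, 118 WPS):
print(f"approx. MFU: {approx_mfu(405e9, 118):.1%}")  # ballpark only; see caveats above
```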

## Llama 2 performance numbers

Below are the WPS and MFU results which torchtitan achieves on Llama 2 models with FSDP2 on 64 A100 (80GB) GPUs.
estimation.py (25 additions, 16 deletions)

@@ -16,7 +16,7 @@

from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_tokenizer
from torchtitan.float8 import Float8Handler
from torchtitan.float8_linear import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
@@ -122,25 +122,33 @@ def loss_fn(pred, labels):
f"Building {model_name} {job_config.model.flavor} with {model_config}"
)
with torch.device("meta"):
model = model_cls.from_model_args(model_config)
whole_model = model_cls.from_model_args(model_config)

# a no-op handler if float8 is not enabled
# a no-op handler if fp8 is not enabled
float8_handler = Float8Handler(job_config, parallel_dims)
# swap to Float8Linear based on float8 configs
float8_handler.convert_to_float8_training(model)
# swap to Float8Linear based on fp8 config
float8_handler.convert_to_float8_training(whole_model)

# apply PT-D DP/TP parallelisms and activation checkpointing
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
model_parts = [whole_model]
model_parts = [
models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
for m in model_parts
]

init_device = "cuda"
for model in model_parts:
model.to_empty(device=init_device)

model.to_empty(device="cuda")
if not active_fake_mode():
model.init_weights()
model.train()
whole_model.init_weights()

# build optimizer after applying parallelisms to the model
optimizers = build_optimizers([model], job_config)
optimizers = build_optimizers(model_parts, job_config)
lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)

for model in model_parts:
model.train()
logger.info(f"Vocab size: {model_config.vocab_size}")
# Create a dummy batch instead of loading from a dataset
batch = (
@@ -157,31 +165,32 @@
device="cuda",
),
)
fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0])
fsdp_memtracker = FSDPMemTracker(mod=whole_model, optm=optimizers.optimizers[0])
fsdp_memtracker.track_inputs(batch)

with fsdp_memtracker:
for iter_idx in range(2):
input_ids, labels = batch
# train step
with train_context():
pred = model(input_ids)
pred = whole_model(input_ids)
loss = loss_fn(pred, labels)
del pred
loss.backward()

# clip gradients
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
)
for model in model_parts:
torch.nn.utils.clip_grad_norm_(
model.parameters(), job_config.training.max_norm, foreach=True
)
# sync float8 amaxes and scales
float8_handler.sync_float8_amax_and_scale_history(model)
# optimizer step
optimizers.step()
lr_schedulers.step()
# calculate float8 dynamic amax/scale for all-parameter for FSDP2
# it issues a single all-reduce for all parameters at once for better performance
float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)
optimizers.zero_grad()
print(f"Peak Memory at iter: {iter_idx}")
fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
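
The estimation.py changes above drive PyTorch's `FSDPMemTracker` over two dummy iterations and print a peak-memory snapshot. The sketch below isolates that pattern. It assumes the import path used by the PyTorch nightlies this repo targets, an already-initialized process group (e.g. under `torchrun`), and a model already sharded with FSDP2; treat it as a minimal illustration rather than a drop-in script.

```python
# Minimal sketch of the FSDPMemTracker pattern used in estimation.py (assumed
# import path; requires an initialized process group and an FSDP2-wrapped model).
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker


def estimate_peak_memory(model, optimizer, batch, loss_fn, num_iters: int = 2):
    tracker = FSDPMemTracker(mod=model, optm=optimizer)
    tracker.track_inputs(batch)  # attribute the input tensors' memory to this run
    with tracker:                # allocations inside the context are recorded
        for _ in range(num_iters):
            input_ids, labels = batch
            loss = loss_fn(model(input_ids), labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    # mirror the call in estimation.py: tabulated peak-memory snapshot in MiB
    tracker.display_snapshot("peak", units="MiB", tabulate=True)
```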
test/datasets/test_checkpoint.py (2 additions, 2 deletions)

@@ -11,8 +11,8 @@

class TestCheckpoint:
def test_c4_resumption(self):
dataset_name = "c4_test"
dataset_path = "./test/assets/c4_test"
dataset_name = "c4_mini"
dataset_path = "./torchtitan/datasets/c4_mini"
batch_size = 1
seq_len = 1024
world_size = 4
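
The test above exercises dataloader resumption from a checkpoint. As a hedged illustration of the mechanism underneath (torchdata's stateful dataloader, which is also why the README now installs a torchdata nightly), here is a minimal save/restore round trip; the function name is made up for the example, and the repo's real test goes through its own dataloader builder.

```python
# Illustrative sketch of dataloader-state checkpointing (not the repo's test code).
# Assumes torchdata's StatefulDataLoader; `dataset` is any dataset you supply.
from torchdata.stateful_dataloader import StatefulDataLoader


def resume_round_trip(dataset, batch_size: int = 1, consume: int = 3):
    dl = StatefulDataLoader(dataset, batch_size=batch_size)
    it = iter(dl)
    for _ in range(consume):        # consume a few batches
        next(it)
    state = dl.state_dict()         # checkpoint the iteration state

    resumed = StatefulDataLoader(dataset, batch_size=batch_size)
    resumed.load_state_dict(state)  # restore; iteration continues where it left off
    return next(iter(resumed))      # first batch after the resume point
```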