feat/refactor partition strategy #13

Merged
Changes from all commits (191 commits)
10aa63f
support optimized sp
yingtongxiong Oct 7, 2023
e5a2909
Merge remote-tracking branch 'upstream/develop' into feat/deepspeed_sp
yingtongxiong Oct 7, 2023
bf475b6
debug
yingtongxiong Oct 8, 2023
bd4af3a
modify the all2all
yingtongxiong Oct 8, 2023
189a313
support fstp and refactor code
yingtongxiong Oct 9, 2023
21c1a7f
support evaluation with fstp
yingtongxiong Oct 9, 2023
949431f
modify the config
yingtongxiong Oct 9, 2023
0fa1083
Merge remote-tracking branch 'upstream/develop' into feat/fstp
yingtongxiong Oct 9, 2023
54e5616
remove useless code for no-pp
yingtongxiong Oct 9, 2023
144731c
fix evaluation bug in pp
yingtongxiong Oct 9, 2023
ef9e7cc
modify the config
yingtongxiong Oct 9, 2023
5d39c33
restore train.py
yingtongxiong Oct 9, 2023
29df765
refactor code
yingtongxiong Oct 9, 2023
f191853
fix lint
yingtongxiong Oct 9, 2023
007e58a
merge upstream develop
yingtongxiong Oct 9, 2023
a8dea63
fix the ci incompatible in config
yingtongxiong Oct 9, 2023
1b7935d
merge upstream develop
yingtongxiong Oct 9, 2023
dd67ab9
merge develop
yingtongxiong Oct 9, 2023
db63754
fix lint
yingtongxiong Oct 9, 2023
5fb6d99
feat(configs/7B_sft.py): update parallel config comment
huangting4201 Oct 10, 2023
0fac845
overlap grad_input computation and grad_weight reduce_scatter
yingtongxiong Oct 10, 2023
c94be64
merge origin
yingtongxiong Oct 10, 2023
792b066
communication overlap
yingtongxiong Oct 11, 2023
5fd5a8a
support fine-grained overlap
yingtongxiong Oct 11, 2023
d0b1346
feat(model/linear.py): support block allgather overlap
huangting4201 Oct 12, 2023
d0f0c22
feat(model/linear.py): change pre backward from wqkv to block
huangting4201 Oct 13, 2023
82204ee
support hybrid overlap
yingtongxiong Oct 16, 2023
0d1fa03
feat(model/linear.py): set block 0 full weight
huangting4201 Oct 16, 2023
d1af0d6
feat(model/linear.py): block-grained backward
huangting4201 Oct 17, 2023
229cc5c
impl reduce scatter async
Oct 17, 2023
4e99a7f
feat(train/training_internlm.py): remove abnormal tgs when calculatin…
huangting4201 Oct 17, 2023
6682f5d
fix reduce scatter async bug
Oct 17, 2023
b51cf4e
Merge branch 'feat/fstp' of github.com:yingtongxiong/InternLM into fe…
Oct 17, 2023
6408b94
support fine grained
yingtongxiong Oct 17, 2023
a5c6e45
Merge branch 'feat/fstp' of https://github.com/yingtongxiong/InternLM…
yingtongxiong Oct 17, 2023
5c38cb6
add head overlap
yingtongxiong Oct 17, 2023
5abe519
remove full weight for block 0
yingtongxiong Oct 17, 2023
16ef7b7
add test
yingtongxiong Oct 17, 2023
a5aeab2
memory profiling test
yingtongxiong Oct 17, 2023
4742271
add memory pool
yingtongxiong Oct 19, 2023
ed72327
support reduce scatter memory pool
yingtongxiong Oct 20, 2023
815a584
feat(model/linear.py): remove useless code
huangting4201 Oct 20, 2023
95488d8
update optimizer accumulate grad impl when fstp
Oct 20, 2023
d91a5d9
feat(initialize/launch.py): refactor config for fstp
huangting4201 Oct 20, 2023
3c69254
feat(optimizer/hybrid_zero_optim.py): resolve conflicts
huangting4201 Oct 20, 2023
eac382a
feat(optimizer/hybrid_zero_optim.py): fix lint error
huangting4201 Oct 20, 2023
2acf9b8
feat(utils/gputest.py): fix lint error
huangting4201 Oct 20, 2023
f22e5b3
Merge pull request #4 from yingtongxiong/fstp/refactor-config
yingtongxiong Oct 20, 2023
dcd89ed
refactor linear
yingtongxiong Oct 20, 2023
1804d01
merge reduce-scatter
yingtongxiong Oct 20, 2023
85ad917
feat(model/overlap_handler.py): refactor overlap hook handle
huangting4201 Oct 20, 2023
b20f47a
feat(model/overlap_handler.py): move handler to gpc
huangting4201 Oct 23, 2023
e7f9f1d
feat(model/overlap_handler.py): optimize reduce scatter mem pool
huangting4201 Oct 23, 2023
f6a5086
support bias
yingtongxiong Oct 23, 2023
0d693cf
feat(model/overlap_handler.py): fix lint error
huangting4201 Oct 23, 2023
03cc7f9
feat(model/overlap_handler.py): fix lint error
huangting4201 Oct 23, 2023
9cf1ff0
feat(solver/optimizer/hybrid_zero_optim.py): minor update
huangting4201 Oct 23, 2023
b2c1a70
feat(train/training_internlm.py): fix lint error
huangting4201 Oct 23, 2023
b48687a
Merge pull request #5 from yingtongxiong/fstp/refactor-hook-handle
huangting4201 Oct 23, 2023
0996c47
fix accumulate grads bug
Oct 23, 2023
97dcefc
support model activation checkpoint
yingtongxiong Oct 24, 2023
5d83136
feat(model/overlap_handler.py): fix head post backward hook when acti…
huangting4201 Oct 24, 2023
262de4b
support tflops computation and generate test py files
yingtongxiong Oct 24, 2023
0d3592a
Merge branch 'feat/fstp_refactor' of https://github.com/yingtongxiong…
yingtongxiong Oct 24, 2023
41cfa1a
feat(model/overlap_handler.py): fix overlap handler None bug
huangting4201 Oct 24, 2023
0bac166
add test
yingtongxiong Oct 25, 2023
918dff7
reset moe
yingtongxiong Oct 25, 2023
363275b
add memory print
yingtongxiong Oct 25, 2023
985465c
merge upstream
yingtongxiong Oct 25, 2023
cc20fa2
reset print memory
yingtongxiong Oct 25, 2023
d831ddc
modify the config
yingtongxiong Oct 26, 2023
1aae39b
Merge remote-tracking branch 'upstream/develop' into feat/fstp_refactor
yingtongxiong Oct 26, 2023
cbd4f04
add synchronize
yingtongxiong Oct 26, 2023
3253cbf
add a new get_tflops_func
mwiacx Oct 26, 2023
4d83e10
Merge branch 'feat/fstp_refactor' of https://github.com/yingtongxiong…
yingtongxiong Oct 26, 2023
8aefb74
add flash tflops
yingtongxiong Oct 26, 2023
aa3840f
fix some bugs
yingtongxiong Oct 26, 2023
3778c66
feat(model/overlap_handler.py): fix overlap hander to support pp(non-…
huangting4201 Oct 27, 2023
bc5a85c
Merge pull request #6 from yingtongxiong/fstp/overlap-support-pp
yingtongxiong Oct 27, 2023
4c1cd5d
fix async reduce scatter
mwiacx Oct 31, 2023
6b84325
fix(optimizer/hybrid_zero_optim.py): remove redundant _accum_grad_buc…
huangting4201 Oct 31, 2023
b3def4c
fix(optimizer/hybrid_zero_optim.py): add reduce_scatter_overlap switch
huangting4201 Oct 31, 2023
10b5056
fix all-gather overlap the model_checkpoint is 0
yingtongxiong Nov 1, 2023
4851291
fix(optimizer/hybrid_zero_optim.py): fix bucket size full judge condi…
huangting4201 Nov 2, 2023
5a18b3b
fix(model/overlap_handler.py): fix last block hook when pp with activ…
huangting4201 Nov 2, 2023
9b1265c
modify the sp allreduce and support tf32 for fstp linear
yingtongxiong Nov 6, 2023
c517ec5
feat(model/overlap_handler.py): delete reduce_scatter_overlap switch
huangting4201 Nov 6, 2023
7c6d293
reset the sp allreduce in optimizer
yingtongxiong Nov 6, 2023
b80e6cd
merge origin
yingtongxiong Nov 6, 2023
b5e4d04
fix conflicts
yingtongxiong Nov 6, 2023
7475439
feat(model/overlap_handler.py): add memory_pool switch and refactor o…
huangting4201 Nov 13, 2023
3c07423
feat(model/overlap_handler.py): release weight
huangting4201 Nov 14, 2023
a1fd877
fix(train.py): clear memory pool before optim step
huangting4201 Nov 15, 2023
a80fcf8
feat(model): refactor weight and os and data patition strategy
huangting4201 Nov 28, 2023
cab9abd
fix(training_internlm.py): fix loss accuracy(optim init and seed set)
huangting4201 Nov 29, 2023
d3ee3ef
fix(model): reset embedding and head
huangting4201 Nov 30, 2023
6cd271c
fix(model): fix process group error
huangting4201 Dec 1, 2023
0817b8c
fix(model): fix FSTP linear Torch process group
huangting4201 Dec 1, 2023
1b7d2dc
fix(overlap_handler.py): release module post backward when model ckpt is
huangting4201 Dec 7, 2023
fd5a144
feat(model): embedding and head use sp group and refactor parameter g…
huangting4201 Dec 11, 2023
ac72710
feat(model): modify grad norm compute func
huangting4201 Dec 12, 2023
76be8c2
fix(model/utils.py): fix fstp linear reduce scatter sum->avg
huangting4201 Dec 14, 2023
d30aecd
feat(core/context): support pp for initializing isp/msp/fsp process g…
huangting4201 Dec 19, 2023
e9cd521
feat(model): refactor model and optimizer for msp/fsp/isp
huangting4201 Dec 20, 2023
e0cafb0
fix(overlap_handler.py): fix hook error and param group split
huangting4201 Dec 21, 2023
7974a32
fix(overlap_handler.py): fix clear weight error when activation ckpt …
huangting4201 Dec 22, 2023
3361350
fix(parallel_context.py): fix seed mode when TENSOR parallel
huangting4201 Dec 25, 2023
9b22258
feat(*) refactor fstp handler
mwiacx Dec 26, 2023
8e3196b
Merge branch 'feat/refactor-partition-strategy' into feat/refactor-fs…
mwiacx Dec 26, 2023
fe6fed7
feat(*): fix bug
mwiacx Dec 28, 2023
a80fbe3
fix(train/utils.py): fix zp size cheak and embed_param group
huangting4201 Jan 12, 2024
c01f015
Merge branch 'feat/refactor-partition-strategy' into feat/refactor-fs…
mwiacx Jan 12, 2024
1aebcd9
fix(model/util): force to pass communictor
mwiacx Jan 12, 2024
917ab0d
fix(model/utils.py): fix param set
huangting4201 Jan 12, 2024
b77787f
fix(hybrid_zero_optim.py): fix reduce scatter error when wp_size=1
huangting4201 Jan 12, 2024
594d61d
feat(model_checkpoint.py): model and optimizer save/load ckpt adapt t…
huangting4201 Jan 15, 2024
d87d9f9
Merge pull request #9 from yingtongxiong/feat/refactor-fstp-handler
huangting4201 Jan 15, 2024
e4d1ff8
fix(model_checkpoint.py): fix dp/zo size check
huangting4201 Jan 16, 2024
f2f88a7
support sequence parallel for moe
blankde Dec 27, 2023
6e012b1
modify expert groups
blankde Jan 17, 2024
18e6e78
feat(isp): support interleaved pipeline parallel scheduler
mwiacx Jan 17, 2024
55ebba0
add moe group
blankde Jan 17, 2024
ab039d5
fix(isp.py): fix comment
huangting4201 Jan 17, 2024
c113443
Merge pull request #1 from huangting4201/feat/support-interleaved-pp-…
huangting4201 Jan 17, 2024
8347ab4
feat(model): remove useless debug print
huangting4201 Jan 17, 2024
7ed1109
feat(model): fix lint error
huangting4201 Jan 17, 2024
ba254e3
merge huangting/feat/refactor-partition-strategy
blankde Jan 18, 2024
ccc2108
refactor code
blankde Jan 18, 2024
71543b3
fix(conflicts): resolve conflicts from merging develop
huangting4201 Jan 18, 2024
fac2b20
refactor code
blankde Jan 18, 2024
a83a94f
merge huangting/feat/refactor-partition-strategy
blankde Jan 18, 2024
05fa04a
feat(multi_head_attention.py): set bias=True
huangting4201 Jan 19, 2024
91bd3f9
fix bugs
blankde Jan 19, 2024
20f6b36
support moe checkpoint
blankde Jan 19, 2024
7cdeea8
fix(tests): fix ci test error
huangting4201 Jan 19, 2024
f959781
fix(tests): fix ci test error
huangting4201 Jan 19, 2024
e873668
fix(tests): fix ci test error
huangting4201 Jan 19, 2024
b99a642
fix(tests): fix ci test error
huangting4201 Jan 19, 2024
7ac53bf
fix(tests): fix ci test error
huangting4201 Jan 19, 2024
bb5835e
fix(tests): fix ci test error
huangting4201 Jan 19, 2024
d5872e7
fix(tests): fix ci test error
huangting4201 Jan 22, 2024
18da3fc
Merge branch 'feat/refactor-partition-strategy' of https://github.com…
blankde Jan 22, 2024
0aebd2c
update moe config file
blankde Jan 22, 2024
15610f6
adapt grad profiling
JiaoPL Jan 22, 2024
b007c43
Merge branch 'feat/refactor-partition-strategy' into feat/adapt_grad_…
JiaoPL Jan 22, 2024
c8b100e
fix(communication/isp.py): fix bias switch for mem pool
huangting4201 Jan 22, 2024
e1676f0
Merge branch 'feat/refactor-partition-strategy' into feat/adapt_grad_…
JiaoPL Jan 22, 2024
c606bb5
fix(model/utils.py): fix boolean value ambiguous error
huangting4201 Jan 22, 2024
d646f91
Merge branch 'feat/refactor-partition-strategy' into feat/adapt_grad_…
JiaoPL Jan 22, 2024
70a17d6
test grad profiling with mtp,msp,fsp,isp
JiaoPL Jan 22, 2024
4e9b276
feat(training_internlm.py): update initialize_model func to adapt to …
huangting4201 Jan 22, 2024
32df5ad
feat(training_internlm.py): move get_scheduler_hooks from train.py to…
huangting4201 Jan 22, 2024
d388ddc
feat(model): fix dict has no attri mode error
huangting4201 Jan 23, 2024
8e1b619
feat(training_internlm.py): move use_fp32_norm config to gpc.config
huangting4201 Jan 23, 2024
fd349f1
Merge branch 'feat/refactor-partition-strategy' of https://github.com…
blankde Jan 23, 2024
978cea8
feat(version): update internevo version and torch verion
huangting4201 Jan 24, 2024
d5fe8fe
feat(context/parallel_context.py): set default parallel size in paral…
huangting4201 Jan 24, 2024
0c8e0cf
Merge pull request #2 from blankde/feat/support_moe_for_isp
huangting4201 Jan 24, 2024
1d64a22
feat(format): fix ci lint check error
huangting4201 Jan 24, 2024
b0c6a20
feat(format): fix ci lint check error
huangting4201 Jan 24, 2024
571d83c
feat(format): fix ci lint check error
huangting4201 Jan 24, 2024
83517ca
feat(evaluation.py): fix evaluation error when msp/fsp with pp
huangting4201 Jan 24, 2024
48aca7f
Merge branch 'feat/refactor-partition-strategy' into feat/adapt_grad_…
JiaoPL Jan 25, 2024
0ec9b67
fix moe param groups
JiaoPL Jan 25, 2024
aa388b5
modify the distributedAttention for different data pack mode
yingtongxiong Jan 25, 2024
34b9479
feat(model/multi_head_attention.py): fix return output
huangting4201 Jan 25, 2024
10309b8
feat(utils/evaluation.py): rename gpc.evaluation to gpc.is_evaluating
huangting4201 Jan 25, 2024
4c8324a
feat(multi_head_attention.py): rename gpc.evaluation to gpc.is_evalua…
huangting4201 Jan 25, 2024
1b64785
Merge pull request #3 from JiaoPL/feat/adapt_grad_norm_profiling_for_…
huangting4201 Jan 25, 2024
f186a75
feat(communication/isp.py): refactor isp communicator to adapt to dif…
huangting4201 Jan 25, 2024
f064880
fix(conflicts): resolve conflicts from merging develop
huangting4201 Jan 26, 2024
3d7402d
fix(tests): fix ci test error
huangting4201 Jan 26, 2024
8170641
fix(tests): fix ci pipeline test error
huangting4201 Jan 26, 2024
85dd51f
feat(utils/common.py): remove func get_megatron_flops_2
huangting4201 Jan 26, 2024
971c8eb
feat(communication/isp.py): isp communicator support 0.x activation ckpt
huangting4201 Jan 29, 2024
6853bab
feat(train/training_internlm.py): move isp init to func initialize_is…
huangting4201 Jan 29, 2024
8c45118
feat(communication/isp.py): fix prefetch last ckpt block wait handle
huangting4201 Jan 29, 2024
e74f2dd
Merge pull request #4 from huangting4201/feat/isp-communicator-suppor…
huangting4201 Jan 29, 2024
011edcf
feat(utils/parallel.py): add func is_using_isp
huangting4201 Jan 29, 2024
f02523e
fix(tests): fix ci tests error
huangting4201 Jan 29, 2024
23ab67f
feat(model/modeling_llama.py): update model llama
huangting4201 Jan 30, 2024
f11422e
feat(model/utils.py): simplify code
huangting4201 Jan 30, 2024
4a27957
fix(conflicts): resolve conflicts from merging develop
huangting4201 Jan 30, 2024
8e1ee6f
feat(model/linear.py): update FeedForward class to internlm2
huangting4201 Jan 30, 2024
b5f9ada
Merge pull request #5 from huangting4201/feat/support-feedforwardv2-ckpt
huangting4201 Jan 30, 2024
d7928a6
fix(parallel_context.py): fix private repo ci tests error
huangting4201 Jan 30, 2024
1960dc0
feat(parallel_context.py): set zero1 parallel size >= 1
huangting4201 Jan 30, 2024
52ace84
fix(conflicts): resolve conflicts from merging develop
huangting4201 Jan 30, 2024
62a665d
feat(tests): add e2e test case for isp and enable pytorch expandable_…
huangting4201 Jan 31, 2024
e91acb4
feat(doc): update doc torch and flashattn version
huangting4201 Jan 31, 2024
2e4f749
Merge branch 'develop' into feat/refactor-partition-strategy
sunpengsdu Feb 1, 2024
19 changes: 18 additions & 1 deletion .github/workflows/e2e_test.yaml
@@ -23,4 +23,21 @@ jobs:
- name: training_8GPU
run: |
source $evo_env
srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_8GPU" ./tests/test_training
srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_8GPU" ./tests/test_training/test_loss.py

training_8GPU_ISP:
runs-on: [t_cluster]
timeout-minutes: 10
steps:
- name: mask env
run: |
echo "::add-mask::${{env.WORKSPACE_PREFIX}}"
echo "::add-mask::$path_prefix"
- uses: actions/checkout@v3

- name: training_8GPU_ISP
run: |
source $evo_env
conda activate /mnt/petrelfs/share_data/huangting.p/envs/llm-torch2.1-flash2
conda activate /mnt/petrelfs/share_data/huangting.p/envs/llm-torch2.1-flash2
srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=${GITHUB_RUN_ID}-${GITHUB_JOB} -n8 --ntasks-per-node=8 --cpus-per-task=4 --gpus-per-task=1 pytest -s -v --color=yes -m "training_8GPU_ISP" ./tests/test_training/test_loss.py
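The new job selects its cases through the "training_8GPU_ISP" pytest marker and launches them with srun on 8 GPUs. As a rough illustration of the marker mechanism only (the real cases live in tests/test_training/test_loss.py; the test name below is hypothetical), a marker-gated test looks like this:

```python
import pytest

# Hypothetical skeleton: the actual ISP cases are in tests/test_training/test_loss.py.
# This only shows how the workflow's `pytest -m "training_8GPU_ISP"` filter
# selects tests that carry the marker.
@pytest.mark.training_8GPU_ISP
def test_isp_8gpu_loss_smoke():
    # A real case would launch an 8-GPU run with tensor mode "isp"
    # and compare the resulting loss curve against a reference.
    pass
```

Unregistered markers only trigger a pytest warning by default; the repository is assumed to declare "training_8GPU_ISP" in its pytest configuration alongside the existing "training_8GPU" marker.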
63 changes: 43 additions & 20 deletions configs/7B_MoE4_sft.py
@@ -28,7 +28,7 @@
# 'load_ckpt_info' setting guide:
# 1. the 'path' indicate ckpt path,
# 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
# 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, now only 'normal' type is supported.
# 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internlm", "llama", "hf_llama".
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
# 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
# training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
@@ -44,8 +44,8 @@
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
)

TRAIN_FOLDER = "/path/to/dataset"
VALID_FOLDER = "/path/to/dataset"
TRAIN_FOLDER = None # "/path/to/dataset"
VALID_FOLDER = None # "/path/to/dataset"
data = dict(
seq_len=SEQ_LEN,
# micro_num means the number of micro_batch contained in one gradient update
@@ -59,12 +59,17 @@
pack_sample_into_one=False,
total_steps=50000,
skip_batches="",
# rampup_batch_size (str): A string with three space-separated integers representing the
# starting batch size, the increment, and the number of steps between
# each increment. For example, "192 24 8" means that the batch size (micro_num)
# starts at 192 and increases by 24 every 8 steps. Defaults to None.
# (IMPORTANT): The interval step size is 'micro_bsz'.
rampup_batch_size="",
# Datasets with less than 50 rows will be discarded
min_length=50,
# train_folder=TRAIN_FOLDER,
# valid_folder=VALID_FOLDER,
empty_cache_and_diag_interval=10,
train_folder=TRAIN_FOLDER,
valid_folder=VALID_FOLDER,
empty_cache_and_diag_interval=200,
diag_outlier_ratio=1.1,
)

@@ -125,6 +130,7 @@
cur_iter=-1,
)

use_fp32_norm = False
model = dict(
checkpoint=False, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
@@ -145,23 +151,36 @@
moe_use_residual=False,
moe_type="GShard",
)

# zero1 parallel:
# 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
# so parameters will be divided within the range of dp.
# 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
# 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
# For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
# pipeline parallel (dict):
# 1. size: int, the size of pipeline parallel.
# 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
# tensor parallel: tensor parallel size, usually the number of GPUs per node.

"""
zero1 parallel (dict):
1. size: int
* if size <= 0, the size of the zero process group is equal to the size of the dp process group,
so parameters will be divided within the range of dp.
* if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
* if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
tensor parallel (dict):
1. size: int, the size of tensor parallel.
2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel.
msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size.
fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size.
isp: custom intern sequence parallel without tensor parallel, can be used with weight parallel.
pipeline parallel (dict):
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
defaults to False.
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=1,
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=False,
weight=dict(size=1, overlap=True, memory_pool=True),
)

cudnn_deterministic = False
@@ -173,6 +192,10 @@
enable_feishu_alert=DO_ALERT,
feishu_alert_address=None, # feishu webhook to send alert message
light_monitor_address=None, # light_monitor address to send heartbeat
alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
),
tensorboard=dict(
queue_max_length=10,
),
)

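The data block in the config above documents the rampup_batch_size string as three space-separated integers: starting batch size, increment, and step interval. A minimal sketch of that schedule, under stated assumptions (the function name and the cap at the configured micro_num are illustrative, not InternEvo's implementation):

```python
# Minimal sketch of the documented rampup_batch_size format "start increment interval",
# e.g. "192 24 8": micro_num starts at `start` and grows by `increment` every
# `interval` steps. Capping at the configured micro_num is an assumption made here.
def rampup_micro_num(rampup: str, step: int, target_micro_num: int) -> int:
    if not rampup:  # "" or None disables ramp-up
        return target_micro_num
    start, increment, interval = map(int, rampup.split())
    return min(start + increment * (step // interval), target_micro_num)

# "192 24 8": step 0 -> 192, step 8 -> 216, step 16 -> 240, ...
assert rampup_micro_num("192 24 8", 16, 1024) == 240
assert rampup_micro_num("", 100, 4) == 4
```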
200 changes: 200 additions & 0 deletions configs/7B_isp_sft.py
@@ -0,0 +1,200 @@
JOB_NAME = "7b_train"
DO_ALERT = False

SEQ_LEN = 2048
HIDDEN_SIZE = 4096
NUM_ATTENTION_HEAD = 32
MLP_RATIO = 8 / 3
NUM_LAYER = 32
VOCAB_SIZE = 103168

MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
SAVE_CKPT_FOLDER = "local:llm_ckpts"
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"

# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
# load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
load_ckpt_folder="local:llm_ckpts/",
# 'load_ckpt_info' setting guide:
# 1. the 'path' indicate ckpt path,
# 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
# 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internlm", "llama", "hf_llama".
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
# 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
# training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
# with an automatic restart mechanism upon training reboot.
# Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
# path specified in `load_ckpt_info` by default.
# If you want to initialize your model weights from another model, you must set `auto_resume` to False.
# If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
auto_resume=True,
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
)

TRAIN_FOLDER = None # "/path/to/dataset"
VALID_FOLDER = None # "/path/to/dataset"
data = dict(
seq_len=SEQ_LEN,
# micro_num means the number of micro_batch contained in one gradient update
micro_num=4,
# packed_length = micro_bsz * SEQ_LEN
micro_bsz=2,
# defaults to the value of micro_num
valid_micro_num=4,
# defaults to 0, means disable evaluate
valid_every=50,
pack_sample_into_one=False,
total_steps=50000,
skip_batches="",
# rampup_batch_size (str): A string with three space-separated integers representing the
# starting batch size, the increment, and the number of steps between
# each increment. For example, "192 24 8" means that the batch size (micro_num)
# starts at 192 and increases by 24 every 8 steps. Defaults to None.
# (IMPORTANT): The interval step size is 'micro_bsz'.
rampup_batch_size="",
# Datasets with less than 50 rows will be discarded
min_length=50,
train_folder=TRAIN_FOLDER,
valid_folder=VALID_FOLDER,
empty_cache_and_diag_interval=200,
diag_outlier_ratio=1.1,
)

grad_scaler = dict(
fp16=dict(
# the initial loss scale, defaults to 2**16
initial_scale=2**16,
# the minimum loss scale, defaults to None
min_scale=1,
# the number of steps to increase loss scale when no overflow occurs
growth_interval=1000,
),
# the multiplication factor for increasing loss scale, defaults to 2
growth_factor=2,
# the multiplication factor for decreasing loss scale, defaults to 0.5
backoff_factor=0.5,
# the maximum loss scale, defaults to None
max_scale=2**24,
# the number of overflows before decreasing loss scale, defaults to 2
hysteresis=2,
)

hybrid_zero_optimizer = dict(
# Enable low_level_optimizer overlap_communication
overlap_sync_grad=True,
overlap_sync_param=False,
# bucket size for nccl communication params
reduce_bucket_size=512 * 1024 * 1024,
# grad clipping
clip_grad_norm=1.0,
)

loss = dict(
label_smoothing=0,
)

adam = dict(
lr=1e-4,
adam_beta1=0.9,
adam_beta2=0.95,
adam_beta2_c=0,
adam_eps=1e-8,
weight_decay=0.01,
)

lr_scheduler = dict(
total_steps=data["total_steps"],
init_steps=0, # optimizer_warmup_step
warmup_ratio=0.01,
eta_min=1e-5,
last_epoch=-1,
)

beta2_scheduler = dict(
init_beta2=adam["adam_beta2"],
c=adam["adam_beta2_c"],
cur_iter=-1,
)

use_fp32_norm = False
model = dict(
checkpoint=False, # The proportion of layers for activation checkpointing, the optional values are True/False/[0-1]
num_attention_heads=NUM_ATTENTION_HEAD,
embed_split_hidden=True,
vocab_size=VOCAB_SIZE,
embed_grad_scale=1,
parallel_output=True,
hidden_size=HIDDEN_SIZE,
num_layers=NUM_LAYER,
mlp_ratio=MLP_RATIO,
apply_post_layer_norm=False,
dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
)
"""
zero1 parallel (dict):
1. size: int
* if size <= 0, the size of the zero process group is equal to the size of the dp process group,
so parameters will be divided within the range of dp.
* if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
* if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
tensor parallel (dict):
1. size: int, the size of tensor parallel.
2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel.
msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size.
fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size.
isp: custom intern sequence parallel without tensor parallel, can be used with weight parallel.
pipeline parallel (dict):
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
defaults to False.
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=2, mode="isp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=4, overlap=True, memory_pool=True),
)

cudnn_deterministic = False
cudnn_benchmark = False

monitor = dict(
# feishu alert configs
alert=dict(
enable_feishu_alert=DO_ALERT,
feishu_alert_address=None, # feishu webhook to send alert message
light_monitor_address=None, # light_monitor address to send heartbeat
alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
),
tensorboard=dict(
queue_max_length=10,
),
)

# metric_dtype can be "fp32" or other string
# only when set to "fp32" will use fp32 to calc in metrics
# metric_dtype = "fp32"
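To make the parallel comment block above concrete: once the tensor and pipeline sizes are fixed, the zero1 setting resolves against the remaining data-parallel ranks. Below is a minimal sketch under the documented semantics only; the helper name and the data-parallel formula (world size divided by tensor size times pipeline size) are assumptions for illustration, not InternEvo's API, and weight-parallel grouping is omitted.

```python
# Illustrative only: resolve effective data-parallel and zero1 sizes from a
# `parallel` dict, following the rules in the comment block above.
def resolve_parallel_sizes(world_size: int, parallel: dict) -> dict:
    tensor = parallel["tensor"]
    tp = tensor["size"] if isinstance(tensor, dict) else tensor  # plain int form also allowed
    pp = parallel["pipeline"]["size"]
    dp = world_size // (tp * pp)      # assume data parallel fills the remaining ranks
    zero1 = parallel["zero1"]["size"]
    if zero1 <= 0:                    # documented: zero group spans the whole dp group
        zero1 = dp
    assert 1 <= zero1 <= dp, "zero1 must lie in [1, dp world size]"
    return {"tp": tp, "pp": pp, "dp": dp, "zero1": zero1}

# With the isp config above on 64 GPUs: tensor=2, pipeline=1 -> dp=32, zero1=32.
print(resolve_parallel_sizes(64, dict(
    zero1=dict(size=-1),
    tensor=dict(size=2, mode="isp"),
    pipeline=dict(size=1, interleaved_overlap=True),
    weight=dict(size=4, overlap=True, memory_pool=True),
)))
```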