Skip to content

Commit

Permalink
remove isp memory pool (#330)
Browse files Browse the repository at this point in the history
  • Loading branch information
mwiacx authored Sep 13, 2024
1 parent 77e6cb7 commit 1403550
Show file tree
Hide file tree
Showing 37 changed files with 53 additions and 256 deletions.
6 changes: 2 additions & 4 deletions configs/1.8B_MoE16_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
expert parallel (dict):
1. size: int
* if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size
Expand All @@ -196,15 +195,14 @@
expert weight parallel (dict):
1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
expert=dict(size=-1, no_tp=False),
expert_weight=dict(size=1, overlap=True, memory_pool=True),
expert_weight=dict(size=1, overlap=True),
)

cudnn_deterministic = False
Expand Down
6 changes: 2 additions & 4 deletions configs/7B_MoE4_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
expert parallel (dict):
1. size: int
* if size <= 0, ep size equals to dp size, but if the number of experts is smaller than dp size, set ep size
Expand All @@ -194,15 +193,14 @@
expert weight parallel (dict):
1. size: int, the size of weight parallel for expert module, distinct with global weight parallel size.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
expert=dict(size=-1, no_tp=False),
expert_weight=dict(size=1, overlap=True, memory_pool=True),
expert_weight=dict(size=1, overlap=True),
)

cudnn_deterministic = False
Expand Down
3 changes: 1 addition & 2 deletions configs/7B_baichuan2.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,12 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
)

cudnn_deterministic = False
Expand Down
3 changes: 1 addition & 2 deletions configs/7B_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,12 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
)

cudnn_deterministic = False
Expand Down
9 changes: 4 additions & 5 deletions configs/7B_internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
# defaults to 0, means disable evaluate
valid_every=0,
pack_sample_into_one=False,
total_steps=20,
total_steps=20000,
skip_batches="",
# rampup_batch_size (str): A string with three space-separated integers representing the
# starting batch size, the increment, and the number of steps between
Expand Down Expand Up @@ -177,13 +177,12 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=8),
tensor=dict(size=1, mode="mtp"),
zero1=dict(size=-1),
tensor=dict(size=2, mode="isp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=2, overlap=True),
)

cudnn_deterministic = False
Expand Down
3 changes: 1 addition & 2 deletions configs/7B_isp_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
sequence_2D (dict):
1. enable: bool, whether enable the 2D sequence parallel or not.
2. head_size: int, the parallel degree of head parallelism (DeepSpeed Ulysses).
Expand All @@ -206,7 +205,7 @@
zero1=dict(size=-1),
tensor=dict(size=2, mode="isp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=4, overlap=True, memory_pool=False),
weight=dict(size=4, overlap=True),
sequence_2D=dict(
enable=False,
head_size=2,
Expand Down
3 changes: 1 addition & 2 deletions configs/7B_llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,12 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
)

cudnn_deterministic = False
Expand Down
3 changes: 1 addition & 2 deletions configs/7B_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,12 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
)

cudnn_deterministic = False
Expand Down
3 changes: 1 addition & 2 deletions configs/7B_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,12 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True, zero_bubble=False),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
)

cudnn_deterministic = False
Expand Down
3 changes: 1 addition & 2 deletions configs/_base_/models/internlm2_1B.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,10 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=8),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
)
3 changes: 1 addition & 2 deletions configs/_base_/models/internlm2_20B.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,10 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=16),
tensor=dict(size=2, mode="fsp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
)
3 changes: 1 addition & 2 deletions configs/_base_/models/internlm2_7B.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,10 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=8),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
)
3 changes: 1 addition & 2 deletions configs/_base_/models/internlm_20B.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,10 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=8),
tensor=dict(size=4, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
)
3 changes: 1 addition & 2 deletions configs/_base_/models/internlm_7B.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,10 @@
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=8),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
)
2 changes: 1 addition & 1 deletion configs/demo_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@
zero1=dict(size=-1),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
)

cudnn_deterministic = False
Expand Down
6 changes: 2 additions & 4 deletions doc/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,12 @@ pipeline parallel (dict):
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
3. memory_pool: bool, enable/disable memory pool, defaults to False.
"""
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
)

cudnn_deterministic = False
Expand Down Expand Up @@ -425,7 +424,7 @@ parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=1, mode="mtp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=1, overlap=True),
)
```
- zero1(字典):
Expand All @@ -447,7 +446,6 @@ parallel = dict(
- weight(字典):
1. size: 整数,权重并行的大小。
2. overlap: 布尔值,启用/禁用all_gather/reduce_scatter通信重叠,默认为False。
3. memory_pool: 布尔值,启用/禁用内存池,默认为False。

注意:`数据并行大小 = 总的 GPU 数目 / 流水线并行大小 / 张量并行大小`

Expand Down
4 changes: 2 additions & 2 deletions internlm/core/context/parallel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,11 +506,11 @@ def init_parallel_groups(self):
if "tensor" not in parallel_config:
parallel_config._add_item("tensor", dict(size=1, mode=TensorParallelMode.mtp.name))
if "weight" not in parallel_config:
parallel_config._add_item("weight", dict(size=1, overlap=False, memory_pool=False))
parallel_config._add_item("weight", dict(size=1, overlap=False))
if "expert" not in parallel_config:
parallel_config._add_item("expert", dict(size=-1, no_tp=False))
if "expert_weight" not in parallel_config:
parallel_config._add_item("expert_weight", dict(size=1, overlap=False, memory_pool=False))
parallel_config._add_item("expert_weight", dict(size=1, overlap=False))
# set default value for sequence_2D
if "sequence_2D" not in parallel_config:
parallel_config._add_item(
Expand Down
Loading

0 comments on commit 1403550

Please sign in to comment.