force megatron dataloader to re-use StaticBatchSampler
zigzagcai committed Sep 27, 2024
1 parent a572ca1 commit 3a226b1
Showing 3 changed files with 14 additions and 71 deletions.
21 changes: 14 additions & 7 deletions internlm/data/build_dataloader.py
@@ -2,12 +2,13 @@
import subprocess
from functools import partial

import torch
import torch.distributed as dist
from torch.utils.data import ConcatDataset, DataLoader

from internlm.accelerator.abstract_accelerator import get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.data.megatron.batch_sampler import MegatronBatchSampler
from internlm.data.megatron.collaters import megatron_collate_fn
from internlm.data.megatron.dataset import build_megatron_dataset
from internlm.data.mocked.batch_sampler import MockedSequentialBatchSampler
@@ -41,8 +42,8 @@
from internlm.utils.logger import get_logger
from internlm.utils.utils import DataType

# global llm logger
logger = get_logger(__file__)
internlm_accelerator = get_accelerator()


def get_tokenized_train_loader_items(data_cfg):
@@ -162,7 +163,8 @@ def get_megatron_train_loader_items(data_cfg):
try:
from internlm.data.megatron import helpers # noqa # pylint: disable=W0611
except ImportError:
if gpc.is_rank_for_log():
# Compile dynamic library on-demand
if gpc.get_global_rank() % internlm_accelerator.device_count() == 0:
subprocess.run( # noqa # pylint: disable=W1510
[
"g++",
@@ -176,8 +178,9 @@ def get_megatron_train_loader_items(data_cfg):
"internlm/data/megatron/helpers.cpp",
"-o",
"internlm/data/megatron/helpers.so",
]
],
)
torch.distributed.barrier()

# NOTICE: Currently we only support a single megatron dataset, i.e., a single .bin and .idx
# The Megatron dataset (.bin and .idx) should be generated by Megatron-LM tools/preprocess_data.py
@@ -188,11 +191,15 @@
seed=data_cfg.get("seed", 1024),
)

train_sampler = MegatronBatchSampler(
total_samples=len(train_ds),
consumed_samples=0,
train_sampler = StaticBatchSampler(
train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
batch_size=data_cfg.micro_num * data_cfg.micro_bsz,
rampup_batch_size=data_cfg.rampup_batch_size,
micro_bsz=data_cfg.micro_bsz,
seed=data_cfg.get("seed", 1024),
drop_last=True,
data_rank=gpc.get_local_rank(ParallelMode.DATA),
data_world_size=gpc.get_world_size(ParallelMode.DATA),
)

train_collate_fn = partial(
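
One of the hunks above also changes how the helpers.cpp fallback is built: instead of compiling only on the logging rank, the build now runs on the first rank of each node and every rank waits at a barrier before continuing. A minimal standalone sketch of that pattern, using plain torch.distributed and the LOCAL_RANK convention set by torchrun rather than InternEvo's gpc/accelerator wrappers (the g++ flags here are illustrative, not the exact ones from the diff):

```python
import os
import subprocess

import torch.distributed as dist


def compile_helpers_once_per_node(src: str, out: str) -> None:
    """Compile a C++ helper library on one rank per node, then synchronize.

    Assumes torch.distributed is already initialized and that LOCAL_RANK
    identifies this process's position within its node (as torchrun sets it).
    """
    if int(os.environ.get("LOCAL_RANK", "0")) == 0 and not os.path.exists(out):
        # Only the first rank on each node invokes g++; flags are illustrative.
        subprocess.run(["g++", "-O3", "-shared", "-fPIC", src, "-o", out], check=True)
    # Every rank blocks here until the shared object has been built everywhere.
    dist.barrier()
```

In the actual diff the guard is `gpc.get_global_rank() % internlm_accelerator.device_count() == 0`, which amounts to the same per-node check when each node runs one rank per device.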
2 changes: 0 additions & 2 deletions internlm/data/megatron/__init__.py
@@ -1,9 +1,7 @@
from .batch_sampler import MegatronBatchSampler
from .collaters import megatron_collate_fn
from .dataset import build_megatron_dataset

__all__ = [
"MegatronBatchSampler",
"build_megatron_dataset",
"megatron_collate_fn",
]
62 changes: 0 additions & 62 deletions internlm/data/megatron/batch_sampler.py

This file was deleted.
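
With MegatronBatchSampler removed from the package exports and its module deleted, the megatron data path now drives its DataLoader with the same StaticBatchSampler used by the tokenized-data path. A rough sketch of the new wiring is below; the sampler arguments are copied from the build_dataloader.py hunk above, while the import path for StaticBatchSampler and the DataLoader construction around it are assumptions, since those parts of the file are not shown in this diff:

```python
from torch.utils.data import ConcatDataset, DataLoader

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.data.megatron.collaters import megatron_collate_fn

# Assumed import location; the diff only shows that StaticBatchSampler is in scope.
from internlm.data.tokenized.batch_sampler import StaticBatchSampler


def build_megatron_train_loader(train_ds, data_cfg):
    # Sampler construction as introduced by this commit.
    train_sampler = StaticBatchSampler(
        train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
        batch_size=data_cfg.micro_num * data_cfg.micro_bsz,
        rampup_batch_size=data_cfg.rampup_batch_size,
        micro_bsz=data_cfg.micro_bsz,
        seed=data_cfg.get("seed", 1024),
        drop_last=True,
        data_rank=gpc.get_local_rank(ParallelMode.DATA),
        data_world_size=gpc.get_world_size(ParallelMode.DATA),
    )

    # Assumed wiring: the real file builds train_collate_fn with functools.partial
    # and extra config-dependent arguments that this hunk does not show.
    return DataLoader(
        train_ds,
        batch_sampler=train_sampler,
        num_workers=data_cfg.get("num_worker", 4),
        pin_memory=True,
        collate_fn=megatron_collate_fn,
    )
```

This keeps batch ramp-up and data-parallel sharding behaviour consistent between the megatron and tokenized data paths.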
