From 3a226b167168c07a0b4e62a478ea73cecda190b1 Mon Sep 17 00:00:00 2001
From: zigzagcai
Date: Fri, 27 Sep 2024 16:41:00 +0800
Subject: [PATCH] force megatron dataloader re-use StaticBatchSampler

---
 internlm/data/build_dataloader.py       | 21 ++++++---
 internlm/data/megatron/__init__.py      |  2 -
 internlm/data/megatron/batch_sampler.py | 62 -------------------------
 3 files changed, 14 insertions(+), 71 deletions(-)
 delete mode 100644 internlm/data/megatron/batch_sampler.py

diff --git a/internlm/data/build_dataloader.py b/internlm/data/build_dataloader.py
index 039282c2..e99bbfc7 100644
--- a/internlm/data/build_dataloader.py
+++ b/internlm/data/build_dataloader.py
@@ -2,12 +2,13 @@
 import subprocess
 from functools import partial
 
+import torch
 import torch.distributed as dist
 from torch.utils.data import ConcatDataset, DataLoader
 
+from internlm.accelerator.abstract_accelerator import get_accelerator
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.data.megatron.batch_sampler import MegatronBatchSampler
 from internlm.data.megatron.collaters import megatron_collate_fn
 from internlm.data.megatron.dataset import build_megatron_dataset
 from internlm.data.mocked.batch_sampler import MockedSequentialBatchSampler
@@ -41,8 +42,8 @@
 from internlm.utils.logger import get_logger
 from internlm.utils.utils import DataType
 
-# global llm logger
 logger = get_logger(__file__)
+internlm_accelerator = get_accelerator()
 
 
 def get_tokenized_train_loader_items(data_cfg):
@@ -162,7 +163,8 @@ def get_megatron_train_loader_items(data_cfg):
     try:
         from internlm.data.megatron import helpers  # noqa  # pylint: disable=W0611
     except ImportError:
-        if gpc.is_rank_for_log():
+        # Compile dynamic library on-demand
+        if gpc.get_global_rank() % internlm_accelerator.device_count() == 0:
             subprocess.run(  # noqa  # pylint: disable=W1510
                 [
                     "g++",
@@ -176,8 +178,9 @@
                     "internlm/data/megatron/helpers.cpp",
                     "-o",
                     "internlm/data/megatron/helpers.so",
-                ]
+                ],
             )
+        torch.distributed.barrier()
 
     # NOTICE: Currently we only support single megatron dataset, a.k.a., single .bin and .idx
     # Megatron dataset (.bin and.idx) should be generated by Megatron-LM tools/preprocess_data.py
@@ -188,11 +191,15 @@
     train_ds = build_megatron_dataset(
         seed=data_cfg.get("seed", 1024),
     )
-    train_sampler = MegatronBatchSampler(
-        total_samples=len(train_ds),
-        consumed_samples=0,
+    train_sampler = StaticBatchSampler(
+        train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
         batch_size=data_cfg.micro_num * data_cfg.micro_bsz,
+        rampup_batch_size=data_cfg.rampup_batch_size,
+        micro_bsz=data_cfg.micro_bsz,
+        seed=data_cfg.get("seed", 1024),
         drop_last=True,
+        data_rank=gpc.get_local_rank(ParallelMode.DATA),
+        data_world_size=gpc.get_world_size(ParallelMode.DATA),
     )
 
     train_collate_fn = partial(
diff --git a/internlm/data/megatron/__init__.py b/internlm/data/megatron/__init__.py
index 5e447596..5405f6f8 100644
--- a/internlm/data/megatron/__init__.py
+++ b/internlm/data/megatron/__init__.py
@@ -1,9 +1,7 @@
-from .batch_sampler import MegatronBatchSampler
 from .collaters import megatron_collate_fn
 from .dataset import build_megatron_dataset
 
 __all__ = [
-    "MegatronBatchSampler",
     "build_megatron_dataset",
     "megatron_collate_fn",
 ]
diff --git a/internlm/data/megatron/batch_sampler.py b/internlm/data/megatron/batch_sampler.py
deleted file mode 100644
index 88b44c62..00000000
--- a/internlm/data/megatron/batch_sampler.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import copy
-import math
-
-from internlm.core.context import ParallelMode
-from internlm.core.context import global_context as gpc
-
-
-class MegatronBatchSampler:
-    """
-    MegatronBatchSampler
-    """
-
-    def __init__(self, total_samples, consumed_samples, batch_size, drop_last=True):
-        # Keep a copy of input params for later use.
-        self.total_samples = total_samples
-        self.consumed_samples = consumed_samples
-        self.batch_size = batch_size
-        self.drop_last = drop_last
-
-        self.dp_rank = gpc.get_local_rank(ParallelMode.DATA)
-        self.dp_size = gpc.get_world_size(ParallelMode.DATA)
-
-        # Sanity checks.
-        assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples)
-        assert self.consumed_samples < self.total_samples, "no samples left to consume: {}, {}".format(
-            self.consumed_samples, self.total_samples
-        )
-        assert self.batch_size > 0
-        assert self.dp_size > 0
-        assert self.dp_rank < self.dp_size, "dp_rank should be smaller than dp_size: {}, " "{}".format(
-            self.dp_rank, self.dp_size
-        )
-
-    def __len__(self):
-        if self.drop_last and self.total_samples % self.dp_size != 0:
-            return math.ceil(self.total_samples - self.dp_size) / self.dp_size
-        else:
-            return math.ceil(self.total_samples / self.dp_size)
-
-    def get_start_end_idx(self):
-        start_idx = self.dp_rank * self.batch_size
-        end_idx = start_idx + self.batch_size
-        return start_idx, end_idx
-
-    def __iter__(self):
-        batch = []
-        # Last batch will be dropped if drop_last is not set False
-        for idx in range(self.consumed_samples, self.total_samples):
-            batch.append(idx)
-            if len(batch) == self.batch_size * self.dp_size:
-                start_idx, end_idx = self.get_start_end_idx()
-                yield batch[start_idx:end_idx]
-                batch = []
-
-        # Check the last partial batch and see drop_last is set
-        if len(batch) > 0 and not self.drop_last:
-            start_idx, end_idx = self.get_start_end_idx()
-            yield batch[start_idx:end_idx]
-
-    # TODO: implement copy method that compatible with InternEvo trainstate ckpt save and load.
-    def copy(self):
-        return copy.deepcopy(self)
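
Note on the helpers.so hunks above (not part of the diff itself): the patch moves the on-demand g++ build from the logging rank to the first rank of each node and adds torch.distributed.barrier() so that no rank imports helpers.so before it has been written. The sketch below isolates that pattern; it is a minimal illustration, not InternEvo code. build_megatron_helpers, devices_per_node, and extra_flags are hypothetical names, torch.distributed is assumed to be initialized already, and the real compiler flags (the lines that fall between the two hunks) are not reproduced.

import subprocess

import torch.distributed as dist


def build_megatron_helpers(global_rank, devices_per_node, extra_flags=None):
    """Compile helpers.so once per node, then make every rank wait for it."""
    if global_rank % devices_per_node == 0:
        # Only the first rank on each node runs g++; extra_flags stands in for
        # the compiler options elided between the two hunks of the patch.
        subprocess.run(
            ["g++"]
            + list(extra_flags or [])
            + ["internlm/data/megatron/helpers.cpp", "-o", "internlm/data/megatron/helpers.so"]
        )
    # Every rank blocks here so nobody imports helpers.so before it exists.
    dist.barrier()

In the patch itself the same roles are played by gpc.get_global_rank(), internlm_accelerator.device_count(), and the barrier added at the end of the except ImportError block.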