Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ramanishsingh committed Dec 4, 2024
1 parent ce98c57 commit fa556eb
Showing 1 changed file with 103 additions and 3 deletions.
106 changes: 103 additions & 3 deletions test/nodes/test_multi_node_round_robin_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
# LICENSE file in the root directory of this source tree.

import collections
import itertools
import math

import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase

from torchdata.nodes.adapters import IterableWrapper
Expand Down Expand Up @@ -40,6 +43,82 @@ def get_unequal_dataset(self, num_samples, num_datasets):
}
return datasets

def test_empty_datasets(self) -> None:
    """Batching over datasets built with zero samples must yield no batch."""
    empty_datasets = self.get_equal_dataset(0, self._num_datasets)
    round_robin = MultiNodeRoundRobinSampler(empty_datasets, StopCriteria.FIRST_DATASET_EXHAUSTED)
    batcher = Batcher(round_robin, batch_size=3)
    # The content of a batch is irrelevant: producing one at all is the failure.
    for _ in batcher:
        self.fail("Expected no batches as each dataset is empty.")

@parameterized.expand([4, 8, 16])
def test_single_dataset(self, num_samples: int) -> None:
    """Check batch size and batch count when sampling from a single dataset.

    Args:
        num_samples: Sample count passed to ``get_equal_dataset``. Every
            parameterized value is a multiple of the batch size used here.
    """
    datasets = self.get_equal_dataset(num_samples, 1)
    sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED)
    batch_size = 4
    batcher = Batcher(sampler, batch_size=batch_size)
    # Count explicitly instead of reading the loop variable after the loop:
    # if the batcher yields nothing, that variable is never bound and the test
    # would error with UnboundLocalError instead of failing on the assertion.
    num_batches = 0
    for batch in batcher:
        num_batches += 1
        # An exact-size check also covers "batch is non-empty".
        self.assertEqual(len(batch), batch_size)
    self.assertEqual(num_batches, num_samples // batch_size)

@parameterized.expand(
    itertools.product(
        [8, 16, 32],
        [True, False],
    )
)
def test_single_dataset_drop_last(self, num_samples: int, drop_last: bool) -> None:
    """Check the number and size of batches for both ``drop_last`` settings.

    Args:
        num_samples: Sample count passed to ``get_equal_dataset``.
        drop_last: Forwarded to ``Batcher``; when true, every batch is
            expected to hold exactly ``batch_size`` items.
    """
    datasets = self.get_equal_dataset(num_samples, 1)
    sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED)
    batch_size = 5
    batcher = Batcher(sampler, batch_size=batch_size, drop_last=drop_last)
    num_batches = 0
    for batch in batcher:
        num_batches += 1
        self.assertGreater(len(batch), 0)
        if drop_last:
            self.assertEqual(len(batch), batch_size)
    # Floor division is the number of full batches whether or not num_samples
    # is a multiple of batch_size; "ceil(...) - 1" is only right when a partial
    # batch exists. Both forms agree for the parameter values used above.
    if drop_last:
        expected_batches = num_samples // batch_size
    else:
        expected_batches = math.ceil(num_samples / batch_size)
    self.assertEqual(num_batches, expected_batches)

def test_stop_criteria_all_datasets_exhausted(self) -> None:
    """With ALL_DATASETS_EXHAUSTED, the batch count must match the total item count.

    The expected total is the sum of ``_num_datasets`` consecutive sizes
    starting at ``_num_samples``; this presumes ``get_unequal_dataset``
    builds datasets of exactly those sizes.
    """
    batch_size = 3
    smallest = self._num_samples
    expected_items = sum(range(smallest, smallest + self._num_datasets))
    sampler = MultiNodeRoundRobinSampler(
        self.get_unequal_dataset(self._num_samples, self._num_datasets),
        StopCriteria.ALL_DATASETS_EXHAUSTED,
    )
    batcher = Batcher(sampler, batch_size=batch_size, drop_last=True)
    # drop_last=True means only full batches are produced, hence floor division.
    produced = sum(1 for _ in batcher)
    self.assertEqual(produced, expected_items // batch_size)

def test_stop_criteria_first_dataset_exhausted(self) -> None:
    """With FIRST_DATASET_EXHAUSTED, the batch count follows the shortest dataset."""
    num_samples = 4
    batch_size = 2
    # first dataset has 4 samples, second has 5, and third has 6
    unequal = self.get_unequal_dataset(num_samples, self._num_datasets)
    batcher = Batcher(
        MultiNodeRoundRobinSampler(unequal, StopCriteria.FIRST_DATASET_EXHAUSTED),
        batch_size=batch_size,
    )
    produced = sum(1 for _ in batcher)
    # Expected: num_samples items from each of the datasets, grouped into batches.
    expected = num_samples * self._num_datasets // batch_size
    self.assertEqual(produced, expected)

def test_stop_criteria_cycle_until_all_datasets_exhausted(self) -> None:
    """Check the batch count under CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED."""
    num_samples = 4
    unequal = self.get_unequal_dataset(num_samples, self._num_datasets)
    cycling_sampler = MultiNodeRoundRobinSampler(
        unequal,
        StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
    )
    batcher = Batcher(cycling_sampler, batch_size=3)
    produced = sum(1 for _ in batcher)
    expected = num_samples + self._num_datasets - 1
    self.assertEqual(produced, expected)

def test_multi_node_round_robin_sampler_equal_dataset(self) -> None:
datasets = self.get_equal_dataset(self._num_samples, self._num_datasets)
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.FIRST_DATASET_EXHAUSTED)
Expand Down Expand Up @@ -161,13 +240,23 @@ def test_get_state(self) -> None:
self.assertIn("datasets_exhausted", state)
self.assertIn("dataset_node_states", state)

def test_save_load_state(self) -> None:
@parameterized.expand(
itertools.product(
[100, 500, 1200],
[
StopCriteria.ALL_DATASETS_EXHAUSTED,
StopCriteria.FIRST_DATASET_EXHAUSTED,
StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
],
)
)
def test_save_load_state(self, midpoint: int, stop_criteria: str) -> None:
num_samples = 1500
num_datasets = 3
datasets = self.get_equal_dataset(num_samples, num_datasets)
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.ALL_DATASETS_EXHAUSTED)
sampler = MultiNodeRoundRobinSampler(datasets, stop_criteria)
prefetcher = Prefetcher(sampler, 3)
run_test_save_load_state(self, prefetcher, 400)
run_test_save_load_state(self, prefetcher, midpoint)

datasets = self.get_unequal_dataset(num_samples, num_datasets)
sampler = MultiNodeRoundRobinSampler(datasets, StopCriteria.ALL_DATASETS_EXHAUSTED)
Expand All @@ -189,3 +278,14 @@ def test_multiple_epochs(self) -> None:
for batch in loader:
results[epoch].append(batch)
self.assertEqual(results[0], results[1])

def test_get_state_after_reset(self) -> None:
    """The sampler state taken after consuming a batch must differ from the state after reset()."""
    sampler = MultiNodeRoundRobinSampler(
        self.get_equal_dataset(self._num_samples, self._num_datasets),
        StopCriteria.FIRST_DATASET_EXHAUSTED,
    )
    batcher = Batcher(sampler, batch_size=3)
    # Pull one batch so the first snapshot is taken after some consumption.
    next(batcher)
    consumed_snapshot = sampler.get_state()
    sampler.reset()
    fresh_snapshot = sampler.get_state()
    self.assertNotEqual(consumed_snapshot, fresh_snapshot)

0 comments on commit fa556eb

Please sign in to comment.