Skip to content

Commit 8a9369f

Browse files
Victor Bourgin authored and facebook-github-bot committed
Migrate sampler_test.py from pytest to unittest (#1077)
Summary: Pull Request resolved: #1077
Reviewed By: moto-meta
Differential Revision: D85999483
1 parent 2a1c9a3 commit 8a9369f

File tree

1 file changed

+189
-170
lines changed

1 file changed

+189
-170
lines changed

tests/spdl_unittest/dataloader/sampler_test.py

Lines changed: 189 additions & 170 deletions
Original file line number | Diff line number | Diff line change
@@ -4,11 +4,13 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import unittest
78
from collections import Counter
89
from functools import partial
10+
from typing import Optional
911

1012
import numpy as np
11-
import pytest
13+
from parameterized import parameterized
1214
from spdl.pipeline import iterate_in_subprocess
1315
from spdl.source import (
1416
DistributedDeterministicSampler,
@@ -21,195 +23,212 @@
2123
# pyre-unsafe
2224

2325

24-
def testDistributedsampler_interface():
25-
"""samplers conform to Iterable/IterableWithShuffle protocol"""
26-
assert isinstance(
27-
DistributedRandomSampler(9, rank=0, world_size=1), SizedIterableWithShuffle
28-
)
29-
assert isinstance(
30-
DistributedDeterministicSampler(9, rank=0, world_size=1), SizedIterable
31-
)
32-
33-
34-
def testDistributedsamplerdeterministic_iter():
35-
"""without distributed, deterministic iteration behaves same as `range(N)`"""
36-
N = 30
37-
sampler = DistributedDeterministicSampler(N, rank=0, world_size=1)
38-
assert len(sampler) == N
39-
assert list(sampler) == list(range(N))
26+
class TestDistributedSamplerInterface(unittest.TestCase):
27+
def test_distributed_sampler_interface(self) -> None:
28+
"""samplers conform to Iterable/IterableWithShuffle protocol"""
29+
self.assertIsInstance(
30+
DistributedRandomSampler(9, rank=0, world_size=1), SizedIterableWithShuffle
31+
)
32+
self.assertIsInstance(
33+
DistributedDeterministicSampler(9, rank=0, world_size=1), SizedIterable
34+
)
35+
36+
37+
class TestDistributedSamplerDeterministic(unittest.TestCase):
38+
def test_deterministic_iter(self) -> None:
39+
"""without distributed, deterministic iteration behaves same as `range(N)`"""
40+
N = 30
41+
sampler = DistributedDeterministicSampler(N, rank=0, world_size=1)
42+
self.assertEqual(len(sampler), N)
43+
self.assertEqual(list(sampler), list(range(N)))
44+
45+
def test_deterministic_iter_distributed(self) -> None:
46+
"""deterministic iteration behaves same as `range(rank, M, world_size)`"""
47+
N = 26
48+
for world_size in range(1, N + 1):
49+
len_ = N // world_size
50+
max_ = len_ * world_size
51+
c = Counter()
52+
for rank in range(world_size):
53+
print(f"{N=}, {world_size=}, {rank=}, {len_=}, {max_=}")
54+
sampler = DistributedDeterministicSampler(
55+
N, rank=rank, world_size=world_size
56+
)
57+
self.assertEqual(len(sampler), len_)
58+
59+
indices = list(sampler)
60+
self.assertEqual(indices, list(range(rank, max_, world_size)))
61+
c.update(indices)
62+
63+
# Check that together, the samplers covered the whole dataset
64+
num_iters = N // world_size * world_size
65+
self.assertEqual(c.total(), num_iters)
66+
self.assertEqual(len(c.keys()), num_iters)
67+
self.assertEqual(set(c.keys()), set(range(num_iters)))
68+
self.assertTrue(all(v == 1 for v in c.values()))
69+
70+
71+
class TestDistributedSamplerRandom(unittest.TestCase):
72+
def test_shuffle(self) -> None:
73+
"""shuffling makes sampler generates different indices."""
74+
N = 640
75+
rank = 3
76+
world_size = 8
4077

41-
42-
def testDistributedsamplerdeterministic_iterDistributed():
43-
"""deterministic iteration behaves same as `range(rank, M, world_size)`"""
44-
N = 26
45-
for world_size in range(1, N + 1):
46-
len_ = N // world_size
47-
max_ = len_ * world_size
48-
c = Counter()
49-
for rank in range(world_size):
50-
print(f"{N=}, {world_size=}, {rank=}, {len_=}, {max_=}")
51-
sampler = DistributedDeterministicSampler(
52-
N, rank=rank, world_size=world_size
53-
)
54-
assert len(sampler) == len_
55-
56-
indices = list(sampler)
57-
assert indices == list(range(rank, max_, world_size))
58-
c.update(indices)
59-
60-
# Check that together, the samplers covered the whole dataset
61-
num_iters = N // world_size * world_size
62-
assert c.total() == num_iters
63-
assert len(c.keys()) == num_iters
64-
assert set(c.keys()) == set(range(num_iters))
65-
assert all(v == 1 for v in c.values())
66-
67-
68-
def testDistributedsampler_shuffle():
69-
"""shuffling makes sampler generates different indices."""
70-
N = 640
71-
rank = 3
72-
world_size = 8
73-
74-
previous = []
75-
for epoch in range(100):
76-
sampler = DistributedRandomSampler(N, rank=rank, world_size=world_size)
77-
sampler.shuffle(seed=epoch)
78-
79-
indices = list(sampler)
80-
print(f"{indices=}")
81-
assert indices != previous
82-
previous = indices
83-
84-
85-
@pytest.mark.parametrize("w", [None, 1])
86-
def testDistributedsampler_repeat(w):
87-
"""Without calling shuffle, sampler generates the same sequence."""
88-
N = 40
89-
world_size = 8
90-
91-
weights = None if w is None else [1] * N
92-
for rank in range(world_size):
9378
previous = []
94-
for i in range(100):
95-
sampler = DistributedRandomSampler(
96-
N, rank=rank, world_size=world_size, weights=weights
97-
)
79+
for epoch in range(100):
80+
sampler = DistributedRandomSampler(N, rank=rank, world_size=world_size)
81+
sampler.shuffle(seed=epoch)
9882

9983
indices = list(sampler)
10084
print(f"{indices=}")
101-
if i > 0:
102-
assert indices == previous
85+
self.assertNotEqual(indices, previous)
10386
previous = indices
10487

88+
@parameterized.expand(
89+
[
90+
(None,),
91+
(1,),
92+
]
93+
)
94+
def test_repeat(self, w: Optional[int]) -> None:
95+
"""Without calling shuffle, sampler generates the same sequence."""
96+
N = 40
97+
world_size = 8
10598

106-
@pytest.mark.parametrize("shuffle", [True, False])
107-
def testDistributedsampler_mutual_exclusive(shuffle):
108-
"""Without weights, samplers generate mutually exclusive sets"""
109-
N = 640
110-
world_size = 8
111-
112-
for epoch in range(100):
113-
c = Counter()
114-
for rank in range(world_size):
115-
sampler = DistributedRandomSampler(N, rank=rank, world_size=world_size)
116-
if shuffle:
117-
sampler.shuffle(seed=epoch)
118-
c.update(sampler)
119-
120-
assert c.total() == N
121-
assert len(c.keys()) == N
122-
assert set(c.keys()) == set(range(N))
123-
assert all(v == 1 for v in c.values())
124-
125-
126-
@pytest.mark.parametrize("shuffle", [True, False])
127-
def testDistributedsampler_mutual_exclusive_num_draws(shuffle):
128-
"""Without weights, samplers generate mutually exclusive sets"""
129-
N = 640
130-
num_draws = 321
131-
world_size = 8
132-
133-
for epoch in range(100):
134-
c = Counter()
99+
weights = None if w is None else [1.0] * N
135100
for rank in range(world_size):
136-
sampler = DistributedRandomSampler(
137-
N, rank=rank, world_size=world_size, num_draws=num_draws
138-
)
139-
if shuffle:
140-
sampler.shuffle(seed=epoch)
141-
c.update(sampler)
142-
143-
m = num_draws // world_size * world_size
144-
assert c.total() == m
145-
assert len(c.keys()) == m
146-
assert all(v == 1 for v in c.values())
147-
148-
149-
def testDistributedsampler_weighted_sampling():
150-
"""Indices are drawn according to the weights"""
151-
weights = [0, 1, 3, 5, 10]
152-
N = len(weights)
153-
154-
sampler = DistributedRandomSampler(
155-
N, rank=0, world_size=1, num_draws=1_000_000, weights=weights
101+
previous = []
102+
for i in range(100):
103+
sampler = DistributedRandomSampler(
104+
N, rank=rank, world_size=world_size, weights=weights
105+
)
106+
107+
indices = list(sampler)
108+
print(f"{indices=}")
109+
if i > 0:
110+
self.assertEqual(indices, previous)
111+
previous = indices
112+
113+
@parameterized.expand(
114+
[
115+
(True,),
116+
(False,),
117+
]
156118
)
119+
def test_mutual_exclusive(self, shuffle: bool) -> None:
120+
"""Without weights, samplers generate mutually exclusive sets"""
121+
N = 640
122+
world_size = 8
123+
124+
for epoch in range(100):
125+
c = Counter()
126+
for rank in range(world_size):
127+
sampler = DistributedRandomSampler(N, rank=rank, world_size=world_size)
128+
if shuffle:
129+
sampler.shuffle(seed=epoch)
130+
c.update(sampler)
131+
132+
self.assertEqual(c.total(), N)
133+
self.assertEqual(len(c.keys()), N)
134+
self.assertEqual(set(c.keys()), set(range(N)))
135+
self.assertTrue(all(v == 1 for v in c.values()))
136+
137+
@parameterized.expand(
138+
[
139+
(True,),
140+
(False,),
141+
]
142+
)
143+
def test_mutual_exclusive_num_draws(self, shuffle: bool) -> None:
144+
"""Without weights, samplers generate mutually exclusive sets"""
145+
N = 640
146+
num_draws = 321
147+
world_size = 8
148+
149+
for epoch in range(100):
150+
c = Counter()
151+
for rank in range(world_size):
152+
sampler = DistributedRandomSampler(
153+
N, rank=rank, world_size=world_size, num_draws=num_draws
154+
)
155+
if shuffle:
156+
sampler.shuffle(seed=epoch)
157+
c.update(sampler)
158+
159+
m = num_draws // world_size * world_size
160+
self.assertEqual(c.total(), m)
161+
self.assertEqual(len(c.keys()), m)
162+
self.assertTrue(all(v == 1 for v in c.values()))
163+
164+
165+
class TestDistributedSamplerWeighted(unittest.TestCase):
166+
def test_weighted_sampling(self) -> None:
167+
"""Indices are drawn according to the weights"""
168+
weights = [0.0, 1.0, 3.0, 5.0, 10.0]
169+
N = len(weights)
170+
171+
sampler = DistributedRandomSampler(
172+
N, rank=0, world_size=1, num_draws=1_000_000, weights=weights
173+
)
174+
175+
c = Counter(sampler)
176+
distribution = [c[i] for i in range(N)]
177+
178+
print(f"{weights=}")
179+
print(f"{distribution=}")
180+
181+
ref = np.asarray(weights) / np.sum(weights)
182+
hyp = np.asarray(distribution) / np.sum(distribution)
157183

158-
c = Counter(sampler)
159-
distribution = [c[i] for i in range(N)]
160-
161-
print(f"{weights=}")
162-
print(f"{distribution=}")
163-
164-
ref = np.asarray(weights) / np.sum(weights)
165-
hyp = np.asarray(distribution) / np.sum(distribution)
166-
167-
print(f"{ref=}")
168-
print(f"{hyp=}")
169-
170-
assert np.allclose(hyp, ref, atol=1e-3)
184+
print(f"{ref=}")
185+
print(f"{hyp=}")
171186

187+
self.assertTrue(np.allclose(hyp, ref, atol=1e-3))
172188

173-
def testDistributedsampler_embed_shuffle():
174-
"""DistributedSampler is compatibile with embed_shuffle"""
175-
N = 10
176-
weights = [1 for _ in range(N)]
177189

178-
s0 = DistributedRandomSampler(N, rank=0, world_size=1, weights=weights)
179-
s1 = DistributedRandomSampler(N, rank=0, world_size=1, weights=weights)
190+
class TestDistributedSamplerEmbedShuffle(unittest.TestCase):
191+
def test_embed_shuffle(self) -> None:
192+
"""DistributedSampler is compatibile with embed_shuffle"""
193+
N = 10
194+
weights = [1.0 for _ in range(N)]
180195

181-
s1 = embed_shuffle(s1)
196+
s0 = DistributedRandomSampler(N, rank=0, world_size=1, weights=weights)
197+
s1 = DistributedRandomSampler(N, rank=0, world_size=1, weights=weights)
182198

183-
previous = []
184-
for i in range(100):
185-
hyp = list(s1)
186-
print(f"{hyp=}")
199+
s1 = embed_shuffle(s1)
187200

188-
s0.shuffle(i)
189-
ref = list(s0)
190-
print(f"{ref=}")
201+
previous = []
202+
for i in range(100):
203+
hyp = list(s1)
204+
print(f"{hyp=}")
191205

192-
assert hyp == ref
193-
assert hyp != previous
194-
previous = hyp
206+
s0.shuffle(i)
207+
ref = list(s0)
208+
print(f"{ref=}")
195209

210+
self.assertEqual(hyp, ref)
211+
self.assertNotEqual(hyp, previous)
212+
previous = hyp
196213

197-
def testDistributedsampler_iterate_in_subprocess():
198-
"""Iterating in a subprocess generates identical result"""
199-
N = 10
200-
weights = [1 for _ in range(N)]
201214

202-
sampler = DistributedRandomSampler(N, rank=0, world_size=1, weights=weights)
203-
sampler_sub = iterate_in_subprocess(partial(embed_shuffle, sampler))
204-
sampler = embed_shuffle(sampler)
215+
class TestDistributedSamplerIterateInSubprocess(unittest.TestCase):
216+
def test_iterate_in_subprocess(self) -> None:
217+
"""Iterating in a subprocess generates identical result"""
218+
N = 10
219+
weights = [1.0 for _ in range(N)]
205220

206-
previous = []
207-
for _ in range(100):
208-
hyp = list(sampler_sub)
209-
print(f"{hyp=}")
210-
ref = list(sampler)
211-
print(f"{ref=}")
221+
sampler = DistributedRandomSampler(N, rank=0, world_size=1, weights=weights)
222+
sampler_sub = iterate_in_subprocess(partial(embed_shuffle, sampler))
223+
sampler = embed_shuffle(sampler)
212224

213-
assert hyp == ref
214-
assert hyp != previous
215-
previous = hyp
225+
previous = []
226+
for _ in range(100):
227+
hyp = list(sampler_sub)
228+
print(f"{hyp=}")
229+
ref = list(sampler)
230+
print(f"{ref=}")
231+
232+
self.assertEqual(hyp, ref)
233+
self.assertNotEqual(hyp, previous)
234+
previous = hyp

0 commit comments

Comments
 (0)