Skip to content

Commit 3d4342b

Browse files
fix: safely handle edge case of uneven vocab shards in tensor-parallel logprobs
1 parent 352175e commit 3d4342b

File tree

2 files changed

+83
-3
lines changed

2 files changed

+83
-3
lines changed

src/forge/util/ops.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,20 @@ def compute_logprobs_parallel(
155155
return target_logits - log_normalizer
156156

157157

158+
def _get_vocab_shard_bounds(
159+
vocab_size: int, tp_rank: int, tp_size: int
160+
) -> tuple[int, int, int]:
161+
"""
162+
Return (start, end, width) for a shard when vocab dimension is unevenly split.
163+
"""
164+
base_shard = vocab_size // tp_size
165+
remainder = vocab_size % tp_size
166+
shard_width = base_shard + (1 if tp_rank < remainder else 0)
167+
vocab_start = tp_rank * base_shard + min(tp_rank, remainder)
168+
vocab_end = vocab_start + shard_width
169+
return vocab_start, vocab_end, shard_width
170+
171+
158172
def get_vocab_shard_info(
159173
logits: DTensor,
160174
) -> tuple[dist.ProcessGroup | None, int, int, int, int]:
@@ -171,19 +185,26 @@ def get_vocab_shard_info(
171185
local_logits = logits._local_tensor
172186
placements = logits.placements
173187
device_mesh = logits.device_mesh
188+
global_vocab_size = logits.shape[-1]
174189

175190
for i, p in enumerate(placements):
176191
if isinstance(p, Shard) and p.dim == 2: # vocab dimension
177192
tp_group = device_mesh.get_group(mesh_dim=i)
178193
tp_size = dist.get_world_size(tp_group)
179194
tp_rank = dist.get_rank(tp_group)
195+
vocab_start, vocab_end, shard_width = _get_vocab_shard_bounds(
196+
global_vocab_size, tp_rank, tp_size
197+
)
180198
local_vocab_size = local_logits.shape[-1]
181-
vocab_start = tp_rank * local_vocab_size
182-
vocab_end = vocab_start + local_vocab_size
199+
if local_vocab_size != shard_width:
200+
raise ValueError(
201+
"DTensor local shard width does not match inferred shard size "
202+
f"(rank={tp_rank}, local={local_vocab_size}, expected={shard_width})"
203+
)
183204
return tp_group, tp_rank, tp_size, vocab_start, vocab_end
184205

185206
# Not sharded
186-
return None, 0, 1, 0, local_logits.shape[-1]
207+
return None, 0, 1, 0, global_vocab_size
187208

188209

189210
def _distributed_log_normalizer(

tests/unit_tests/util/test_ops.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,65 @@ def test_parallel_logprobs_align_false(self):
350350
msg="Parallel logprobs with align=False should match",
351351
)
352352

353+
@gpu_test(gpu_count=2)
def test_parallel_logprobs_uneven_vocab_shards(self):
    """Ensure uneven vocab shards still produce correct logprobs."""
    torch.manual_seed(321)

    batch_size = 2
    seq_len = 12
    vocab_size = 1001  # Not divisible by world_size
    target_len = 6

    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")

    # Rank 0 generates the reference data; other ranks allocate empty
    # buffers of matching shape/dtype to receive it via broadcast.
    if rank == 0:
        full_logits = torch.randn(
            batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
        )
        target_ids = torch.randint(
            0, vocab_size, (batch_size, target_len), device=device
        )
    else:
        full_logits = torch.empty(
            batch_size, seq_len, vocab_size, dtype=torch.float32, device=device
        )
        target_ids = torch.empty(
            batch_size, target_len, dtype=torch.int64, device=device
        )

    # Every rank now holds identical logits/targets, so the single-device
    # reference result below is computable on each rank.
    dist.broadcast(full_logits, src=0)
    dist.broadcast(target_ids, src=0)

    expected = compute_logprobs(full_logits, target_ids, align=True)

    mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
    # Deliberately uneven shard bounds: the first `remainder` ranks take
    # one extra vocab entry (mirrors _get_vocab_shard_bounds in ops.py).
    base_shard = vocab_size // self.world_size
    remainder = vocab_size % self.world_size
    extra = 1 if rank < remainder else 0
    vocab_start = rank * base_shard + min(rank, remainder)
    vocab_end = vocab_start + base_shard + extra
    local_slice = full_logits[:, :, vocab_start:vocab_end].contiguous()

    # Explicit global shape/stride are required because the local shards
    # have different widths; DTensor cannot infer them from local tensors.
    dtensor_logits = DTensor.from_local(
        local_slice,
        mesh,
        placements=[Shard(2)],
        shape=torch.Size((batch_size, seq_len, vocab_size)),
        stride=full_logits.stride(),
    )

    result = compute_logprobs_parallel(dtensor_logits, target_ids, align=True)

    torch.testing.assert_close(
        result,
        expected,
        atol=1e-5,
        rtol=1e-5,
        msg="Parallel logprobs should support uneven vocab shards",
    )
411+
353412
@gpu_test(gpu_count=2)
354413
def test_parallel_logprobs_numerical_stability(self):
355414
"""Test parallel logprobs handles extreme values correctly."""

0 commit comments

Comments
 (0)