Commit a6b42a8

refactor: simplify and refactor logprobs parallel method
1 parent 064bb3f commit a6b42a8

File tree

1 file changed (+24, -40)

src/forge/util/ops.py

Lines changed: 24 additions & 40 deletions
@@ -124,68 +124,36 @@ def compute_logprobs_parallel(
     Returns:
         Tensor of shape [batch_size, target_len] with log probabilities.
     """
-    # Get sharding info using helper
-    tp_group, tp_rank, tp_size, vocab_start, vocab_end = get_vocab_shard_info(logits)
+    tp_group, _, _, vocab_start, vocab_end = get_vocab_shard_info(logits)
 
     if tp_group is None:
         # DTensor but not sharded on vocab (Replicate or other dim sharding)
         return compute_logprobs(logits.full_tensor(), target_ids, temperature, align)
 
-    # Get the local shard
     local_logits = logits._local_tensor  # [batch, seq_len, vocab_size / tp_size]
+    target_len = target_ids.size(1)
 
-    # Align logits with target if needed
     if align:
-        # Slice to match target length: logits[:, -target_len-1:-1, :]
-        target_len = target_ids.size(1)
         local_logits = local_logits[:, -target_len - 1 : -1, :]
 
-    # Scale by temperature
-    local_logits = local_logits / temperature
-
-    batch_size, seq_len, local_vocab_size = local_logits.shape
-
-    # Move target_ids to the same device as local_logits
     target_ids = target_ids.to(local_logits.device)
+    local_logits_fp32 = local_logits.float() / temperature
 
-    # Cast to float32 for numerical stability
-    local_logits_fp32 = local_logits.float()
-
-    # Compute global max across all shards for numerical stability
-    local_max = local_logits_fp32.max(dim=-1, keepdim=True).values
-    global_max = local_max.clone()
-    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)
-
-    # Compute global sum(exp(x - max)) for the log-sum-exp trick
-    local_exp = torch.exp(local_logits_fp32 - global_max)
-    local_sum_exp = local_exp.sum(dim=-1, keepdim=True)
-    global_sum_exp = local_sum_exp.clone()
-    dist.all_reduce(global_sum_exp, op=dist.ReduceOp.SUM, group=tp_group)
+    log_normalizer = _distributed_log_normalizer(local_logits_fp32, tp_group)
 
-    # log_normalizer = global_max + log(global_sum_exp)
-    log_normalizer = global_max + torch.log(global_sum_exp)  # [batch, seq, 1]
-    log_normalizer = log_normalizer.squeeze(-1)  # [batch, seq]
-
-    # Extract logits at target positions - each rank only has part of the vocab
+    local_vocab_size = local_logits_fp32.shape[-1]
+    local_indices = (target_ids - vocab_start).clamp(0, local_vocab_size - 1)
     is_local = (target_ids >= vocab_start) & (target_ids < vocab_end)
 
-    # Convert global indices to local indices (only valid where is_local=True)
-    local_indices = target_ids - vocab_start
-    local_indices = local_indices.clamp(0, local_vocab_size - 1)  # Clamp for safety
-
     target_logits = torch.gather(
         local_logits_fp32,
         dim=-1,
         index=local_indices.unsqueeze(-1).long(),
     ).squeeze(-1)
-
-    # Zero out where this rank doesn't own the token, then reduce
-    target_logits = target_logits * is_local.float()
+    target_logits = target_logits.masked_fill(~is_local, 0.0)
     dist.all_reduce(target_logits, op=dist.ReduceOp.SUM, group=tp_group)
 
-    logprobs = target_logits - log_normalizer
-
-    return logprobs
+    return target_logits - log_normalizer
 
 
 def get_vocab_shard_info(
@@ -219,3 +187,19 @@ def get_vocab_shard_info(
 
     # Not sharded
     return None, 0, 1, 0, local_logits.shape[-1]
+
+
+def _distributed_log_normalizer(
+    local_logits_fp32: torch.Tensor,
+    tp_group: dist.ProcessGroup,
+) -> torch.Tensor:
+    """
+    Compute logsumexp across vocab shards without materializing the full vocab.
+    """
+    global_max = local_logits_fp32.max(dim=-1, keepdim=True).values
+    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)
+
+    sum_exp = torch.exp(local_logits_fp32 - global_max).sum(dim=-1, keepdim=True)
+    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)
+
+    return (global_max + torch.log(sum_exp)).squeeze(-1)
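
For context, here is a minimal single-process sketch of the math behind the new _distributed_log_normalizer helper: each vocab shard contributes only its local max and its sum of exponentials, the two all-reduces (MAX, then SUM) combine them, and the result equals torch.logsumexp over the full vocabulary. The shard count and tensor shapes below are made up for illustration, and the collectives are replaced by explicit reductions over a list of local shards, so the snippet runs without torch.distributed.

import torch

torch.manual_seed(0)
batch, seq, vocab, num_shards = 2, 5, 32, 4  # hypothetical sizes
logits = torch.randn(batch, seq, vocab)
shards = logits.chunk(num_shards, dim=-1)  # stand-ins for each rank's local shard

# "all_reduce MAX": global max over all shards, per [batch, seq] position.
global_max = torch.stack(
    [s.max(dim=-1, keepdim=True).values for s in shards]
).max(dim=0).values

# "all_reduce SUM": global sum of exp(x - global_max) across shards.
sum_exp = torch.stack(
    [torch.exp(s - global_max).sum(dim=-1, keepdim=True) for s in shards]
).sum(dim=0)

log_normalizer = (global_max + torch.log(sum_exp)).squeeze(-1)  # [batch, seq]

# The sharded log-sum-exp agrees with the dense one.
assert torch.allclose(log_normalizer, torch.logsumexp(logits, dim=-1), atol=1e-5)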

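The target-logit extraction in compute_logprobs_parallel can be sanity-checked the same way: each shard gathers at (target_ids - vocab_start) clamped into its local range, zeroes the positions it does not own, and a SUM reduction recovers each target token's logit from whichever rank owns it. The sketch below simulates the reduction with a loop over shards and checks the final logprobs against a dense log_softmax reference; the shapes and the reference_logprobs name are illustrative, not from the repo.

import torch

torch.manual_seed(0)
batch, seq, vocab, num_shards = 2, 5, 32, 4  # hypothetical sizes
shard_size = vocab // num_shards
logits = torch.randn(batch, seq, vocab)
target_ids = torch.randint(0, vocab, (batch, seq))

log_normalizer = torch.logsumexp(logits, dim=-1)  # [batch, seq]

target_logits = torch.zeros(batch, seq)
for rank in range(num_shards):  # each iteration plays the role of one TP rank
    vocab_start, vocab_end = rank * shard_size, (rank + 1) * shard_size
    local_logits = logits[..., vocab_start:vocab_end]
    local_indices = (target_ids - vocab_start).clamp(0, shard_size - 1)
    is_local = (target_ids >= vocab_start) & (target_ids < vocab_end)
    gathered = torch.gather(local_logits, -1, local_indices.unsqueeze(-1)).squeeze(-1)
    target_logits += gathered.masked_fill(~is_local, 0.0)  # simulated all_reduce SUM

logprobs = target_logits - log_normalizer
reference_logprobs = torch.log_softmax(logits, dim=-1).gather(
    -1, target_ids.unsqueeze(-1)
).squeeze(-1)
assert torch.allclose(logprobs, reference_logprobs, atol=1e-5)
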
0 commit comments
