@@ -124,68 +124,36 @@ def compute_logprobs_parallel(
     Returns:
         Tensor of shape [batch_size, target_len] with log probabilities.
     """
-    # Get sharding info using helper
-    tp_group, tp_rank, tp_size, vocab_start, vocab_end = get_vocab_shard_info(logits)
+    tp_group, _, _, vocab_start, vocab_end = get_vocab_shard_info(logits)
 
     if tp_group is None:
         # DTensor but not sharded on vocab (Replicate or other dim sharding)
         return compute_logprobs(logits.full_tensor(), target_ids, temperature, align)
 
-    # Get the local shard
     local_logits = logits._local_tensor  # [batch, seq_len, vocab_size / tp_size]
+    target_len = target_ids.size(1)
 
-    # Align logits with target if needed
     if align:
-        # Slice to match target length: logits[:, -target_len-1:-1, :]
-        target_len = target_ids.size(1)
         local_logits = local_logits[:, -target_len - 1 : -1, :]
 
-    # Scale by temperature
-    local_logits = local_logits / temperature
-
-    batch_size, seq_len, local_vocab_size = local_logits.shape
-
-    # Move target_ids to the same device as local_logits
     target_ids = target_ids.to(local_logits.device)
+    local_logits_fp32 = local_logits.float() / temperature
 
-    # Cast to float32 for numerical stability
-    local_logits_fp32 = local_logits.float()
-
-    # Compute global max across all shards for numerical stability
-    local_max = local_logits_fp32.max(dim=-1, keepdim=True).values
-    global_max = local_max.clone()
-    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)
-
-    # Compute global sum(exp(x - max)) for the log-sum-exp trick
-    local_exp = torch.exp(local_logits_fp32 - global_max)
-    local_sum_exp = local_exp.sum(dim=-1, keepdim=True)
-    global_sum_exp = local_sum_exp.clone()
-    dist.all_reduce(global_sum_exp, op=dist.ReduceOp.SUM, group=tp_group)
+    log_normalizer = _distributed_log_normalizer(local_logits_fp32, tp_group)
 
-    # log_normalizer = global_max + log(global_sum_exp)
-    log_normalizer = global_max + torch.log(global_sum_exp)  # [batch, seq, 1]
-    log_normalizer = log_normalizer.squeeze(-1)  # [batch, seq]
-
-    # Extract logits at target positions - each rank only has part of the vocab
+    local_vocab_size = local_logits_fp32.shape[-1]
+    local_indices = (target_ids - vocab_start).clamp(0, local_vocab_size - 1)
     is_local = (target_ids >= vocab_start) & (target_ids < vocab_end)
 
-    # Convert global indices to local indices (only valid where is_local=True)
-    local_indices = target_ids - vocab_start
-    local_indices = local_indices.clamp(0, local_vocab_size - 1)  # Clamp for safety
-
     target_logits = torch.gather(
         local_logits_fp32,
         dim=-1,
         index=local_indices.unsqueeze(-1).long(),
     ).squeeze(-1)
-
-    # Zero out where this rank doesn't own the token, then reduce
-    target_logits = target_logits * is_local.float()
+    target_logits = target_logits.masked_fill(~is_local, 0.0)
     dist.all_reduce(target_logits, op=dist.ReduceOp.SUM, group=tp_group)
 
-    logprobs = target_logits - log_normalizer
-
-    return logprobs
+    return target_logits - log_normalizer
 
 
 def get_vocab_shard_info(
@@ -219,3 +187,19 @@ def get_vocab_shard_info(
 
     # Not sharded
     return None, 0, 1, 0, local_logits.shape[-1]
+
+
+def _distributed_log_normalizer(
+    local_logits_fp32: torch.Tensor,
+    tp_group: dist.ProcessGroup,
+) -> torch.Tensor:
+    """
+    Compute logsumexp across vocab shards without materializing the full vocab.
+    """
+    global_max = local_logits_fp32.max(dim=-1, keepdim=True).values
+    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)
+
+    sum_exp = torch.exp(local_logits_fp32 - global_max).sum(dim=-1, keepdim=True)
+    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)
+
+    return (global_max + torch.log(sum_exp)).squeeze(-1)
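
The helper added in the second hunk relies on the sharded log-sum-exp identity: the global max comes from a MAX all-reduce over per-shard maxima, and the normalizer is that max plus the log of the SUM-reduced sum of exp(x - max). A minimal single-process sketch below checks the identity by splitting the vocab dimension into chunks and emulating the two all-reduces with stack/amax and stack/sum; the shapes and shard count are illustrative assumptions, not values from this PR.

import torch

# Illustrative shapes; tp_size stands in for the tensor-parallel world size.
batch, seq, vocab, tp_size = 2, 4, 32, 4
logits = torch.randn(batch, seq, vocab)
shards = logits.chunk(tp_size, dim=-1)  # stand-in for each rank's _local_tensor

# Emulate dist.all_reduce(MAX) over the per-shard maxima ...
global_max = torch.stack([s.max(dim=-1, keepdim=True).values for s in shards]).amax(dim=0)
# ... and dist.all_reduce(SUM) over the per-shard sums of exp(x - max).
sum_exp = torch.stack(
    [torch.exp(s - global_max).sum(dim=-1, keepdim=True) for s in shards]
).sum(dim=0)
sharded_lse = (global_max + torch.log(sum_exp)).squeeze(-1)

# Matches logsumexp over the full, unsharded vocab.
torch.testing.assert_close(sharded_lse, torch.logsumexp(logits, dim=-1))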
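The first hunk's gather path can be sanity-checked the same way: each shard gathers only the target ids that fall inside its [vocab_start, vocab_end) slice, masked_fill zeroes the positions owned by other shards, and the SUM all-reduce (a plain sum over shards in this sketch) reassembles the full-vocab gather. Again a hedged sketch under assumed shapes rather than code from the patch.

import torch

batch, seq, vocab, tp_size = 2, 4, 32, 4
shard_vocab = vocab // tp_size
logits = torch.randn(batch, seq, vocab)
target_ids = torch.randint(0, vocab, (batch, seq))

partial = []
for rank, shard in enumerate(logits.chunk(tp_size, dim=-1)):
    vocab_start = rank * shard_vocab
    is_local = (target_ids >= vocab_start) & (target_ids < vocab_start + shard_vocab)
    local_idx = (target_ids - vocab_start).clamp(0, shard_vocab - 1)
    gathered = torch.gather(shard, dim=-1, index=local_idx.unsqueeze(-1)).squeeze(-1)
    partial.append(gathered.masked_fill(~is_local, 0.0))  # zero where the id lives on another shard

# Emulates dist.all_reduce(SUM): exactly one shard contributed a non-zero value per position.
target_logits = torch.stack(partial).sum(dim=0)
torch.testing.assert_close(
    target_logits,
    torch.gather(logits, dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1),
)

Because the clamp keeps out-of-range ids inside the local shard, the gather never indexes out of bounds; the mask afterwards discards those spurious values before the reduce.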