 # LICENSE file in the root directory of this source tree.
 
 import torch
-import torch.distributed as dist
 import torch.nn.functional as F
 from torch.distributed.tensor import DTensor
-from torch.distributed.tensor.placement_types import Shard
 
 
 @torch.compile
 def compute_logprobs(
-    logits: torch.Tensor,
+    logits: torch.Tensor | DTensor,
     input_ids: torch.Tensor,
     temperature: float = 1.0,
     align: bool = True,
@@ -56,9 +54,21 @@ def compute_logprobs(
     probabilities for the response portion, you don't need to re-run the model. This
     is a key optimization in RL training where the prompt remains constant.
 
+    **Tensor Parallelism Support:**
+    When logits is a DTensor sharded on the vocab dimension (e.g., from tensor-parallel
+    training), wrap calls to this function in the `loss_parallel()` context:
+
+    >>> from torch.distributed.tensor.parallel import loss_parallel
+    >>> with loss_parallel():
+    ...     logprobs = compute_logprobs(logits, input_ids)
+
+    The `loss_parallel` context ensures F.cross_entropy works correctly with
+    vocab-sharded DTensors without needing to gather the full tensor.
+
     Args:
         logits (`torch.Tensor`):
             The model output logits of shape `(batch_size, sequence_length, vocab_size)`.
+            Can be a regular Tensor or a DTensor (when called inside a `loss_parallel()` context).
         input_ids (`torch.Tensor`):
            The target token ids of shape `(batch_size, target_sequence_length)`.
            These are the tokens for which you want to compute log probabilities.
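A minimal single-process sketch of the response-only convention described above, with plain tensors and made-up sizes; `compute_logprobs` is the function defined in this file, and under tensor parallelism the same call would simply be wrapped in `loss_parallel()` as the docstring shows:

    import torch

    batch, prompt_len, response_len, vocab = 2, 5, 3, 11
    torch.manual_seed(0)
    full_ids = torch.randint(vocab, (batch, prompt_len + response_len))
    logits = torch.randn(batch, prompt_len + response_len, vocab)  # stand-in for model output

    # Only the response tokens are passed as input_ids; with align=True (the default),
    # compute_logprobs picks out the logits that predict each response token, so the
    # prompt portion never has to be re-scored.
    response_ids = full_ids[:, prompt_len:]
    logprobs = compute_logprobs(logits, response_ids)
    assert logprobs.shape == (batch, response_len)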
@@ -99,127 +109,3 @@ def compute_logprobs(
     )
 
     return logprobs.reshape(batch_size, seq_len)
-
-
-@torch.compile
-def compute_logprobs_parallel(
-    logits: DTensor,
-    target_ids: torch.Tensor,
-    temperature: float = 1.0,
-    align: bool = True,
-) -> torch.Tensor:
-    """
-    Compute log probabilities for target tokens from vocab-sharded DTensor logits.
-
-    This function computes log_softmax(logits)[target_ids] distributedly,
-    without ever gathering the full vocabulary dimension.
-
-    IMPORTANT: Only use this when logits is a DTensor sharded on vocab dimension.
-    For regular tensors or non-vocab-sharded DTensors, use compute_logprobs instead.
-
-    Args:
-        logits: DTensor of shape [batch_size, seq_len, vocab_size], sharded on dim=-1.
-        target_ids: Tensor of shape [batch_size, target_len] with target token IDs.
-        temperature: Temperature for scaling logits (default 1.0).
-        align: If True, slice logits to align with target_ids (default True).
-
-    Returns:
-        Tensor of shape [batch_size, target_len] with log probabilities.
-    """
-    tp_group, _, _, vocab_start, vocab_end = get_vocab_shard_info(logits)
-
-    if tp_group is None:
-        # DTensor but not sharded on vocab (Replicate or other dim sharding)
-        return compute_logprobs(logits.full_tensor(), target_ids, temperature, align)
-
-    local_logits = logits._local_tensor  # [batch, seq_len, vocab_size / tp_size]
-
-    if align:
-        local_logits = local_logits[:, -target_ids.size(1) - 1 : -1, :]
-
-    target_ids = target_ids.to(local_logits.device)
-    local_logits_fp32 = local_logits.float() / temperature
-
-    log_normalizer = _distributed_log_normalizer(local_logits_fp32, tp_group)
-
-    local_vocab_size = local_logits_fp32.shape[-1]
-    local_indices = (target_ids - vocab_start).clamp(0, local_vocab_size - 1)
-    is_local = (target_ids >= vocab_start) & (target_ids < vocab_end)
-
-    target_logits = torch.gather(
-        local_logits_fp32,
-        dim=-1,
-        index=local_indices.unsqueeze(-1).long(),
-    ).squeeze(-1)
-    target_logits = target_logits.masked_fill(~is_local, 0.0)
-    dist.all_reduce(target_logits, op=dist.ReduceOp.SUM, group=tp_group)
-
-    return target_logits - log_normalizer
-
-
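For reference, the gather-and-mask step of the removed function above can be checked on a single process: plain tensor shards stand in for the per-rank local tensors (the shard widths are made up) and an ordinary Python sum plays the role of the SUM all-reduce:

    import torch

    torch.manual_seed(0)
    logits = torch.randn(2, 4, 10)                  # [batch, seq, vocab]
    target_ids = torch.randint(10, (2, 4))
    widths = [3, 3, 2, 2]                           # uneven vocab shards
    starts = [0, 3, 6, 8]

    gathered = torch.zeros(2, 4)
    for start, shard in zip(starts, logits.split(widths, dim=-1)):
        end = start + shard.shape[-1]
        local_idx = (target_ids - start).clamp(0, shard.shape[-1] - 1)
        vals = torch.gather(shard, -1, local_idx.unsqueeze(-1)).squeeze(-1)
        # Exactly one shard owns each token id; the others contribute zero.
        gathered += vals.masked_fill(~((target_ids >= start) & (target_ids < end)), 0.0)

    expected = torch.gather(logits, -1, target_ids.unsqueeze(-1)).squeeze(-1)
    torch.testing.assert_close(gathered, expected)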
-def _get_vocab_shard_bounds(
-    vocab_size: int, tp_rank: int, tp_size: int
-) -> tuple[int, int, int]:
-    """
-    Return (start, end, width) for a shard when vocab dimension is unevenly split.
-    """
-    base_shard = vocab_size // tp_size
-    remainder = vocab_size % tp_size
-    shard_width = base_shard + (1 if tp_rank < remainder else 0)
-    vocab_start = tp_rank * base_shard + min(tp_rank, remainder)
-    vocab_end = vocab_start + shard_width
-    return vocab_start, vocab_end, shard_width
-
-
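A quick worked example of the bounds arithmetic in the removed helper above, with made-up sizes (vocab_size=10, tp_size=4): the first `remainder` ranks get one extra id and the ranges tile [0, vocab_size) exactly:

    vocab_size, tp_size = 10, 4          # made-up sizes
    base, rem = vocab_size // tp_size, vocab_size % tp_size
    bounds = [
        (r * base + min(r, rem), r * base + min(r, rem) + base + (1 if r < rem else 0))
        for r in range(tp_size)
    ]
    assert bounds == [(0, 3), (3, 6), (6, 8), (8, 10)]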
-def get_vocab_shard_info(
-    logits: DTensor,
-) -> tuple[dist.ProcessGroup | None, int, int, int, int]:
-    """
-    Get vocabulary sharding information from a DTensor.
-
-    Args:
-        logits: DTensor with shape [..., vocab_size], potentially sharded on vocab dim.
-
-    Returns:
-        Tuple of (tp_group, tp_rank, tp_size, vocab_start, vocab_end).
-        If not sharded, returns (None, 0, 1, 0, vocab_size).
-    """
-    local_logits = logits._local_tensor
-    placements = logits.placements
-    device_mesh = logits.device_mesh
-    global_vocab_size = logits.shape[-1]
-
-    for i, p in enumerate(placements):
-        if isinstance(p, Shard) and p.dim == 2:  # vocab dimension
-            tp_group = device_mesh.get_group(mesh_dim=i)
-            tp_size = dist.get_world_size(tp_group)
-            tp_rank = dist.get_rank(tp_group)
-            vocab_start, vocab_end, shard_width = _get_vocab_shard_bounds(
-                global_vocab_size, tp_rank, tp_size
-            )
-            local_vocab_size = local_logits.shape[-1]
-            if local_vocab_size != shard_width:
-                raise ValueError(
-                    "DTensor local shard width does not match inferred shard size "
-                    f"(rank={tp_rank}, local={local_vocab_size}, expected={shard_width})"
-                )
-            return tp_group, tp_rank, tp_size, vocab_start, vocab_end
-
-    # Not sharded
-    return None, 0, 1, 0, global_vocab_size
-
-
-def _distributed_log_normalizer(
-    local_logits_fp32: torch.Tensor,
-    tp_group: dist.ProcessGroup,
-) -> torch.Tensor:
-    """
-    Compute logsumexp across vocab shards without materializing the full vocab.
-    """
-    global_max = local_logits_fp32.max(dim=-1, keepdim=True).values
-    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)
-
-    sum_exp = torch.exp(local_logits_fp32 - global_max).sum(dim=-1, keepdim=True)
-    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)
-
-    return (global_max + torch.log(sum_exp)).squeeze(-1)
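Finally, the two all-reduces in the removed `_distributed_log_normalizer` implement the standard max-shifted logsumexp, which combines shard by shard; a single-process sketch with plain tensor shards standing in for the per-rank local tensors (shard widths made up):

    import torch

    torch.manual_seed(0)
    logits = torch.randn(2, 4, 10)                  # [batch, seq, vocab]
    shards = logits.split([3, 3, 2, 2], dim=-1)     # per-"rank" local tensors

    # MAX all-reduce: global max over the vocab dimension.
    global_max = torch.stack([s.max(dim=-1).values for s in shards]).amax(dim=0)
    # SUM all-reduce: shard-local sums of exp(x - max) add up to the global sum.
    sum_exp = sum(torch.exp(s - global_max.unsqueeze(-1)).sum(dim=-1) for s in shards)

    log_normalizer = global_max + torch.log(sum_exp)
    torch.testing.assert_close(log_normalizer, torch.logsumexp(logits, dim=-1))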