Commit c0aedca

refactor: Simplify sharded logprobs computation using loss_parallel context
1 parent: 805bbf7

3 files changed (+37, -489 lines)

src/forge/actors/reference_model.py

Lines changed: 5 additions & 7 deletions
@@ -17,9 +17,10 @@
 from forge.controller import ForgeActor
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
-from forge.util.ops import compute_logprobs, compute_logprobs_parallel
+from forge.util.ops import compute_logprobs
 from monarch.actor import current_rank, current_size, endpoint
 from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.parallel import loss_parallel
 
 from torchtitan.config.job_config import (
     Checkpoint,
@@ -177,19 +178,16 @@ async def forward(
         t.step("forward")
 
         if not return_logprobs:
-            # Only gather full tensor when returning raw logits
             if isinstance(logits, DTensor):
                 logits = logits.full_tensor()
             t.stop()
             return logits
         else:
-            # Compute logprobs in parallel without gathering full vocab tensor
-            # Use parallel version when TP is enabled (vocab sharded across GPUs)
+            # When TP is enabled, use loss_parallel context for vocab-sharded DTensors
             response_tokens = input_ids[:, max_req_tokens:]
            if parallel_dims.tp_enabled and isinstance(logits, DTensor):
-                logprobs = compute_logprobs_parallel(
-                    logits, response_tokens, align=True
-                )
+                with loss_parallel():
+                    logprobs = compute_logprobs(logits, response_tokens, align=True)
            else:
                logprobs = compute_logprobs(logits, response_tokens)
            t.step("compute_logprobs")

src/forge/util/ops.py

Lines changed: 13 additions & 127 deletions
@@ -5,15 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-import torch.distributed as dist
 import torch.nn.functional as F
 from torch.distributed.tensor import DTensor
-from torch.distributed.tensor.placement_types import Shard
 
 
 @torch.compile
 def compute_logprobs(
-    logits: torch.Tensor,
+    logits: torch.Tensor | DTensor,
     input_ids: torch.Tensor,
     temperature: float = 1.0,
     align: bool = True,
@@ -56,9 +54,21 @@ def compute_logprobs(
     probabilities for the response portion, you don't need to re-run the model. This
     is a key optimization in RL training where the prompt remains constant.
 
+    **Tensor Parallelism Support:**
+    When logits is a DTensor sharded on the vocab dimension (e.g., from tensor parallel
+    training), wrap calls to this function with `loss_parallel()` context:
+
+    >>> from torch.distributed.tensor.parallel import loss_parallel
+    >>> with loss_parallel():
+    ...     logprobs = compute_logprobs(logits, input_ids)
+
+    The `loss_parallel` context ensures F.cross_entropy works correctly with
+    vocab-sharded DTensors without needing to gather the full tensor.
+
     Args:
         logits (`torch.Tensor`):
             The model output logits of shape `(batch_size, sequence_length, vocab_size)`.
+            Can be a regular Tensor or a DTensor (when using with loss_parallel context).
         input_ids (`torch.Tensor`):
             The target token ids of shape `(batch_size, target_sequence_length)`.
             These are the tokens for which you want to compute log probabilities.
@@ -99,127 +109,3 @@ def compute_logprobs(
     )
 
     return logprobs.reshape(batch_size, seq_len)
-
-
-@torch.compile
-def compute_logprobs_parallel(
-    logits: DTensor,
-    target_ids: torch.Tensor,
-    temperature: float = 1.0,
-    align: bool = True,
-) -> torch.Tensor:
-    """
-    Compute log probabilities for target tokens from vocab-sharded DTensor logits.
-
-    This function computes log_softmax(logits)[target_ids] distributedly,
-    without ever gathering the full vocabulary dimension.
-
-    IMPORTANT: Only use this when logits is a DTensor sharded on vocab dimension.
-    For regular tensors or non-vocab-sharded DTensors, use compute_logprobs instead.
-
-    Args:
-        logits: DTensor of shape [batch_size, seq_len, vocab_size], sharded on dim=-1.
-        target_ids: Tensor of shape [batch_size, target_len] with target token IDs.
-        temperature: Temperature for scaling logits (default 1.0).
-        align: If True, slice logits to align with target_ids (default True).
-
-    Returns:
-        Tensor of shape [batch_size, target_len] with log probabilities.
-    """
-    tp_group, _, _, vocab_start, vocab_end = get_vocab_shard_info(logits)
-
-    if tp_group is None:
-        # DTensor but not sharded on vocab (Replicate or other dim sharding)
-        return compute_logprobs(logits.full_tensor(), target_ids, temperature, align)
-
-    local_logits = logits._local_tensor  # [batch, seq_len, vocab_size / tp_size]
-
-    if align:
-        local_logits = local_logits[:, -target_ids.size(1) - 1 : -1, :]
-
-    target_ids = target_ids.to(local_logits.device)
-    local_logits_fp32 = local_logits.float() / temperature
-
-    log_normalizer = _distributed_log_normalizer(local_logits_fp32, tp_group)
-
-    local_vocab_size = local_logits_fp32.shape[-1]
-    local_indices = (target_ids - vocab_start).clamp(0, local_vocab_size - 1)
-    is_local = (target_ids >= vocab_start) & (target_ids < vocab_end)
-
-    target_logits = torch.gather(
-        local_logits_fp32,
-        dim=-1,
-        index=local_indices.unsqueeze(-1).long(),
-    ).squeeze(-1)
-    target_logits = target_logits.masked_fill(~is_local, 0.0)
-    dist.all_reduce(target_logits, op=dist.ReduceOp.SUM, group=tp_group)
-
-    return target_logits - log_normalizer
-
-
-def _get_vocab_shard_bounds(
-    vocab_size: int, tp_rank: int, tp_size: int
-) -> tuple[int, int, int]:
-    """
-    Return (start, end, width) for a shard when vocab dimension is unevenly split.
-    """
-    base_shard = vocab_size // tp_size
-    remainder = vocab_size % tp_size
-    shard_width = base_shard + (1 if tp_rank < remainder else 0)
-    vocab_start = tp_rank * base_shard + min(tp_rank, remainder)
-    vocab_end = vocab_start + shard_width
-    return vocab_start, vocab_end, shard_width
-
-
-def get_vocab_shard_info(
-    logits: DTensor,
-) -> tuple[dist.ProcessGroup | None, int, int, int, int]:
-    """
-    Get vocabulary sharding information from a DTensor.
-
-    Args:
-        logits: DTensor with shape [..., vocab_size], potentially sharded on vocab dim.
-
-    Returns:
-        Tuple of (tp_group, tp_rank, tp_size, vocab_start, vocab_end).
-        If not sharded, returns (None, 0, 1, 0, vocab_size).
-    """
-    local_logits = logits._local_tensor
-    placements = logits.placements
-    device_mesh = logits.device_mesh
-    global_vocab_size = logits.shape[-1]
-
-    for i, p in enumerate(placements):
-        if isinstance(p, Shard) and p.dim == 2:  # vocab dimension
-            tp_group = device_mesh.get_group(mesh_dim=i)
-            tp_size = dist.get_world_size(tp_group)
-            tp_rank = dist.get_rank(tp_group)
-            vocab_start, vocab_end, shard_width = _get_vocab_shard_bounds(
-                global_vocab_size, tp_rank, tp_size
-            )
-            local_vocab_size = local_logits.shape[-1]
-            if local_vocab_size != shard_width:
-                raise ValueError(
-                    "DTensor local shard width does not match inferred shard size "
-                    f"(rank={tp_rank}, local={local_vocab_size}, expected={shard_width})"
-                )
-            return tp_group, tp_rank, tp_size, vocab_start, vocab_end
-
-    # Not sharded
-    return None, 0, 1, 0, global_vocab_size
-
-
-def _distributed_log_normalizer(
-    local_logits_fp32: torch.Tensor,
-    tp_group: dist.ProcessGroup,
-) -> torch.Tensor:
-    """
-    Compute logsumexp across vocab shards without materializing the full vocab.
-    """
-    global_max = local_logits_fp32.max(dim=-1, keepdim=True).values
-    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)
-
-    sum_exp = torch.exp(local_logits_fp32 - global_max).sum(dim=-1, keepdim=True)
-    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)
-
-    return (global_max + torch.log(sum_exp)).squeeze(-1)
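
The removed helpers implemented a distributed form of a simple identity: the log-probability of a target token is its logit minus the logsumexp over the vocabulary. The retained compute_logprobs computes the same quantity (per the docstring added above, it relies on F.cross_entropy), and loss_parallel lets that computation run on vocab-sharded DTensors without an all-gather. Below is a single-device sketch of the equivalence, not the actual body of compute_logprobs; the helper names per_token_logprobs_ce and per_token_logprobs_manual are hypothetical.

# Single-device sketch (assumed, not taken from the source file) showing why the
# two approaches agree: -cross_entropy(..., reduction="none") equals the gathered
# target logit minus logsumexp over the vocabulary.
import torch
import torch.nn.functional as F


def per_token_logprobs_ce(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
    # logits: [batch, seq, vocab]; target_ids: [batch, seq] (int64)
    batch, seq, vocab = logits.shape
    ce = F.cross_entropy(
        logits.reshape(-1, vocab).float(), target_ids.reshape(-1), reduction="none"
    )
    return (-ce).reshape(batch, seq)


def per_token_logprobs_manual(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
    # The same math the deleted compute_logprobs_parallel performed across TP ranks:
    # gather the target logit, then subtract the log-normalizer (logsumexp).
    logits = logits.float()
    target_logits = torch.gather(logits, -1, target_ids.unsqueeze(-1)).squeeze(-1)
    return target_logits - torch.logsumexp(logits, dim=-1)


if __name__ == "__main__":
    logits = torch.randn(2, 5, 11)
    targets = torch.randint(0, 11, (2, 5))
    assert torch.allclose(
        per_token_logprobs_ce(logits, targets),
        per_token_logprobs_manual(logits, targets),
        atol=1e-5,
    )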
