Merged
Commits
20 commits
ea74704
feat: Add parallel log-prob computation for vocab-sharded tensors in …
gitlost-murali Nov 30, 2025
a5b0348
refactor: use util function
gitlost-murali Nov 30, 2025
8b5d6bb
refactor: remove redundant code and reuse compute_logprobs function
gitlost-murali Nov 30, 2025
72fa4c3
test: add test to verify correctness with compute_logprobs
gitlost-murali Nov 30, 2025
540253d
chore: make routing explicit
gitlost-murali Nov 30, 2025
f656334
style: clean up inline comments in parallel_logprobs
gitlost-murali Nov 30, 2025
309ba5c
refactor: move compute_logprobs_parallel into ops alongside compute_l…
gitlost-murali Dec 1, 2025
c0f21ca
refactor: simplify and refactor logprobs parallel method
gitlost-murali Dec 1, 2025
72f58ce
test: move parallel logprobs test to test_ops.py reflecting folder st…
gitlost-murali Dec 1, 2025
59b8799
chore: merge single use declarations
gitlost-murali Dec 1, 2025
ec6d0ae
fix: safely handle edgecase of uneven vocab shards in tensor-parallel…
gitlost-murali Dec 1, 2025
c79740a
fix: fix compute_logprobs_parallel import
gitlost-murali Dec 1, 2025
805bbf7
feat: add torch compile to further reduce the reference model GPU usa…
gitlost-murali Dec 5, 2025
c0aedca
refactor: Simplify sharded logprobs computation using loss_parallel c…
gitlost-murali Dec 5, 2025
90120bb
fix: convert DTensor output to regular tensor after loss_parallel
gitlost-murali Dec 5, 2025
f92b503
refactor: make compile configurable for logprobs computation
gitlost-murali Dec 6, 2025
bc7be03
refactor: remove redundant logging steps as on main
gitlost-murali Dec 8, 2025
459b285
fix: ReferenceModel to properly handle DTensor logits when TP is disa…
gitlost-murali Dec 9, 2025
e21b373
add comment and change unit test to to_local
Dec 9, 2025
e46801e
Merge branch 'main' of https://github.com/meta-pytorch/forge into opt…
Dec 9, 2025
21 changes: 18 additions & 3 deletions src/forge/actors/reference_model.py
@@ -20,6 +20,7 @@
from forge.util.ops import compute_logprobs
from monarch.actor import current_rank, current_size, endpoint
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.parallel import loss_parallel

from torchtitan.config.job_config import (
Checkpoint,
@@ -98,6 +99,10 @@ def __post_init__(self):
self.rank = current_rank().rank
self.size = math.prod(current_size().values())

self.compute_log_probs = compute_logprobs
if self.compile.enable:
self.compute_log_probs = torch.compile(self.compute_log_probs)

env = {
"RANK": str(self.rank),
"LOCAL_RANK": str(self.rank),
@@ -174,13 +179,23 @@ async def forward(
with torch.inference_mode():
logits = self.model(input_ids)
self.step += 1
if isinstance(logits, DTensor):
logits = logits.full_tensor()

if not return_logprobs:
if isinstance(logits, DTensor):
logits = logits.full_tensor()
t.stop()
return logits
else:
logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:])
response_tokens = input_ids[:, max_req_tokens:]
if parallel_dims.tp_enabled and isinstance(logits, DTensor):
with loss_parallel():
logprobs = self.compute_log_probs(logits, response_tokens)

# loss_parallel produces Replicated output - to_local() returns the full tensor
logprobs = logprobs.to_local()
else:
if isinstance(logits, DTensor):
logits = logits.full_tensor()
logprobs = self.compute_log_probs(logits, response_tokens)
t.stop()
return logprobs
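
Read as a whole, the change to forward() routes the log-prob computation by whether tensor parallelism left the logits as a vocab-sharded DTensor. The following is a minimal standalone sketch of that branch logic, not the actor code itself; route_logprobs and compute_log_probs_fn are illustrative names, the latter standing in for the optionally torch.compile'd callable assigned in __post_init__.

from torch.distributed.tensor import DTensor
from torch.distributed.tensor.parallel import loss_parallel


def route_logprobs(logits, response_tokens, tp_enabled, compute_log_probs_fn):
    # Sketch only: mirrors the branch structure added to ReferenceModel.forward.
    if tp_enabled and isinstance(logits, DTensor):
        # Keep the vocab-sharded logits as a DTensor; loss_parallel() lets the
        # log-prob computation run without gathering the full vocab dimension.
        with loss_parallel():
            logprobs = compute_log_probs_fn(logits, response_tokens)
        # loss_parallel yields a replicated DTensor; to_local() returns the full tensor.
        return logprobs.to_local()
    # No TP (or plain tensor logits): gather if needed and compute directly.
    if isinstance(logits, DTensor):
        logits = logits.full_tensor()
    return compute_log_probs_fn(logits, response_tokens)
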
15 changes: 14 additions & 1 deletion src/forge/util/ops.py
@@ -6,10 +6,11 @@

import torch
import torch.nn.functional as F
from torch.distributed.tensor import DTensor


def compute_logprobs(
logits: torch.Tensor,
logits: torch.Tensor | DTensor,
input_ids: torch.Tensor,
temperature: float = 1.0,
align: bool = True,
@@ -52,9 +53,21 @@ def compute_logprobs(
probabilities for the response portion, you don't need to re-run the model. This
is a key optimization in RL training where the prompt remains constant.

**Tensor Parallelism Support:**
When logits is a DTensor sharded on the vocab dimension (e.g., from tensor parallel
training), wrap calls to this function with `loss_parallel()` context:

>>> from torch.distributed.tensor.parallel import loss_parallel
>>> with loss_parallel():
... logprobs = compute_logprobs(logits, input_ids)

The `loss_parallel` context ensures F.cross_entropy works correctly with
vocab-sharded DTensors without needing to gather the full tensor.

Args:
logits (`torch.Tensor`):
The model output logits of shape `(batch_size, sequence_length, vocab_size)`.
Can be a regular Tensor or a DTensor (when using with loss_parallel context).
input_ids (`torch.Tensor`):
The target token ids of shape `(batch_size, target_sequence_length)`.
These are the tokens for which you want to compute log probabilities.
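
For intuition on why wrapping the call in loss_parallel() is sufficient: the docstring's reference to F.cross_entropy suggests a computation roughly like the sketch below. This is a hedged approximation of the align=True path, not the actual body of compute_logprobs; logprobs_sketch is an illustrative name. Under loss_parallel(), the cross_entropy call operates directly on the vocab-sharded DTensor, so the full logits never need to be materialized on one rank.

import torch.nn.functional as F


def logprobs_sketch(logits, input_ids, temperature=1.0):
    # Approximate aligned log-prob computation: select the logits that predict
    # each response token, then read per-token log probabilities off a negated,
    # unreduced cross-entropy.
    target_len = input_ids.shape[1]
    shifted = logits[:, -target_len - 1 : -1, :] / temperature
    token_losses = F.cross_entropy(
        shifted.flatten(0, 1),    # (batch * target_len, vocab)
        input_ids.flatten(0, 1),  # (batch * target_len,)
        reduction="none",
    )
    return (-token_losses).view_as(input_ids)

Called on a regular tensor this matches the doctest-style usage above; called inside loss_parallel() on a vocab-sharded DTensor, the same code produces a replicated DTensor of per-token log probabilities.
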
65 changes: 65 additions & 0 deletions tests/unit_tests/util/test_ops.py
@@ -6,9 +6,17 @@

import pytest
import torch
import torch.distributed as dist
import torch.nn.functional as F

from forge.util.ops import compute_logprobs

from tests.test_utils import gpu_test
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard
from torch.distributed.tensor.parallel import loss_parallel
from torch.testing._internal.common_fsdp import FSDPTest


def _textbook_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor):
# Helper: Textbook Log Softmax
@@ -162,3 +170,60 @@ def test_align_comparison(self):

# Both should give the same result
assert torch.allclose(result_aligned, result_manual, atol=1e-5)


class TestComputeLogprobsWithLossParallel(FSDPTest):
"""Test compute_logprobs with loss_parallel context for vocab-sharded DTensors."""

@property
def world_size(self) -> int:
return 2

@gpu_test(gpu_count=2)
def test_loss_parallel_matches_sequential(self):
"""Verify compute_logprobs under loss_parallel matches non-sharded version."""
torch.manual_seed(42)

batch_size, seq_len, vocab_size, target_len = 4, 16, 1000, 8
rank = dist.get_rank()
device = torch.device(f"cuda:{rank}")

# Create and broadcast test data
if rank == 0:
full_logits = torch.randn(batch_size, seq_len, vocab_size, device=device)
target_ids = torch.randint(
0, vocab_size, (batch_size, target_len), device=device
)
else:
full_logits = torch.empty(batch_size, seq_len, vocab_size, device=device)
target_ids = torch.empty(
batch_size, target_len, dtype=torch.int64, device=device
)

dist.broadcast(full_logits, src=0)
dist.broadcast(target_ids, src=0)

# Reference: non-sharded computation
expected = compute_logprobs(full_logits, target_ids, align=True)

# Create vocab-sharded DTensor
mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
local_vocab = vocab_size // self.world_size
dtensor_logits = DTensor.from_local(
full_logits[:, :, rank * local_vocab : (rank + 1) * local_vocab],
mesh,
placements=[Shard(2)],
)

# Compute with loss_parallel context
with loss_parallel():
result = compute_logprobs(dtensor_logits, target_ids, align=True)

# Verify output is Replicated as expected from loss_parallel
assert isinstance(result, DTensor)
assert result.placements[
0
].is_replicate(), f"Expected Replicated placement, got {result.placements}"
result = result.to_local()

torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)