diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py
index 42efd9054..99e25f937 100644
--- a/src/forge/actors/reference_model.py
+++ b/src/forge/actors/reference_model.py
@@ -20,6 +20,7 @@
 from forge.util.ops import compute_logprobs
 from monarch.actor import current_rank, current_size, endpoint
 from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.parallel import loss_parallel
 
 from torchtitan.config.job_config import (
     Checkpoint,
@@ -98,6 +99,10 @@ def __post_init__(self):
         self.rank = current_rank().rank
         self.size = math.prod(current_size().values())
 
+        self.compute_log_probs = compute_logprobs
+        if self.compile.enable:
+            self.compute_log_probs = torch.compile(self.compute_log_probs)
+
         env = {
             "RANK": str(self.rank),
             "LOCAL_RANK": str(self.rank),
@@ -174,13 +179,23 @@ async def forward(
         with torch.inference_mode():
             logits = self.model(input_ids)
         self.step += 1
-        if isinstance(logits, DTensor):
-            logits = logits.full_tensor()
         if not return_logprobs:
+            if isinstance(logits, DTensor):
+                logits = logits.full_tensor()
             t.stop()
             return logits
         else:
-            logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:])
+            response_tokens = input_ids[:, max_req_tokens:]
+            if parallel_dims.tp_enabled and isinstance(logits, DTensor):
+                with loss_parallel():
+                    logprobs = self.compute_log_probs(logits, response_tokens)
+
+                # loss_parallel produces Replicated output - to_local() returns the full tensor
+                logprobs = logprobs.to_local()
+            else:
+                if isinstance(logits, DTensor):
+                    logits = logits.full_tensor()
+                logprobs = self.compute_log_probs(logits, response_tokens)
             t.stop()
             return logprobs
diff --git a/src/forge/util/ops.py b/src/forge/util/ops.py
index f7152f065..59bff8570 100644
--- a/src/forge/util/ops.py
+++ b/src/forge/util/ops.py
@@ -6,10 +6,11 @@
 
 import torch
 import torch.nn.functional as F
+from torch.distributed.tensor import DTensor
 
 
 def compute_logprobs(
-    logits: torch.Tensor,
+    logits: torch.Tensor | DTensor,
     input_ids: torch.Tensor,
     temperature: float = 1.0,
     align: bool = True,
@@ -52,9 +53,21 @@
     probabilities for the response portion, you don't need to re-run the model.
     This is a key optimization in RL training where the prompt remains constant.
 
+    **Tensor Parallelism Support:**
+    When logits is a DTensor sharded on the vocab dimension (e.g., from tensor parallel
+    training), wrap calls to this function with `loss_parallel()` context:
+
+    >>> from torch.distributed.tensor.parallel import loss_parallel
+    >>> with loss_parallel():
+    ...     logprobs = compute_logprobs(logits, input_ids)
+
+    The `loss_parallel` context ensures F.cross_entropy works correctly with
+    vocab-sharded DTensors without needing to gather the full tensor.
+
     Args:
         logits (`torch.Tensor`):
             The model output logits of shape `(batch_size, sequence_length, vocab_size)`.
+            Can be a regular Tensor or a DTensor (when used with the loss_parallel context).
         input_ids (`torch.Tensor`):
             The target token ids of shape `(batch_size, target_sequence_length)`.
             These are the tokens for which you want to compute log probabilities.
diff --git a/tests/unit_tests/util/test_ops.py b/tests/unit_tests/util/test_ops.py
index 2f224743a..b9e929120 100644
--- a/tests/unit_tests/util/test_ops.py
+++ b/tests/unit_tests/util/test_ops.py
@@ -6,9 +6,17 @@
 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn.functional as F
+
 from forge.util.ops import compute_logprobs
+from tests.test_utils import gpu_test
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor import DTensor, Shard
+from torch.distributed.tensor.parallel import loss_parallel
+from torch.testing._internal.common_fsdp import FSDPTest
+
 
 def _textbook_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor):
     # Helper: Textbook Log Softmax
@@ -162,3 +170,60 @@ def test_align_comparison(self):
 
         # Both should give the same result
         assert torch.allclose(result_aligned, result_manual, atol=1e-5)
+
+
+class TestComputeLogprobsWithLossParallel(FSDPTest):
+    """Test compute_logprobs with loss_parallel context for vocab-sharded DTensors."""
+
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @gpu_test(gpu_count=2)
+    def test_loss_parallel_matches_sequential(self):
+        """Verify compute_logprobs under loss_parallel matches non-sharded version."""
+        torch.manual_seed(42)
+
+        batch_size, seq_len, vocab_size, target_len = 4, 16, 1000, 8
+        rank = dist.get_rank()
+        device = torch.device(f"cuda:{rank}")
+
+        # Create and broadcast test data
+        if rank == 0:
+            full_logits = torch.randn(batch_size, seq_len, vocab_size, device=device)
+            target_ids = torch.randint(
+                0, vocab_size, (batch_size, target_len), device=device
+            )
+        else:
+            full_logits = torch.empty(batch_size, seq_len, vocab_size, device=device)
+            target_ids = torch.empty(
+                batch_size, target_len, dtype=torch.int64, device=device
+            )
+
+        dist.broadcast(full_logits, src=0)
+        dist.broadcast(target_ids, src=0)
+
+        # Reference: non-sharded computation
+        expected = compute_logprobs(full_logits, target_ids, align=True)
+
+        # Create vocab-sharded DTensor
+        mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
+        local_vocab = vocab_size // self.world_size
+        dtensor_logits = DTensor.from_local(
+            full_logits[:, :, rank * local_vocab : (rank + 1) * local_vocab],
+            mesh,
+            placements=[Shard(2)],
+        )
+
+        # Compute with loss_parallel context
+        with loss_parallel():
+            result = compute_logprobs(dtensor_logits, target_ids, align=True)
+
+        # Verify output is Replicated as expected from loss_parallel
+        assert isinstance(result, DTensor)
+        assert result.placements[
+            0
+        ].is_replicate(), f"Expected Replicated placement, got {result.placements}"
+        result = result.to_local()
+
+        torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)
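Reviewer note: the following is a minimal end-to-end sketch, not part of the patch, showing the usage pattern the diff enables outside the actor: vocab-sharded DTensor logits go straight into `compute_logprobs` under `loss_parallel()`, and the replicated result is unwrapped with `to_local()`, mirroring `ReferenceModel.forward` and the new unit test. The 2-GPU `torchrun` launch, tensor shapes, and variable names are illustrative assumptions; a real tensor-parallel output head would produce the sharded logits instead of the hand-built DTensor used here.

```python
# Illustrative sketch (not part of the patch). Launch with:
#   torchrun --nproc-per-node 2 sketch.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard
from torch.distributed.tensor.parallel import loss_parallel

from forge.util.ops import compute_logprobs

# Two-way tensor-parallel mesh; init_device_mesh initializes the default process group.
mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
rank = dist.get_rank()
torch.cuda.set_device(rank)

batch, seq_len, vocab = 2, 8, 128  # toy sizes, assumed for illustration
local_vocab = vocab // 2

# Each rank holds its slice of the vocab dimension, as a TP output head would produce.
local_logits = torch.randn(batch, seq_len, local_vocab, device="cuda")
logits = DTensor.from_local(local_logits, mesh, placements=[Shard(2)])

# Response tokens must be identical on every rank, so seed before sampling them.
torch.manual_seed(0)
response_tokens = torch.randint(0, vocab, (batch, 4), device="cuda")

# Under loss_parallel, compute_logprobs consumes the vocab-sharded DTensor directly;
# the result comes back Replicated, so to_local() yields the full per-token logprobs.
with loss_parallel():
    logprobs = compute_logprobs(logits, response_tokens)
logprobs = logprobs.to_local()

print(f"rank {rank}: logprobs shape {tuple(logprobs.shape)}")
```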