Commit 9a05f62

gitlost-murali and Felipe Mello authored
feat: Reduce reference model memory with parallel logprob computation (#608)
Co-authored-by: Felipe Mello <[email protected]>
1 parent e451ad5 commit 9a05f62

File tree

3 files changed: +97 -4 lines changed


src/forge/actors/reference_model.py

Lines changed: 18 additions & 3 deletions
@@ -20,6 +20,7 @@
 from forge.util.ops import compute_logprobs
 from monarch.actor import current_rank, current_size, endpoint
 from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.parallel import loss_parallel

 from torchtitan.config.job_config import (
     Checkpoint,
@@ -98,6 +99,10 @@ def __post_init__(self):
         self.rank = current_rank().rank
         self.size = math.prod(current_size().values())

+        self.compute_log_probs = compute_logprobs
+        if self.compile.enable:
+            self.compute_log_probs = torch.compile(self.compute_log_probs)
+
         env = {
             "RANK": str(self.rank),
             "LOCAL_RANK": str(self.rank),
@@ -174,13 +179,23 @@ async def forward(
         with torch.inference_mode():
             logits = self.model(input_ids)
         self.step += 1
-        if isinstance(logits, DTensor):
-            logits = logits.full_tensor()

         if not return_logprobs:
+            if isinstance(logits, DTensor):
+                logits = logits.full_tensor()
             t.stop()
             return logits
         else:
-            logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:])
+            response_tokens = input_ids[:, max_req_tokens:]
+            if parallel_dims.tp_enabled and isinstance(logits, DTensor):
+                with loss_parallel():
+                    logprobs = self.compute_log_probs(logits, response_tokens)
+
+                # loss_parallel produces Replicated output - to_local() returns the full tensor
+                logprobs = logprobs.to_local()
+            else:
+                if isinstance(logits, DTensor):
+                    logits = logits.full_tensor()
+                logprobs = self.compute_log_probs(logits, response_tokens)
             t.stop()
             return logprobs
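
What the new branch buys: under tensor parallelism the reference model's logits arrive as a DTensor sharded on the vocab dimension, and the old path called logits.full_tensor() unconditionally, materializing the entire (batch, seq, vocab) tensor on every rank before computing log-probabilities. With loss_parallel() the log-probabilities are computed directly on each rank's vocab shard, and only the much smaller (batch, response_len) result is replicated. A rough sketch of the per-rank saving, using purely illustrative shapes that are not taken from this repository:

# Illustrative numbers only: logits memory per rank, gathered vs. kept sharded.
batch, seq, vocab, tp_degree = 8, 4096, 128_000, 8
bytes_per_elem = 2  # assume bf16 logits

gathered = batch * seq * vocab * bytes_per_elem   # what full_tensor() materializes
sharded = gathered // tp_degree                   # what loss_parallel works on

print(f"gathered: {gathered / 1e9:.1f} GB per rank")  # ~8.4 GB
print(f"sharded:  {sharded / 1e9:.1f} GB per rank")   # ~1.0 GB

The return_logprobs=False branch still gathers, because its callers receive raw logits; only the log-probability path skips the gather.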

src/forge/util/ops.py

Lines changed: 14 additions & 1 deletion
@@ -6,10 +6,11 @@

 import torch
 import torch.nn.functional as F
+from torch.distributed.tensor import DTensor


 def compute_logprobs(
-    logits: torch.Tensor,
+    logits: torch.Tensor | DTensor,
     input_ids: torch.Tensor,
     temperature: float = 1.0,
     align: bool = True,
@@ -52,9 +53,21 @@ def compute_logprobs(
     probabilities for the response portion, you don't need to re-run the model. This
     is a key optimization in RL training where the prompt remains constant.

+    **Tensor Parallelism Support:**
+    When logits is a DTensor sharded on the vocab dimension (e.g., from tensor parallel
+    training), wrap calls to this function with `loss_parallel()` context:
+
+    >>> from torch.distributed.tensor.parallel import loss_parallel
+    >>> with loss_parallel():
+    ...     logprobs = compute_logprobs(logits, input_ids)
+
+    The `loss_parallel` context ensures F.cross_entropy works correctly with
+    vocab-sharded DTensors without needing to gather the full tensor.
+
     Args:
         logits (`torch.Tensor`):
             The model output logits of shape `(batch_size, sequence_length, vocab_size)`.
+            Can be a regular Tensor or a DTensor (when using with loss_parallel context).
         input_ids (`torch.Tensor`):
             The target token ids of shape `(batch_size, target_sequence_length)`.
             These are the tokens for which you want to compute log probabilities.
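
The body of compute_logprobs is not part of this diff, but the docstring's reference to F.cross_entropy explains why wrapping the call in loss_parallel() is enough: that context reimplements cross-entropy for DTensors sharded on the class (vocab) dimension. The following is a minimal sketch of that style of implementation, written for this page only; the function name, the alignment slicing, and the temperature handling are assumptions rather than the repository's actual code.

import torch
import torch.nn.functional as F


def compute_logprobs_sketch(logits, input_ids, temperature=1.0):
    # Align: the logit at position t predicts token t + 1, so select the
    # positions whose predictions line up with the target span.
    target_len = input_ids.shape[1]
    logits = logits[:, -target_len - 1 : -1, :] / temperature

    # cross_entropy returns -log p(target); negate to recover log-probs.
    # Inside loss_parallel(), this call also accepts vocab-sharded DTensor
    # logits and never materializes the full vocab dimension on one rank.
    token_nll = F.cross_entropy(
        logits.flatten(0, 1),
        input_ids.flatten(0, 1),
        reduction="none",
    )
    return -token_nll.reshape(input_ids.shape)

With reduction="none", only one value per target token survives the call, so no full (batch, seq, vocab) log-softmax is ever held in memory.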

tests/unit_tests/util/test_ops.py

Lines changed: 65 additions & 0 deletions
@@ -6,9 +6,17 @@

 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn.functional as F
+
 from forge.util.ops import compute_logprobs

+from tests.test_utils import gpu_test
+from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.tensor import DTensor, Shard
+from torch.distributed.tensor.parallel import loss_parallel
+from torch.testing._internal.common_fsdp import FSDPTest
+

 def _textbook_log_softmax(logits: torch.Tensor, input_ids: torch.Tensor):
     # Helper: Textbook Log Softmax
@@ -162,3 +170,60 @@ def test_align_comparison(self):

         # Both should give the same result
         assert torch.allclose(result_aligned, result_manual, atol=1e-5)
+
+
+class TestComputeLogprobsWithLossParallel(FSDPTest):
+    """Test compute_logprobs with loss_parallel context for vocab-sharded DTensors."""
+
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @gpu_test(gpu_count=2)
+    def test_loss_parallel_matches_sequential(self):
+        """Verify compute_logprobs under loss_parallel matches non-sharded version."""
+        torch.manual_seed(42)
+
+        batch_size, seq_len, vocab_size, target_len = 4, 16, 1000, 8
+        rank = dist.get_rank()
+        device = torch.device(f"cuda:{rank}")
+
+        # Create and broadcast test data
+        if rank == 0:
+            full_logits = torch.randn(batch_size, seq_len, vocab_size, device=device)
+            target_ids = torch.randint(
+                0, vocab_size, (batch_size, target_len), device=device
+            )
+        else:
+            full_logits = torch.empty(batch_size, seq_len, vocab_size, device=device)
+            target_ids = torch.empty(
+                batch_size, target_len, dtype=torch.int64, device=device
+            )
+
+        dist.broadcast(full_logits, src=0)
+        dist.broadcast(target_ids, src=0)
+
+        # Reference: non-sharded computation
+        expected = compute_logprobs(full_logits, target_ids, align=True)
+
+        # Create vocab-sharded DTensor
+        mesh = init_device_mesh("cuda", (self.world_size,), mesh_dim_names=("tp",))
+        local_vocab = vocab_size // self.world_size
+        dtensor_logits = DTensor.from_local(
+            full_logits[:, :, rank * local_vocab : (rank + 1) * local_vocab],
+            mesh,
+            placements=[Shard(2)],
+        )
+
+        # Compute with loss_parallel context
+        with loss_parallel():
+            result = compute_logprobs(dtensor_logits, target_ids, align=True)
+
+        # Verify output is Replicated as expected from loss_parallel
+        assert isinstance(result, DTensor)
+        assert result.placements[
+            0
+        ].is_replicate(), f"Expected Replicated placement, got {result.placements}"
+        result = result.to_local()
+
+        torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)
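
The test broadcasts identical logits to both ranks, computes a plain-tensor reference, then re-runs compute_logprobs on a vocab-sharded DTensor inside loss_parallel() and checks that the replicated result matches. Because it subclasses FSDPTest, each case spawns its own two-process group, so it should run under plain pytest on a host with at least two GPUs; the gpu_test(gpu_count=2) decorator presumably skips it when fewer are available.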
