Skip to content

Commit 90120bb

Browse files
fix: convert DTensor output to regular tensor after loss_parallel
1 parent c0aedca commit 90120bb

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/forge/actors/reference_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ async def forward(
188188
if parallel_dims.tp_enabled and isinstance(logits, DTensor):
189189
with loss_parallel():
190190
logprobs = compute_logprobs(logits, response_tokens, align=True)
191+
192+
logprobs = logprobs.to_local()
191193
else:
192194
logprobs = compute_logprobs(logits, response_tokens)
193195
t.step("compute_logprobs")

0 commit comments

Comments
 (0)