We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c0aedca commit 90120bbCopy full SHA for 90120bb
src/forge/actors/reference_model.py
@@ -188,6 +188,8 @@ async def forward(
188
if parallel_dims.tp_enabled and isinstance(logits, DTensor):
189
with loss_parallel():
190
logprobs = compute_logprobs(logits, response_tokens, align=True)
191
+
192
+ logprobs = logprobs.to_local()
193
else:
194
logprobs = compute_logprobs(logits, response_tokens)
195
t.step("compute_logprobs")
0 commit comments