We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ae4c880 commit 352175eCopy full SHA for 352175e
src/forge/util/ops.py
@@ -131,10 +131,9 @@ def compute_logprobs_parallel(
131
return compute_logprobs(logits.full_tensor(), target_ids, temperature, align)
132
133
local_logits = logits._local_tensor # [batch, seq_len, vocab_size / tp_size]
134
- target_len = target_ids.size(1)
135
136
if align:
137
- local_logits = local_logits[:, -target_len - 1 : -1, :]
+ local_logits = local_logits[:, -target_ids.size(1) - 1 : -1, :]
138
139
target_ids = target_ids.to(local_logits.device)
140
local_logits_fp32 = local_logits.float() / temperature
0 commit comments