Skip to content

Commit 8e1224e

Browse files
authored
fix: make stdlib.reward more robust (#1300)
Signed-off-by: Louis Mandel <[email protected]>
1 parent a734ddc commit 8e1224e

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

src/pdl/pdl_stdlib.pdl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,23 +38,38 @@ defs:
3838
else:
3939
lp_n = lp_evaluation
4040
other = 'true'
41-
if evaluation in contents_keys:
41+
42+
if other in contents_keys:
4243
exact_match = True
4344
else:
4445
exact_match = False
4546

47+
lp_other = -math.inf
4648
for tp in top_logprobs:
4749
if match(other, tp['token'], exact_match):
4850
lp_other = tp['logprob']
4951
break
5052

53+
# if math.isinf(lp_other):
54+
# print("XXXXXXXXXXXXXXXXXXXXXXXXXXXX")
55+
# print("XXXX lp_other is not defined")
56+
# print(f"{evaluation=}")
57+
# print(f"{other=}")
58+
# print(f"{exact_match=}")
59+
# print(f"top_logprobs tokens:")
60+
# for tp in top_logprobs:
61+
# print(f" {tp['token']}")
62+
# print("XXXXXXXXXXXXXXXXXXXXXXXXXXXX")
63+
5164
if other == 'true':
5265
lp_y = lp_other
5366
else:
5467
lp_n = lp_other
5568

56-
result = math.log(math.exp(lp_y) / (math.exp(lp_y) + math.exp(lp_n)))
57-
69+
if math.exp(lp_y) == 0.0:
70+
result = -math.inf
71+
else:
72+
result = math.log(math.exp(lp_y) / (math.exp(lp_y) + math.exp(lp_n)))
5873

5974
# print(f"evaluation: { evaluation }")
6075
# print(f"exact_match: { exact_match }")

0 commit comments

Comments
 (0)