Skip to content

Commit c3bd1cb

Browse files
authored
fix: raise RuntimeError when failing to generate an answer (#94)
1 parent 03db7c4 commit c3bd1cb

File tree

7 files changed

+59
-5
lines changed

7 files changed

+59
-5
lines changed

pfgen.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
{"question": "神在月とは何ですか?", "answer": "神在月とは、旧暦10月のことを指し、全国の八百万の神々が出雲大社に集まり、縁結びの神議り(かむはかり)が行われるとされる月です。出雲地方では「神在月」と呼びますが、他の地域では「神無月」と呼ばれます。"}
6060
""" # noqa: E501
6161
QUESTIONS: list[dict[str, str]] = []
62+
FAILED_TO_GEN_MSG = "(FAILED TO GENERATE AN ANSWER)"
6263

6364

6465
def get_questions() -> list[dict[str, str]]:
@@ -130,6 +131,8 @@ def run_tasks(
130131
num_examples: int = 20,
131132
num_trials: int = 100,
132133
seed: str = "",
134+
num_retries: int = 10,
135+
ignore_failure: bool = False,
133136
**parameters: typing.Any,
134137
) -> None:
135138
questions = get_questions()
@@ -158,7 +161,7 @@ def run_tasks(
158161
print(f"Starting a trial: {trial}", file=sys.stderr)
159162
if buf == "":
160163
outputs: dict[str, str] = {}
161-
for _ in range(10):
164+
for _ in range(num_retries):
162165
tasks: list[dict[str, str]] = []
163166
task_questions: list[str] = []
164167
for q_info in questions:
@@ -184,13 +187,17 @@ def run_tasks(
184187
for q, a in zip(task_questions, callback(tasks, parameters)):
185188
if a is None or a == "":
186189
print(f"Failed to get an answer for: {q}", file=sys.stderr)
187-
time.sleep(3)
188-
continue
190+
if ignore_failure:
191+
a = FAILED_TO_GEN_MSG
192+
else:
193+
time.sleep(3)
194+
continue
189195
if mode in ("chat", "qa") and "A:" in a:
190196
a = a.split("A:", 1)[1].strip()
191197
result = {
192198
"question": q,
193199
"answer": a.strip(),
200+
"generated": a != FAILED_TO_GEN_MSG,
194201
"timestamp": datetime.datetime.now().isoformat(),
195202
}
196203
output = json.dumps(result, ensure_ascii=False)

pfgen_eval.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def run_result(
206206
with open(score_path, "w") as f:
207207
json.dump(data, f, indent=2, ensure_ascii=False)
208208

209-
result["num_trials"] = min([len(x) for x in data.values()])
209+
result["num_trials"] = max([len(x) for x in data.values()])
210210
task_scores = [[a["scores"]["average"] for a in x] for x in data.values()]
211211
trial_scores = [sum(x) / len(x) for x in zip(*task_scores)]
212212
result["score"], result["score_std"] = mean_std(trial_scores)
@@ -246,6 +246,7 @@ def run_result(
246246
r["score"], r["score_std"] = mean_std([a["scores"]["average"] for a in answers])
247247
r["length"], r["length_std"] = mean_std([len(a["answer"]) for a in answers], 1)
248248
r["scores"] = scores
249+
r["num_valid_trials"] = len(answers)
249250
r["samples"] = samples
250251
result_questions[question_id] = r
251252
result["scores"] = scores_all
@@ -324,7 +325,8 @@ def run(self, force: bool = False) -> None:
324325
"answers": [],
325326
},
326327
)
327-
answers[d["question"]][output_path]["answers"].append(d)
328+
if d["generated"] is True:
329+
answers[d["question"]][output_path]["answers"].append(d)
328330
# Check if the output needs to be updated.
329331
if not force:
330332
for question, data in answers.items():

run-gemini.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,13 @@ def callback(
8484
)
8585
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling.")
8686
parser.add_argument("--num-trials", type=int, default=10, help="Number of trials to run.")
87+
parser.add_argument("--num-retries", type=int, default=10, help="Number of retries.")
88+
parser.add_argument(
89+
"--ignore-failure",
90+
action="store_true",
91+
default=False,
92+
help="Do not throw an exception if answer generation fails.",
93+
)
8794
args = parser.parse_args()
8895
pfgen.run_tasks(
8996
args.mode,
@@ -94,4 +101,6 @@ def callback(
94101
temperature=args.temperature,
95102
num_trials=args.num_trials,
96103
max_tokens=3000,
104+
num_retries=args.num_retries,
105+
ignore_failure=args.ignore_failure,
97106
)

run-hf.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,13 @@ def __call__(
142142
parser.add_argument("--batch-size", type=int, default=1, help="Batch size for sampling.")
143143
parser.add_argument("--device", type=str, default="auto", help="Device for sampling.")
144144
parser.add_argument("--dtype", type=str, default="", help="Data type.")
145+
parser.add_argument("--num-retries", type=int, default=10, help="Number of retries.")
146+
parser.add_argument(
147+
"--ignore-failure",
148+
action="store_true",
149+
default=False,
150+
help="Do not throw an exception if answer generation fails.",
151+
)
145152
args = parser.parse_args()
146153
kwargs = {}
147154
if args.dtype:
@@ -160,6 +167,8 @@ def __call__(
160167
num_trials=args.num_trials,
161168
temperature=args.temperature,
162169
top_p=args.top_p,
170+
num_retries=args.num_retries,
171+
ignore_failure=args.ignore_failure,
163172
_path=args.path,
164173
_batch_size=args.batch_size,
165174
_device=args.device if torch.cuda.is_available() else "cpu",

run-manual.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,19 @@ def callback(
3636
default="mistralai/Mistral-7B-v0.1",
3737
help="Huggingface model name.",
3838
)
39+
parser.add_argument("--num-retries", type=int, default=10, help="Number of retries.")
40+
parser.add_argument(
41+
"--ignore-failure",
42+
action="store_true",
43+
default=False,
44+
help="Do not throw an exception if answer generation fails.",
45+
)
3946
args = parser.parse_args()
4047
pfgen.run_tasks(
4148
args.mode,
4249
callback,
4350
engine="manual",
4451
model=args.model,
52+
num_retries=args.num_retries,
53+
ignore_failure=args.ignore_failure,
4554
)

run-openai.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,13 @@ def callback(
9393
action="store_true",
9494
help="Disable reasoning when generation by Qwen3 models",
9595
)
96+
parser.add_argument("--num-retries", type=int, default=10, help="Number of retries.")
97+
parser.add_argument(
98+
"--ignore-failure",
99+
action="store_true",
100+
default=False,
101+
help="Do not throw an exception if answer generation fails.",
102+
)
96103
args = parser.parse_args()
97104

98105
wrapped_callback = partial(
@@ -110,4 +117,6 @@ def callback(
110117
top_p=args.top_p,
111118
num_trials=args.num_trials,
112119
enable_thinking=not args.disable_thinking,
120+
num_retries=args.num_retries,
121+
ignore_failure=args.ignore_failure,
113122
)

run-vllm.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,13 @@ def __call__(
153153
"--tensor-parallel-size", type=int, default=-1, help="Tensor Parallel Size."
154154
)
155155
parser.add_argument("--quantization", type=str, default=None, help="Quantization method.")
156+
parser.add_argument("--num-retries", type=int, default=10, help="Number of retries.")
157+
parser.add_argument(
158+
"--ignore-failure",
159+
action="store_true",
160+
default=False,
161+
help="Do not throw an exception if answer generation fails.",
162+
)
156163
args = parser.parse_args()
157164
kwargs = {}
158165
if args.max_tokens:
@@ -180,6 +187,8 @@ def __call__(
180187
num_trials=args.num_trials,
181188
temperature=args.temperature,
182189
top_p=args.top_p,
190+
num_retries=args.num_retries,
191+
ignore_failure=args.ignore_failure,
183192
_path=args.path,
184193
**kwargs,
185194
)

0 commit comments

Comments (0)