Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions pfgen.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -59,6 +59,7 @@
{"question": "神在月とは何ですか?", "answer": "神在月とは、旧暦10月のことを指し、全国の八百万の神々が出雲大社に集まり、縁結びの神議り(かむはかり)が行われるとされる月です。出雲地方では「神在月」と呼びますが、他の地域では「神無月」と呼ばれます。"}
""" # noqa: E501
QUESTIONS: list[dict[str, str]] = []
FAILED_TO_GEN_MSG = "(FAILED TO GENERATE AN ANSWER)"


def get_questions() -> list[dict[str, str]]:
Expand Down Expand Up @@ -130,6 +131,8 @@ def run_tasks(
num_examples: int = 20,
num_trials: int = 100,
seed: str = "",
num_retries: int = 10,
ignore_failure: bool = False,
**parameters: typing.Any,
) -> None:
questions = get_questions()
Expand Down Expand Up @@ -158,7 +161,7 @@ def run_tasks(
print(f"Starting a trial: {trial}", file=sys.stderr)
if buf == "":
outputs: dict[str, str] = {}
for _ in range(10):
for _ in range(num_retries):
tasks: list[dict[str, str]] = []
task_questions: list[str] = []
for q_info in questions:
Expand All @@ -184,13 +187,17 @@ def run_tasks(
for q, a in zip(task_questions, callback(tasks, parameters)):
if a is None or a == "":
print(f"Failed to get an answer for: {q}", file=sys.stderr)
time.sleep(3)
continue
if ignore_failure:
a = FAILED_TO_GEN_MSG
else:
time.sleep(3)
continue
if mode in ("chat", "qa") and "A:" in a:
a = a.split("A:", 1)[1].strip()
result = {
"question": q,
"answer": a.strip(),
"generated": a != FAILED_TO_GEN_MSG,
"timestamp": datetime.datetime.now().isoformat(),
}
output = json.dumps(result, ensure_ascii=False)
Expand Down
6 changes: 4 additions & 2 deletions pfgen_eval.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -206,7 +206,7 @@ def run_result(
with open(score_path, "w") as f:
json.dump(data, f, indent=2, ensure_ascii=False)

result["num_trials"] = min([len(x) for x in data.values()])
result["num_trials"] = max([len(x) for x in data.values()])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Storing `num_trials` in config.json seems like the appropriate approach.

task_scores = [[a["scores"]["average"] for a in x] for x in data.values()]
trial_scores = [sum(x) / len(x) for x in zip(*task_scores)]
result["score"], result["score_std"] = mean_std(trial_scores)
Expand Down Expand Up @@ -246,6 +246,7 @@ def run_result(
r["score"], r["score_std"] = mean_std([a["scores"]["average"] for a in answers])
r["length"], r["length_std"] = mean_std([len(a["answer"]) for a in answers], 1)
r["scores"] = scores
r["num_valid_trials"] = len(answers)
r["samples"] = samples
result_questions[question_id] = r
result["scores"] = scores_all
Expand Down Expand Up @@ -324,7 +325,8 @@ def run(self, force: bool = False) -> None:
"answers": [],
},
)
answers[d["question"]][output_path]["answers"].append(d)
if d["generated"] is True:
answers[d["question"]][output_path]["answers"].append(d)
# Check if the output needs to be updated.
if not force:
for question, data in answers.items():
Expand Down
9 changes: 9 additions & 0 deletions run-gemini.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -84,6 +84,13 @@ def callback(
)
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling.")
parser.add_argument("--num-trials", type=int, default=10, help="Number of trials to run.")
parser.add_argument("--num-retries", type=int, default=10, help="Number of retries.")
parser.add_argument(
"--ignore-failure",
action="store_true",
default=False,
help="Do not throw an exception if answer generation fails.",
)
args = parser.parse_args()
pfgen.run_tasks(
args.mode,
Expand All @@ -94,4 +101,6 @@ def callback(
temperature=args.temperature,
num_trials=args.num_trials,
max_tokens=3000,
num_retries=args.num_retries,
ignore_failure=args.ignore_failure,
)
9 changes: 9 additions & 0 deletions run-hf.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -142,6 +142,13 @@ def __call__(
parser.add_argument("--batch-size", type=int, default=1, help="Batch size for sampling.")
parser.add_argument("--device", type=str, default="auto", help="Device for sampling.")
parser.add_argument("--dtype", type=str, default="", help="Data type.")
parser.add_argument("--num-retries", type=int, default=10, help="Number of retries.")
parser.add_argument(
"--ignore-failure",
action="store_true",
default=False,
help="Do not throw an exception if answer generation fails.",
)
args = parser.parse_args()
kwargs = {}
if args.dtype:
Expand All @@ -160,6 +167,8 @@ def __call__(
num_trials=args.num_trials,
temperature=args.temperature,
top_p=args.top_p,
num_retries=args.num_retries,
ignore_failure=args.ignore_failure,
_path=args.path,
_batch_size=args.batch_size,
_device=args.device if torch.cuda.is_available() else "cpu",
Expand Down
9 changes: 9 additions & 0 deletions run-manual.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -36,10 +36,19 @@ def callback(
default="mistralai/Mistral-7B-v0.1",
help="Huggingface model name.",
)
parser.add_argument("--num-retries", type=int, default=10, help="Number of retries.")
parser.add_argument(
"--ignore-failure",
action="store_true",
default=False,
help="Do not throw an exception if answer generation fails.",
)
args = parser.parse_args()
pfgen.run_tasks(
args.mode,
callback,
engine="manual",
model=args.model,
num_retries=args.num_retries,
ignore_failure=args.ignore_failure,
)
9 changes: 9 additions & 0 deletions run-openai.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -93,6 +93,13 @@ def callback(
action="store_true",
help="Disable reasoning when generation by Qwen3 models",
)
parser.add_argument("--num-retries", type=int, default=10, help="Number of retries.")
parser.add_argument(
"--ignore-failure",
action="store_true",
default=False,
help="Do not throw an exception if answer generation fails.",
)
args = parser.parse_args()

wrapped_callback = partial(
Expand All @@ -110,4 +117,6 @@ def callback(
top_p=args.top_p,
num_trials=args.num_trials,
enable_thinking=not args.disable_thinking,
num_retries=args.num_retries,
ignore_failure=args.ignore_failure,
)
9 changes: 9 additions & 0 deletions run-vllm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -153,6 +153,13 @@ def __call__(
"--tensor-parallel-size", type=int, default=-1, help="Tensor Parallel Size."
)
parser.add_argument("--quantization", type=str, default=None, help="Quantization method.")
parser.add_argument("--num-retries", type=int, default=10, help="Number of retries.")
parser.add_argument(
"--ignore-failure",
action="store_true",
default=False,
help="Do not throw an exception if answer generation fails.",
)
args = parser.parse_args()
kwargs = {}
if args.max_tokens:
Expand Down Expand Up @@ -180,6 +187,8 @@ def __call__(
num_trials=args.num_trials,
temperature=args.temperature,
top_p=args.top_p,
num_retries=args.num_retries,
ignore_failure=args.ignore_failure,
_path=args.path,
**kwargs,
)