5959{"question": "神在月とは何ですか?", "answer": "神在月とは、旧暦10月のことを指し、全国の八百万の神々が出雲大社に集まり、縁結びの神議り(かむはかり)が行われるとされる月です。出雲地方では「神在月」と呼びますが、他の地域では「神無月」と呼ばれます。"}
6060""" # noqa: E501
6161QUESTIONS : list [dict [str , str ]] = []
62+ FAILED_TO_GEN_MSG = "(FAILED TO GENERATE AN ANSWER)"
6263
6364
6465def get_questions () -> list [dict [str , str ]]:
@@ -130,6 +131,8 @@ def run_tasks(
130131 num_examples : int = 20 ,
131132 num_trials : int = 100 ,
132133 seed : str = "" ,
134+ num_retries : int = 10 ,
135+ ignore_failure : bool = False ,
133136 ** parameters : typing .Any ,
134137) -> None :
135138 questions = get_questions ()
@@ -158,7 +161,7 @@ def run_tasks(
158161 print (f"Starting a trial: { trial } " , file = sys .stderr )
159162 if buf == "" :
160163 outputs : dict [str , str ] = {}
161- for _ in range (10 ):
164+ for _ in range (num_retries ):
162165 tasks : list [dict [str , str ]] = []
163166 task_questions : list [str ] = []
164167 for q_info in questions :
@@ -184,13 +187,17 @@ def run_tasks(
184187 for q , a in zip (task_questions , callback (tasks , parameters )):
185188 if a is None or a == "" :
186189 print (f"Failed to get an answer for: { q } " , file = sys .stderr )
187- time .sleep (3 )
188- continue
190+ if ignore_failure :
191+ a = FAILED_TO_GEN_MSG
192+ else :
193+ time .sleep (3 )
194+ continue
189195 if mode in ("chat" , "qa" ) and "A:" in a :
190196 a = a .split ("A:" , 1 )[1 ].strip ()
191197 result = {
192198 "question" : q ,
193199 "answer" : a .strip (),
200+ "generated" : a != FAILED_TO_GEN_MSG ,
194201 "timestamp" : datetime .datetime .now ().isoformat (),
195202 }
196203 output = json .dumps (result , ensure_ascii = False )
0 commit comments