 import re
 import os
 
+
 class ComplianceCheck(BaseCheck):
-    def __init__(self, log, path, config: Config, submission_logs: SubmissionLogs):
+    def __init__(self, log, path, config: Config,
+                 submission_logs: SubmissionLogs):
         super().__init__(log, path)
         self.submission_logs = submission_logs
         self.config = config
         self.model = self.submission_logs.loader_data.get("benchmark", "")
-        self.model_mapping = self.submission_logs.loader_data.get("model_mapping", {})
-        self.compliance_dir = self.submission_logs.loader_data.get("compliance_path", {})
+        self.model_mapping = self.submission_logs.loader_data.get(
+            "model_mapping", {})
+        self.compliance_dir = self.submission_logs.loader_data.get(
+            "compliance_path", {})
         self.division = self.submission_logs.loader_data.get("division", "")
-        self.model = self.config.get_mlperf_model(self.model, self.model_mapping)
+        self.model = self.config.get_mlperf_model(
+            self.model, self.model_mapping)
         self.test_list = self.get_test_list(self.model)
         self.setup_checks()
 
@@ -36,36 +41,54 @@ def get_test_list(self, model):
         if model in self.config.base["models_TEST06"]:
             test_list.append("TEST06")
         return test_list
-
+
     def dir_exists_check(self):
         if self.division.lower() == "open":
-            self.log.info("Compliance tests not needed for open division. Skipping tests on %s", self.path)
+            self.log.info(
+                "Compliance tests not needed for open division. Skipping tests on %s",
+                self.path)
             return True
         is_valid = True
         for test in self.test_list:
             test_dir = os.path.join(self.compliance_dir, test)
-            acc_path = os.path.join(self.compliance_dir, test, "verify_accuracy.txt")
-            perf_comp_path = os.path.join(self.compliance_dir, test, "verify_performance.txt")
-            perf_path = os.path.join(self.compliance_dir, test, "performance", "run_1", "mlperf_log_detail.txt")
+            acc_path = os.path.join(
+                self.compliance_dir, test, "verify_accuracy.txt")
+            perf_comp_path = os.path.join(
+                self.compliance_dir, test, "verify_performance.txt")
+            perf_path = os.path.join(
+                self.compliance_dir,
+                test,
+                "performance",
+                "run_1",
+                "mlperf_log_detail.txt")
             if not os.path.exists(test_dir):
-                self.log.error("Missing %s in compliance dir %s", test, self.compliance_dir)
+                self.log.error(
+                    "Missing %s in compliance dir %s",
+                    test,
+                    self.compliance_dir)
                 is_valid = False
             if test in ["TEST01", "TEST06"]:
                 if not os.path.exists(acc_path):
-                    self.log.error("Missing accuracy file in compliance dir. Needs file %s", acc_path)
+                    self.log.error(
+                        "Missing accuracy file in compliance dir. Needs file %s", acc_path)
                     is_valid = False
             if test in ["TEST01", "TEST04"]:
                 if not os.path.exists(perf_comp_path):
-                    self.log.error("Missing performance file in compliance dir. Needs file %s", perf_comp_path)
+                    self.log.error(
+                        "Missing performance file in compliance dir. Needs file %s",
+                        perf_comp_path)
                     is_valid = False
             if not os.path.exists(perf_path):
-                self.log.error("Missing perfomance file in compliance dir. Needs file %s", perf_path)
+                self.log.error(
+                    "Missing perfomance file in compliance dir. Needs file %s", perf_path)
                 is_valid = False
         return is_valid
-
+
     def performance_check(self):
         if self.division.lower() == "open":
-            self.log.info("Compliance tests not needed for open division. Skipping tests on %s", self.path)
+            self.log.info(
+                "Compliance tests not needed for open division. Skipping tests on %s",
+                self.path)
             return True
         is_valid = True
         for test in self.test_list:
@@ -76,14 +99,24 @@ def performance_check(self):
                 "scenario": self.submission_logs.loader_data.get("scenario", ""),
                 "model_mapping": self.submission_logs.loader_data.get("model_mapping", {})
             }
-            test_logs = SubmissionLogs(self.submission_logs.loader_data[f"{test}_perf_log"], None, None, None, self.submission_logs.system_json, None, test_data)
-            perf_check = PerformanceCheck(self.log, os.path.join(self.compliance_dir, test), self.config, test_logs)
+            test_logs = SubmissionLogs(
+                self.submission_logs.loader_data[f"{test}_perf_log"],
+                None,
+                None,
+                None,
+                self.submission_logs.system_json,
+                None,
+                test_data)
+            perf_check = PerformanceCheck(self.log, os.path.join(
+                self.compliance_dir, test), self.config, test_logs)
             is_valid &= perf_check()
         return is_valid
 
     def accuracy_check(self):
         if self.division.lower() == "open":
-            self.log.info("Compliance tests not needed for open division. Skipping tests on %s", self.path)
+            self.log.info(
+                "Compliance tests not needed for open division. Skipping tests on %s",
+                self.path)
             return True
         is_valid = True
         for test in self.test_list:
@@ -103,11 +136,13 @@ def accuracy_check(self):
                 )
                 test_acc_path = os.path.join(test_dir, "accuracy")
                 if not os.path.exists(test_acc_path):
-                    self.log.error("%s has no accuracy directory", test_dir)
+                    self.log.error(
+                        "%s has no accuracy directory", test_dir)
                     is_valid = False
                 else:
                     diff = files_diff(
-                        list_files(test_acc_path), REQUIRED_TEST01_ACC_FILES,
+                        list_files(
+                            test_acc_path), REQUIRED_TEST01_ACC_FILES,
                     )
                     if diff:
                         self.log.error(
@@ -116,30 +151,39 @@ def accuracy_check(self):
                             diff)
                         is_valid = False
                     else:
-                        target = self.config.get_accuracy_target(self.model)
-                        patterns, acc_targets, acc_types, acc_limits, up_patterns, acc_upper_limit = self.config.get_accuracy_values(self.model)
+                        target = self.config.get_accuracy_target(
+                            self.model)
+                        patterns, acc_targets, acc_types, acc_limits, up_patterns, acc_upper_limit = self.config.get_accuracy_values(
+                            self.model)
                         acc_limit_check = True
 
                         acc_seen = [False for _ in acc_targets]
-                        acc_baseline = {acc_type: 0 for acc_type in acc_types}
-                        acc_compliance = {acc_type: 0 for acc_type in acc_types}
+                        acc_baseline = {
+                            acc_type: 0 for acc_type in acc_types}
+                        acc_compliance = {
+                            acc_type: 0 for acc_type in acc_types}
                         with open(
-                            os.path.join(test_acc_path, "baseline_accuracy.txt"),
+                            os.path.join(
+                                test_acc_path, "baseline_accuracy.txt"),
                             "r",
                             encoding="utf-8",
                         ) as f:
                             for line in f:
-                                for acc_type, pattern in zip(acc_types, patterns):
+                                for acc_type, pattern in zip(
+                                        acc_types, patterns):
                                     m = re.match(pattern, line)
                                     if m:
-                                        acc_baseline[acc_type] = float(m.group(1))
+                                        acc_baseline[acc_type] = float(
+                                            m.group(1))
                         with open(
-                            os.path.join(test_acc_path, "compliance_accuracy.txt"),
+                            os.path.join(
+                                test_acc_path, "compliance_accuracy.txt"),
                             "r",
                             encoding="utf-8",
                         ) as f:
                             for line in f:
-                                for acc_type, pattern in zip(acc_types, patterns):
+                                for acc_type, pattern in zip(
+                                        acc_types, patterns):
                                     m = re.match(pattern, line)
                                     if m:
                                         acc_compliance[acc_type] = float(
@@ -178,15 +222,16 @@ def accuracy_check(self):
                 )
                 eos_pass = "EOS check pass: True" in lines
                 length_check_pass = "Sample length check pass: True" in lines
-                is_valid &= (first_token_pass and eos_pass and length_check_pass)
+                is_valid &= (
+                    first_token_pass and eos_pass and length_check_pass)
                 if not is_valid:
                     self.log.error(
                         f"TEST06 accuracy check failed. first_token_check: {first_token_pass} eos_check: {eos_pass} length_check: {length_check_pass}."
                     )
             else:
                 self.log.info(f"{test_dir} does not require accuracy check")
         return is_valid
-
+
     def compliance_performance_check(self):
         is_valid = True
         for test in self.test_list:
@@ -209,9 +254,11 @@ def compliance_performance_check(self):
                     test_dir)
 
             # Check performance dir
-            test_perf_path = os.path.join(test_dir, "performance", "run_1")
+            test_perf_path = os.path.join(
+                test_dir, "performance", "run_1")
             if not os.path.exists(test_perf_path):
-                self.log.error("%s has no performance/run_1 directory", test_dir)
+                self.log.error(
+                    "%s has no performance/run_1 directory", test_dir)
                 is_valid = False
             else:
                 diff = files_diff(
@@ -225,4 +272,4 @@ def compliance_performance_check(self):
                         test_perf_path,
                         diff)
                     is_valid = False
-        return is_valid
+        return is_valid
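
For context, a minimal usage sketch of ComplianceCheck as it reads after this change. Only the constructor signature and the check method names come from the diff above; everything else is an assumption, including the import location, how Config and SubmissionLogs are built by the loader, and whether BaseCheck.__call__ runs the checks registered in setup_checks().

# Usage sketch -- illustrative only, see the caveats above.
import logging

# from .compliance_check import ComplianceCheck, Config, SubmissionLogs  # import path is a guess

log = logging.getLogger("compliance")
path = "closed/ORG/compliance/SYSTEM/resnet/Offline"   # hypothetical submission path
config = ...           # Config instance built by the checker's loader (not shown in this diff)
submission_logs = ...  # SubmissionLogs whose loader_data carries benchmark, division, compliance_path, ...

check = ComplianceCheck(log, path, config, submission_logs)

# Call the individual checks named in the diff...
ok = check.dir_exists_check() and check.accuracy_check()
ok = ok and check.performance_check() and check.compliance_performance_check()
# ...or, if BaseCheck.__call__ runs everything registered in setup_checks():
# ok = check()
print("compliance checks passed:", ok)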