11
22from .base import BaseCheck
3+ from ..constants import *
4+ from ..loader import SubmissionLogs
5+ from ..configuration .configuration import Config
6+ from .performance_check import PerformanceCheck
7+ from .accuracy_check import AccuracyCheck
8+ from ..utils import *
9+ import re
10+ import os
311
class ComplianceCheck(BaseCheck):
    """Validate the MLPerf compliance artifacts of a submission.

    For closed-division submissions this verifies, for each compliance test
    required by the model (TEST01 / TEST04 / TEST06):
      * the expected files exist under the compliance directory,
      * the compliance performance run passes the regular PerformanceCheck,
      * the accuracy verification results indicate a pass.
    Open-division submissions are skipped by every sub-check.
    """

    def __init__(self, log, path, config: Config, submission_logs: SubmissionLogs):
        """Initialize the check and register its sub-checks.

        Args:
            log: Logger used for reporting results.
            path: Submission path being checked (forwarded to BaseCheck).
            config: Benchmark configuration (model maps, accuracy targets).
            submission_logs: Parsed submission artifacts. Its ``loader_data``
                dict is expected to provide "benchmark", "model_mapping",
                "compliance_path", "division", and per-test entries such as
                "TEST01_perf_log" / "TEST01_acc_result" — TODO confirm against
                the loader.
        """
        super().__init__(log, path)
        self.submission_logs = submission_logs
        self.config = config
        self.model = self.submission_logs.loader_data.get("benchmark", "")
        self.model_mapping = self.submission_logs.loader_data.get("model_mapping", {})
        self.compliance_dir = self.submission_logs.loader_data.get("compliance_path", {})
        self.division = self.submission_logs.loader_data.get("division", "")
        # Map the submitted benchmark name onto the canonical MLPerf model
        # name before deciding which compliance tests apply.
        self.model = self.config.get_mlperf_model(self.model, self.model_mapping)
        self.test_list = self.get_test_list(self.model)
        self.setup_checks()

    def setup_checks(self):
        """Register the compliance sub-checks executed by BaseCheck."""
        # NOTE(review): compliance_performance_check below is NOT registered
        # here — confirm whether that is intentional.
        self.checks.append(self.dir_exists_check)
        self.checks.append(self.performance_check)
        self.checks.append(self.accuracy_check)

    def get_test_list(self, model):
        """Return the compliance tests (e.g. ["TEST01", "TEST04"]) required for *model*."""
        test_list = []
        if model in self.config.base["models_TEST01"]:
            test_list.append("TEST01")
        if model in self.config.base["models_TEST04"]:
            test_list.append("TEST04")
        if model in self.config.base["models_TEST06"]:
            test_list.append("TEST06")
        return test_list

    def dir_exists_check(self):
        """Check that every required compliance test directory and its
        mandatory files exist.

        Returns:
            True when everything is present (or the submission is open
            division), False otherwise. All missing items are logged.
        """
        if self.division.lower() == "open":
            self.log.info("Compliance tests not needed for open division. Skipping tests on %s", self.path)
            return True
        is_valid = True
        for test in self.test_list:
            test_dir = os.path.join(self.compliance_dir, test)
            acc_path = os.path.join(self.compliance_dir, test, "verify_accuracy.txt")
            perf_comp_path = os.path.join(self.compliance_dir, test, "verify_performance.txt")
            perf_path = os.path.join(self.compliance_dir, test, "performance", "run_1", "mlperf_log_detail.txt")
            if not os.path.exists(test_dir):
                self.log.error("Missing %s in compliance dir %s", test, self.compliance_dir)
                is_valid = False
            # TEST04 has no accuracy verification; TEST06 has no performance one.
            if test in ["TEST01", "TEST06"]:
                if not os.path.exists(acc_path):
                    self.log.error("Missing accuracy file in compliance dir. Needs file %s", acc_path)
                    is_valid = False
            if test in ["TEST01", "TEST04"]:
                if not os.path.exists(perf_comp_path):
                    self.log.error("Missing performance file in compliance dir. Needs file %s", perf_comp_path)
                    is_valid = False
                if not os.path.exists(perf_path):
                    # Fixed message typo: "perfomance" -> "performance".
                    self.log.error("Missing performance file in compliance dir. Needs file %s", perf_path)
                    is_valid = False
        return is_valid

    def performance_check(self):
        """Run the regular PerformanceCheck against each compliance
        performance log (TEST01 / TEST04).

        Returns:
            True when every applicable compliance performance run passes.
        """
        if self.division.lower() == "open":
            self.log.info("Compliance tests not needed for open division. Skipping tests on %s", self.path)
            return True
        is_valid = True
        for test in self.test_list:
            if test in ["TEST01", "TEST04"]:
                # Build a minimal SubmissionLogs wrapper around the compliance
                # performance log so PerformanceCheck can consume it.
                test_data = {
                    "division": self.submission_logs.loader_data.get("division", ""),
                    "benchmark": self.submission_logs.loader_data.get("benchmark", ""),
                    "scenario": self.submission_logs.loader_data.get("scenario", ""),
                    "model_mapping": self.submission_logs.loader_data.get("model_mapping", {}),
                }
                test_logs = SubmissionLogs(self.submission_logs.loader_data[f"{test}_perf_log"], None, None, None, self.submission_logs.system_json, None, test_data)
                perf_check = PerformanceCheck(self.log, os.path.join(self.compliance_dir, test), self.config, test_logs)
                is_valid &= perf_check()
        return is_valid

    def accuracy_check(self):
        """Verify the accuracy results of the compliance runs.

        TEST01: accept a deterministic "TEST PASS" in the verification output;
        otherwise fall back to comparing baseline vs. compliance accuracy
        values within the model's allowed percentage delta.
        TEST06: all three textual sub-checks must pass.

        Returns:
            True when every applicable test passes, False otherwise.
        """
        if self.division.lower() == "open":
            self.log.info("Compliance tests not needed for open division. Skipping tests on %s", self.path)
            return True
        is_valid = True
        for test in self.test_list:
            test_dir = os.path.join(self.compliance_dir, test)
            if test == "TEST01":
                lines = self.submission_logs.loader_data[f"{test}_acc_result"]
                lines = [line.strip() for line in lines]
                if "TEST PASS" in lines:
                    self.log.info(
                        "Compliance test accuracy check (deterministic mode) in %s passed",
                        test_dir,
                    )
                else:
                    # Deterministic check failed; fall back to comparing the
                    # baseline and compliance accuracy values below.
                    self.log.info(
                        "Compliance test accuracy check (deterministic mode) in %s failed",
                        test_dir,
                    )
                    test_acc_path = os.path.join(test_dir, "accuracy")
                    if not os.path.exists(test_acc_path):
                        self.log.error("%s has no accuracy directory", test_dir)
                        is_valid = False
                    else:
                        diff = files_diff(
                            list_files(test_acc_path), REQUIRED_TEST01_ACC_FILES,
                        )
                        if diff:
                            self.log.error(
                                "%s has file list mismatch (%s)",
                                test_acc_path,
                                diff)
                            is_valid = False
                        else:
                            # Only the regex patterns and accuracy-type names
                            # are used here; the remaining tuple members
                            # describe targets/limits consumed elsewhere.
                            patterns, acc_targets, acc_types, acc_limits, up_patterns, acc_upper_limit = self.config.get_accuracy_values(self.model)
                            acc_baseline = {acc_type: 0 for acc_type in acc_types}
                            acc_compliance = {acc_type: 0 for acc_type in acc_types}
                            with open(
                                os.path.join(test_acc_path, "baseline_accuracy.txt"),
                                "r",
                                encoding="utf-8",
                            ) as f:
                                for line in f:
                                    for acc_type, pattern in zip(acc_types, patterns):
                                        m = re.match(pattern, line)
                                        if m:
                                            acc_baseline[acc_type] = float(m.group(1))
                            with open(
                                os.path.join(test_acc_path, "compliance_accuracy.txt"),
                                "r",
                                encoding="utf-8",
                            ) as f:
                                for line in f:
                                    for acc_type, pattern in zip(acc_types, patterns):
                                        m = re.match(pattern, line)
                                        if m:
                                            acc_compliance[acc_type] = float(m.group(1))
                            for acc_type in acc_types:
                                if acc_baseline[acc_type] == 0 or acc_compliance[acc_type] == 0:
                                    # An accuracy value was not found in one
                                    # of the files; treat as failure.
                                    is_valid = False
                                    break
                                required_delta_perc = self.config.get_delta_perc(
                                    self.model, acc_type
                                )
                                delta_perc = abs(
                                    1 - acc_baseline[acc_type] / acc_compliance[acc_type]
                                ) * 100
                                if delta_perc > required_delta_perc:
                                    self.log.error(
                                        "Compliance test accuracy check (non-deterministic mode) in %s failed",
                                        test_dir,
                                    )
                                    is_valid = False
                                    break
                                # Within tolerance: deliberately do NOT reset
                                # is_valid to True here, so an earlier failure
                                # cannot be masked (bug fix).
            elif test == "TEST06":
                lines = self.submission_logs.loader_data[f"{test}_acc_result"]
                lines = [line.strip() for line in lines]
                first_token_pass = (
                    "First token check pass: True" in lines
                    or "First token check pass: Skipped" in lines
                )
                eos_pass = "EOS check pass: True" in lines
                length_check_pass = "Sample length check pass: True" in lines
                test_pass = first_token_pass and eos_pass and length_check_pass
                # Bug fix: report the TEST06 failure based on this test's own
                # flags, not the accumulated is_valid (which may already be
                # False from another test).
                if not test_pass:
                    self.log.error(
                        f"TEST06 accuracy check failed. first_token_check: {first_token_pass} eos_check: {eos_pass} length_check: {length_check_pass}."
                    )
                is_valid &= test_pass
            else:
                self.log.info(f"{test_dir} does not require accuracy check")
        return is_valid

    def compliance_performance_check(self):
        """Check the textual verify_performance.txt results and the
        performance/run_1 directory contents for each compliance test.

        NOTE(review): this check is not registered in setup_checks() —
        confirm whether it should be.

        Returns:
            True when every applicable performance artifact is valid.
        """
        is_valid = True
        for test in self.test_list:
            test_dir = os.path.join(self.compliance_dir, test)
            if test in ["TEST01", "TEST04"]:
                fname = os.path.join(test_dir, "verify_performance.txt")
                if not os.path.exists(fname):
                    self.log.error("%s is missing in %s", fname, test_dir)
                    is_valid = False
                else:
                    # Scan for the "TEST PASS" marker. Track it in a local
                    # flag so a passing test cannot mask an earlier failure
                    # and a missing marker is actually reported (bug fix: the
                    # old code never set is_valid = False in this case).
                    found_pass = False
                    with open(fname, "r") as f:
                        for line in f:
                            # look for: TEST PASS
                            if "TEST PASS" in line:
                                found_pass = True
                                break
                    if not found_pass:
                        self.log.error(
                            "Compliance test performance check in %s failed",
                            test_dir)
                        is_valid = False

                # Check performance dir
                test_perf_path = os.path.join(test_dir, "performance", "run_1")
                if not os.path.exists(test_perf_path):
                    self.log.error("%s has no performance/run_1 directory", test_dir)
                    is_valid = False
                else:
                    diff = files_diff(
                        list_files(test_perf_path),
                        REQUIRED_COMP_PER_FILES,
                        ["mlperf_log_accuracy.json"],
                    )
                    if diff:
                        self.log.error(
                            "%s has file list mismatch (%s)",
                            test_perf_path,
                            diff)
                        is_valid = False
        return is_valid
0 commit comments