Skip to content

Commit 237bbe3

Browse files
[Automated Commit] Format Codebase
1 parent cc89d84 commit 237bbe3

File tree

10 files changed

+384
-163
lines changed

10 files changed

+384
-163
lines changed

tools/submission/submission_checker/checks/compliance_check.py

Lines changed: 81 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,21 @@
99
import re
1010
import os
1111

12+
1213
class ComplianceCheck(BaseCheck):
13-
def __init__(self, log, path, config: Config, submission_logs: SubmissionLogs):
14+
def __init__(self, log, path, config: Config,
15+
submission_logs: SubmissionLogs):
1416
super().__init__(log, path)
1517
self.submission_logs = submission_logs
1618
self.config = config
1719
self.model = self.submission_logs.loader_data.get("benchmark", "")
18-
self.model_mapping = self.submission_logs.loader_data.get("model_mapping", {})
19-
self.compliance_dir = self.submission_logs.loader_data.get("compliance_path", {})
20+
self.model_mapping = self.submission_logs.loader_data.get(
21+
"model_mapping", {})
22+
self.compliance_dir = self.submission_logs.loader_data.get(
23+
"compliance_path", {})
2024
self.division = self.submission_logs.loader_data.get("division", "")
21-
self.model = self.config.get_mlperf_model(self.model, self.model_mapping)
25+
self.model = self.config.get_mlperf_model(
26+
self.model, self.model_mapping)
2227
self.test_list = self.get_test_list(self.model)
2328
self.setup_checks()
2429

@@ -36,36 +41,54 @@ def get_test_list(self, model):
3641
if model in self.config.base["models_TEST06"]:
3742
test_list.append("TEST06")
3843
return test_list
39-
44+
4045
def dir_exists_check(self):
4146
if self.division.lower() == "open":
42-
self.log.info("Compliance tests not needed for open division. Skipping tests on %s", self.path)
47+
self.log.info(
48+
"Compliance tests not needed for open division. Skipping tests on %s",
49+
self.path)
4350
return True
4451
is_valid = True
4552
for test in self.test_list:
4653
test_dir = os.path.join(self.compliance_dir, test)
47-
acc_path = os.path.join(self.compliance_dir, test, "verify_accuracy.txt")
48-
perf_comp_path = os.path.join(self.compliance_dir, test, "verify_performance.txt")
49-
perf_path = os.path.join(self.compliance_dir, test, "performance", "run_1", "mlperf_log_detail.txt")
54+
acc_path = os.path.join(
55+
self.compliance_dir, test, "verify_accuracy.txt")
56+
perf_comp_path = os.path.join(
57+
self.compliance_dir, test, "verify_performance.txt")
58+
perf_path = os.path.join(
59+
self.compliance_dir,
60+
test,
61+
"performance",
62+
"run_1",
63+
"mlperf_log_detail.txt")
5064
if not os.path.exists(test_dir):
51-
self.log.error("Missing %s in compliance dir %s", test, self.compliance_dir)
65+
self.log.error(
66+
"Missing %s in compliance dir %s",
67+
test,
68+
self.compliance_dir)
5269
is_valid = False
5370
if test in ["TEST01", "TEST06"]:
5471
if not os.path.exists(acc_path):
55-
self.log.error("Missing accuracy file in compliance dir. Needs file %s", acc_path)
72+
self.log.error(
73+
"Missing accuracy file in compliance dir. Needs file %s", acc_path)
5674
is_valid = False
5775
if test in ["TEST01", "TEST04"]:
5876
if not os.path.exists(perf_comp_path):
59-
self.log.error("Missing performance file in compliance dir. Needs file %s", perf_comp_path)
77+
self.log.error(
78+
"Missing performance file in compliance dir. Needs file %s",
79+
perf_comp_path)
6080
is_valid = False
6181
if not os.path.exists(perf_path):
62-
self.log.error("Missing perfomance file in compliance dir. Needs file %s", perf_path)
82+
self.log.error(
83+
"Missing perfomance file in compliance dir. Needs file %s", perf_path)
6384
is_valid = False
6485
return is_valid
65-
86+
6687
def performance_check(self):
6788
if self.division.lower() == "open":
68-
self.log.info("Compliance tests not needed for open division. Skipping tests on %s", self.path)
89+
self.log.info(
90+
"Compliance tests not needed for open division. Skipping tests on %s",
91+
self.path)
6992
return True
7093
is_valid = True
7194
for test in self.test_list:
@@ -76,14 +99,24 @@ def performance_check(self):
7699
"scenario": self.submission_logs.loader_data.get("scenario", ""),
77100
"model_mapping": self.submission_logs.loader_data.get("model_mapping", {})
78101
}
79-
test_logs = SubmissionLogs(self.submission_logs.loader_data[f"{test}_perf_log"], None, None, None, self.submission_logs.system_json, None, test_data)
80-
perf_check = PerformanceCheck(self.log, os.path.join(self.compliance_dir, test), self.config, test_logs)
102+
test_logs = SubmissionLogs(
103+
self.submission_logs.loader_data[f"{test}_perf_log"],
104+
None,
105+
None,
106+
None,
107+
self.submission_logs.system_json,
108+
None,
109+
test_data)
110+
perf_check = PerformanceCheck(self.log, os.path.join(
111+
self.compliance_dir, test), self.config, test_logs)
81112
is_valid &= perf_check()
82113
return is_valid
83114

84115
def accuracy_check(self):
85116
if self.division.lower() == "open":
86-
self.log.info("Compliance tests not needed for open division. Skipping tests on %s", self.path)
117+
self.log.info(
118+
"Compliance tests not needed for open division. Skipping tests on %s",
119+
self.path)
87120
return True
88121
is_valid = True
89122
for test in self.test_list:
@@ -103,11 +136,13 @@ def accuracy_check(self):
103136
)
104137
test_acc_path = os.path.join(test_dir, "accuracy")
105138
if not os.path.exists(test_acc_path):
106-
self.log.error("%s has no accuracy directory", test_dir)
139+
self.log.error(
140+
"%s has no accuracy directory", test_dir)
107141
is_valid = False
108142
else:
109143
diff = files_diff(
110-
list_files(test_acc_path), REQUIRED_TEST01_ACC_FILES,
144+
list_files(
145+
test_acc_path), REQUIRED_TEST01_ACC_FILES,
111146
)
112147
if diff:
113148
self.log.error(
@@ -116,30 +151,39 @@ def accuracy_check(self):
116151
diff)
117152
is_valid = False
118153
else:
119-
target = self.config.get_accuracy_target(self.model)
120-
patterns, acc_targets, acc_types, acc_limits, up_patterns, acc_upper_limit = self.config.get_accuracy_values(self.model)
154+
target = self.config.get_accuracy_target(
155+
self.model)
156+
patterns, acc_targets, acc_types, acc_limits, up_patterns, acc_upper_limit = self.config.get_accuracy_values(
157+
self.model)
121158
acc_limit_check = True
122159

123160
acc_seen = [False for _ in acc_targets]
124-
acc_baseline = {acc_type: 0 for acc_type in acc_types}
125-
acc_compliance = {acc_type: 0 for acc_type in acc_types}
161+
acc_baseline = {
162+
acc_type: 0 for acc_type in acc_types}
163+
acc_compliance = {
164+
acc_type: 0 for acc_type in acc_types}
126165
with open(
127-
os.path.join(test_acc_path, "baseline_accuracy.txt"),
166+
os.path.join(
167+
test_acc_path, "baseline_accuracy.txt"),
128168
"r",
129169
encoding="utf-8",
130170
) as f:
131171
for line in f:
132-
for acc_type, pattern in zip(acc_types, patterns):
172+
for acc_type, pattern in zip(
173+
acc_types, patterns):
133174
m = re.match(pattern, line)
134175
if m:
135-
acc_baseline[acc_type] = float(m.group(1))
176+
acc_baseline[acc_type] = float(
177+
m.group(1))
136178
with open(
137-
os.path.join(test_acc_path, "compliance_accuracy.txt"),
179+
os.path.join(
180+
test_acc_path, "compliance_accuracy.txt"),
138181
"r",
139182
encoding="utf-8",
140183
) as f:
141184
for line in f:
142-
for acc_type, pattern in zip(acc_types, patterns):
185+
for acc_type, pattern in zip(
186+
acc_types, patterns):
143187
m = re.match(pattern, line)
144188
if m:
145189
acc_compliance[acc_type] = float(
@@ -178,15 +222,16 @@ def accuracy_check(self):
178222
)
179223
eos_pass = "EOS check pass: True" in lines
180224
length_check_pass = "Sample length check pass: True" in lines
181-
is_valid &= (first_token_pass and eos_pass and length_check_pass)
225+
is_valid &= (
226+
first_token_pass and eos_pass and length_check_pass)
182227
if not is_valid:
183228
self.log.error(
184229
f"TEST06 accuracy check failed. first_token_check: {first_token_pass} eos_check: {eos_pass} length_check: {length_check_pass}."
185230
)
186231
else:
187232
self.log.info(f"{test_dir} does not require accuracy check")
188233
return is_valid
189-
234+
190235
def compliance_performance_check(self):
191236
is_valid = True
192237
for test in self.test_list:
@@ -209,9 +254,11 @@ def compliance_performance_check(self):
209254
test_dir)
210255

211256
# Check performance dir
212-
test_perf_path = os.path.join(test_dir, "performance", "run_1")
257+
test_perf_path = os.path.join(
258+
test_dir, "performance", "run_1")
213259
if not os.path.exists(test_perf_path):
214-
self.log.error("%s has no performance/run_1 directory", test_dir)
260+
self.log.error(
261+
"%s has no performance/run_1 directory", test_dir)
215262
is_valid = False
216263
else:
217264
diff = files_diff(
@@ -225,4 +272,4 @@ def compliance_performance_check(self):
225272
test_perf_path,
226273
diff)
227274
is_valid = False
228-
return is_valid
275+
return is_valid

tools/submission/submission_checker/checks/measurements_checks.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,23 @@
77

88

99
class MeasurementsCheck(BaseCheck):
10-
def __init__(self, log, path, config: Config, submission_logs: SubmissionLogs):
10+
def __init__(self, log, path, config: Config,
11+
submission_logs: SubmissionLogs):
1112
super().__init__(log, path)
1213
self.name = "measurement checks"
1314
self.submission_logs = submission_logs
1415
self.measurements_json = self.submission_logs.measurements_json
1516
self.submitter = self.submission_logs.loader_data.get("submitter", "")
1617
self.division = self.submission_logs.loader_data.get("division", "")
17-
self.measurements_dir = self.submission_logs.loader_data.get("measurements_dir", "")
18+
self.measurements_dir = self.submission_logs.loader_data.get(
19+
"measurements_dir", "")
1820
self.config = config
1921
self.setup_checks()
2022

2123
def setup_checks(self):
2224
self.checks.append(self.missing_check)
2325
self.checks.append(self.required_files_check)
2426
self.checks.append(self.required_fields_check)
25-
2627

2728
def missing_check(self):
2829
if self.measurements_json is None:
@@ -32,7 +33,7 @@ def missing_check(self):
3233
)
3334
return False
3435
return True
35-
36+
3637
def required_files_check(self):
3738
is_valid = True
3839
files = list_files(self.measurements_dir)
@@ -43,7 +44,10 @@ def required_files_check(self):
4344
elif not self.config.skip_empty_files_check and (
4445
os.stat(os.path.join(self.measurements_dir, i)).st_size == 0
4546
):
46-
self.log.error("%s is having empty %s", self.measurements_dir, i)
47+
self.log.error(
48+
"%s is having empty %s",
49+
self.measurements_dir,
50+
i)
4751
is_valid = False
4852
return is_valid
4953

@@ -58,4 +62,4 @@ def required_fields_check(self):
5862
is_valid = False
5963
self.log.error(
6064
"%s, field %s is missing meaningful value", self.path, k)
61-
return is_valid
65+
return is_valid

0 commit comments

Comments
 (0)