Skip to content

Commit 8f8c6cb

Browse files
jatinwadhwa921ankitm3k
authored andcommitted
[OVEP] Fix for precision accuracy
1 parent 0c4bb03 commit 8f8c6cb

File tree

4 files changed

+111
-56
lines changed

4 files changed

+111
-56
lines changed

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,8 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
158158
if (session_context_.precision.find("FP32") != std::string::npos) {
159159
device_config.emplace(ov::hint::inference_precision("f32"));
160160
}
161-
if (session_context_.precision.find("ACCURACY") != std::string::npos &&
162-
session_context_.device_type.find("GPU") != std::string::npos) {
161+
if (session_context_.precision.find("ACCURACY") != std::string::npos) {
163162
if (session_context_.OpenVINO_Version.at(0) >= 2024) {
164-
device_config.emplace(ov::hint::inference_precision(ov::element::dynamic));
165163
device_config.emplace(ov::hint::execution_mode(ov::hint::ExecutionMode::ACCURACY));
166164
} else {
167165
if (!subgraph_context_.model_precision.empty())
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#include <algorithm>
2+
#include "core/providers/openvino/openvino_parser_utils.h"
3+
#include "core/providers/shared_library/provider_api.h"
4+
5+
namespace onnxruntime {
6+
namespace openvino_ep {
7+
8+
std::string OpenVINOParserUtils::ParsePrecision(const ProviderOptions& provider_options,
9+
std::string& device_type,
10+
const std::string& option_name) {
11+
using DeviceName = std::string;
12+
using DefaultValue = std::string;
13+
using ValidValues = std::list<std::string>;
14+
using foo = std::pair<DefaultValue, ValidValues>;
15+
using ParserHelper = std::map<DeviceName, foo>;
16+
17+
ParserHelper helper = {
18+
{"GPU", {"FP16", {"FP16", "FP32", "ACCURACY"}}},
19+
{"NPU", {"FP16", {"FP16", "ACCURACY"}}},
20+
{"CPU", {"FP32", {"FP32", "ACCURACY"}}},
21+
};
22+
23+
std::set<std::string> deprecated_device_types = {
24+
"CPU_FP32", "GPU_FP32", "GPU.0_FP32", "GPU.1_FP32", "GPU_FP16",
25+
"GPU.0_FP16", "GPU.1_FP16"};
26+
27+
bool is_composite = device_type.find(':') != std::string::npos; // FOR devices AUTO:,HETERO:,MULTI:
28+
29+
if (provider_options.contains(option_name)) {
30+
const auto& precision = provider_options.at(option_name);
31+
32+
if (is_composite) {
33+
std::set<std::string> allowed_precisions = {"FP16", "FP32", "ACCURACY"};
34+
if (allowed_precisions.contains(precision)) {
35+
return precision;
36+
} else {
37+
ORT_THROW("[ERROR] [OpenVINO] Unsupported inference precision is selected. ",
38+
precision, ".\n");
39+
}
40+
} else {
41+
if (helper.contains(device_type)) {
42+
auto const& valid_values = helper[device_type].second;
43+
44+
if (precision == "ACCURACY") {
45+
return valid_values.back(); // Return highest supported precision
46+
} else {
47+
if (std::find(valid_values.begin(), valid_values.end(), precision) != valid_values.end()) {
48+
return precision; // Return precision selected if valid
49+
} else {
50+
auto value_iter = valid_values.begin();
51+
std::string valid_values_joined = *value_iter;
52+
// Append 2nd and up, if only one then ++value_iter is same as end()
53+
for (++value_iter; value_iter != valid_values.end(); ++value_iter) {
54+
valid_values_joined += ", " + *value_iter;
55+
}
56+
57+
ORT_THROW("[ERROR] [OpenVINO] Unsupported inference precision is selected. ",
58+
device_type, " only supports", valid_values_joined, ".\n");
59+
}
60+
}
61+
} else if (deprecated_device_types.contains(device_type)) {
62+
LOGS_DEFAULT(WARNING)
63+
<< "[OpenVINO] Selected 'device_type' " + device_type + " is deprecated. \n"
64+
<< "Update the 'device_type' to specified types 'CPU', 'GPU', 'GPU.0', "
65+
<< "'GPU.1', 'NPU' or from HETERO/MULTI/AUTO options and set 'precision' separately. \n";
66+
auto delimit = device_type.find("_");
67+
device_type = device_type.substr(0, delimit);
68+
return device_type.substr(delimit + 1);
69+
} else {
70+
ORT_THROW("[ERROR] [OpenVINO] Unsupported device type provided: ",
71+
device_type, "\n");
72+
}
73+
}
74+
} else {
75+
if (device_type.find("NPU") != std::string::npos || device_type.find("GPU") != std::string::npos) {
76+
return "FP16";
77+
} else if (device_type.find("CPU") != std::string::npos) {
78+
return "FP32";
79+
} else {
80+
ORT_THROW("[ERROR] [OpenVINO] Unsupported device is selected", device_type, "\n");
81+
}
82+
}
83+
}
84+
85+
} // namespace openvino_ep
86+
} // namespace onnxruntime
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#pragma once
2+
3+
#include <list>
4+
#include <map>
5+
#include <set>
6+
#include <string>
7+
#include <utility>
8+
9+
#include "core/framework/provider_options.h"
10+
11+
namespace onnxruntime {
12+
namespace openvino_ep {
13+
14+
class OpenVINOParserUtils {
15+
public:
16+
static std::string ParsePrecision(const ProviderOptions& provider_options,
17+
std::string& device_type,
18+
const std::string& option_name);
19+
};
20+
21+
} // namespace openvino_ep
22+
} // namespace onnxruntime

onnxruntime/core/providers/openvino/openvino_provider_factory.cc

Lines changed: 2 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "core/providers/openvino/backend_utils.h"
1212
#include "core/session/onnxruntime_session_options_config_keys.h"
1313
#include "nlohmann/json.hpp"
14+
#include "core/providers/openvino/openvino_parser_utils.h"
1415

1516
namespace onnxruntime {
1617
namespace openvino_ep {
@@ -114,58 +115,6 @@ std::string ParseDeviceType(std::shared_ptr<OVCore> ov_core, const ProviderOptio
114115
}
115116
}
116117

117-
// Depends on ProviderOptions.
118-
std::string ParsePrecision(const ProviderOptions& provider_options, std::string& device_type, const std::string& option_name) {
119-
using DeviceName = std::string;
120-
using DefaultValue = std::string;
121-
using ValidValues = std::list<std::string>;
122-
using foo = std::pair<DefaultValue, ValidValues>;
123-
using ParserHelper = std::map<DeviceName, foo>;
124-
ParserHelper helper = {
125-
{"GPU", {"FP16", {"FP16", "FP32"}}},
126-
{"NPU", {"FP16", {"FP16"}}},
127-
{"CPU", {"FP32", {"FP32"}}},
128-
};
129-
130-
std::set<std::string> deprecated_device_types = {"CPU_FP32", "GPU_FP32",
131-
"GPU.0_FP32", "GPU.1_FP32", "GPU_FP16",
132-
"GPU.0_FP16", "GPU.1_FP16"};
133-
134-
if (provider_options.contains(option_name)) {
135-
// Start by checking if the device_type is a normal valid one
136-
if (helper.contains(device_type)) {
137-
auto const& valid_values = helper[device_type].second;
138-
const auto& precision = provider_options.at(option_name);
139-
if (precision == "ACCURACY") {
140-
return valid_values.back(); // Return highest supported precision
141-
} else {
142-
if (std::find(valid_values.begin(), valid_values.end(), precision) != valid_values.end()) {
143-
return precision; // Return precision selected if valid
144-
} else {
145-
auto value_iter = valid_values.begin();
146-
std::string valid_values_joined = *value_iter;
147-
// Append 2nd and up, if only one then ++value_iter is same as end()
148-
for (++value_iter; value_iter != valid_values.end(); ++value_iter) {
149-
valid_values_joined += ", " + *value_iter;
150-
}
151-
152-
ORT_THROW("[ERROR] [OpenVINO] Unsupported inference precision is selected. ", device_type, " only supports", valid_values_joined, ".\n");
153-
}
154-
}
155-
} else if (deprecated_device_types.contains(device_type)) {
156-
LOGS_DEFAULT(WARNING) << "[OpenVINO] Selected 'device_type' " + device_type + " is deprecated. \n"
157-
<< "Update the 'device_type' to specified types 'CPU', 'GPU', 'GPU.0', "
158-
<< "'GPU.1', 'NPU' or from"
159-
<< " HETERO/MULTI/AUTO options and set 'precision' separately. \n";
160-
auto delimit = device_type.find("_");
161-
device_type = device_type.substr(0, delimit);
162-
return device_type.substr(delimit + 1);
163-
}
164-
}
165-
// Return default
166-
return helper[device_type].first;
167-
}
168-
169118
void ParseProviderOptions([[maybe_unused]] ProviderInfo& result, [[maybe_unused]] const ProviderOptions& config_options) {}
170119

171120
struct OpenVINOProviderFactory : IExecutionProviderFactory {
@@ -227,7 +176,7 @@ struct OpenVINO_Provider : Provider {
227176
pi.cache_dir = provider_options.at("cache_dir");
228177
}
229178

230-
pi.precision = ParsePrecision(provider_options, pi.device_type, "precision");
179+
pi.precision = OpenVINOParserUtils::ParsePrecision(provider_options, pi.device_type, "precision");
231180

232181
if (provider_options.contains("load_config")) {
233182
auto parse_config = [&](const std::string& config_str) -> std::map<std::string, ov::AnyMap> {

0 commit comments

Comments
 (0)