Skip to content

Commit 382fdcc

Browse files
committed
Add support for parsing AUTO, HETERO and MULTI from json config
1 parent 2c61a3a commit 382fdcc

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,15 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
222222
}
223223
}
224224
}
225+
auto find_device_type_mode = [&](const std::string& device_type) -> std::string {
226+
std::string device_mode="";
227+
auto delimiter_pos = device_type.find(':');
228+
if (delimiter_pos != std::string::npos) {
229+
std::stringstream str_stream(device_type.substr(0, delimiter_pos));
230+
std::getline(str_stream, device_mode, ',');
231+
}
232+
return device_mode;
233+
};
225234

226235
// Parse device types like "AUTO:CPU,GPU" and extract individual devices
227236
auto parse_individual_devices = [&](const std::string& device_type) -> std::vector<std::string> {
@@ -270,8 +279,12 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
270279
if (session_context_.device_type.find("AUTO") == 0 ||
271280
session_context_.device_type.find("HETERO") == 0 ||
272281
session_context_.device_type.find("MULTI") == 0) {
282+
//// Parse to get the device mode (e.g., "AUTO:CPU,GPU" -> "AUTO")
283+
auto device_mode = find_device_type_mode(session_context_.device_type);
273284
// Parse individual devices (e.g., "AUTO:CPU,GPU" -> ["CPU", "GPU"])
274285
auto individual_devices = parse_individual_devices(session_context_.device_type);
286+
if(!device_mode.empty()) individual_devices.emplace_back(device_mode);
287+
275288
// Set properties only for individual devices (e.g., "CPU", "GPU")
276289
for (const std::string& device : individual_devices) {
277290
if (target_config.count(device)) {
@@ -282,6 +295,8 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
282295
}
283296
}
284297
} else {
298+
std::unordered_set<std::string> valid_ov_devices = {"CPU", "GPU", "NPU", "AUTO", "HETERO", "MULTI"};
299+
285300
if (target_config.count(session_context_.device_type)) {
286301
auto supported_properties = OVCore::Get()->core.get_property(session_context_.device_type,
287302
ov::supported_properties);

onnxruntime/core/providers/openvino/openvino_provider_factory.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ std::string ParseDeviceType(std::shared_ptr<OVCore> ov_core, const ProviderOptio
101101
default_device = DEVICE_NAME;
102102

103103
// Validate that devices passed are valid
104-
int delimit = device_type.find(":");
105-
const auto& devices = device_type.substr(delimit + 1);
104+
int delimit = default_device.find(":");
105+
const auto& devices = default_device.substr(delimit + 1);
106106
auto device_list = split(devices, ',');
107107
for (const auto& device : devices) {
108108
if (!ov_supported_device_types.contains(device)) {
@@ -199,9 +199,9 @@ struct OpenVINO_Provider : Provider {
199199

200200
for (auto& [key, value] : json_config.items()) {
201201
ov::AnyMap inner_map;
202-
202+
std::unordered_set<std::string> valid_ov_devices = {"CPU", "GPU", "NPU", "AUTO", "HETERO", "MULTI"};
203203
// Ensure the key is one of "CPU", "GPU", or "NPU"
204-
if (key != "CPU" && key != "GPU" && key != "NPU") {
204+
if (valid_ov_devices.find(key) == valid_ov_devices.end()) {
205205
LOGS_DEFAULT(WARNING) << "Unsupported device key: " << key << ". Skipping entry.\n";
206206
continue;
207207
}

0 commit comments

Comments
 (0)