Skip to content

Commit 143f4c1

Browse files
committed
fix: refactor EpCtx OVIR parsing logic to use ep.context_file_path
1 parent 89b6bd1 commit 143f4c1

File tree

7 files changed

+102
-19
lines changed

7 files changed

+102
-19
lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ BackendManager::BackendManager(SessionContext& session_context,
4343
session_context_(session_context),
4444
shared_context_{shared_context} {
4545
subgraph_context_.is_ep_ctx_graph = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(subgraph);
46+
// If the graph contains a OVIR wrapped node, we check if it has xml file attribute
47+
subgraph_context_.is_ep_ctx_ovir_encapsulated = ep_ctx_handle_.CheckEPCacheContextAttribute(subgraph, "xml");
4648

4749
bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos ||
4850
session_context_.device_type.find("GPU") != std::string::npos;

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,23 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
7171
!session_context_.so_disable_cpu_ep_fallback &&
7272
!subgraph_context_.is_ep_ctx_graph);
7373
if (subgraph_context_.is_ep_ctx_graph) {
74-
// If the blob is held in an EPContext node, then skip FE+Compile
75-
// and directly move on to creating a backend with the executable blob
76-
exe_network_ = OVCore::Get()->ImportModel(*model_stream,
77-
hw_target,
78-
device_config,
79-
enable_causallm,
80-
session_context_.onnx_model_path_name.string());
74+
if (subgraph_context_.is_ep_ctx_ovir_encapsulated) {
75+
// If the EPContext node with OVIR Encapsulation, then create
76+
// an executable network from EP_CACHE_CONTEXT using read_model() & compile_model()
77+
exe_network_ = OVCore::Get()->ImportEPCtxOVIREncapsulation(*model_stream,
78+
hw_target,
79+
device_config,
80+
enable_causallm,
81+
session_context_.so_context_file_path,
82+
subgraph_context_.subgraph_name);
83+
} else {
84+
// If the blob is held in an EPContext node, then skip FE+Compile
85+
// and directly move on to creating a backend with the executable blob
86+
exe_network_ = OVCore::Get()->ImportModel(*model_stream,
87+
hw_target,
88+
device_config,
89+
subgraph_context_.subgraph_name);
90+
}
8191
model_stream.reset(); // Delete stream after it is no longer needed
8292
} else if (!session_context_.has_external_weights &&
8393
!subgraph_context_.has_dynamic_input_shape &&

onnxruntime/core/providers/openvino/contexts.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ struct SubGraphContext {
137137
string_index_map_t output_names;
138138
std::string model_precision;
139139
bool is_ep_ctx_graph = false;
140+
bool is_ep_ctx_ovir_encapsulated = false;
140141
};
141142

142143
} // namespace openvino_ep

onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ std::unique_ptr<std::istream> EPCtxHandler::GetModelBlobStream(const std::filesy
130130
// If the model stream is not an XML (i.e. precompiled blob), the OpenVINO SDK version that it was
131131
// exported with must match the version that is currently running.
132132
ORT_ENFORCE((attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_),
133-
"EPCtx blob was exported / is compatible with with OpenVINO SDK version " + attrs.at(EP_SDK_VER).s() +
133+
"EPCtx blob was exported / is compatible with OpenVINO SDK version " + attrs.at(EP_SDK_VER).s() +
134134
", but OpenVINO SDK version currently in use is " + openvino_sdk_version_);
135135
}
136136

@@ -165,5 +165,32 @@ InlinedVector<const Node*> EPCtxHandler::GetEPCtxNodes() const {
165165
return InlinedVector<const Node*>(epctx_nodes.begin(), epctx_nodes.end());
166166
}
167167

168+
// Check if graph's only node is EPContext & EP_CACHE_CONTEXT attribute has target extension.
169+
// @param graph_viewer: The graph to inspect.
170+
// @param target_attr_extn: The string to search for in the EP_CACHE_CONTEXT attribute.
171+
// @return true if the node exists, is of the correct type, and the attribute contains the extension; false otherwise.
172+
bool EPCtxHandler::CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, const std::string& target_attr_extn) const {
173+
// Only check if the graph has exactly one node
174+
if (graph_viewer.NumberOfNodes() != 1) {
175+
return false;
176+
}
177+
// Get the first node in topological order
178+
auto first_index = *graph_viewer.GetNodesInTopologicalOrder().begin();
179+
const Node* node = graph_viewer.GetNode(first_index);
180+
if (!node) {
181+
return false;
182+
}
183+
// Check OpType and required attributes
184+
if (node->OpType() != EPCONTEXT_OP) {
185+
return false;
186+
}
187+
const auto& attrs = node->GetAttributes();
188+
auto it = attrs.find(EP_CACHE_CONTEXT);
189+
if (it != attrs.end()) {
190+
return it->second().s().find(target_attr_extn) != std::string::npos;
191+
}
192+
return false;
193+
}
194+
168195
} // namespace openvino_ep
169196
} // namespace onnxruntime

onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class EPCtxHandler {
3333
std::string&& model_blob_str) const;
3434
std::unique_ptr<std::istream> GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const;
3535
InlinedVector<const Node*> GetEPCtxNodes() const;
36+
bool CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, const std::string& target_attr_extn) const;
3637

3738
private:
3839
const std::string openvino_sdk_version_;

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -191,30 +191,66 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model,
191191
}
192192

193193
OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
194+
std::string hw_target,
195+
const ov::AnyMap& device_config,
196+
std::string name) {
197+
try {
198+
ov::CompiledModel obj;
199+
obj = core.import_model(model_stream, hw_target, device_config);
200+
#ifndef NDEBUG
201+
printDebugInfo(exe.Get());
202+
#endif
203+
OVExeNetwork exe(obj, hw_target);
204+
return exe;
205+
} catch (const Exception& e) {
206+
ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
207+
} catch (...) {
208+
ORT_THROW(log_tag + " Exception while Loading Network for graph " + name);
209+
}
210+
}
211+
212+
OVExeNetwork OVCore::ImportEPCtxOVIREncapsulation(std::istream& model_stream,
194213
std::string hw_target,
195214
const ov::AnyMap& device_config,
196215
bool enable_causallm,
216+
std::filesystem::path context_file_path,
197217
std::string name) {
198218
try {
199219
OVExeNetwork exe;
200220

201221
bool isXML = backend_utils::IsModelStreamXML(model_stream);
202222

203-
if (!isXML) {
204-
auto obj = core.import_model(model_stream, hw_target, device_config);
205-
exe = OVExeNetwork(obj, hw_target);
206-
} else {
223+
ORT_ENFORCE(!context_file_path.string().empty(),
224+
"The session option ep.context_file_path is not set for EPContext node with OVIR Encapsulation. "
225+
"Current value: '" + context_file_path.string() + "'");
226+
227+
// Helper function to check if file exists and is readable
228+
const auto check_file_access = [&context_file_path](const std::filesystem::path& path) {
229+
try {
230+
const auto status = std::filesystem::status(path);
231+
if (!std::filesystem::exists(status)) {
232+
ORT_THROW(log_tag + "Required file missing: " + path.string());
233+
}
234+
std::ifstream file(path);
235+
if (!file.is_open()) {
236+
ORT_THROW(log_tag + "Required file not readable: " + path.string());
237+
}
238+
} catch (const std::exception& e) {
239+
ORT_THROW(log_tag + "Exception while checking file access for: " + path.string() + " - " + e.what());
240+
}
241+
};
242+
243+
if (isXML) {
207244
// If the model is XML, we need to load it with the XML content in read_model()
208245
// where weights from bin file is directly consumed
209-
std::string xml_file_name = name;
210-
if (name.size() >= 5 && name.substr(name.size() - 5) == ".onnx") {
211-
xml_file_name.replace(name.size() - 5, 5, ".xml");
212-
} else {
213-
throw std::runtime_error("Invalid model name. Make sure *.onnx, *.xml, and *.bin carry the same name.");
214-
}
246+
auto xml_file_path = context_file_path.parent_path() / (context_file_path.stem().string() + ".xml");
247+
248+
check_file_access(xml_file_path);
249+
250+
LOGS_DEFAULT(INFO) << log_tag << "Reading OVIR from XML file path: " << xml_file_path.string();
215251

216252
// Load the model explicitly with XML contents
217-
std::shared_ptr<ov::Model> model = core.read_model(xml_file_name);
253+
std::shared_ptr<ov::Model> model = core.read_model(xml_file_path.string());
218254

219255
if (enable_causallm) {
220256
exe = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config);

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,16 @@ struct OVCore : WeakSingleton<OVCore> {
8080
const std::string& name);
8181
// OV Interface for Import model Stream
8282
OVExeNetwork ImportModel(std::istream& model_stream,
83+
std::string hw_target,
84+
const ov::AnyMap& device_config,
85+
std::string name);
86+
OVExeNetwork ImportEPCtxOVIREncapsulation(std::istream& model_stream,
8387
std::string hw_target,
8488
const ov::AnyMap& device_config,
8589
bool enable_causallm,
90+
std::filesystem::path context_file_path,
8691
std::string name);
92+
8793
std::vector<std::string> GetAvailableDevices() const;
8894
std::vector<std::string> GetAvailableDevices(const std::string& device_type) const;
8995
void SetCache(const std::string& cache_dir_path);

0 commit comments

Comments
 (0)