 #include "core/providers/openvino/backends/basic_backend.h"
 #include "core/providers/openvino/onnx_ctx_model_helper.h"
 #include "core/providers/openvino/backend_manager.h"
+#include "core/providers/openvino/ov_stateful_patch_utils.h"
 
 namespace onnxruntime {
 
@@ -29,6 +30,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
                            ptr_stream_t& model_stream)
     : session_context_{session_context}, subgraph_context_{subgraph_context}, shared_context_{shared_context} {
   std::string& hw_target = session_context_.device_type;
+  bool enable_causallm = session_context_.enable_causallm;
 
   if (ValidateSubgraph(const_outputs_map_))
     return;
@@ -43,7 +45,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
   // Setting OpenCL queue throttling for GPU
   EnableGPUThrottling(device_config);
 
-  // Enable streams; default=1 unless ovverriden by user config
+  // Enable streams; default=1 unless overridden by user configuration
   EnableStreams();
 
   // Set the inference_num_threads property of the CPU
@@ -76,7 +78,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
   } else if (!session_context_.has_external_weights &&
              !subgraph_context_.has_dynamic_input_shape &&
              !session_context_.so_context_enable &&
-             auto_unified_compile) {
+             !enable_causallm && auto_unified_compile) {
     // Unified OV compile_model is efficient when ov model caching is enabled
     // Unified OV compile_model API is supported with AUTO from version 2024.3 and above
     // Inputs with static dimensions
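+    // With enable_causallm set, this unified path is skipped and the model is compiled
+    // through the CreateOVModel/CompileModel branch below, which receives the flag.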
@@ -96,7 +98,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
       }
       auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_);
       exe_network_ = OVCore::Get()->CompileModel(
-          ov_model, hw_target, device_config, subgraph_context_.subgraph_name);
+          ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name);
     }
     LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
   } catch (const char* msg) {
@@ -120,7 +122,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
     };
   }
   inferRequestsQueue_ = std::unique_ptr<InferRequestsQueue>(new InferRequestsQueue(exe_network_, num_infer_req, std::move(initializer)));
-  bindings_ = std::make_unique<OnnxToOvNetworkBindings>(exe_network_, subgraph_context_);
+  bindings_ = std::make_unique<OnnxToOvNetworkBindings>(exe_network_, subgraph_context_, session_context_);
 }
 
 bool BasicBackend::ValidateSubgraph(std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map) {
@@ -181,6 +183,15 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
   if (!session_context_.load_config.empty()) {
     const std::map<std::string, ov::AnyMap>& target_config = session_context_.load_config;
 
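+    // Apply any user-supplied GenAI (causal-LM) options for NPU before the generic NPUW
+    // handling below; CausalLMConfig::ApplyConfig (from ov_stateful_patch_utils.h) is
+    // expected to merge these options into device_config.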
+    if ((session_context_.device_type.find("NPU") != std::string::npos) && session_context_.enable_causallm) {
+      if (target_config.find("NPU") != target_config.end()) {
+        auto npu_genai_config = target_config.at("NPU");
+        CausalLMConfig().ApplyConfig(npu_genai_config, device_config);
+      } else {
+        LOGS_DEFAULT(WARNING) << "ORT GenAI CausalLMConfig Configuration not found.";
+      }
+    }
+
     if (session_context_.device_type.find("NPU") != std::string::npos) {
       auto npuw_config = target_config.at("NPU");
 
@@ -246,7 +257,8 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
     auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options,
                                      const std::vector<ov::PropertyName>& supported_properties) {
       for (const auto& [key, value] : config_options) {
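+        // Skip NPUW-prefixed keys, and (when causal LM is enabled) any key already
+        // present in device_config, so the values chosen for the causal-LM path above
+        // are left untouched.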
-        if (key.find("NPUW") != std::string::npos) {
+        if ((key.find("NPUW") != std::string::npos) ||
+            ((device_config.find(key) != device_config.end()) && session_context_.enable_causallm)) {
           continue;
         }
         if (is_supported_and_mutable(key, supported_properties)) {
@@ -339,6 +351,13 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) {
   device_config.emplace(ov::inference_num_threads(session_context_.num_of_threads));
 }
 
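+// Rewinds the KV cache of the stateful (causal-LM) model to the given position,
+// delegating to OVInferRequest::RewindKVCache on an idle infer request from the pool
+// and returning the request when done.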
+void BasicBackend::RewindKVCache(size_t index) {
+  OVInferRequestPtr infer_request;
+  infer_request = inferRequestsQueue_->getIdleRequest();
+  infer_request->RewindKVCache(index);
+  inferRequestsQueue_->putIdleRequest(std::move(infer_request));
+}
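+// Illustrative usage sketch (`backend`, `prefix_token_count`, and `ctx` are assumed
+// caller-side names, not defined here):
+//   backend->RewindKVCache(prefix_token_count);  // trim cached state back to the prefix
+//   backend->Infer(ctx);                         // subsequent inference reuses the rewound cache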
+
 // Starts an asynchronous inference request for data in slice indexed by batch_slice_idx on
 // an Infer Request indexed by infer_req_idx
 void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) {
@@ -351,7 +370,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
     size_t batch_slice_idx = 0;
     if (subgraph_context_.has_dynamic_input_shape &&
         !session_context_.disable_dynamic_shapes &&
-        cpu_or_gpu) {
+        cpu_or_gpu || (npu && session_context_.enable_causallm)) {
       auto tensor = context.GetInput(input_info.onnx_index);
       auto tensor_info = tensor.GetTensorTypeAndShapeInfo();
       auto tensor_shape = tensor_info.GetShape();
@@ -409,7 +428,8 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
     }
   }  // Loop subgraph original input
 
-  if (npu) {
+  // For stateful compilation (i.e., enable_causallm is true), we use the dynamic shapes path for the NPU plugin as well.
+  if (npu && !session_context_.enable_causallm) {
     // Set the output blob as remote blob
     for (const auto& output_info : bindings_->network_outputs_) {
       Ort::UnownedValue tensor = context.GetOutput(output_info.onnx_index, output_info.onnx_shape);
@@ -453,19 +473,20 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe
 
   bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos ||
                     session_context_.device_type.find("GPU") != std::string::npos;
-  if (cpu_or_gpu) {
+  bool npu = session_context_.device_type.find("NPU") != std::string::npos;
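+  // For the stateful causal-LM path on NPU, outputs are read back from the infer request
+  // as ordinary tensors (same as CPU/GPU) rather than via pre-bound remote blobs.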
+  if (cpu_or_gpu || (npu && session_context_.enable_causallm)) {
     for (const auto& output_info : bindings_->network_outputs_) {
-    OVTensorPtr graph_output_blob;
-    try {
-      graph_output_blob = infer_request->GetTensor(output_info.name);
-    } catch (const char* msg) {
-      ORT_THROW(msg);
-    }
-    size_t batch_size = 1;
-    Ort::UnownedValue output_tensor =
-        GetOutputTensor(context, batch_size, infer_request, output_info.name, subgraph_context_.output_names);
-    auto mem_info = output_tensor.GetTensorMemoryInfo();
-    if (mem_info.GetAllocatorName() == OpenVINO_GPU) {
+      OVTensorPtr graph_output_blob;
+      try {
+        graph_output_blob = infer_request->GetTensor(output_info.name);
+      } catch (const char* msg) {
+        ORT_THROW(msg);
+      }
+      size_t batch_size = 1;
+      Ort::UnownedValue output_tensor =
+          GetOutputTensor(context, batch_size, infer_request, output_info.name, subgraph_context_.output_names);
+      auto mem_info = output_tensor.GetTensorMemoryInfo();
+      if (mem_info.GetAllocatorName() == OpenVINO_GPU) {
         return;
       } else {
         size_t batch_slice = 0;
@@ -538,11 +559,19 @@ void BasicBackend::Infer(OrtKernelContext* ctx) {
   try {
     StartAsyncInference(context, infer_request);
   } catch (const std::runtime_error& e) {
+    // If the inference fails (exception from ov::InferRequest::infer()),
+    // we need to put the infer_request back into the pool to avoid deadlocks
+    // and to allow the next inference request to proceed.
+    inferRequestsQueue_->putIdleRequest(std::move(infer_request));
     ORT_THROW(log_tag + " Exception at StartAsyncInference: " + e.what());
   }
   try {
     CompleteAsyncInference(context, infer_request);
   } catch (const std::runtime_error& e) {
+    // If the inference fails (exception from ov::InferRequest::infer()),
+    // we need to put the infer_request back into the pool to avoid deadlocks
+    // and to allow the next inference request to proceed.
+    inferRequestsQueue_->putIdleRequest(std::move(infer_request));
     ORT_THROW(log_tag + " Exception at CompleteAsyncInference: " + e.what());
   }
 