 #include "core/providers/openvino/backends/basic_backend.h"
 #include "core/providers/openvino/onnx_ctx_model_helper.h"
 #include "core/providers/openvino/backend_manager.h"
+#include "core/providers/openvino/ov_stateful_patch_utils.h"

 namespace onnxruntime {

@@ -29,6 +30,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
                             ptr_stream_t& model_stream)
     : session_context_{session_context}, subgraph_context_{subgraph_context}, shared_context_{shared_context} {
   std::string& hw_target = session_context_.device_type;
+  bool enable_causallm = session_context_.enable_causallm;

   if (ValidateSubgraph(const_outputs_map_))
     return;
@@ -43,7 +45,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
   // Setting OpenCL queue throttling for GPU
   EnableGPUThrottling(device_config);

-  // Enable streams; default=1 unless ovverriden by user config
+  // Enable streams; default=1 unless overridden by user configuration
   EnableStreams();

   // Set the inference_num_threads property of the CPU
@@ -95,7 +97,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
   } else if (!session_context_.has_external_weights &&
              !subgraph_context_.has_dynamic_input_shape &&
              !session_context_.so_context_enable &&
-             auto_unified_compile) {
+             !enable_causallm && auto_unified_compile) {
     // Unified OV compile_model is efficient when ov model caching is enabled
     // Unified OV compile_model API is supported with AUTO from version 2024.3 and above
     // Inputs with static dimensions
@@ -115,7 +117,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
     }
     auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_);
     exe_network_ = OVCore::Get()->CompileModel(
-        ov_model, hw_target, device_config, subgraph_context_.subgraph_name);
+        ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name);
   }
 #endif
   LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
@@ -200,6 +202,15 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
   if (!session_context_.load_config.empty()) {
     const std::map<std::string, ov::AnyMap>& target_config = session_context_.load_config;

+    if ((session_context_.device_type.find("NPU") != std::string::npos) && session_context_.enable_causallm) {
+      if (target_config.find("NPU") != target_config.end()) {
+        auto npu_genai_config = target_config.at("NPU");
+        CausalLMConfig().ApplyConfig(npu_genai_config, device_config);
+      } else {
+        LOGS_DEFAULT(WARNING) << "ORT GenAI CausalLMConfig Configuration not found.";
+      }
+    }
+
     if (session_context_.device_type.find("NPU") != std::string::npos) {
       auto npuw_config = target_config.at("NPU");

@@ -265,7 +276,8 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
     auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options,
                                      const std::vector<ov::PropertyName>& supported_properties) {
       for (const auto& [key, value] : config_options) {
-        if (key.find("NPUW") != std::string::npos) {
+        if ((key.find("NPUW") != std::string::npos) ||
+            ((device_config.find(key) != device_config.end()) && session_context_.enable_causallm)) {
           continue;
         }
         if (is_supported_and_mutable(key, supported_properties)) {
@@ -358,6 +370,13 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) {
   device_config.emplace(ov::inference_num_threads(session_context_.num_of_threads));
 }

+void BasicBackend::RewindKVCache(size_t index) {
+  OVInferRequestPtr infer_request;
+  infer_request = inferRequestsQueue_->getIdleRequest();
+  infer_request->RewindKVCache(index);
+  inferRequestsQueue_->putIdleRequest(std::move(infer_request));
+}
+
 // Starts an asynchronous inference request for data in slice indexed by batch_slice_idx on
 // an Infer Request indexed by infer_req_idx
 void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) {
@@ -376,14 +395,22 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
       }
       index++;
     }
+
+    // For Stateful Model Compilation, the ONNX model includes KV cache (past/present) tensors.
+    // However, these tensors are internally converted to a stateful representation, which removes them.
+    // To prevent runtime exceptions, we simply continue processing here.
+    if (input_name.empty() || input_name == "beam_idx") continue;
+
     ORT_ENFORCE(!input_name.empty(), log_tag,
                 "Input names mismatch between OpenVINO and ONNX. ", onnx_input_name,
                 " doesn't exist in the list of OpenVINO input tensor names");
     size_t batch_slice_idx = 0;
     if (subgraph_context_.has_dynamic_input_shape &&
         !session_context_.disable_dynamic_shapes &&
         (session_context_.device_type.find("CPU") != std::string::npos ||
-         session_context_.device_type.find("GPU") != std::string::npos)) {
+         session_context_.device_type.find("GPU") != std::string::npos ||
+         (session_context_.device_type.find("NPU") != std::string::npos &&
+          session_context_.enable_causallm))) {
       auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name));
       auto tensor_info = tensor.GetTensorTypeAndShapeInfo();
       auto tensor_shape = tensor_info.GetShape();
@@ -445,7 +472,8 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
     }
   }  // Loop subgraph original input names

-  if (session_context_.device_type.find("NPU") != std::string::npos) {
+  // For Stateful Compilation (i.e., enable_causallm is true), we use the dynamic shapes path for the NPU plugin as well.
+  if (session_context_.device_type.find("NPU") != std::string::npos && !session_context_.enable_causallm) {
     // Set the output blob as remote blob
     auto graph_output_info = exe_network_.Get().outputs();
     auto output_idx = 0;
@@ -640,7 +668,9 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe
                 " list of OpenVINO output tensor names");
     }
     if ((session_context_.device_type.find("CPU") != std::string::npos ||
-         session_context_.device_type.find("GPU") != std::string::npos)) {
+         session_context_.device_type.find("GPU") != std::string::npos ||
+         (session_context_.device_type.find("NPU") != std::string::npos &&
+          session_context_.enable_causallm))) {
       try {
         graph_output_blob = infer_request->GetTensor(output_name);
       } catch (const char* msg) {
@@ -719,25 +749,41 @@ void BasicBackend::Infer(OrtKernelContext* ctx) {
     try {
       StartRemoteAsyncInference(context, infer_request);
     } catch (std::string const& msg) {
+      // If the inference fails (exception from ov::InferRequest::infer()),
+      // we need to put the infer_request back into the pool to avoid deadlocks
+      // and to allow the next inference request to proceed.
+      inferRequestsQueue_->putIdleRequest(std::move(infer_request));
       ORT_THROW(msg);
     }
   } else {
     try {
       StartAsyncInference(context, infer_request);
     } catch (std::string const& msg) {
+      // If the inference fails (exception from ov::InferRequest::infer()),
+      // we need to put the infer_request back into the pool to avoid deadlocks
+      // and to allow the next inference request to proceed.
+      inferRequestsQueue_->putIdleRequest(std::move(infer_request));
       ORT_THROW(msg);
     }
   }
 #else
   try {
     StartAsyncInference(context, infer_request);
   } catch (const std::runtime_error& e) {
+    // If the inference fails (exception from ov::InferRequest::infer()),
+    // we need to put the infer_request back into the pool to avoid deadlocks
+    // and to allow the next inference request to proceed.
+    inferRequestsQueue_->putIdleRequest(std::move(infer_request));
     ORT_THROW(log_tag + " Exception at StartAsyncInference: " + e.what());
   }
 #endif
   try {
     CompleteAsyncInference(context, infer_request);
   } catch (const std::runtime_error& e) {
+    // If the inference fails (exception from ov::InferRequest::infer()),
+    // we need to put the infer_request back into the pool to avoid deadlocks
+    // and to allow the next inference request to proceed.
+    inferRequestsQueue_->putIdleRequest(std::move(infer_request));
     ORT_THROW(log_tag + " Exception at CompleteAsyncInference: " + e.what());
   }

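The four catch blocks above all return the pooled infer request to the idle queue before rethrowing. The same guarantee can be expressed once with an RAII guard; the sketch below is only illustrative, the IdleRequestGuard type is hypothetical (not part of this change) and merely assumes a queue exposing the getIdleRequest()/putIdleRequest() interface used in this file.

#include <utility>

// Hypothetical RAII sketch: hands the pooled request back to its queue on every
// exit path (normal return or exception), so no catch block has to call
// putIdleRequest() by hand before rethrowing.
template <typename Queue, typename RequestPtr>
class IdleRequestGuard {
 public:
  IdleRequestGuard(Queue& queue, RequestPtr request)
      : queue_(queue), request_(std::move(request)) {}
  ~IdleRequestGuard() { queue_.putIdleRequest(std::move(request_)); }
  IdleRequestGuard(const IdleRequestGuard&) = delete;
  IdleRequestGuard& operator=(const IdleRequestGuard&) = delete;
  RequestPtr& get() { return request_; }

 private:
  Queue& queue_;
  RequestPtr request_;
};

// Usage sketch inside BasicBackend::Infer() (assumed, not part of this diff):
//   IdleRequestGuard guard(*inferRequestsQueue_, inferRequestsQueue_->getIdleRequest());
//   StartAsyncInference(context, guard.get());     // may throw; the guard still
//   CompleteAsyncInference(context, guard.get());  // returns the request to the queue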