
Commit 660adfc

feat: ORT GenAI Stateful Compilation changes (#676)

* feat: ORT GenAI Stateful Compilation changes
* fix: Disabled UT as testdata/attention_past_state.onnx model is invalid
* fix: lint fixes
* fix: refactor tensor caching
* update: Fix optional position ids caching logic
* fix: remove unwanted comment

1 parent be8f8be commit 660adfc
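
The new `enable_causallm` provider option (added to `ProviderInfo` in contexts.h and parsed in openvino_provider_factory.cc below) is what switches the EP into stateful compilation. A minimal caller-side sketch, not part of this commit, showing how a client might pass it; it assumes the `AppendExecutionProvider_OpenVINO_V2` C++ API and a placeholder model path:

```cpp
// Hypothetical usage sketch: enable the OpenVINO EP with the new
// "enable_causallm" option on NPU. The option value is parsed by
// ParseBooleanOption() in ParseProviderInfo(); "model.onnx" is a placeholder.
#include <onnxruntime_cxx_api.h>
#include <string>
#include <unordered_map>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "ovep_causallm"};
  Ort::SessionOptions session_options;

  std::unordered_map<std::string, std::string> ov_options{
      {"device_type", "NPU"},
      {"enable_causallm", "true"}};  // new option introduced by this commit
  session_options.AppendExecutionProvider_OpenVINO_V2(ov_options);

  Ort::Session session{env, ORT_TSTR("model.onnx"), session_options};
  return 0;
}
```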

File tree

14 files changed: +839 -52 lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 15 additions & 4 deletions
```diff
@@ -44,6 +44,10 @@ BackendManager::BackendManager(SessionContext& session_context,
       shared_context_{shared_context} {
   subgraph_context_.is_ep_ctx_graph = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(subgraph);
 
+  bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos ||
+                    session_context_.device_type.find("GPU") != std::string::npos;
+  bool npu = session_context_.device_type.find("NPU") != std::string::npos;
+
   subgraph_context_.model_precision = [&](const GraphViewer& graph_viewer) {
     // return empty if graph has no inputs or if types are not one of FP32/FP16
     // else assume the type of the first input
@@ -105,8 +109,7 @@ BackendManager::BackendManager(SessionContext& session_context,
   if (ModelHasSymbolicInputDims(subgraph)) {
     subgraph_context_.has_dynamic_input_shape = true;
     LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims";
-    if ((session_context_.device_type.find("CPU") != std::string::npos ||
-         session_context_.device_type.find("GPU") != std::string::npos) &&
+    if (cpu_or_gpu || (npu && session_context_.enable_causallm) &&
        !session_context_.disable_dynamic_shapes) {
      LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. "
                         << "Creating backend Dynamic Shapes";
@@ -480,6 +483,9 @@ BackendManager::ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_p
 void BackendManager::Compute(OrtKernelContext* context) {
   Ort::KernelContext ctx(context);
   std::chrono::high_resolution_clock::time_point start_compute, end_compute;
+  bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos ||
+                    session_context_.device_type.find("GPU") != std::string::npos;
+  bool npu = session_context_.device_type.find("NPU") != std::string::npos;
 #ifdef OPENVINO_FIL_ENABLED
   static bool fil_enabled = true;
   if (fil_enabled) {
@@ -493,8 +499,7 @@ void BackendManager::Compute(OrtKernelContext* context) {
   // disable_dynamic_shapes is always set to true for OV NPU plugin.
   if (subgraph_context_.has_dynamic_input_shape &&
       !session_context_.disable_dynamic_shapes &&
-      (session_context_.device_type.find("CPU") != std::string::npos ||
-       session_context_.device_type.find("GPU") != std::string::npos)) {
+      (cpu_or_gpu || (npu && session_context_.enable_causallm))) {
     concrete_backend_->Infer(context);
   } else if (subgraph_context_.has_dynamic_input_shape) {
     std::vector<std::vector<int64_t>> tensor_shapes = GetInputTensorShapes(ctx);
@@ -567,5 +572,11 @@ void BackendManager::ShutdownBackendManager() {
   concrete_backend_.reset();
 }
 
+void BackendManager::RewindKVCache(size_t index) {
+  if (concrete_backend_) {
+    concrete_backend_->RewindKVCache(index);
+  }
+}
+
 }  // namespace openvino_ep
 }  // namespace onnxruntime
```

onnxruntime/core/providers/openvino/backend_manager.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -30,6 +30,7 @@ class BackendManager {
   SessionContext& GetSessionContext();
   Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph);
   ov::CompiledModel GetOVCompiledModel();
+  void RewindKVCache(size_t index);
 
  private:
   std::unique_ptr<ONNX_NAMESPACE::ModelProto> GetModelProtoFromFusedNode(
```

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 48 additions & 19 deletions
```diff
@@ -15,6 +15,7 @@
 #include "core/providers/openvino/backends/basic_backend.h"
 #include "core/providers/openvino/onnx_ctx_model_helper.h"
 #include "core/providers/openvino/backend_manager.h"
+#include "core/providers/openvino/ov_stateful_patch_utils.h"
 
 namespace onnxruntime {
 
@@ -29,6 +30,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
                            ptr_stream_t& model_stream)
     : session_context_{session_context}, subgraph_context_{subgraph_context}, shared_context_{shared_context} {
   std::string& hw_target = session_context_.device_type;
+  bool enable_causallm = session_context_.enable_causallm;
 
   if (ValidateSubgraph(const_outputs_map_))
     return;
@@ -43,7 +45,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
   // Setting OpenCL queue throttling for GPU
   EnableGPUThrottling(device_config);
 
-  // Enable streams; default=1 unless ovverriden by user config
+  // Enable streams; default=1 unless overridden by user configuration
   EnableStreams();
 
   // Set the inference_num_threads property of the CPU
@@ -76,7 +78,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
     } else if (!session_context_.has_external_weights &&
                !subgraph_context_.has_dynamic_input_shape &&
                !session_context_.so_context_enable &&
-               auto_unified_compile) {
+               !enable_causallm && auto_unified_compile) {
       // Unified OV compile_model is efficient when ov model caching is enabled
       // Unified OV compile_model API is supported with AUTO from version 2024.3 and above
       // Inputs with static dimensions
@@ -96,7 +98,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
       }
       auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_);
       exe_network_ = OVCore::Get()->CompileModel(
-          ov_model, hw_target, device_config, subgraph_context_.subgraph_name);
+          ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name);
     }
     LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
   } catch (const char* msg) {
@@ -120,7 +122,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
     };
   }
   inferRequestsQueue_ = std::unique_ptr<InferRequestsQueue>(new InferRequestsQueue(exe_network_, num_infer_req, std::move(initializer)));
-  bindings_ = std::make_unique<OnnxToOvNetworkBindings>(exe_network_, subgraph_context_);
+  bindings_ = std::make_unique<OnnxToOvNetworkBindings>(exe_network_, subgraph_context_, session_context_);
 }
 
 bool BasicBackend::ValidateSubgraph(std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map) {
@@ -181,6 +183,15 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
   if (!session_context_.load_config.empty()) {
     const std::map<std::string, ov::AnyMap>& target_config = session_context_.load_config;
 
+    if ((session_context_.device_type.find("NPU") != std::string::npos) && session_context_.enable_causallm) {
+      if (target_config.find("NPU") != target_config.end()) {
+        auto npu_genai_config = target_config.at("NPU");
+        CausalLMConfig().ApplyConfig(npu_genai_config, device_config);
+      } else {
+        LOGS_DEFAULT(WARNING) << "ORT GenAI CausalLMConfig Configuration not found.";
+      }
+    }
+
     if (session_context_.device_type.find("NPU") != std::string::npos) {
       auto npuw_config = target_config.at("NPU");
 
@@ -246,7 +257,8 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
     auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options,
                                      const std::vector<ov::PropertyName>& supported_properties) {
       for (const auto& [key, value] : config_options) {
-        if (key.find("NPUW") != std::string::npos) {
+        if ((key.find("NPUW") != std::string::npos) ||
+            ((device_config.find(key) != device_config.end()) && session_context_.enable_causallm)) {
           continue;
         }
         if (is_supported_and_mutable(key, supported_properties)) {
@@ -339,6 +351,13 @@ void BasicBackend::SetNumThreads(ov::AnyMap& device_config) {
     device_config.emplace(ov::inference_num_threads(session_context_.num_of_threads));
 }
 
+void BasicBackend::RewindKVCache(size_t index) {
+  OVInferRequestPtr infer_request;
+  infer_request = inferRequestsQueue_->getIdleRequest();
+  infer_request->RewindKVCache(index);
+  inferRequestsQueue_->putIdleRequest(std::move(infer_request));
+}
+
 // Starts an asynchronous inference request for data in slice indexed by batch_slice_idx on
 // an Infer Request indexed by infer_req_idx
 void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferRequestPtr infer_request) {
@@ -351,7 +370,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
     size_t batch_slice_idx = 0;
     if (subgraph_context_.has_dynamic_input_shape &&
         !session_context_.disable_dynamic_shapes &&
-        cpu_or_gpu) {
+        cpu_or_gpu || (npu && session_context_.enable_causallm)) {
       auto tensor = context.GetInput(input_info.onnx_index);
       auto tensor_info = tensor.GetTensorTypeAndShapeInfo();
       auto tensor_shape = tensor_info.GetShape();
@@ -409,7 +428,8 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
     }
   }  // Loop subgraph original input
 
-  if (npu) {
+  // For Stateful Compilation i.e. enable_causallm as True, we use the dynamic shapes path for NPU plugin as well.
+  if (npu && !session_context_.enable_causallm) {
     // Set the output blob as remote blob
     for (const auto& output_info : bindings_->network_outputs_) {
       Ort::UnownedValue tensor = context.GetOutput(output_info.onnx_index, output_info.onnx_shape);
@@ -453,19 +473,20 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe
 
   bool cpu_or_gpu = session_context_.device_type.find("CPU") != std::string::npos ||
                     session_context_.device_type.find("GPU") != std::string::npos;
-  if (cpu_or_gpu) {
+  bool npu = session_context_.device_type.find("NPU") != std::string::npos;
+  if (cpu_or_gpu || (npu && session_context_.enable_causallm)) {
     for (const auto& output_info : bindings_->network_outputs_) {
-    OVTensorPtr graph_output_blob;
-    try {
-      graph_output_blob = infer_request->GetTensor(output_info.name);
-    } catch (const char* msg) {
-      ORT_THROW(msg);
-    }
-    size_t batch_size = 1;
-    Ort::UnownedValue output_tensor =
-        GetOutputTensor(context, batch_size, infer_request, output_info.name, subgraph_context_.output_names);
-    auto mem_info = output_tensor.GetTensorMemoryInfo();
-    if (mem_info.GetAllocatorName() == OpenVINO_GPU) {
+      OVTensorPtr graph_output_blob;
+      try {
+        graph_output_blob = infer_request->GetTensor(output_info.name);
+      } catch (const char* msg) {
+        ORT_THROW(msg);
+      }
+      size_t batch_size = 1;
+      Ort::UnownedValue output_tensor =
+          GetOutputTensor(context, batch_size, infer_request, output_info.name, subgraph_context_.output_names);
+      auto mem_info = output_tensor.GetTensorMemoryInfo();
+      if (mem_info.GetAllocatorName() == OpenVINO_GPU) {
       return;
     } else {
       size_t batch_slice = 0;
@@ -538,11 +559,19 @@ void BasicBackend::Infer(OrtKernelContext* ctx) {
   try {
     StartAsyncInference(context, infer_request);
   } catch (const std::runtime_error& e) {
+    // If the inference fails (exception from ov::InferRequest::infer()),
+    // we need to put the infer_request back into the pool to avoid deadlocks
+    // and to allow the next inference request to proceed.
+    inferRequestsQueue_->putIdleRequest(std::move(infer_request));
     ORT_THROW(log_tag + " Exception at StartAsyncInference: " + e.what());
   }
   try {
     CompleteAsyncInference(context, infer_request);
   } catch (const std::runtime_error& e) {
+    // If the inference fails (exception from ov::InferRequest::infer()),
+    // we need to put the infer_request back into the pool to avoid deadlocks
+    // and to allow the next inference request to proceed.
+    inferRequestsQueue_->putIdleRequest(std::move(infer_request));
     ORT_THROW(log_tag + " Exception at CompleteAsyncInference: " + e.what());
   }
 
```
onnxruntime/core/providers/openvino/backends/basic_backend.h

Lines changed: 13 additions & 3 deletions
```diff
@@ -42,12 +42,22 @@ struct OnnxToOvNetworkBindings {
   std::vector<ParameterInfo> network_outputs_;
   std::vector<ParameterInfo> network_inputs_;
 
-  OnnxToOvNetworkBindings(OVExeNetwork& exec_network, SubGraphContext& subgraph_context) {
+  OnnxToOvNetworkBindings(OVExeNetwork& exec_network, SubGraphContext& subgraph_context, SessionContext& session_context) {
     auto populate = [&](auto& input_output_map, const SubGraphContext::string_index_map_t& onnx_input_map, const auto& ov_parameters) {
       for (const auto& [onnx_name, onnx_param_index] : onnx_input_map) {
         auto it = std::find_if(ov_parameters.begin(), ov_parameters.end(),
                                [&onnx_name](const auto& ov_parameter_info) { return ov_parameter_info.get_names().contains(onnx_name); });
 
+        // For Stateful Model Compilation, the ONNX model includes KV cache (past/present) tensors.
+        // However, these tensors are internally converted to a stateful representation, which removes them.
+        // To prevent runtime exceptions, we simply continue processing here.
+        if ((onnx_name.empty() || onnx_name == "beam_idx" ||
+             onnx_name.find("past_key_values") != std::string::npos ||
+             onnx_name.find("present") != std::string::npos) &&
+            session_context.enable_causallm) {
+          continue;
+        }
+
         ORT_ENFORCE(it != ov_parameters.end(), backend_utils::log_tag,
                     "Input names mismatch between OpenVINO and ONNX. ", onnx_name,
                     " doesn't exist in the list of OpenVINO input tensor names");
@@ -85,6 +95,7 @@ class BasicBackend : public IBackend {
   ov::CompiledModel GetOVCompiledModel() override {
     return exe_network_.Get();
   }
+  void RewindKVCache(size_t index) override;
 
  private:
   bool ValidateSubgraph(std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);
@@ -114,7 +125,7 @@ class InferRequestsQueue {
     OVInferRequestPtr infer_request;
     live_threads=nireq;
     for (size_t id = 0; id < nireq; id++) {
-      infer_request = std::make_shared<OVInferRequest>(net.CreateInferRequest());
+      infer_request = net.CreateInferRequest();
       initializer(infer_request);
       infer_requests_.push_back(infer_request);
     }
@@ -144,7 +155,6 @@ class InferRequestsQueue {
 
   OVInferRequestPtr getIdleRequest() {
     std::unique_lock<std::mutex> lock(_mutex);
-    std::cout << "get Idle Request" << live_threads << "\n";
     if(live_threads==0) {
       return nullptr;
     }
```

onnxruntime/core/providers/openvino/contexts.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -97,6 +97,7 @@ struct ProviderInfo {
   bool disable_dynamic_shapes{false};      // [disable_dynamic_shapes]: Rewrite dynamic shaped models to
                                            // static shape at runtime and execute.
   bool enable_qdq_optimizer{false};        // Enables QDQ pruning for efficient inference latency with NPU
+  bool enable_causallm{false};             // Enables Causal LM Compilation for ORT GenAI OVEP Pass
   bool so_context_enable{false};           // ORT session option
   bool so_disable_cpu_ep_fallback{false};  // ORT session option
   bool so_context_embed_mode{false};       // ORT session option
```

onnxruntime/core/providers/openvino/ibackend.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -17,6 +17,7 @@ class IBackend {
   virtual void Infer(OrtKernelContext* context) = 0;
   virtual ov::CompiledModel GetOVCompiledModel() = 0;
   virtual ~IBackend() = default;
+  virtual void RewindKVCache(size_t index) {}
 };
 using ptr_stream_t = std::unique_ptr<std::istream>;
 class BackendFactory {
```

onnxruntime/core/providers/openvino/openvino_execution_provider.cc

Lines changed: 19 additions & 0 deletions
```diff
@@ -254,6 +254,25 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span<const ch
         }
       }
     }
+  } else if (key == "kvcache_rewind") {
+    // Convert kvcache_rewind value to int64_t
+    int64_t index;
+    try {
+      index = std::stoll(value);
+    } catch (const std::exception& e) {
+      LOGS_DEFAULT(WARNING) << "Conversion for kvcache_rewind string value to int64_t index failed."
+                            << "Exception:" + std::string(e.what());
+      return Status::OK();
+    }
+
+    // Trigger KVCache Rewind for target Backend
+    for (auto& backend : backend_managers_) {
+      if (index >= 0) {
+        backend.RewindKVCache(static_cast<size_t>(index));
+      } else {
+        LOGS_DEFAULT(WARNING) << "kvcache_rewind index is < 0:\t" << index;
+      }
+    }
   } else {
     // Handle unknown options
     LOGS_DEFAULT(WARNING) << "Unknown key/value pair - ignoring " << key << "/" << value;
```

onnxruntime/core/providers/openvino/openvino_provider_factory.cc

Lines changed: 8 additions & 1 deletion
```diff
@@ -343,13 +343,20 @@ static void ParseProviderInfo(const ProviderOptions& provider_options,
 
     pi.enable_qdq_optimizer = ParseBooleanOption(provider_options, "enable_qdq_optimizer");
 
+    pi.enable_causallm = ParseBooleanOption(provider_options, "enable_causallm");
+
     pi.disable_dynamic_shapes = ParseBooleanOption(provider_options, "disable_dynamic_shapes");
   } catch (std::string msg) {
     ORT_THROW(msg);
   }
   // Always true for NPU plugin or when passed .
   if (pi.device_type.find("NPU") != std::string::npos) {
-    pi.disable_dynamic_shapes = true;
+    // For Stateful Compilation i.e. enable_causallm as True, we use the dynamic shapes path.
+    if (pi.enable_causallm) {
+      pi.disable_dynamic_shapes = false;
+    } else {
+      pi.disable_dynamic_shapes = true;
+    }
   }
 }
 
```
