Skip to content
105 changes: 81 additions & 24 deletions onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License

#include "core/providers/openvino/ov_stateful_patch_utils.h"
#include "regex"

namespace onnxruntime {
namespace openvino_ep {
Expand Down Expand Up @@ -134,44 +135,100 @@

// Converted to C++ from below reference URL:
// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// Helper function to extract KV patterns from output names dynamically
std::pair<std::vector<std::string>, std::vector<std::string>> ExtractKVPatternsFromOutputs(const std::shared_ptr<ov::Model>& model) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function returns two std::vectors only to check that the first one is non-empty and the second one is used as a sort of lookup table. Therefore, it can return std::optional<T> instead.

Copy link
Author

@Kotomi-Du Kotomi-Du Nov 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each element in key_value_output_names will be passed as information for making stateful, it cannot be switched to std::optional.

It is updated to std::optional for patterns now.

std::set<std::string> unique_patterns;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider switching to std::unordered_set<T> if you don't need the values to be sorted.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

std::vector<std::string> key_value_output_names;

// Regex to match "present_" prefix and numeric suffix
std::regex present_pattern(R"(present_(.+)_(\d+))");

// Scan all outputs with "present" in the name
for (const ov::Output<ov::Node>& output : model->outputs()) {
const auto& names = output.get_names();
for (const auto& name : names) {
if (name.starts_with("present")) {
key_value_output_names.push_back(name);
std::smatch match;
if (std::regex_match(name, match, present_pattern)) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a regular expression here? It seems like it can be simplified to a regular string manipulation: we already know that the string starts with "present" and we can find where the last underscore is, gathering the substring.
This is also a micro-optimization in terms of performance, since C++ regexps are compiled at runtime and are known for their bad performance.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

// Extract the middle part (between "present_" and "_number")
std::string pattern = match[1].str();
unique_patterns.insert(pattern);
}
break;
}
}
}

std::vector<std::string> extracted_patterns(unique_patterns.begin(), unique_patterns.end());

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to construct a std::vector here? Would it be possible to return the set directly?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated it to std::optional<std::pair<std::string, std::string>> now.


return std::make_pair(key_value_output_names, extracted_patterns);
}

// Main function to extract KV tensors using dynamic pattern matching
std::pair<std::vector<std::string>, std::vector<std::string>> ExtractInputKVTensors(
const std::shared_ptr<ov::Model>& model, const std::vector<std::string>& patterns) {

std::vector<std::string> key_value_input_names;
std::vector<std::string> not_kv_inputs;

if (patterns.empty()) {
// Fallback: use original substring matching
for (const ov::Output<ov::Node>& input : model->inputs()) {
const auto& names = input.get_names();
const std::string input_name = input.get_any_name();

bool is_kv_input = false;
for (const auto& name : names) {
if (name.find("key_values") != std::string::npos ||
name.find("keys") != std::string::npos ||
name.find("values") != std::string::npos) {
key_value_input_names.push_back(name);
is_kv_input = true;
break;
}
}

if (!is_kv_input) {
not_kv_inputs.push_back(input_name);
}
}

return std::make_pair(key_value_input_names, not_kv_inputs);
}

std::set<std::string> found_kv_inputs;

Check warning on line 200 in onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <set> for set<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc:200: Add #include <set> for set<> [build/include_what_you_use] [4]

for (const ov::Output<ov::Node>& input : model->inputs()) {
auto& names = input.get_names();

bool found = false;
for (auto& name : names) {
if (name.find("key_values") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;
} else if (name.find("keys") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;
} else if (name.find("values") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;

// Check each input name against potential patterns
for (const auto& name : names) {
for (const auto& pattern : patterns) {
if (name.find(pattern) != std::string::npos){
key_value_input_names.push_back(name);
found = true;
break;
}
}
if (found) break;
}

if (!found) {
not_kv_inputs.push_back(input.get_any_name());
}
}

std::vector<std::string> key_value_output_names;
for (const ov::Output<ov::Node>& output : model->outputs()) {
auto& names = output.get_names();
for (auto& name : names) {
if (name.find("present") != std::string::npos) {
key_value_output_names.push_back(name);
break;
}
}
}
return std::make_pair(key_value_input_names, not_kv_inputs);
}

// Updated PatchStatefulDecoder function
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// Use the dynamic pattern-based extraction logic
auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model);
auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns);

if (key_value_input_names.empty() || key_value_output_names.empty()) {
std::cout << "no key_value_input_names or key_value_output_names found" << std::endl;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for here as below -- I think there should be a runtime exception thrown here. I don't think we'd ever intend for the stateful flow to get enabled, and not identify pairs of tensors to perform a make_stateful transformation on.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

Expand Down
Loading