111 changes: 87 additions & 24 deletions onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
@@ -134,50 +134,113 @@

// Converted to C++ from below reference URL:
// https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// Helper function to extract KV patterns from output names dynamically
std::pair<std::vector<std::string>, std::vector<std::string>> ExtractKVPatternsFromOutputs(const std::shared_ptr<ov::Model>& model) {

This function returns two std::vectors only to check that the first one is non-empty and the second one is used as a sort of lookup table. Therefore, it can return std::optional<T> instead.

std::set<std::string> unique_patterns;

Consider switching to std::unordered_set<T> if you don't need the values to be sorted.

std::vector<std::string> key_value_output_names;

const std::string prefix = "present_";
const size_t prefix_len = prefix.length();
for (const ov::Output<ov::Node>& output : model->outputs()) {
const auto& names = output.get_names();
for (const auto& name : names) {
if (name.find(prefix) == 0 && name.length() > prefix_len) {
key_value_output_names.push_back(name);
size_t last_underscore_pos = name.rfind('_');

// Extract pattern between "present_" and the last underscore
if (last_underscore_pos != std::string::npos && last_underscore_pos > prefix_len) {
std::string pattern = name.substr(prefix_len, last_underscore_pos - prefix_len);

if (!pattern.empty()) {
unique_patterns.insert(pattern);
}
}
break;
}
}
}
std::vector<std::string> extracted_patterns(unique_patterns.begin(), unique_patterns.end());

Is it necessary to construct a std::vector here? Would it be possible to return the set directly?


return std::make_pair(key_value_output_names, extracted_patterns);
}
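
The three review suggestions on this helper (an optional return, an unordered set, and returning the set directly) could be combined roughly as follows. This is only a sketch under stated assumptions: `KVOutputInfo` and `ExtractKVPatterns` are hypothetical names that do not appear in the PR, and the function takes plain name lists rather than `ov::Model` so that it compiles without OpenVINO.

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical result type (not part of the PR).
struct KVOutputInfo {
  std::vector<std::string> output_names;     // full "present_*" names, in model order
  std::unordered_set<std::string> patterns;  // text between "present_" and the last '_'
};

// Assumption: each inner vector holds the names of one model output.
// Returns std::nullopt when no output name starts with "present_".
std::optional<KVOutputInfo> ExtractKVPatterns(
    const std::vector<std::vector<std::string>>& outputs) {
  static const std::string kPrefix = "present_";
  KVOutputInfo info;
  for (const auto& names : outputs) {
    for (const auto& name : names) {
      // rfind(prefix, 0) == 0 is a starts-with test that only looks at index 0.
      if (name.rfind(kPrefix, 0) != 0 || name.size() <= kPrefix.size()) continue;
      info.output_names.push_back(name);
      const std::size_t last_underscore = name.rfind('_');
      // Same guard as the PR: the last underscore must lie past the prefix,
      // which also guarantees the extracted pattern is non-empty.
      if (last_underscore != std::string::npos && last_underscore > kPrefix.size()) {
        info.patterns.insert(
            name.substr(kPrefix.size(), last_underscore - kPrefix.size()));
      }
      break;  // record at most one name per output, as the PR does
    }
  }
  if (info.output_names.empty()) return std::nullopt;
  return info;
}
```

With made-up outputs named `present_key_0`, `present_value_0` and `logits`, this would yield two output names and the patterns `key` and `value`. A model with no `present_*` outputs yields `std::nullopt`, which the caller can treat as "nothing to patch".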

// Main function to extract KV tensors using dynamic pattern matching
std::pair<std::vector<std::string>, std::vector<std::string>> ExtractInputKVTensors(

The same here, consider switching to std::optional<T>

const std::shared_ptr<ov::Model>& model, const std::vector<std::string>& patterns) {

std::vector<std::string> key_value_input_names;
std::vector<std::string> not_kv_inputs;

if (patterns.empty()) {
// Fallback: use original substring matching
for (const ov::Output<ov::Node>& input : model->inputs()) {
const auto& names = input.get_names();
const std::string input_name = input.get_any_name();

bool is_kv_input = false;
for (const auto& name : names) {
if (name.find("key_values") != std::string::npos ||
name.find("keys") != std::string::npos ||
name.find("values") != std::string::npos) {
key_value_input_names.push_back(name);
is_kv_input = true;
break;
}
}

if (!is_kv_input) {
not_kv_inputs.push_back(input_name);
}
}

return std::make_pair(key_value_input_names, not_kv_inputs);
}

std::set<std::string> found_kv_inputs;

Check warning on line 200 in onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog: Add #include <set> for set<> [build/include_what_you_use] [4]

Copilot AI Nov 6, 2025

The variable found_kv_inputs is declared but never used. Consider removing it or implementing the intended logic.

Suggested change
std::set<std::string> found_kv_inputs;

for (const ov::Output<ov::Node>& input : model->inputs()) {
auto& names = input.get_names();

bool found = false;
for (auto& name : names) {
if (name.find("key_values") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;
} else if (name.find("keys") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;
} else if (name.find("values") != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;

// Check if any input name contains the extracted patterns
for (const auto& name : names) {
for (const auto& pattern : patterns) {
if (name.find(pattern) != std::string::npos) {
key_value_input_names.push_back(name);
found = true;
break;
}
}
if (found) break;

This logic contradicts the comment above. If we find a pattern in the name, we won't check other names. Is that an expected behavior?

Author

@Kotomi-Du Kotomi-Du Nov 7, 2025

yes, it is expected behavior. I rephrased the comment.

}

if (!found) {
not_kv_inputs.push_back(input.get_any_name());
}
}

std::vector<std::string> key_value_output_names;
for (const ov::Output<ov::Node>& output : model->outputs()) {
auto& names = output.get_names();
for (auto& name : names) {
if (name.find("present") != std::string::npos) {
key_value_output_names.push_back(name);
break;
}
}
}
return std::make_pair(key_value_input_names, not_kv_inputs);
}
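
The first-match behavior confirmed in the thread above (once one name of an input matches a pattern, the remaining names of that input are not checked) can be illustrated in isolation. This is a minimal sketch, not the PR's code: `SplitInputs` is a hypothetical name, inputs are modeled as plain lists of names instead of `ov::Output<ov::Node>`, and the first listed name stands in for `get_any_name()`.

```cpp
#include <string>
#include <utility>
#include <vector>

// Assumption: each inner vector holds all names of one model input.
// Returns {key/value input names, names of all other inputs}.
std::pair<std::vector<std::string>, std::vector<std::string>> SplitInputs(
    const std::vector<std::vector<std::string>>& inputs,
    const std::vector<std::string>& patterns) {
  std::vector<std::string> kv_inputs;
  std::vector<std::string> other_inputs;
  for (const auto& names : inputs) {
    bool found = false;
    for (const auto& name : names) {
      for (const auto& pattern : patterns) {
        if (name.find(pattern) != std::string::npos) {
          kv_inputs.push_back(name);
          found = true;
          break;
        }
      }
      if (found) break;  // first matching name wins; later names are skipped
    }
    if (!found && !names.empty()) {
      other_inputs.push_back(names.front());  // stand-in for get_any_name()
    }
  }
  return {kv_inputs, other_inputs};
}
```

Because of the early exit, each input contributes at most one entry to the key/value list, which is what keeps the input count comparable with the output count later on.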

// Updated PatchStatefulDecoder function
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
// Use the dynamic pattern-based extraction logic
auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model);
auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns);

std::cout << key_value_input_names.size() << ";" << key_value_output_names.size() << std::endl;

looks like a debug statement here.

Author

removed

if (key_value_input_names.empty() || key_value_output_names.empty()) {

I know this line is not something you added, but can we add stricter checking here? If all goes well, key_value_input_names and key_value_output_names should both be non-empty and have the same number of elements, correct?

Author

updated

std::cout << "no key_value_input_names or key_value_output_names found" << std::endl;

Same for here as below -- I think there should be a runtime exception thrown here. I don't think we'd ever intend for the stateful flow to get enabled, and not identify pairs of tensors to perform a make_stateful transformation on.

Author

updated

return;
}

if (key_value_input_names.size() != key_value_output_names.size()) {
std::cout << "found different sizes between key_value_input_names and key_value_output_names, they couldn't be paired" << std::endl;

Check warning on line 240 in onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog: Add #include <iostream> for cout [build/include_what_you_use] [4]

I think that this one should be a runtime exception of some sort. I don't think we'd ever want to hit this state, return, and have the rest of the stateful flow continue on.

Author

updated

return;
}
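
The reviewers ask for both checks above to throw rather than print and return. The page only records the author's reply "updated", so the following is a guess at the shape of such a check, not the actual change: `ValidateKVPairs` is a hypothetical helper, and `std::runtime_error` is used here only to keep the sketch self-contained (the real code may well use ONNX Runtime's own error macros instead).

```cpp
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper (not from the PR): fails loudly when the key/value
// inputs and outputs cannot be paired one-to-one.
void ValidateKVPairs(const std::vector<std::string>& kv_inputs,
                     const std::vector<std::string>& kv_outputs) {
  if (kv_inputs.empty() || kv_outputs.empty()) {
    throw std::runtime_error(
        "stateful patch: no key/value input or output tensors found");
  }
  if (kv_inputs.size() != kv_outputs.size()) {
    throw std::runtime_error(
        "stateful patch: " + std::to_string(kv_inputs.size()) +
        " key/value inputs vs " + std::to_string(kv_outputs.size()) +
        " outputs; they cannot be paired");
  }
}
```

Throwing here stops the stateful flow before any transformation is attempted on mismatched tensors, which is the outcome the review comments argue for.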

// By default, batch is the 0-th but chatglm uses 1-st dimension as batch
// TODO(ryan): Deduce from a model via ordinal reshape(?) and topology
// batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0