Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion onnxruntime/core/providers/openvino/ov_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,11 @@ void OVInferRequest::Infer() {
// Constructs a stateful infer request that wraps `infer_request` for the target `device`.
// prefill_use_full_chat_history is switched on only when the device string contains
// "GPU" or "NPU" AND FindTensor("input_ids") yields a tensor whose element type is
// ov::element::i64; otherwise the member keeps its existing value.
StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device)
    : OVInferRequest(std::move(infer_request)), target_device(device) {
  const bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));

  // The prefill_use_full_chat_history logic only applies to specific inputs and data types,
  // so check that an "input_ids" tensor exists and that its element type is int64.
  auto input_ids_opt = FindTensor("input_ids");
  if (gpu_or_npu && input_ids_opt.has_value() && input_ids_opt->get_element_type() == ov::element::i64) {
    prefill_use_full_chat_history = true;
  }
}
Expand Down
48 changes: 33 additions & 15 deletions onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,17 @@ bool ModelHasInputOutputNames(std::shared_ptr<ov::Model> model, const std::strin
return false;
}

// Selects a model input/output name from a prioritized list of candidates.
// Candidates are probed in order with ModelHasInputOutputNames and the first one the
// model reports is returned. If none match, the first candidate is returned as the
// default; an empty candidate list yields an empty string.
std::string GetInputOutputName(std::shared_ptr<ov::Model> ov_model,
                               const std::vector<std::string>& candidate_names) {
  // Nothing to choose from: there is not even a default to fall back on.
  if (candidate_names.empty()) {
    return "";
  }

  for (auto it = candidate_names.begin(); it != candidate_names.end(); ++it) {
    if (ModelHasInputOutputNames(ov_model, *it)) {
      return *it;
    }
  }

  // No candidate is present in the model; fall back to the highest-priority entry.
  return candidate_names.front();
}

void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
std::vector<std::string>& not_kv_inputs,
const std::vector<std::string>& key_value_input_names,
Expand All @@ -67,10 +78,15 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
throw std::runtime_error("Model already has fused cache");
}

std::string main_input_name = "inputs_embeds";
if (ModelHasInputOutputNames(ov_model, "input_ids")) {
main_input_name = "input_ids";
}
// Define input name candidates in priority order
const std::vector<std::string> input_name_candidates = {
"inputs_embeds", // Default fallback
"input_ids", // Most common
"input_hidden_states", // Alternative
"/model/embed_tokens/Gather_output_0" // Specific model type
};

std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates);

auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];

Expand Down Expand Up @@ -121,20 +137,22 @@ void MakeStateful(std::shared_ptr<ov::Model>& ov_model,
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
std::vector<std::string> key_value_input_names;
std::vector<std::string> not_kv_inputs;
for (const ov::Output<ov::Node>& input : model->inputs()) {
auto& names = input.get_names();

bool found = false;
for (auto& name : names) {
if (name.find("key_values") != std::string::npos) {
key_value_input_names.push_back(name);
const auto& params = model->get_parameters();
bool found = false;
for (size_t i = 0; i < params.size(); i++) {
auto param_name = params.at(i)->output(0).get_any_name();
Copy link

Copilot AI Oct 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

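Using params.at(i) performs a bounds check on every access, which is unnecessary here because i is already bounded by params.size(), so it is slightly less efficient than direct indexing with params[i]. Consider using a range-based for loop or direct indexing instead.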

Suggested change
auto param_name = params.at(i)->output(0).get_any_name();
auto param_name = params[i]->output(0).get_any_name();

Copilot uses AI. Check for mistakes.
if (param_name.find("key_values") != std::string::npos) {
key_value_input_names.push_back(param_name);
found = true;
} else if (param_name.find("key") != std::string::npos) {
key_value_input_names.push_back(param_name);
found = true;
} else if (param_name.find("value") != std::string::npos) {
key_value_input_names.push_back(param_name);
found = true;
break;
}
}

if (!found) {
not_kv_inputs.push_back(input.get_any_name());
not_kv_inputs.push_back(param_name);
}
Comment on lines 160 to 162
Copy link

Copilot AI Oct 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 'found' variable is never reset to false between iterations, causing incorrect classification of subsequent parameters. Reset 'found = false' at the beginning of each loop iteration.

Copilot uses AI. Check for mistakes.
}

Expand Down
Loading