diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc
index 7723ce0a6c7f7..e97bbaceee4e2 100644
--- a/onnxruntime/core/providers/openvino/ov_interface.cc
+++ b/onnxruntime/core/providers/openvino/ov_interface.cc
@@ -361,7 +361,11 @@ void OVInferRequest::Infer() {
 StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device)
     : OVInferRequest(std::move(infer_request)), target_device(device) {
   bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));
-  if (gpu_or_npu) {
+
+  // Check whether an input_ids tensor exists and whether its element type is int64,
+  // since the prefill_use_full_chat_history logic only applies to that input and data type.
+  auto input_ids_opt = FindTensor("input_ids");
+  if (gpu_or_npu && input_ids_opt.has_value() && input_ids_opt->get_element_type() == ov::element::i64) {
     prefill_use_full_chat_history = true;
   }
 }
diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
index b48b0efde7ab6..ca4867b7d8ae4 100644
--- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
+++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
@@ -59,6 +59,17 @@ bool ModelHasInputOutputNames(std::shared_ptr<ov::Model> model, const std::strin
   return false;
 }
 
+std::string GetInputOutputName(std::shared_ptr<ov::Model> ov_model,
+                               const std::vector<std::string>& candidate_names) {
+  for (const auto& name : candidate_names) {
+    if (ModelHasInputOutputNames(ov_model, name)) {
+      return name;
+    }
+  }
+  // Return the first candidate as the default if none is found
+  return candidate_names.empty() ? "" : candidate_names[0];
+}
+
 void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
                       std::vector<std::string>& not_kv_inputs,
                       const std::vector<std::string>& key_value_input_names,
@@ -67,10 +78,15 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
     throw std::runtime_error("Model already has fused cache");
   }
 
-  std::string main_input_name = "inputs_embeds";
-  if (ModelHasInputOutputNames(ov_model, "input_ids")) {
-    main_input_name = "input_ids";
-  }
+  // Define input name candidates in priority order
+  const std::vector<std::string> input_name_candidates = {
+      "inputs_embeds",                        // Default fallback
+      "input_ids",                            // Most common
+      "input_hidden_states",                  // Alternative
+      "/model/embed_tokens/Gather_output_0"   // Specific model type
+  };
+
+  std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates);
 
   auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];
 
@@ -130,6 +146,14 @@ void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
       key_value_input_names.push_back(name);
      found = true;
       break;
+    } else if (name.find("keys") != std::string::npos) {
+      key_value_input_names.push_back(name);
+      found = true;
+      break;
+    } else if (name.find("values") != std::string::npos) {
+      key_value_input_names.push_back(name);
+      found = true;
+      break;
     }
   }
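
For reference, below is a minimal standalone sketch of the candidate-resolution logic introduced in FuseCacheReorder, assuming OpenVINO's public C++ API (ov::Core, ov::Model, ov::Output::get_names()). The two helpers mirror the patch; the main() driver and the model path argument are hypothetical and exist only to exercise the priority-order lookup.

#include <openvino/openvino.hpp>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Returns true if any input or output port of the model carries `name`.
static bool ModelHasInputOutputNames(const std::shared_ptr<ov::Model>& model,
                                     const std::string& name) {
  for (const auto& port : model->inputs()) {
    if (port.get_names().count(name)) return true;
  }
  for (const auto& port : model->outputs()) {
    if (port.get_names().count(name)) return true;
  }
  return false;
}

// First candidate present in the model wins; the first entry also serves as
// the fallback when nothing matches, as in the patched helper.
static std::string GetInputOutputName(const std::shared_ptr<ov::Model>& model,
                                      const std::vector<std::string>& candidates) {
  for (const auto& name : candidates) {
    if (ModelHasInputOutputNames(model, name)) return name;
  }
  return candidates.empty() ? "" : candidates[0];
}

int main(int argc, char** argv) {
  if (argc < 2) {
    std::cerr << "usage: " << argv[0] << " model.xml\n";
    return 1;
  }
  ov::Core core;
  auto model = core.read_model(argv[1]);  // hypothetical driver: any IR model
  std::string main_input = GetInputOutputName(
      model, {"inputs_embeds", "input_ids", "input_hidden_states",
              "/model/embed_tokens/Gather_output_0"});
  std::cout << "main input resolved to: " << main_input << "\n";
  return 0;
}

Note the fallback behavior: for a model exposing none of the candidates, the helper silently returns "inputs_embeds", and any failure then surfaces later at ov_model->input(main_input_name) rather than at resolution time.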