Commit a5ac79d

update: Fix optional position ids caching logic
1 parent 833cff9 commit a5ac79d

4 files changed: +43 −10 lines changed


onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 1 addition & 1 deletion

@@ -142,7 +142,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
     };
   }
   inferRequestsQueue_ = std::unique_ptr<InferRequestsQueue>(new InferRequestsQueue(exe_network_, num_infer_req, std::move(initializer)));
-  bindings_ = std::make_unique<OnnxToOvNetworkBindings>(exe_network_, subgraph_context_);
+  bindings_ = std::make_unique<OnnxToOvNetworkBindings>(exe_network_, subgraph_context_, session_context_);
 }

 bool BasicBackend::ValidateSubgraph(std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map) {
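The call-site change above threads the session context into the bindings object, because the input-skipping logic in basic_backend.h below is now gated on session_context.enable_causallm. A minimal runnable sketch of that dependency threading, using simplified stand-in types (the real SessionContext, SubGraphContext, and OVExeNetwork are defined in the provider headers):

#include <string>
#include <vector>

// Stand-in types for illustration only.
struct SessionContext { bool enable_causallm = false; };
struct SubGraphContext { std::vector<std::string> input_names; };
struct OVExeNetwork {};

struct OnnxToOvNetworkBindings {
  std::vector<std::string> bound_inputs;

  // The constructor now also receives the session context so it can consult
  // session-wide flags while deciding which inputs to bind.
  OnnxToOvNetworkBindings(OVExeNetwork&, SubGraphContext& sg, SessionContext& sc) {
    for (const auto& name : sg.input_names) {
      // KV-cache style inputs are skipped only for causal-LM stateful models.
      if (name.find("past_key_values") != std::string::npos && sc.enable_causallm) {
        continue;
      }
      bound_inputs.push_back(name);
    }
  }
};

int main() {
  OVExeNetwork net;
  SubGraphContext sg{{"input_ids", "past_key_values.0.key"}};
  SessionContext sc;  // enable_causallm == false: both inputs get bound
  OnnxToOvNetworkBindings bindings(net, sg, sc);
  return bindings.bound_inputs.size() == 2 ? 0 : 1;
}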

onnxruntime/core/providers/openvino/backends/basic_backend.h

Lines changed: 4 additions & 3 deletions

@@ -42,7 +42,7 @@ struct OnnxToOvNetworkBindings {
   std::vector<ParameterInfo> network_outputs_;
   std::vector<ParameterInfo> network_inputs_;

-  OnnxToOvNetworkBindings(OVExeNetwork& exec_network, SubGraphContext& subgraph_context) {
+  OnnxToOvNetworkBindings(OVExeNetwork& exec_network, SubGraphContext& subgraph_context, SessionContext& session_context) {
     auto populate = [&](auto& input_output_map, const SubGraphContext::string_index_map_t& onnx_input_map, const auto& ov_parameters) {
       for (const auto& [onnx_name, onnx_param_index] : onnx_input_map) {
         auto it = std::find_if(ov_parameters.begin(), ov_parameters.end(),
@@ -51,9 +51,10 @@ struct OnnxToOvNetworkBindings {
         // For Stateful Model Compilation, the ONNX model includes KV cache (past/present) tensors.
         // However, these tensors are internally converted to a stateful representation, which removes them.
         // To prevent runtime exceptions, we simply continue processing here.
-        if (onnx_name.empty() || onnx_name == "beam_idx" ||
+        if ((onnx_name.empty() || onnx_name == "beam_idx" ||
             onnx_name.find("past_key_values") != std::string::npos ||
-            onnx_name.find("present") != std::string::npos) {
+            onnx_name.find("present") != std::string::npos) &&
+            session_context.enable_causallm) {
           continue;
         }
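Note the parenthesization in the hunk above: since && binds tighter than || in C++, appending && session_context.enable_causallm without the added outer parentheses would gate only the last find("present") test, and empty names, "beam_idx", and "past_key_values" inputs would still be skipped unconditionally. A standalone sketch of the corrected predicate, using a hypothetical ShouldSkipAsStatefulKVInput helper (illustration only, not part of the provider):

#include <cassert>
#include <string>

// Mirrors the fixed condition: KV-cache style inputs are skipped only when
// causal-LM stateful compilation is enabled; otherwise every input is bound.
bool ShouldSkipAsStatefulKVInput(const std::string& onnx_name, bool enable_causallm) {
  const bool looks_like_kv_input =
      onnx_name.empty() || onnx_name == "beam_idx" ||
      onnx_name.find("past_key_values") != std::string::npos ||
      onnx_name.find("present") != std::string::npos;
  return looks_like_kv_input && enable_causallm;
}

int main() {
  // With causal-LM mode off, nothing is skipped anymore.
  assert(!ShouldSkipAsStatefulKVInput("past_key_values.0.key", false));
  // With causal-LM mode on, the stateful KV inputs are still skipped.
  assert(ShouldSkipAsStatefulKVInput("beam_idx", true));
  return 0;
}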

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 36 additions & 5 deletions

@@ -391,7 +391,7 @@ StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, s
   }
 }

-void StatefulOVInferRequest::CacheTensor(const std::string& tensor_name, const ov::element::Type& type,
+void StatefulOVInferRequest::FillTensor(const std::string& tensor_name, const ov::element::Type& type,
                                          const std::vector<size_t>& shape, int32_t fill_value) {
   ov::Tensor tensor = ov::Tensor(type, shape);
   std::fill_n(tensor.data<int32_t>(), tensor.get_size(), fill_value);
@@ -419,16 +419,43 @@ void StatefulOVInferRequest::SetTensorFromCache(const std::string& tensor_name,
   ovInfReq.set_tensor(tensor_name, new_tensor);
 }

+std::optional<ov::Tensor> StatefulOVInferRequest::FindTensor(const std::string& tensor_name) {
+  // Check whether the tensor exists by examining input names in the compiled model.
+  const auto& model = ovInfReq.get_compiled_model();
+  bool tensor_exists = false;
+
+  for (const auto& input : model.inputs()) {
+    const auto& names = input.get_names();
+    if (names.find(tensor_name) != names.end()) {
+      tensor_exists = true;
+      break;
+    }
+  }
+
+  if (tensor_exists) {
+    return ovInfReq.get_tensor(tensor_name);
+  }
+
+  return std::nullopt;
+}
+
 void StatefulOVInferRequest::PreProcessInferRequest() {
   // Workaround: Setting the value here as it cannot be set at the ORT GenAI layer currently.
   // TODO(ankit): Address this issue and implement the fix at the appropriate layer.
-  CacheTensor("beam_idx", ov::element::i32, {1}, 0);
+  FillTensor("beam_idx", ov::element::i32, {1}, 0);

-  // If 'prefill full chat history' mode is enabled, we need to cache input_ids and position_ids.
+  // If 'prefill use full chat history' mode is enabled, we need to cache input_ids and position_ids.
   if (prefill_use_full_chat_history) {
     auto input_ids_tensor = ovInfReq.get_tensor("input_ids");
     CacheTensor("input_ids", cached_input_ids);
-    CacheTensor("position_ids", cached_position_ids);
+
+    // "position_ids" may be absent (GQA with rotary embeddings doesn't have position_ids), so check whether it exists.
+    auto position_ids_opt = FindTensor("position_ids");
+    bool has_position_ids = position_ids_opt.has_value();
+
+    if (has_position_ids) {
+      CacheTensor("position_ids", cached_position_ids);
+    }

     // If we're about to run the prefill model
     if (input_ids_tensor.get_size() > 1) {
@@ -440,7 +467,11 @@ void StatefulOVInferRequest::PreProcessInferRequest() {

       // Set tensors using cached values
       SetTensorFromCache("input_ids", cached_input_ids);
-      SetTensorFromCache("position_ids", cached_position_ids);
+
+      // Only set position_ids if it exists and we have cached values.
+      if (has_position_ids && !cached_position_ids.empty()) {
+        SetTensorFromCache("position_ids", cached_position_ids);
+      }
     }
   }
 }
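FindTensor probes the compiled model's declared inputs before touching the tensor, because GQA models with rotary embeddings expose no "position_ids" input and a blind get_tensor("position_ids") would throw. A self-contained sketch of the same probe against the public OpenVINO 2.0 API (the model path "model.xml" is a placeholder):

#include <openvino/openvino.hpp>
#include <iostream>
#include <optional>
#include <string>

// Returns the input tensor only if the compiled model actually declares an
// input with this name; otherwise std::nullopt, mirroring FindTensor above.
std::optional<ov::Tensor> find_input_tensor(ov::InferRequest& request, const std::string& name) {
  for (const auto& input : request.get_compiled_model().inputs()) {
    const auto& names = input.get_names();
    if (names.find(name) != names.end()) {
      return request.get_tensor(name);
    }
  }
  return std::nullopt;
}

int main() {
  ov::Core core;
  auto compiled = core.compile_model("model.xml", "CPU");  // placeholder model path
  ov::InferRequest request = compiled.create_infer_request();
  if (auto tensor = find_input_tensor(request, "position_ids")) {
    std::cout << "position_ids present, size " << tensor->get_size() << "\n";
  } else {
    std::cout << "no position_ids input; caching and SetTensorFromCache are skipped\n";
  }
  return 0;
}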

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 2 additions & 1 deletion

@@ -146,10 +146,11 @@ class StatefulOVInferRequest : public OVInferRequest {
   void StartAsync() override;
   void Infer() override;
   void RewindKVCache(size_t index) override;
-  void CacheTensor(const std::string& tensor_name, const ov::element::Type& type,
+  void FillTensor(const std::string& tensor_name, const ov::element::Type& type,
                    const std::vector<size_t>& shape, int32_t fill_value);
   void CacheTensor(const std::string& tensor_name, std::vector<int64_t>& cache);
   void SetTensorFromCache(const std::string& tensor_name, const std::vector<int64_t>& cache_data);
+  std::optional<ov::Tensor> FindTensor(const std::string& tensor_name);

 private:
   void PreProcessInferRequest();
