@@ -391,7 +391,7 @@ StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, s
391391 }
392392}
393393
394- void StatefulOVInferRequest::CacheTensor(const std::string& tensor_name, const ov::element::Type& type,
394+ void StatefulOVInferRequest::FillTensor(const std::string& tensor_name, const ov::element::Type& type,
395395 const std::vector<size_t>& shape, int32_t fill_value) {
396396 ov::Tensor tensor = ov::Tensor(type, shape);
397397 std::fill_n(tensor.data<int32_t>(), tensor.get_size(), fill_value);
@@ -419,16 +419,43 @@ void StatefulOVInferRequest::SetTensorFromCache(const std::string& tensor_name,
419419 ovInfReq.set_tensor (tensor_name, new_tensor);
420420}
421421
422+ std::optional<ov::Tensor> StatefulOVInferRequest::FindTensor(const std::string& tensor_name) {
423+   // Check if tensor exists by examining input names in the compiled model
424+   const auto& model = ovInfReq.get_compiled_model();
425+   bool tensor_exists = false;
426+
427+   for (const auto& input : model.inputs()) {
428+     const auto& names = input.get_names();
429+     if (names.find(tensor_name) != names.end()) {
430+       tensor_exists = true;
431+       break;
432+     }
433+   }
434+
435+   if (tensor_exists) {
436+     return ovInfReq.get_tensor(tensor_name);
437+   }
438+
439+   return std::nullopt;
440+ }
441+
422442void StatefulOVInferRequest::PreProcessInferRequest () {
423443 // Workaround: Setting the value here as it cannot be set at the ORT GenAI layer currently.
424444 // TODO(ankit): Address this issue and implement the fix at the appropriate layer.
425- CacheTensor (" beam_idx" , ov::element::i32 , {1 }, 0 );
445+ FillTensor (" beam_idx" , ov::element::i32 , {1 }, 0 );
426446
427- // If 'prefill full chat history' mode is enabled, we need to cache input_ids and position_ids.
447+ // If 'prefill use full chat history' mode is enabled, we need to cache input_ids and position_ids.
428448 if (prefill_use_full_chat_history) {
429449 auto input_ids_tensor = ovInfReq.get_tensor (" input_ids" );
430450 CacheTensor (" input_ids" , cached_input_ids);
431- CacheTensor (" position_ids" , cached_position_ids);
451+
452+ // "position_ids" (GQA with Rotary Embeddings doesn't have position_ids) - check if exists
453+ auto position_ids_opt = FindTensor (" position_ids" );
454+ bool has_position_ids = position_ids_opt.has_value ();
455+
456+ if (has_position_ids) {
457+ CacheTensor (" position_ids" , cached_position_ids);
458+ }
432459
433460 // If we're about to run the prefill model
434461 if (input_ids_tensor.get_size () > 1 ) {
@@ -440,7 +467,11 @@ void StatefulOVInferRequest::PreProcessInferRequest() {
440467
441468 // Set tensors using cached values
442469 SetTensorFromCache (" input_ids" , cached_input_ids);
443- SetTensorFromCache (" position_ids" , cached_position_ids);
470+
471+ // Only set position_ids if it exists and we have cached values
472+ if (has_position_ids && !cached_position_ids.empty ()) {
473+ SetTensorFromCache (" position_ids" , cached_position_ids);
474+ }
444475 }
445476 }
446477 }
0 commit comments