@@ -385,38 +385,50 @@ void OVInferRequest::QueryStatus() {
385385
386386StatefulOVInferRequest::StatefulOVInferRequest (ov::InferRequest infer_request, std::string device)
387387 : OVInferRequest(std::move(infer_request)), target_device(device) {
388- if ((device.find (" NPU" ) != std::string::npos) || (device.find (" GPU" ) != std::string::npos)) {
388+ bool gpu_or_npu = ((device.find (" NPU" ) != std::string::npos) || (device.find (" GPU" ) != std::string::npos));
389+ if (gpu_or_npu) {
389390 prefill_use_full_chat_history = true ;
390391 }
391392}
392393
394+ void StatefulOVInferRequest::CacheTensor (const std::string& tensor_name, const ov::element::Type& type,
395+ const std::vector<size_t >& shape, int32_t fill_value) {
396+ ov::Tensor tensor = ov::Tensor (type, shape);
397+ std::fill_n (tensor.data <int32_t >(), tensor.get_size (), fill_value);
398+ ovInfReq.set_tensor (tensor_name, tensor);
399+ }
400+
401+ void StatefulOVInferRequest::CacheTensor (const std::string& tensor_name, std::vector<int64_t >& cache) {
402+ auto tensor = ovInfReq.get_tensor (tensor_name);
403+ auto * pData = tensor.data <int64_t >();
404+ for (size_t i = 0 ; i < tensor.get_size (); i++) {
405+ cache.emplace_back (pData[i]);
406+ }
407+ }
408+
409+ void StatefulOVInferRequest::SetTensorFromCache (const std::string& tensor_name,
410+ const std::vector<int64_t >& cache_data) {
411+ auto tensor = ovInfReq.get_tensor (tensor_name);
412+ auto new_shape = tensor.get_shape ();
413+ new_shape[1 ] = cache_data.size ();
414+
415+ auto new_tensor = ov::Tensor (tensor.get_element_type (), new_shape);
416+ auto * pNewData = new_tensor.data <int64_t >();
417+ std::memcpy (pNewData, cache_data.data (), cache_data.size () * sizeof (int64_t ));
418+
419+ ovInfReq.set_tensor (tensor_name, new_tensor);
420+ }
421+
393422void StatefulOVInferRequest::PreProcessInferRequest () {
394423 // Workaround: Setting the value here as it cannot be set at the ORT GenAI layer currently.
395424 // TODO(ankit): Address this issue and implement the fix at the appropriate layer.
396- ov::Tensor beam_idx = ov::Tensor (ov::element::i32 , {1 });
397- std::fill_n (beam_idx.data <int32_t >(), 1 , 0 );
398- ovInfReq.set_tensor (" beam_idx" , beam_idx);
425+ CacheTensor (" beam_idx" , ov::element::i32 , {1 }, 0 );
399426
400427 // If 'prefill full chat history' mode is enabled, we need to cache input_ids and position_ids.
401428 if (prefill_use_full_chat_history) {
402429 auto input_ids_tensor = ovInfReq.get_tensor (" input_ids" );
403-
404- // Cache the "input_ids" tensor
405- {
406- auto * pData = input_ids_tensor.data <int64_t >();
407- for (size_t i = 0 ; i < input_ids_tensor.get_size (); i++) {
408- cached_input_ids.push_back (pData[i]);
409- }
410- }
411-
412- // Cache the "position_ids" tensor
413- {
414- auto position_ids = ovInfReq.get_tensor (" position_ids" );
415- auto * pData = position_ids.data <int64_t >();
416- for (size_t i = 0 ; i < position_ids.get_size (); i++) {
417- cached_position_ids.push_back (pData[i]);
418- }
419- }
430+ CacheTensor (" input_ids" , cached_input_ids);
431+ CacheTensor (" position_ids" , cached_position_ids);
420432
421433 // If we're about to run the prefill model
422434 if (input_ids_tensor.get_size () > 1 ) {
@@ -426,26 +438,9 @@ void StatefulOVInferRequest::PreProcessInferRequest() {
426438 // Clear the internal KVCache state. For NPU device, this operation is a no-op.
427439 ovInfReq.reset_state ();
428440
429- // Create and set a new "input_ids" tensor using the cached "input_ids" values.
430- {
431- auto new_shape = input_ids_tensor.get_shape ();
432- new_shape[1 ] = cached_input_ids.size ();
433- auto new_input_ids = ov::Tensor (input_ids_tensor.get_element_type (), new_shape);
434- auto * pNewInputIds = new_input_ids.data <int64_t >();
435- std::memcpy (pNewInputIds, cached_input_ids.data (), cached_input_ids.size () * sizeof (int64_t ));
436- ovInfReq.set_tensor (" input_ids" , new_input_ids);
437- }
438-
439- // Create and set a new "position_ids" tensor using the cached "position_ids" values.
440- {
441- auto position_ids_tensor = ovInfReq.get_tensor (" position_ids" );
442- auto new_shape = position_ids_tensor.get_shape ();
443- new_shape[1 ] = cached_position_ids.size ();
444- auto new_position_ids = ov::Tensor (position_ids_tensor.get_element_type (), new_shape);
445- auto * pNewPositionIds = new_position_ids.data <int64_t >();
446- std::memcpy (pNewPositionIds, cached_position_ids.data (), cached_position_ids.size () * sizeof (int64_t ));
447- ovInfReq.set_tensor (" position_ids" , new_position_ids);
448- }
441+ // Set tensors using cached values
442+ SetTensorFromCache (" input_ids" , cached_input_ids);
443+ SetTensorFromCache (" position_ids" , cached_position_ids);
449444 }
450445 }
451446 }
0 commit comments