
Commit 833cff9

fix: refactor tensor caching

1 parent 0475528 commit 833cff9

File tree

2 files changed: +40 -41 lines

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 36 additions & 41 deletions
@@ -385,38 +385,50 @@ void OVInferRequest::QueryStatus() {
 
 StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device)
     : OVInferRequest(std::move(infer_request)), target_device(device) {
-  if ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos)) {
+  bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));
+  if (gpu_or_npu) {
     prefill_use_full_chat_history = true;
   }
 }
 
+void StatefulOVInferRequest::CacheTensor(const std::string& tensor_name, const ov::element::Type& type,
+                                         const std::vector<size_t>& shape, int32_t fill_value) {
+  ov::Tensor tensor = ov::Tensor(type, shape);
+  std::fill_n(tensor.data<int32_t>(), tensor.get_size(), fill_value);
+  ovInfReq.set_tensor(tensor_name, tensor);
+}
+
+void StatefulOVInferRequest::CacheTensor(const std::string& tensor_name, std::vector<int64_t>& cache) {
+  auto tensor = ovInfReq.get_tensor(tensor_name);
+  auto* pData = tensor.data<int64_t>();
+  for (size_t i = 0; i < tensor.get_size(); i++) {
+    cache.emplace_back(pData[i]);
+  }
+}
+
+void StatefulOVInferRequest::SetTensorFromCache(const std::string& tensor_name,
+                                                const std::vector<int64_t>& cache_data) {
+  auto tensor = ovInfReq.get_tensor(tensor_name);
+  auto new_shape = tensor.get_shape();
+  new_shape[1] = cache_data.size();
+
+  auto new_tensor = ov::Tensor(tensor.get_element_type(), new_shape);
+  auto* pNewData = new_tensor.data<int64_t>();
+  std::memcpy(pNewData, cache_data.data(), cache_data.size() * sizeof(int64_t));
+
+  ovInfReq.set_tensor(tensor_name, new_tensor);
+}
+
 void StatefulOVInferRequest::PreProcessInferRequest() {
   // Workaround: Setting the value here as it cannot be set at the ORT GenAI layer currently.
   // TODO(ankit): Address this issue and implement the fix at the appropriate layer.
-  ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {1});
-  std::fill_n(beam_idx.data<int32_t>(), 1, 0);
-  ovInfReq.set_tensor("beam_idx", beam_idx);
+  CacheTensor("beam_idx", ov::element::i32, {1}, 0);
 
   // If 'prefill full chat history' mode is enabled, we need to cache input_ids and position_ids.
   if (prefill_use_full_chat_history) {
     auto input_ids_tensor = ovInfReq.get_tensor("input_ids");
-
-    // Cache the "input_ids" tensor
-    {
-      auto* pData = input_ids_tensor.data<int64_t>();
-      for (size_t i = 0; i < input_ids_tensor.get_size(); i++) {
-        cached_input_ids.push_back(pData[i]);
-      }
-    }
-
-    // Cache the "position_ids" tensor
-    {
-      auto position_ids = ovInfReq.get_tensor("position_ids");
-      auto* pData = position_ids.data<int64_t>();
-      for (size_t i = 0; i < position_ids.get_size(); i++) {
-        cached_position_ids.push_back(pData[i]);
-      }
-    }
+    CacheTensor("input_ids", cached_input_ids);
+    CacheTensor("position_ids", cached_position_ids);
 
     // If we're about to run the prefill model
     if (input_ids_tensor.get_size() > 1) {
@@ -426,26 +438,9 @@ void StatefulOVInferRequest::PreProcessInferRequest() {
       // Clear the internal KVCache state. For NPU device, this operation is a no-op.
       ovInfReq.reset_state();
 
-      // Create and set a new "input_ids" tensor using the cached "input_ids" values.
-      {
-        auto new_shape = input_ids_tensor.get_shape();
-        new_shape[1] = cached_input_ids.size();
-        auto new_input_ids = ov::Tensor(input_ids_tensor.get_element_type(), new_shape);
-        auto* pNewInputIds = new_input_ids.data<int64_t>();
-        std::memcpy(pNewInputIds, cached_input_ids.data(), cached_input_ids.size() * sizeof(int64_t));
-        ovInfReq.set_tensor("input_ids", new_input_ids);
-      }
-
-      // Create and set a new "position_ids" tensor using the cached "position_ids" values.
-      {
-        auto position_ids_tensor = ovInfReq.get_tensor("position_ids");
-        auto new_shape = position_ids_tensor.get_shape();
-        new_shape[1] = cached_position_ids.size();
-        auto new_position_ids = ov::Tensor(position_ids_tensor.get_element_type(), new_shape);
-        auto* pNewPositionIds = new_position_ids.data<int64_t>();
-        std::memcpy(pNewPositionIds, cached_position_ids.data(), cached_position_ids.size() * sizeof(int64_t));
-        ovInfReq.set_tensor("position_ids", new_position_ids);
-      }
+      // Set tensors using cached values
+      SetTensorFromCache("input_ids", cached_input_ids);
+      SetTensorFromCache("position_ids", cached_position_ids);
     }
   }
 }
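The refactor replaces two hand-rolled copy loops and two tensor-rebuild blocks with the CacheTensor/SetTensorFromCache helpers above. A minimal, self-contained sketch of the same cache-then-restore pattern, using a hypothetical ToyTensor stand-in rather than the real ov::Tensor/ov::InferRequest API (ToyTensor, CacheToy, and FromCache are illustration-only names, not part of this commit):

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Stand-in for ov::Tensor: a shape plus a flat int64 buffer,
// like tensor.data<int64_t>().
struct ToyTensor {
  std::vector<size_t> shape;   // e.g. {batch, seq_len}
  std::vector<int64_t> data;
};

// Mirrors CacheTensor(name, cache): append every element to the running cache.
void CacheToy(const ToyTensor& t, std::vector<int64_t>& cache) {
  cache.insert(cache.end(), t.data.begin(), t.data.end());
}

// Mirrors SetTensorFromCache: rebuild a tensor whose dim 1 spans the full cache.
ToyTensor FromCache(const ToyTensor& like, const std::vector<int64_t>& cache) {
  ToyTensor t;
  t.shape = like.shape;
  t.shape[1] = cache.size();   // grow sequence dimension to cached length
  t.data.resize(cache.size());
  std::memcpy(t.data.data(), cache.data(), cache.size() * sizeof(int64_t));
  return t;
}

int main() {
  std::vector<int64_t> cached_input_ids;

  ToyTensor turn1{{1, 3}, {101, 7, 102}};   // first prompt
  CacheToy(turn1, cached_input_ids);

  ToyTensor turn2{{1, 2}, {55, 102}};       // follow-up prompt
  CacheToy(turn2, cached_input_ids);

  // Before re-running prefill, replace input_ids with the full chat history.
  ToyTensor full = FromCache(turn2, cached_input_ids);
  std::cout << "restored seq_len = " << full.shape[1] << "\n";  // prints 5
}

The invariant mirrored here is the one the real helpers enforce: dimension 1 of the restored tensor is grown to the full cached sequence length before the prefill model re-runs against the freshly reset KV cache.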

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 4 additions & 0 deletions
@@ -146,6 +146,10 @@ class StatefulOVInferRequest : public OVInferRequest {
   void StartAsync() override;
   void Infer() override;
   void RewindKVCache(size_t index) override;
+  void CacheTensor(const std::string& tensor_name, const ov::element::Type& type,
+                   const std::vector<size_t>& shape, int32_t fill_value);
+  void CacheTensor(const std::string& tensor_name, std::vector<int64_t>& cache);
+  void SetTensorFromCache(const std::string& tensor_name, const std::vector<int64_t>& cache_data);
 
  private:
   void PreProcessInferRequest();
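For reference, a sketch of the fill-and-set overload in isolation, assuming only OpenVINO's C++ runtime headers are available; MakeFilledTensor is a hypothetical free-function variant for illustration, not part of this commit:

#include <algorithm>
#include <cstdint>
#include <openvino/runtime/tensor.hpp>

// Note: tensor.data<int32_t>() is only valid when the element type really is
// i32 (OpenVINO asserts on mismatch), which holds for the overload's single
// call site in this commit: CacheTensor("beam_idx", ov::element::i32, {1}, 0).
ov::Tensor MakeFilledTensor(const ov::element::Type& type, const ov::Shape& shape,
                            int32_t fill_value) {
  ov::Tensor tensor(type, shape);
  std::fill_n(tensor.data<int32_t>(), tensor.get_size(), fill_value);
  return tensor;
}

// Usage, mirroring PreProcessInferRequest():
//   ov::Tensor beam_idx = MakeFilledTensor(ov::element::i32, {1}, 0);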
