@@ -10,12 +10,6 @@ namespace onnxruntime {
1010
1111using namespace openvino_ep ;
1212
13- constexpr size_t default_alignment = 4096 ;
14-
15- static inline size_t align_up (size_t size, size_t pow2_alignment) {
16- return (size + pow2_alignment - 1 ) & ~(pow2_alignment - 1 );
17- }
18-
1913OVRTAllocator::OVRTAllocator (ov::Core& core, OrtDevice::DeviceType device_type, OrtDevice::DeviceId device_id, const char * name) : IAllocator(OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(device_type, OrtDevice::MemType::DEFAULT, device_id), device_id, OrtMemTypeCPUInput)), core_(core) {
2014 if (device_type == OrtDevice::NPU) {
2115 remote_ctx_ = core_.get_default_context (" NPU" ).as <ov::intel_npu::level_zero::ZeroContext>();
@@ -26,25 +20,26 @@ OVRTAllocator::OVRTAllocator(ov::Core& core, OrtDevice::DeviceType device_type,
2620
2721void * OVRTAllocator::Alloc (size_t size) {
2822 try {
29- size_t alloc_size = align_up (size + sizeof (ov::Tensor*) + default_alignment, default_alignment);
3023 ov::Tensor* tensor = new ov::Tensor (remote_ctx_.create_host_tensor (ov::element::Type_t::u8 ,
31- {alloc_size}));
32- uintptr_t data_ptr = reinterpret_cast <uintptr_t >(tensor->data ());
33-
34- ov::Tensor** ptr = reinterpret_cast <ov::Tensor**>(align_up (data_ptr + sizeof (ov::Tensor*), default_alignment));
35- ptr[-1 ] = tensor;
36-
37- return reinterpret_cast <void *>(ptr);
38-
24+ {size}));
25+ std::unique_lock lock (mutex_);
26+ allocated_.insert ({tensor->data (), tensor});
27+ return reinterpret_cast <void *>(tensor->data ());
3928 } catch (const ov::Exception& e) {
4029 ORT_THROW (std::string (" Alloc failed: " ) + e.what ());
4130 }
4231}
4332
4433void OVRTAllocator::Free (void * p) {
4534 try {
46- ov::Tensor** ptr = reinterpret_cast <ov::Tensor**>(p);
47- delete ptr[-1 ];
35+ std::unique_lock lock (mutex_);
36+ auto it = allocated_.find (p);
37+ if (it != allocated_.end ()) {
38+ ov::Tensor* tensor = it->second ;
39+ allocated_.erase (it);
40+ lock.unlock ();
41+ delete tensor;
42+ }
4843 } catch (const ov::Exception& e) {
4944 ORT_THROW (std::string (" Free failed: " ) + e.what ());
5045 }
0 commit comments