diff --git a/include/thread_context.h b/include/thread_context.h new file mode 100644 index 000000000..828b93799 --- /dev/null +++ b/include/thread_context.h @@ -0,0 +1,51 @@ +#pragma once + +#ifndef _THREAD_CONTEXT_H_ +#define _THREAD_CONTEXT_H_ + +void SetContextFuncs(void *getter, void *setter); + +class OmpParallelContext +{ +public: + OmpParallelContext(); + OmpParallelContext(const OmpParallelContext& other); + ~OmpParallelContext(); + +private: + void* m_savedContext; +}; + +class DefaultThreadContext +{ +public: + DefaultThreadContext(); + ~DefaultThreadContext(); + + void* SavedContext(); + +private: + void* m_savedContext; +}; + +class ThreadContext +{ +public: + ThreadContext(void* context); + ~ThreadContext(); + +private: + void* m_savedContext; +}; + +class SavePartitionContext +{ +public: + SavePartitionContext(); + ~SavePartitionContext(); + void* SavedContext(); + +private: + void* m_savedContext; +}; +#endif \ No newline at end of file diff --git a/src/dll/CMakeLists.txt b/src/dll/CMakeLists.txt index 090d3b096..ebbd1a1ba 100644 --- a/src/dll/CMakeLists.txt +++ b/src/dll/CMakeLists.txt @@ -5,16 +5,16 @@ if (DISKANN_USE_STATIC_LIB) add_library(${PROJECT_NAME} STATIC ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp ../windows_aligned_file_reader.cpp ../distance.cpp ../memory_mapper.cpp ../index.cpp ../in_mem_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp - ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) + ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp ../thread_context.cpp) else() add_library(${PROJECT_NAME} SHARED dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp ../windows_aligned_file_reader.cpp ../distance.cpp ../memory_mapper.cpp ../index.cpp ../in_mem_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp - ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) + ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp ../thread_context.cpp) add_library(${PROJECT_NAME}_build SHARED dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp ../windows_aligned_file_reader.cpp ../distance.cpp ../memory_mapper.cpp ../index.cpp ../in_mem_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp - ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) + ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp ../thread_context.cpp) endif() set(TARGET_DIR "$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>") diff --git a/src/index.cpp b/src/index.cpp index 92007fbb0..c24c64ad8 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -15,6 +15,9 @@ #if defined(RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) #include "gperftools/malloc_extension.h" #endif +#ifdef _ANN_ALLOCATOR +#include "thread_context.h" +#endif #ifdef _WINDOWS #include @@ -1566,7 +1569,12 @@ template void Index void Index::prune_all_neighbors(const uint32_t max_degree, cons _filtered_index = true; diskann::Timer timer; +#ifdef _ANN_ALLOCATOR + OmpParallelContext context; +#pragma omp parallel for firstprivate(context) +#else #pragma omp parallel for +#endif for (int64_t node = 0; node < (int64_t)(_max_points + _num_frozen_pts); node++) { if ((size_t)node < _nd || (size_t)node >= _max_points) @@ -2718,7 +2735,12 @@ consolidation_report Index::consolidate_deletes(const IndexWrit uint32_t num_calls_to_process_delete = 0; diskann::Timer timer; +#ifdef _ANN_ALLOCATOR + OmpParallelContext context; +#pragma omp parallel for num_threads(num_threads) schedule(dynamic, 8192) reduction(+ : num_calls_to_process_delete) firstprivate(context) +#else #pragma omp parallel for num_threads(num_threads) schedule(dynamic, 8192) reduction(+ : num_calls_to_process_delete) +#endif for (int64_t loc = 0; loc < (int64_t)_max_points; loc++) { if (old_delete_set->find((uint32_t)loc) == old_delete_set->end() && !_empty_slots.is_in_set((uint32_t)loc)) diff --git a/src/thread_context.cpp b/src/thread_context.cpp new file mode 100644 index 000000000..b709605f0 --- /dev/null +++ b/src/thread_context.cpp @@ -0,0 +1,73 @@ +#include "inc/Helper/ThreadContext.h" + +void* DummyGet() { return nullptr; } +void DummySet(void* context) {} + +typedef void* (*GetOmpContextFuncType)(); +typedef void (*SetOmpContextFuncType)(void *); + +GetOmpContextFuncType s_getOmpContextFunc = &DummyGet; +SetOmpContextFuncType s_setOmpContextFunc = &DummySet; + +void SetContextFuncs(void *getter, void *setter) +{ + s_getOmpContextFunc = (GetOmpContextFuncType)getter; + s_setOmpContextFunc = (SetOmpContextFuncType)setter; +} + +OmpParallelContext::OmpParallelContext() +{ + m_savedContext = (*s_getOmpContextFunc)(); +} +OmpParallelContext::OmpParallelContext(const OmpParallelContext& other) +{ + m_savedContext = other.m_savedContext; + (*s_setOmpContextFunc)(m_savedContext); +} +OmpParallelContext::~OmpParallelContext() +{ + (*s_setOmpContextFunc)(m_savedContext); +} + +DefaultThreadContext::DefaultThreadContext() +{ + m_savedContext = (*s_getOmpContextFunc)(); + (*s_setOmpContextFunc)(nullptr); +} + +DefaultThreadContext::~DefaultThreadContext() +{ + (*s_setOmpContextFunc)(m_savedContext); +} + +void* DefaultThreadContext::SavedContext() +{ + return m_savedContext; +} + +ThreadContext::ThreadContext(void* context) +{ + m_savedContext = (*s_getOmpContextFunc)(); + (*s_setOmpContextFunc)(context); +} + +ThreadContext::~ThreadContext() +{ + (*s_setOmpContextFunc)(m_savedContext); +} + +SavePartitionContext::SavePartitionContext() +{ + // Save the current allocator address, but not do allocator switch + m_savedContext = (*s_getOmpContextFunc)(); +} + +SavePartitionContext::~SavePartitionContext() +{ + // No switch, so nothing to do here +} + +void *SavePartitionContext::SavedContext() +{ + return m_savedContext; +}