Skip to content

Commit 5de0c43

Browse files
committed
Turn prefetching back on
1 parent 688f06a commit 5de0c43

File tree

3 files changed

+58
-73
lines changed

3 files changed

+58
-73
lines changed

CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -200,8 +200,6 @@ if(HNSWLIB_EXAMPLES)
200200
if(ENABLE_ASAN OR ENABLE_UBSAN)
201201
add_cxx_flags(-DHNSWLIB_USE_PREFETCH=0)
202202
endif()
203-
add_cxx_flags(-DHNSWLIB_USE_PREFETCH=0) # TODO(mbautin): remove
204-
205203
add_cxx_flags(-Wall -Wextra -Wpedantic -Werror)
206204

207205
# Unused functions in header files might still be used by other code

hnswlib/hnswalg.h

Lines changed: 31 additions & 56 deletions
Original file line number | Diff line number | Diff line change
@@ -129,8 +129,10 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
129129
label_offset_ = size_links_level0_ + data_size_;
130130
offsetLevel0_ = 0;
131131

132+
// Allocate 64 more bytes for each chunk so we can safely prefetch a
133+
// cache line beyond the chunk.
132134
data_level0_memory_ = ChunkedArray(
133-
size_data_per_element_, k_elements_per_chunk, max_elements);
135+
size_data_per_element_, k_elements_per_chunk, max_elements, 64);
134136

135137
cur_element_count = 0;
136138

@@ -141,7 +143,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
141143
maxlevel_ = -1;
142144

143145
linkLists_ = ChunkedArray(
144-
sizeof(void *), k_elements_per_chunk, max_elements);
146+
sizeof(void *), k_elements_per_chunk, max_elements, 0);
145147

146148
size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
147149
mult_ = 1 / log(1.0 * M_);
@@ -226,7 +228,6 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
226228
return (data_level0_memory_[internal_id] + offsetData_);
227229
}
228230

229-
230231
int getRandomLevel(double reverse_size) {
231232
std::uniform_real_distribution<double> distribution(0.0, 1.0);
232233
double r = -log(distribution(level_generator_)) * reverse_size;
@@ -286,36 +287,26 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
286287
}
287288
size_t size = getListCount((linklistsizeint*)data);
288289
tableint *datal = (tableint *) (data + 1);
289-
#ifdef USE_SSE
290-
#if HNSWLIB_USE_PREFETCH
291-
_mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
292-
_mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
293-
_mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
294-
_mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
295-
#endif
296-
#endif
290+
HNSWLIB_MM_PREFETCH((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
291+
HNSWLIB_MM_PREFETCH((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
292+
HNSWLIB_MM_PREFETCH(getDataByInternalId(*datal), _MM_HINT_T0);
293+
HNSWLIB_MM_PREFETCH(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
297294

298295
for (size_t j = 0; j < size; j++) {
299296
tableint candidate_id = *(datal + j);
300297
// if (candidate_id == 0) continue;
301-
#ifdef USE_SSE
302-
#if HNSWLIB_USE_PREFETCH
303-
_mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
304-
_mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
305-
#endif
306-
#endif
298+
if (j + 1 < size) {
299+
HNSWLIB_MM_PREFETCH((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
300+
HNSWLIB_MM_PREFETCH(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
301+
}
307302
if (visited_array[candidate_id] == visited_array_tag) continue;
308303
visited_array[candidate_id] = visited_array_tag;
309304
char *currObj1 = (getDataByInternalId(candidate_id));
310305

311306
dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
312307
if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
313308
candidateSet.emplace(-dist1, candidate_id);
314-
#ifdef USE_SSE
315-
#if HNSWLIB_USE_PREFETCH
316-
_mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
317-
#endif
318-
#endif
309+
HNSWLIB_MM_PREFETCH(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
319310

320311
if (!isMarkedDeleted(candidate_id))
321312
top_candidates.emplace(dist1, candidate_id);
@@ -396,25 +387,18 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
396387
metric_distance_computations+=size;
397388
}
398389

399-
#ifdef USE_SSE
400-
#if HNSWLIB_USE_PREFETCH
401-
_mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
402-
_mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
403-
_mm_prefetch(data_level0_memory_[*(data + 1)] + offsetData_, _MM_HINT_T0);
404-
_mm_prefetch((char *) (data + 2), _MM_HINT_T0);
405-
#endif
406-
#endif
390+
HNSWLIB_MM_PREFETCH((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
391+
HNSWLIB_MM_PREFETCH((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
392+
HNSWLIB_MM_PREFETCH(data_level0_memory_[*(data + 1)] + offsetData_, _MM_HINT_T0);
393+
HNSWLIB_MM_PREFETCH((char *) (data + 2), _MM_HINT_T0);
407394

408395
for (size_t j = 1; j <= size; j++) {
409396
int candidate_id = *(data + j);
410-
// if (candidate_id == 0) continue;
411-
#ifdef USE_SSE
412-
#if HNSWLIB_USE_PREFETCH
413-
_mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
414-
_mm_prefetch(data_level0_memory_[*(data + j + 1)] + offsetData_,
415-
_MM_HINT_T0); ////////////
416-
#endif
417-
#endif
397+
if (j < size) {
398+
HNSWLIB_MM_PREFETCH((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
399+
HNSWLIB_MM_PREFETCH(data_level0_memory_[*(data + j + 1)] + offsetData_,
400+
_MM_HINT_T0);
401+
}
418402
if (!(visited_array[candidate_id] == visited_array_tag)) {
419403
visited_array[candidate_id] = visited_array_tag;
420404

@@ -430,13 +414,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
430414

431415
if (flag_consider_candidate) {
432416
candidate_set.emplace(-dist, candidate_id);
433-
#ifdef USE_SSE
434-
#if HNSWLIB_USE_PREFETCH
435-
_mm_prefetch(data_level0_memory_[candidate_set.top().second] +
417+
HNSWLIB_MM_PREFETCH(data_level0_memory_[candidate_set.top().second] +
436418
offsetLevel0_, ///////////
437419
_MM_HINT_T0); ////////////////////////
438-
#endif
439-
#endif
440420

441421
if (bare_bone_search ||
442422
(!isMarkedDeleted(candidate_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))) {
@@ -822,7 +802,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
822802
data_level0_memory_ = ChunkedArray(
823803
size_data_per_element_,
824804
k_elements_per_chunk,
825-
max_elements);
805+
max_elements,
806+
64);
826807
data_level0_memory_.readFromStream(input, cur_element_count);
827808

828809
size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
@@ -833,7 +814,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
833814

834815
visited_list_pool_.reset(new VisitedListPool(1, max_elements));
835816

836-
linkLists_.resize(max_elements);
817+
linkLists_ = ChunkedArray(
818+
sizeof(void *), k_elements_per_chunk, max_elements, 0);
819+
837820
element_levels_ = std::vector<int>(max_elements);
838821
revSize_ = 1.0 / mult_;
839822
ef_ = 10;
@@ -1126,17 +1109,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
11261109
data = get_linklist_at_level(currObj, level);
11271110
int size = getListCount(data);
11281111
tableint *datal = (tableint *) (data + 1);
1129-
#ifdef USE_SSE
1130-
#if HNSWLIB_USE_PREFETCH
1131-
_mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
1132-
#endif
1133-
#endif
1112+
HNSWLIB_MM_PREFETCH(getDataByInternalId(*datal), _MM_HINT_T0);
11341113
for (int i = 0; i < size; i++) {
1135-
#ifdef USE_SSE
1136-
#if HNSWLIB_USE_PREFETCH
1137-
_mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);
1138-
#endif
1139-
#endif
1114+
HNSWLIB_MM_PREFETCH(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0);
11401115
tableint cand = datal[i];
11411116
dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_);
11421117
if (d < curdist) {
@@ -1523,7 +1498,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
15231498
if (isMarkedDeleted(internalId)) {
15241499
unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2;
15251500
*ll_cur &= ~DELETE_MARK;
1526-
num_deleted_ -= 1;
1501+
num_deleted_ -= 1;
15271502
if (allow_replace_deleted_) {
15281503
std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
15291504
deleted_elements.erase(internalId);

hnswlib/hnswlib.h

Lines changed: 27 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -384,15 +384,18 @@ class ChunkedArray {
384384
ChunkedArray()
385385
: element_byte_size_(0),
386386
elements_per_chunk_(0),
387-
element_count_(0) {
387+
element_count_(0),
388+
chunk_padding_bytes_(0) {
388389
}
389390

390391
ChunkedArray(size_t element_byte_size,
391392
size_t elements_per_chunk,
392-
size_t element_count) :
393+
size_t element_count,
394+
size_t chunk_padding_bytes) :
393395
element_byte_size_(element_byte_size),
394396
elements_per_chunk_(elements_per_chunk),
395-
element_count_(0) {
397+
element_count_(0),
398+
chunk_padding_bytes_(chunk_padding_bytes) {
396399
resize(element_count);
397400
}
398401

@@ -415,6 +418,7 @@ class ChunkedArray {
415418
std::swap(elements_per_chunk_, other.elements_per_chunk_);
416419
std::swap(element_count_, other.element_count_);
417420
std::swap(chunks_, other.chunks_);
421+
std::swap(chunk_padding_bytes_, other.chunk_padding_bytes_);
418422
}
419423

420424
~ChunkedArray() {
@@ -435,9 +439,7 @@ class ChunkedArray {
435439
char* operator[](size_t i) const {
436440
assert(i < getCapacity());
437441
if (i >= getCapacity()) return nullptr;
438-
size_t chunk_index = i / elements_per_chunk_;
439-
size_t index_in_chunk = i % elements_per_chunk_;
440-
return chunks_[chunk_index].get() + element_byte_size_ * index_in_chunk;
442+
return getElementNoRangeChecking(i);
441443
}
442444

443445
void clear() {
@@ -451,7 +453,8 @@ class ChunkedArray {
451453

452454
chunks_.resize(new_chunk_count);
453455
for (size_t i = chunk_count; i < new_chunk_count; i++) {
454-
chunks_[i] = internal::makeUniqueCharArray(getSizePerChunk());
456+
chunks_[i] = internal::makeUniqueCharArray(
457+
getSizePerChunk() + chunk_padding_bytes_);
455458
}
456459

457460
element_count_ = new_element_count;
@@ -479,14 +482,6 @@ class ChunkedArray {
479482
i + 1 == num_chunks_to_read ? last_chunk_bytes : getSizePerChunk());
480483
}
481484
}
482-
483-
std::deque<internal::MallocUniqueCharArrayPtr>::const_iterator begin_chunk() const {
484-
return chunks_.begin();
485-
}
486-
487-
std::deque<internal::MallocUniqueCharArrayPtr>::const_iterator end_chunk() const {
488-
return chunks_.end();
489-
}
490485

491486
private:
492487
size_t getChunkCount(size_t element_count) const {
@@ -497,10 +492,27 @@ class ChunkedArray {
497492
size_t elements_per_chunk_;
498493
size_t element_count_;
499494
std::deque<internal::MallocUniqueCharArrayPtr> chunks_;
495+
size_t chunk_padding_bytes_;
500496
};
501497

502498
} // namespace hnswlib
503499

500+
#if defined(USE_SSE) && HNSWLIB_USE_PREFETCH
501+
#if HNSWLIB_DEBUG_PREFETCH
502+
// This mode is used to find prefetch statements causing range check errors in
503+
// tests. We only print line numbers, which makes the output compact enough to
504+
// catch range check errors in some tests.
505+
#define HNSWLIB_MM_PREFETCH(address, hint) do { \
506+
std::cout << __LINE__ << " "; \
507+
_mm_prefetch(address, hint); \
508+
} while (0)
509+
#else
510+
#define HNSWLIB_MM_PREFETCH(address, hint) _mm_prefetch(address, hint)
511+
#endif
512+
#else
513+
#define HNSWLIB_MM_PREFETCH(address, hint)
514+
#endif
515+
504516
#include "space_l2.h"
505517
#include "space_ip.h"
506518
#include "stop_condition.h"

0 commit comments

Comments
 (0)