Skip to content

Commit 38ed79c

Browse files
mdouze authored and facebook-github-bot committed
Convert PQ 4-bit code to dynamic dispatch
Summary: Migration of the 4-bit codecs to dynamic dispatch. The migration consists of:
- templatizing the SIMD ResultHandlers on the SIMDLevel
- instantiating the AVX2 and AVX512 code in their own files (compile units)
- removing any SIMD dependency from IndexFastScan and IndexIVFFastScan
- adding dispatching code for the SIMD code

Differential Revision: D73581633
1 parent 51f6741 commit 38ed79c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+4178
-1380
lines changed

demos/demo_simd_levels.py

Lines changed: 54 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import time
7+
import faiss
8+
import numpy as np
9+
import os
10+
from collections import defaultdict
11+
from faiss.contrib.datasets import SyntheticDataset
12+
13+
14+
print("compile options", faiss.get_compile_options())
15+
print("SIMD level: ", faiss.SIMDConfig.get_level_name())
16+
17+
18+
ds = SyntheticDataset(32, 8000, 10000, 8000)
19+
20+
21+
index = faiss.index_factory(ds.d, "PQ16x4fs")
22+
# index = faiss.index_factory(ds.d, "IVF64,PQ16x4fs")
23+
# index = faiss.index_factory(ds.d, "SQ8")
24+
25+
index.train(ds.get_train())
26+
index.add(ds.get_database())
27+
28+
29+
if False:
30+
faiss.omp_set_num_threads(1)
31+
print("PID=", os.getpid())
32+
input("press enter to continue")
33+
# for simd_level in faiss.NONE, faiss.AVX2, faiss.AVX512F:
34+
for simd_level in faiss.AVX2, faiss.AVX512F:
35+
36+
faiss.SIMDConfig.set_level(simd_level)
37+
print("simd_level=", faiss.SIMDConfig.get_level_name())
38+
for run in range(1000):
39+
D, I = index.search(ds.get_queries(), 10)
40+
41+
times = defaultdict(list)
42+
43+
for run in range(10):
44+
for simd_level in faiss.SIMDLevel_NONE, faiss.SIMDLevel_AVX2, faiss.SIMDLevel_AVX512F:
45+
faiss.SIMDConfig.set_level(simd_level)
46+
47+
t0 = time.time()
48+
D, I = index.search(ds.get_queries(), 10)
49+
t1 = time.time()
50+
51+
times[faiss.SIMDConfig.get_level_name()].append(t1 - t0)
52+
53+
for simd_level in times:
54+
print(f"simd_level={simd_level} search time: {np.mean(times[simd_level])*1000:.3f} ms")

faiss/CMakeLists.txt

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -186,11 +186,11 @@ set(FAISS_HEADERS
186186
impl/residual_quantizer_encode_steps.h
187187
impl/simd_result_handlers.h
188188
impl/zerocopy_io.h
189-
impl/code_distance/code_distance.h
190-
impl/code_distance/code_distance-generic.h
191-
impl/code_distance/code_distance-avx2.h
192-
impl/code_distance/code_distance-avx512.h
193-
impl/code_distance/code_distance-sve.h
189+
impl/pq_code_distance/code_distance.h
190+
impl/pq_code_distance/code_distance-generic.h
191+
impl/pq_code_distance/code_distance-avx2.h
192+
impl/pq_code_distance/code_distance-avx512.h
193+
impl/pq_code_distance/code_distance-sve.h
194194
invlists/BlockInvertedLists.h
195195
invlists/DirectMap.h
196196
invlists/InvertedLists.h

faiss/IndexAdditiveQuantizerFastScan.cpp

Lines changed: 2 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -7,14 +7,10 @@
77

88
#include <faiss/IndexAdditiveQuantizerFastScan.h>
99

10-
#include <cassert>
11-
#include <memory>
12-
1310
#include <faiss/impl/FaissAssert.h>
1411
#include <faiss/impl/LocalSearchQuantizer.h>
15-
#include <faiss/impl/LookupTableScaler.h>
1612
#include <faiss/impl/ResidualQuantizer.h>
17-
#include <faiss/impl/pq4_fast_scan.h>
13+
#include <faiss/impl/pq_4bit/pq4_fast_scan.h>
1814
#include <faiss/utils/quantize_lut.h>
1915
#include <faiss/utils/utils.h>
2016

@@ -199,12 +195,7 @@ void IndexAdditiveQuantizerFastScan::search(
199195
return;
200196
}
201197

202-
NormTableScaler scaler(norm_scale);
203-
if (metric_type == METRIC_L2) {
204-
search_dispatch_implem<true>(n, x, k, distances, labels, &scaler);
205-
} else {
206-
search_dispatch_implem<false>(n, x, k, distances, labels, &scaler);
207-
}
198+
search_dispatch_implem(n, x, k, distances, labels, norm_scale);
208199
}
209200

210201
void IndexAdditiveQuantizerFastScan::sa_decode(

0 commit comments

Comments
 (0)