Skip to content

Commit d2058fc

Browse files
mdouzefacebook-github-bot
authored andcommitted
moved IndexIVFPQ and IndexPQ to dynamic dispatch (facebookresearch#4291)
Summary: Pull Request resolved: facebookresearch#4291 moved IndexIVFPQ and IndexPQ to dynamic dispatch. Since the code was already quite modular (thanks Alex!), this boils down to make independent cpp files for the different SIMD versions. Differential Revision: D72937709
1 parent cdb5faf commit d2058fc

17 files changed

+896
-1094
lines changed

faiss/IndexIVFPQ.cpp

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -817,8 +817,9 @@ struct RangeSearchResults {
817817
* The scanning functions call their favorite precompute_*
818818
* function to precompute the tables they need.
819819
*****************************************************/
820-
template <typename IDType, MetricType METRIC_TYPE, class PQDecoder>
820+
template <typename IDType, MetricType METRIC_TYPE, class PQCodeDistance>
821821
struct IVFPQScannerT : QueryTables {
822+
using PQDecoder = typename PQCodeDistance::PQDecoder;
822823
const uint8_t* list_codes;
823824
const IDType* list_ids;
824825
size_t list_size;
@@ -894,7 +895,7 @@ struct IVFPQScannerT : QueryTables {
894895
float distance_1 = 0;
895896
float distance_2 = 0;
896897
float distance_3 = 0;
897-
distance_four_codes<PQDecoder>(
898+
PQCodeDistance::distance_four_codes(
898899
pq.M,
899900
pq.nbits,
900901
sim_table,
@@ -917,7 +918,7 @@ struct IVFPQScannerT : QueryTables {
917918

918919
if (counter >= 1) {
919920
float dis = dis0 +
920-
distance_single_code<PQDecoder>(
921+
PQCodeDistance::distance_single_code(
921922
pq.M,
922923
pq.nbits,
923924
sim_table,
@@ -926,7 +927,7 @@ struct IVFPQScannerT : QueryTables {
926927
}
927928
if (counter >= 2) {
928929
float dis = dis0 +
929-
distance_single_code<PQDecoder>(
930+
PQCodeDistance::distance_single_code(
930931
pq.M,
931932
pq.nbits,
932933
sim_table,
@@ -935,7 +936,7 @@ struct IVFPQScannerT : QueryTables {
935936
}
936937
if (counter >= 3) {
937938
float dis = dis0 +
938-
distance_single_code<PQDecoder>(
939+
PQCodeDistance::distance_single_code(
939940
pq.M,
940941
pq.nbits,
941942
sim_table,
@@ -1101,7 +1102,7 @@ struct IVFPQScannerT : QueryTables {
11011102
float distance_1 = dis0;
11021103
float distance_2 = dis0;
11031104
float distance_3 = dis0;
1104-
distance_four_codes<PQDecoder>(
1105+
PQCodeDistance::distance_four_codes(
11051106
pq.M,
11061107
pq.nbits,
11071108
sim_table,
@@ -1132,7 +1133,7 @@ struct IVFPQScannerT : QueryTables {
11321133
n_hamming_pass++;
11331134

11341135
float dis = dis0 +
1135-
distance_single_code<PQDecoder>(
1136+
PQCodeDistance::distance_single_code(
11361137
pq.M,
11371138
pq.nbits,
11381139
sim_table,
@@ -1152,7 +1153,7 @@ struct IVFPQScannerT : QueryTables {
11521153
n_hamming_pass++;
11531154

11541155
float dis = dis0 +
1155-
distance_single_code<PQDecoder>(
1156+
PQCodeDistance::distance_single_code(
11561157
pq.M,
11571158
pq.nbits,
11581159
sim_table,
@@ -1197,8 +1198,8 @@ struct IVFPQScannerT : QueryTables {
11971198
*
11981199
* use_sel: store or ignore the IDSelector
11991200
*/
1200-
template <MetricType METRIC_TYPE, class C, class PQDecoder, bool use_sel>
1201-
struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
1201+
template <MetricType METRIC_TYPE, class C, class PQCodeDistance, bool use_sel>
1202+
struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQCodeDistance>,
12021203
InvertedListScanner {
12031204
int precompute_mode;
12041205
const IDSelector* sel;
@@ -1208,7 +1209,7 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
12081209
bool store_pairs,
12091210
int precompute_mode,
12101211
const IDSelector* sel)
1211-
: IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>(ivfpq, nullptr),
1212+
: IVFPQScannerT<idx_t, METRIC_TYPE, PQCodeDistance>(ivfpq, nullptr),
12121213
precompute_mode(precompute_mode),
12131214
sel(sel) {
12141215
this->store_pairs = store_pairs;
@@ -1228,7 +1229,7 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
12281229
float distance_to_code(const uint8_t* code) const override {
12291230
assert(precompute_mode == 2);
12301231
float dis = this->dis0 +
1231-
distance_single_code<PQDecoder>(
1232+
PQCodeDistance::distance_single_code(
12321233
this->pq.M, this->pq.nbits, this->sim_table, code);
12331234
return dis;
12341235
}
@@ -1292,7 +1293,9 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
12921293
}
12931294
};
12941295

1295-
template <class PQDecoder, bool use_sel>
1296+
/** follow 3 stages of template dispatching */
1297+
1298+
template <class PQCodeDistance, bool use_sel>
12961299
InvertedListScanner* get_InvertedListScanner1(
12971300
const IndexIVFPQ& index,
12981301
bool store_pairs,
@@ -1301,32 +1304,47 @@ InvertedListScanner* get_InvertedListScanner1(
13011304
return new IVFPQScanner<
13021305
METRIC_INNER_PRODUCT,
13031306
CMin<float, idx_t>,
1304-
PQDecoder,
1307+
PQCodeDistance,
13051308
use_sel>(index, store_pairs, 2, sel);
13061309
} else if (index.metric_type == METRIC_L2) {
13071310
return new IVFPQScanner<
13081311
METRIC_L2,
13091312
CMax<float, idx_t>,
1310-
PQDecoder,
1313+
PQCodeDistance,
13111314
use_sel>(index, store_pairs, 2, sel);
13121315
}
13131316
return nullptr;
13141317
}
13151318

1316-
template <bool use_sel>
1319+
template <bool use_sel, SIMDLevel SL>
13171320
InvertedListScanner* get_InvertedListScanner2(
13181321
const IndexIVFPQ& index,
13191322
bool store_pairs,
13201323
const IDSelector* sel) {
13211324
if (index.pq.nbits == 8) {
1322-
return get_InvertedListScanner1<PQDecoder8, use_sel>(
1323-
index, store_pairs, sel);
1325+
return get_InvertedListScanner1<
1326+
PQCodeDistance<PQDecoder8, SL>,
1327+
use_sel>(index, store_pairs, sel);
13241328
} else if (index.pq.nbits == 16) {
1325-
return get_InvertedListScanner1<PQDecoder16, use_sel>(
1326-
index, store_pairs, sel);
1329+
return get_InvertedListScanner1<
1330+
PQCodeDistance<PQDecoder16, SL>,
1331+
use_sel>(index, store_pairs, sel);
1332+
} else {
1333+
return get_InvertedListScanner1<
1334+
PQCodeDistance<PQDecoderGeneric, SL>,
1335+
use_sel>(index, store_pairs, sel);
1336+
}
1337+
}
1338+
1339+
template <SIMDLevel SL>
1340+
InvertedListScanner* get_InvertedListScanner3(
1341+
const IndexIVFPQ& index,
1342+
bool store_pairs,
1343+
const IDSelector* sel) {
1344+
if (sel) {
1345+
return get_InvertedListScanner2<true, SL>(index, store_pairs, sel);
13271346
} else {
1328-
return get_InvertedListScanner1<PQDecoderGeneric, use_sel>(
1329-
index, store_pairs, sel);
1347+
return get_InvertedListScanner2<false, SL>(index, store_pairs, sel);
13301348
}
13311349
}
13321350

@@ -1336,11 +1354,7 @@ InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
13361354
bool store_pairs,
13371355
const IDSelector* sel,
13381356
const IVFSearchParameters*) const {
1339-
if (sel) {
1340-
return get_InvertedListScanner2<true>(*this, store_pairs, sel);
1341-
} else {
1342-
return get_InvertedListScanner2<false>(*this, store_pairs, sel);
1343-
}
1357+
DISPATCH_SIMDLevel(get_InvertedListScanner3, *this, store_pairs, sel);
13441358
return nullptr;
13451359
}
13461360

faiss/IndexPQ.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ void IndexPQ::train(idx_t n, const float* x) {
7272

7373
namespace {
7474

75-
template <class PQDecoder>
75+
template <class PQCodeDistance>
7676
struct PQDistanceComputer : FlatCodesDistanceComputer {
7777
size_t d;
7878
MetricType metric;
@@ -85,7 +85,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
8585
float distance_to_code(const uint8_t* code) final {
8686
ndis++;
8787

88-
float dis = distance_single_code<PQDecoder>(
88+
float dis = PQCodeDistance::distance_single_code(
8989
pq.M, pq.nbits, precomputed_table.data(), code);
9090
return dis;
9191
}
@@ -94,8 +94,10 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
9494
FAISS_THROW_IF_NOT(sdc);
9595
const float* sdci = sdc;
9696
float accu = 0;
97-
PQDecoder codei(codes + i * code_size, pq.nbits);
98-
PQDecoder codej(codes + j * code_size, pq.nbits);
97+
typename PQCodeDistance::PQDecoder codei(
98+
codes + i * code_size, pq.nbits);
99+
typename PQCodeDistance::PQDecoder codej(
100+
codes + j * code_size, pq.nbits);
99101

100102
for (int l = 0; l < pq.M; l++) {
101103
accu += sdci[codei.decode() + (codej.decode() << codei.nbits)];
@@ -131,16 +133,24 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
131133
}
132134
};
133135

136+
template <SIMDLevel SL>
137+
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer1(
138+
const IndexPQ& index) {
139+
int nbits = index.pq.nbits;
140+
if (nbits == 8) {
141+
return new PQDistanceComputer<PQCodeDistance<PQDecoder8, SL>>(index);
142+
} else if (nbits == 16) {
143+
return new PQDistanceComputer<PQCodeDistance<PQDecoder16, SL>>(index);
144+
} else {
145+
return new PQDistanceComputer<PQCodeDistance<PQDecoderGeneric, SL>>(
146+
index);
147+
}
148+
}
149+
134150
} // namespace
135151

136152
FlatCodesDistanceComputer* IndexPQ::get_FlatCodesDistanceComputer() const {
137-
if (pq.nbits == 8) {
138-
return new PQDistanceComputer<PQDecoder8>(*this);
139-
} else if (pq.nbits == 16) {
140-
return new PQDistanceComputer<PQDecoder16>(*this);
141-
} else {
142-
return new PQDistanceComputer<PQDecoderGeneric>(*this);
143-
}
153+
DISPATCH_SIMDLevel(get_FlatCodesDistanceComputer1, *this);
144154
}
145155

146156
/*****************************************

0 commit comments

Comments
 (0)