@@ -817,8 +817,9 @@ struct RangeSearchResults {
817817 * The scanning functions call their favorite precompute_*
818818 * function to precompute the tables they need.
819819 *****************************************************/
820- template <typename IDType, MetricType METRIC_TYPE, class PQDecoder >
820+ template <typename IDType, MetricType METRIC_TYPE, class PQCodeDistance >
821821struct IVFPQScannerT : QueryTables {
822+ using PQDecoder = typename PQCodeDistance::PQDecoder;
822823 const uint8_t * list_codes;
823824 const IDType* list_ids;
824825 size_t list_size;
@@ -894,7 +895,7 @@ struct IVFPQScannerT : QueryTables {
894895 float distance_1 = 0 ;
895896 float distance_2 = 0 ;
896897 float distance_3 = 0 ;
897- distance_four_codes<PQDecoder> (
898+ PQCodeDistance:: distance_four_codes (
898899 pq.M ,
899900 pq.nbits ,
900901 sim_table,
@@ -917,7 +918,7 @@ struct IVFPQScannerT : QueryTables {
917918
918919 if (counter >= 1 ) {
919920 float dis = dis0 +
920- distance_single_code<PQDecoder> (
921+ PQCodeDistance:: distance_single_code (
921922 pq.M ,
922923 pq.nbits ,
923924 sim_table,
@@ -926,7 +927,7 @@ struct IVFPQScannerT : QueryTables {
926927 }
927928 if (counter >= 2 ) {
928929 float dis = dis0 +
929- distance_single_code<PQDecoder> (
930+ PQCodeDistance:: distance_single_code (
930931 pq.M ,
931932 pq.nbits ,
932933 sim_table,
@@ -935,7 +936,7 @@ struct IVFPQScannerT : QueryTables {
935936 }
936937 if (counter >= 3 ) {
937938 float dis = dis0 +
938- distance_single_code<PQDecoder> (
939+ PQCodeDistance:: distance_single_code (
939940 pq.M ,
940941 pq.nbits ,
941942 sim_table,
@@ -1101,7 +1102,7 @@ struct IVFPQScannerT : QueryTables {
11011102 float distance_1 = dis0;
11021103 float distance_2 = dis0;
11031104 float distance_3 = dis0;
1104- distance_four_codes<PQDecoder> (
1105+ PQCodeDistance:: distance_four_codes (
11051106 pq.M ,
11061107 pq.nbits ,
11071108 sim_table,
@@ -1132,7 +1133,7 @@ struct IVFPQScannerT : QueryTables {
11321133 n_hamming_pass++;
11331134
11341135 float dis = dis0 +
1135- distance_single_code<PQDecoder> (
1136+ PQCodeDistance:: distance_single_code (
11361137 pq.M ,
11371138 pq.nbits ,
11381139 sim_table,
@@ -1152,7 +1153,7 @@ struct IVFPQScannerT : QueryTables {
11521153 n_hamming_pass++;
11531154
11541155 float dis = dis0 +
1155- distance_single_code<PQDecoder> (
1156+ PQCodeDistance:: distance_single_code (
11561157 pq.M ,
11571158 pq.nbits ,
11581159 sim_table,
@@ -1197,8 +1198,8 @@ struct IVFPQScannerT : QueryTables {
11971198 *
11981199 * use_sel: store or ignore the IDSelector
11991200 */
1200- template <MetricType METRIC_TYPE, class C , class PQDecoder , bool use_sel>
1201- struct IVFPQScanner : IVFPQScannerT<idx_t , METRIC_TYPE, PQDecoder >,
1201+ template <MetricType METRIC_TYPE, class C , class PQCodeDistance , bool use_sel>
1202+ struct IVFPQScanner : IVFPQScannerT<idx_t , METRIC_TYPE, PQCodeDistance >,
12021203 InvertedListScanner {
12031204 int precompute_mode;
12041205 const IDSelector* sel;
@@ -1208,7 +1209,7 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
12081209 bool store_pairs,
12091210 int precompute_mode,
12101211 const IDSelector* sel)
1211- : IVFPQScannerT<idx_t , METRIC_TYPE, PQDecoder >(ivfpq, nullptr ),
1212+ : IVFPQScannerT<idx_t , METRIC_TYPE, PQCodeDistance >(ivfpq, nullptr ),
12121213 precompute_mode (precompute_mode),
12131214 sel (sel) {
12141215 this ->store_pairs = store_pairs;
@@ -1228,7 +1229,7 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
12281229 float distance_to_code (const uint8_t * code) const override {
12291230 assert (precompute_mode == 2 );
12301231 float dis = this ->dis0 +
1231- distance_single_code<PQDecoder> (
1232+ PQCodeDistance:: distance_single_code (
12321233 this ->pq .M , this ->pq .nbits , this ->sim_table , code);
12331234 return dis;
12341235 }
@@ -1292,7 +1293,9 @@ struct IVFPQScanner : IVFPQScannerT<idx_t, METRIC_TYPE, PQDecoder>,
12921293 }
12931294};
12941295
1295- template <class PQDecoder , bool use_sel>
1296+ /* * follow 3 stages of template dispatching */
1297+
1298+ template <class PQCodeDistance , bool use_sel>
12961299InvertedListScanner* get_InvertedListScanner1 (
12971300 const IndexIVFPQ& index,
12981301 bool store_pairs,
@@ -1301,32 +1304,47 @@ InvertedListScanner* get_InvertedListScanner1(
13011304 return new IVFPQScanner<
13021305 METRIC_INNER_PRODUCT,
13031306 CMin<float , idx_t >,
1304- PQDecoder ,
1307+ PQCodeDistance ,
13051308 use_sel>(index, store_pairs, 2 , sel);
13061309 } else if (index.metric_type == METRIC_L2) {
13071310 return new IVFPQScanner<
13081311 METRIC_L2,
13091312 CMax<float , idx_t >,
1310- PQDecoder ,
1313+ PQCodeDistance ,
13111314 use_sel>(index, store_pairs, 2 , sel);
13121315 }
13131316 return nullptr ;
13141317}
13151318
1316- template <bool use_sel>
1319+ template <bool use_sel, SIMDLevel SL >
13171320InvertedListScanner* get_InvertedListScanner2 (
13181321 const IndexIVFPQ& index,
13191322 bool store_pairs,
13201323 const IDSelector* sel) {
13211324 if (index.pq .nbits == 8 ) {
1322- return get_InvertedListScanner1<PQDecoder8, use_sel>(
1323- index, store_pairs, sel);
1325+ return get_InvertedListScanner1<
1326+ PQCodeDistance<PQDecoder8, SL>,
1327+ use_sel>(index, store_pairs, sel);
13241328 } else if (index.pq .nbits == 16 ) {
1325- return get_InvertedListScanner1<PQDecoder16, use_sel>(
1326- index, store_pairs, sel);
1329+ return get_InvertedListScanner1<
1330+ PQCodeDistance<PQDecoder16, SL>,
1331+ use_sel>(index, store_pairs, sel);
1332+ } else {
1333+ return get_InvertedListScanner1<
1334+ PQCodeDistance<PQDecoderGeneric, SL>,
1335+ use_sel>(index, store_pairs, sel);
1336+ }
1337+ }
1338+
1339+ template <SIMDLevel SL>
1340+ InvertedListScanner* get_InvertedListScanner3 (
1341+ const IndexIVFPQ& index,
1342+ bool store_pairs,
1343+ const IDSelector* sel) {
1344+ if (sel) {
1345+ return get_InvertedListScanner2<true , SL>(index, store_pairs, sel);
13271346 } else {
1328- return get_InvertedListScanner1<PQDecoderGeneric, use_sel>(
1329- index, store_pairs, sel);
1347+ return get_InvertedListScanner2<false , SL>(index, store_pairs, sel);
13301348 }
13311349}
13321350
@@ -1336,11 +1354,7 @@ InvertedListScanner* IndexIVFPQ::get_InvertedListScanner(
13361354 bool store_pairs,
13371355 const IDSelector* sel,
13381356 const IVFSearchParameters*) const {
1339- if (sel) {
1340- return get_InvertedListScanner2<true >(*this , store_pairs, sel);
1341- } else {
1342- return get_InvertedListScanner2<false >(*this , store_pairs, sel);
1343- }
1357+ DISPATCH_SIMDLevel (get_InvertedListScanner3, *this , store_pairs, sel);
13441358 return nullptr ;
13451359}
13461360
0 commit comments