#pragma once

#include <faiss/impl/ScalarQuantizer.h>
+ #include <faiss/utils/simd_levels.h>

namespace faiss {

@@ -19,7 +20,17 @@ namespace scalar_quantizer {
 * index).
 */

- struct Codec8bit {
+ template <SIMDLevel>
+ struct Codec8bit {};
+
+ template <SIMDLevel>
+ struct Codec4bit {};
+
+ template <SIMDLevel>
+ struct Codec6bit {};
+
+ template <>
+ struct Codec8bit<SIMDLevel::NONE> {
    static FAISS_ALWAYS_INLINE void encode_component(
            float x,
            uint8_t* code,
@@ -32,45 +43,9 @@ struct Codec8bit {
            int i) {
        return (code[i] + 0.5f) / 255.0f;
    }
-
- #if defined(__AVX512F__)
-     static FAISS_ALWAYS_INLINE simd16float32
-     decode_16_components(const uint8_t* code, int i) {
-         const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i));
-         const __m512i i32 = _mm512_cvtepu8_epi32(c16);
-         const __m512 f16 = _mm512_cvtepi32_ps(i32);
-         const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f);
-         const __m512 one_255 = _mm512_set1_ps(1.f / 255.f);
-         return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255));
-     }
- #elif defined(__AVX2__)
-     static FAISS_ALWAYS_INLINE simd8float32
-     decode_8_components(const uint8_t* code, int i) {
-         const uint64_t c8 = *(uint64_t*)(code + i);
-
-         const __m128i i8 = _mm_set1_epi64x(c8);
-         const __m256i i32 = _mm256_cvtepu8_epi32(i8);
-         const __m256 f8 = _mm256_cvtepi32_ps(i32);
-         const __m256 half_one_255 = _mm256_set1_ps(0.5f / 255.f);
-         const __m256 one_255 = _mm256_set1_ps(1.f / 255.f);
-         return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255));
-     }
- #endif
-
- #ifdef USE_NEON
-     static FAISS_ALWAYS_INLINE simd8float32
-     decode_8_components(const uint8_t* code, int i) {
-         float32_t result[8] = {};
-         for (size_t j = 0; j < 8; j++) {
-             result[j] = decode_component(code, i + j);
-         }
-         float32x4_t res1 = vld1q_f32(result);
-         float32x4_t res2 = vld1q_f32(result + 4);
-         return simd8float32(float32x4x2_t{res1, res2});
-     }
- #endif
};
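Note on the deleted branches above: under the new scheme, each SIMD variant presumably becomes an explicit specialization on SIMDLevel instead of a preprocessor branch; this commit only adds the scalar SIMDLevel::NONE fallback. A minimal sketch of what the AVX2 path could look like when re-homed, with the body taken from the removed #elif defined(__AVX2__) block (the SIMDLevel::AVX2 enumerator name is an assumption about simd_levels.h, not confirmed by this hunk):

    // Hypothetical follow-up specialization; not part of this commit.
    template <>
    struct Codec8bit<SIMDLevel::AVX2> {
        static FAISS_ALWAYS_INLINE simd8float32
        decode_8_components(const uint8_t* code, int i) {
            // widen 8 packed uint8 codes to 8 x int32, then to 8 x float
            const uint64_t c8 = *(uint64_t*)(code + i);
            const __m128i i8 = _mm_set1_epi64x(c8);
            const __m256i i32 = _mm256_cvtepu8_epi32(i8);
            const __m256 f8 = _mm256_cvtepi32_ps(i32);
            // decode as (c + 0.5) / 255 in one fused multiply-add
            const __m256 half_one_255 = _mm256_set1_ps(0.5f / 255.f);
            const __m256 one_255 = _mm256_set1_ps(1.f / 255.f);
            return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255));
        }
    };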
-
- struct Codec4bit {
+ template <>
+ struct Codec4bit<SIMDLevel::NONE> {
    static FAISS_ALWAYS_INLINE void encode_component(
            float x,
            uint8_t* code,
@@ -83,64 +58,10 @@ struct Codec4bit {
            int i) {
        return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
    }
-
- #if defined(__AVX512F__)
-     static FAISS_ALWAYS_INLINE simd16float32
-     decode_16_components(const uint8_t* code, int i) {
-         uint64_t c8 = *(uint64_t*)(code + (i >> 1));
-         uint64_t mask = 0x0f0f0f0f0f0f0f0f;
-         uint64_t c8ev = c8 & mask;
-         uint64_t c8od = (c8 >> 4) & mask;
-
-         __m128i c16 =
-                 _mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od));
-         __m256i c8lo = _mm256_cvtepu8_epi32(c16);
-         __m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8));
-         __m512i i16 = _mm512_castsi256_si512(c8lo);
-         i16 = _mm512_inserti32x8(i16, c8hi, 1);
-         __m512 f16 = _mm512_cvtepi32_ps(i16);
-         const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f);
-         const __m512 one_255 = _mm512_set1_ps(1.f / 15.f);
-         return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255));
-     }
- #elif defined(__AVX2__)
-     static FAISS_ALWAYS_INLINE simd8float32
-     decode_8_components(const uint8_t* code, int i) {
-         uint32_t c4 = *(uint32_t*)(code + (i >> 1));
-         uint32_t mask = 0x0f0f0f0f;
-         uint32_t c4ev = c4 & mask;
-         uint32_t c4od = (c4 >> 4) & mask;
-
-         // the 8 lower bytes of c8 contain the values
-         __m128i c8 =
-                 _mm_unpacklo_epi8(_mm_set1_epi32(c4ev), _mm_set1_epi32(c4od));
-         __m128i c4lo = _mm_cvtepu8_epi32(c8);
-         __m128i c4hi = _mm_cvtepu8_epi32(_mm_srli_si128(c8, 4));
-         __m256i i8 = _mm256_castsi128_si256(c4lo);
-         i8 = _mm256_insertf128_si256(i8, c4hi, 1);
-         __m256 f8 = _mm256_cvtepi32_ps(i8);
-         __m256 half = _mm256_set1_ps(0.5f);
-         f8 = _mm256_add_ps(f8, half);
-         __m256 one_255 = _mm256_set1_ps(1.f / 15.f);
-         return simd8float32(_mm256_mul_ps(f8, one_255));
-     }
- #endif
-
- #ifdef USE_NEON
-     static FAISS_ALWAYS_INLINE simd8float32
-     decode_8_components(const uint8_t* code, int i) {
-         float32_t result[8] = {};
-         for (size_t j = 0; j < 8; j++) {
-             result[j] = decode_component(code, i + j);
-         }
-         float32x4_t res1 = vld1q_f32(result);
-         float32x4_t res2 = vld1q_f32(result + 4);
-         return simd8float32({res1, res2});
-     }
- #endif
};
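The scalar 4-bit path packs two components per byte, even components in the low nibble and odd ones in the high nibble, so `(i & 1) << 2` selects a shift of 0 or 4 before masking. A small standalone illustration of the same arithmetic (plain C++, independent of faiss; names are ours):

    #include <cstdint>
    #include <cstdio>

    // Same expression as Codec4bit<SIMDLevel::NONE>::decode_component.
    static float decode4(const uint8_t* code, int i) {
        return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
    }

    int main() {
        uint8_t code[1] = {0xB3}; // component 0 = 0x3, component 1 = 0xB
        printf("%f %f\n", decode4(code, 0), decode4(code, 1));
        // prints 0.233333 (= 3.5/15) and 0.766667 (= 11.5/15)
        return 0;
    }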

- struct Codec6bit {
+ template <>
+ struct Codec6bit<SIMDLevel::NONE> {
    static FAISS_ALWAYS_INLINE void encode_component(
            float x,
            uint8_t* code,
@@ -188,117 +109,6 @@ struct Codec6bit {
        }
        return (bits + 0.5f) / 63.0f;
    }
-
- #if defined(__AVX512F__)
-
-     static FAISS_ALWAYS_INLINE simd16float32
-     decode_16_components(const uint8_t* code, int i) {
-         // pure AVX512 implementation (not necessarily the fastest).
-         // see:
-         // https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h
-
-         // clang-format off
-
-         // 16 components, 16x6 bit = 12 bytes
-         const __m128i bit_6v =
-                 _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3);
-         const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v);
-
-         // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F
-         // 00 01 02 03
-         const __m256i shuffle_mask = _mm256_setr_epi16(
-                 0xFF00, 0x0100, 0x0201, 0xFF02,
-                 0xFF03, 0x0403, 0x0504, 0xFF05,
-                 0xFF06, 0x0706, 0x0807, 0xFF08,
-                 0xFF09, 0x0A09, 0x0B0A, 0xFF0B);
-         const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask);
-
-         // 0: xxxxxxxx xx543210
-         // 1: xxxx5432 10xxxxxx
-         // 2: xxxxxx54 3210xxxx
-         // 3: xxxxxxxx 543210xx
-         const __m256i shift_right_v = _mm256_setr_epi16(
-                 0x0U, 0x6U, 0x4U, 0x2U,
-                 0x0U, 0x6U, 0x4U, 0x2U,
-                 0x0U, 0x6U, 0x4U, 0x2U,
-                 0x0U, 0x6U, 0x4U, 0x2U);
-         __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v);
-
-         // remove unneeded bits
-         shuffled_shifted =
-                 _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F));
-
-         // scale
-         const __m512 f8 =
-                 _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted));
-         const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f);
-         const __m512 one_255 = _mm512_set1_ps(1.f / 63.f);
-         return simd16float32(_mm512_fmadd_ps(f8, one_255, half_one_255));
-
-         // clang-format on
-     }
-
- #elif defined(__AVX2__)
-
-     /* Load 6 bytes that represent 8 6-bit values, return them as a
-      * 8*32 bit vector register */
-     static FAISS_ALWAYS_INLINE __m256i load6(const uint16_t* code16) {
-         const __m128i perm = _mm_set_epi8(
-                 -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0);
-         const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0);
-
-         // load 6 bytes
-         __m128i c1 =
-                 _mm_set_epi16(0, 0, 0, 0, 0, code16[2], code16[1], code16[0]);
-
-         // put in 8 * 32 bits
-         __m128i c2 = _mm_shuffle_epi8(c1, perm);
-         __m256i c3 = _mm256_cvtepi16_epi32(c2);
-
-         // shift and mask out useless bits
-         __m256i c4 = _mm256_srlv_epi32(c3, shifts);
-         __m256i c5 = _mm256_and_si256(_mm256_set1_epi32(63), c4);
-         return c5;
-     }
-
-     static FAISS_ALWAYS_INLINE simd8float32
-     decode_8_components(const uint8_t* code, int i) {
-         // // Faster code for Intel CPUs or AMD Zen3+, just keeping it here
-         // // for the reference; maybe it becomes used one day.
-         // const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3);
-         // const uint32_t* data32 = (const uint32_t*)data16;
-         // const uint64_t val = *data32 + ((uint64_t)data16[2] << 32);
-         // const uint64_t vext = _pdep_u64(val, 0x3F3F3F3F3F3F3F3FULL);
-         // const __m128i i8 = _mm_set1_epi64x(vext);
-         // const __m256i i32 = _mm256_cvtepi8_epi32(i8);
-         // const __m256 f8 = _mm256_cvtepi32_ps(i32);
-         // const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
-         // const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
-         // return _mm256_fmadd_ps(f8, one_255, half_one_255);
-
-         __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3));
-         __m256 f8 = _mm256_cvtepi32_ps(i8);
-         // this could also be done with bit manipulations but it is
-         // not obviously faster
-         const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
-         const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
-         return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255));
-     }
-
- #endif
-
- #ifdef USE_NEON
-     static FAISS_ALWAYS_INLINE simd8float32
-     decode_8_components(const uint8_t* code, int i) {
-         float32_t result[8] = {};
-         for (size_t j = 0; j < 8; j++) {
-             result[j] = decode_component(code, i + j);
-         }
-         float32x4_t res1 = vld1q_f32(result);
-         float32x4_t res2 = vld1q_f32(result + 4);
-         return simd8float32(float32x4x2_t({res1, res2}));
-     }
- #endif
};
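The 6-bit codec packs four components into three bytes, which is why every variant above addresses the code array at (i >> 2) * 3. The decode_component switch is mostly elided by this hunk; a scalar equivalent of the extraction, consistent with the bit diagrams in the removed AVX512 comments but not the verbatim faiss body, reads:

    #include <cstdint>

    // Extract the (i & 3)-th 6-bit field from its 3-byte group.
    static inline float decode6(const uint8_t* code, int i) {
        const uint8_t* c = code + (i >> 2) * 3;
        uint8_t bits;
        switch (i & 3) {
            case 0:
                bits = c[0] & 0x3f; // byte0 bits [5:0]
                break;
            case 1:
                bits = ((c[0] >> 6) | (c[1] << 2)) & 0x3f; // spans byte0/byte1
                break;
            case 2:
                bits = ((c[1] >> 4) | (c[2] << 4)) & 0x3f; // spans byte1/byte2
                break;
            default:
                bits = c[2] >> 2; // byte2 bits [7:2]
                break;
        }
        return (bits + 0.5f) / 63.0f;
    }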

} // namespace scalar_quantizer
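On how these specializations are meant to be consumed: code templated on SIMDLevel can name the right codec directly, and adding a new level becomes a matter of adding a specialization rather than another #ifdef branch. A hedged usage sketch (decode_all is a hypothetical caller, not part of this header):

    // Hypothetical caller templated on the SIMD level.
    template <SIMDLevel L>
    void decode_all(const uint8_t* codes, float* out, size_t n) {
        using Codec = Codec8bit<L>;
        for (size_t i = 0; i < n; i++) {
            out[i] = Codec::decode_component(codes, (int)i);
        }
    }

    // Instantiated with the scalar fallback added by this commit:
    // decode_all<SIMDLevel::NONE>(codes, out, n);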