Skip to content

Commit b7a3309

Browse files
Subhadeep Karan and facebook-github-bot
authored and committed
Split ScalarQuantizer code into independent parts (facebookresearch#4557)
Summary: Pull Request resolved: facebookresearch#4557 Pull Request resolved: facebookresearch#4296 Splits the ScalarQuantizer code into parts so that the AVX2 and AVX512 variants can be compiled independently. Reviewed By: mnorris11 Differential Revision: D73037185
1 parent 6604082 commit b7a3309

File tree

12 files changed

+1901
-1731
lines changed

12 files changed

+1901
-1731
lines changed

faiss/impl/ScalarQuantizer.cpp

Lines changed: 37 additions & 557 deletions
Large diffs are not rendered by default.

faiss/impl/code_distance/code_distance-avx512.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ struct PQCodeDistance<PQDecoder8, SIMDLevel::AVX512> {
192192
};
193193

194194
// explicit template instanciations
195-
// template struct PQCodeDistance<PQDecoder8, SIMDLevel::AVX512F>;
195+
// template struct PQCodeDistance<PQDecoder8, SIMDLevel::AVX512>;
196196

197197
// these two will automatically use the generic implementation
198198
template struct PQCodeDistance<PQDecoder16, SIMDLevel::AVX512>;

faiss/impl/scalar_quantizer/codecs.h

Lines changed: 16 additions & 206 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
#include <faiss/impl/ScalarQuantizer.h>
11+
#include <faiss/utils/simd_levels.h>
1112

1213
namespace faiss {
1314

@@ -19,7 +20,17 @@ namespace scalar_quantizer {
1920
* index).
2021
*/
2122

22-
struct Codec8bit {
23+
template <SIMDLevel>
24+
struct Codec8bit {};
25+
26+
template <SIMDLevel>
27+
struct Codec4bit {};
28+
29+
template <SIMDLevel>
30+
struct Codec6bit {};
31+
32+
template <>
33+
struct Codec8bit<SIMDLevel::NONE> {
2334
static FAISS_ALWAYS_INLINE void encode_component(
2435
float x,
2536
uint8_t* code,
@@ -32,45 +43,9 @@ struct Codec8bit {
3243
int i) {
3344
return (code[i] + 0.5f) / 255.0f;
3445
}
35-
36-
#if defined(__AVX512F__)
37-
static FAISS_ALWAYS_INLINE simd16float32
38-
decode_16_components(const uint8_t* code, int i) {
39-
const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i));
40-
const __m512i i32 = _mm512_cvtepu8_epi32(c16);
41-
const __m512 f16 = _mm512_cvtepi32_ps(i32);
42-
const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f);
43-
const __m512 one_255 = _mm512_set1_ps(1.f / 255.f);
44-
return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255));
45-
}
46-
#elif defined(__AVX2__)
47-
static FAISS_ALWAYS_INLINE simd8float32
48-
decode_8_components(const uint8_t* code, int i) {
49-
const uint64_t c8 = *(uint64_t*)(code + i);
50-
51-
const __m128i i8 = _mm_set1_epi64x(c8);
52-
const __m256i i32 = _mm256_cvtepu8_epi32(i8);
53-
const __m256 f8 = _mm256_cvtepi32_ps(i32);
54-
const __m256 half_one_255 = _mm256_set1_ps(0.5f / 255.f);
55-
const __m256 one_255 = _mm256_set1_ps(1.f / 255.f);
56-
return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255));
57-
}
58-
#endif
59-
60-
#ifdef USE_NEON
61-
static FAISS_ALWAYS_INLINE decode_8_components(const uint8_t* code, int i) {
62-
float32_t result[8] = {};
63-
for (size_t j = 0; j < 8; j++) {
64-
result[j] = decode_component(code, i + j);
65-
}
66-
float32x4_t res1 = vld1q_f32(result);
67-
float32x4_t res2 = vld1q_f32(result + 4);
68-
return simd8float32(float32x4x2_t{res1, res2});
69-
}
70-
#endif
7146
};
72-
73-
struct Codec4bit {
47+
template <>
48+
struct Codec4bit<SIMDLevel::NONE> {
7449
static FAISS_ALWAYS_INLINE void encode_component(
7550
float x,
7651
uint8_t* code,
@@ -83,64 +58,10 @@ struct Codec4bit {
8358
int i) {
8459
return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
8560
}
86-
87-
#if defined(__AVX512F__)
88-
static FAISS_ALWAYS_INLINE simd16float32
89-
decode_16_components(const uint8_t* code, int i) {
90-
uint64_t c8 = *(uint64_t*)(code + (i >> 1));
91-
uint64_t mask = 0x0f0f0f0f0f0f0f0f;
92-
uint64_t c8ev = c8 & mask;
93-
uint64_t c8od = (c8 >> 4) & mask;
94-
95-
__m128i c16 =
96-
_mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od));
97-
__m256i c8lo = _mm256_cvtepu8_epi32(c16);
98-
__m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8));
99-
__m512i i16 = _mm512_castsi256_si512(c8lo);
100-
i16 = _mm512_inserti32x8(i16, c8hi, 1);
101-
__m512 f16 = _mm512_cvtepi32_ps(i16);
102-
const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f);
103-
const __m512 one_255 = _mm512_set1_ps(1.f / 15.f);
104-
return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255));
105-
}
106-
#elif defined(__AVX2__)
107-
static FAISS_ALWAYS_INLINE simd8float32
108-
decode_8_components(const uint8_t* code, int i) {
109-
uint32_t c4 = *(uint32_t*)(code + (i >> 1));
110-
uint32_t mask = 0x0f0f0f0f;
111-
uint32_t c4ev = c4 & mask;
112-
uint32_t c4od = (c4 >> 4) & mask;
113-
114-
// the 8 lower bytes of c8 contain the values
115-
__m128i c8 =
116-
_mm_unpacklo_epi8(_mm_set1_epi32(c4ev), _mm_set1_epi32(c4od));
117-
__m128i c4lo = _mm_cvtepu8_epi32(c8);
118-
__m128i c4hi = _mm_cvtepu8_epi32(_mm_srli_si128(c8, 4));
119-
__m256i i8 = _mm256_castsi128_si256(c4lo);
120-
i8 = _mm256_insertf128_si256(i8, c4hi, 1);
121-
__m256 f8 = _mm256_cvtepi32_ps(i8);
122-
__m256 half = _mm256_set1_ps(0.5f);
123-
f8 = _mm256_add_ps(f8, half);
124-
__m256 one_255 = _mm256_set1_ps(1.f / 15.f);
125-
return simd8float32(_mm256_mul_ps(f8, one_255));
126-
}
127-
#endif
128-
129-
#ifdef USE_NEON
130-
static FAISS_ALWAYS_INLINE simd8float32
131-
decode_8_components(const uint8_t* code, int i) {
132-
float32_t result[8] = {};
133-
for (size_t j = 0; j < 8; j++) {
134-
result[j] = decode_component(code, i + j);
135-
}
136-
float32x4_t res1 = vld1q_f32(result);
137-
float32x4_t res2 = vld1q_f32(result + 4);
138-
return simd8float32({res1, res2});
139-
}
140-
#endif
14161
};
14262

143-
struct Codec6bit {
63+
template <>
64+
struct Codec6bit<SIMDLevel::NONE> {
14465
static FAISS_ALWAYS_INLINE void encode_component(
14566
float x,
14667
uint8_t* code,
@@ -188,117 +109,6 @@ struct Codec6bit {
188109
}
189110
return (bits + 0.5f) / 63.0f;
190111
}
191-
192-
#if defined(__AVX512F__)
193-
194-
static FAISS_ALWAYS_INLINE simd16float32
195-
decode_16_components(const uint8_t* code, int i) {
196-
// pure AVX512 implementation (not necessarily the fastest).
197-
// see:
198-
// https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h
199-
200-
// clang-format off
201-
202-
// 16 components, 16x6 bit=12 bytes
203-
const __m128i bit_6v =
204-
_mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3);
205-
const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v);
206-
207-
// 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F
208-
// 00 01 02 03
209-
const __m256i shuffle_mask = _mm256_setr_epi16(
210-
0xFF00, 0x0100, 0x0201, 0xFF02,
211-
0xFF03, 0x0403, 0x0504, 0xFF05,
212-
0xFF06, 0x0706, 0x0807, 0xFF08,
213-
0xFF09, 0x0A09, 0x0B0A, 0xFF0B);
214-
const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask);
215-
216-
// 0: xxxxxxxx xx543210
217-
// 1: xxxx5432 10xxxxxx
218-
// 2: xxxxxx54 3210xxxx
219-
// 3: xxxxxxxx 543210xx
220-
const __m256i shift_right_v = _mm256_setr_epi16(
221-
0x0U, 0x6U, 0x4U, 0x2U,
222-
0x0U, 0x6U, 0x4U, 0x2U,
223-
0x0U, 0x6U, 0x4U, 0x2U,
224-
0x0U, 0x6U, 0x4U, 0x2U);
225-
__m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v);
226-
227-
// remove unneeded bits
228-
shuffled_shifted =
229-
_mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F));
230-
231-
// scale
232-
const __m512 f8 =
233-
_mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted));
234-
const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f);
235-
const __m512 one_255 = _mm512_set1_ps(1.f / 63.f);
236-
return simd16float32(_mm512_fmadd_ps(f8, one_255, half_one_255));
237-
238-
// clang-format on
239-
}
240-
241-
#elif defined(__AVX2__)
242-
243-
/* Load 6 bytes that represent 8 6-bit values, return them as a
244-
* 8*32 bit vector register */
245-
static FAISS_ALWAYS_INLINE __m256i load6(const uint16_t* code16) {
246-
const __m128i perm = _mm_set_epi8(
247-
-1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0);
248-
const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0);
249-
250-
// load 6 bytes
251-
__m128i c1 =
252-
_mm_set_epi16(0, 0, 0, 0, 0, code16[2], code16[1], code16[0]);
253-
254-
// put in 8 * 32 bits
255-
__m128i c2 = _mm_shuffle_epi8(c1, perm);
256-
__m256i c3 = _mm256_cvtepi16_epi32(c2);
257-
258-
// shift and mask out useless bits
259-
__m256i c4 = _mm256_srlv_epi32(c3, shifts);
260-
__m256i c5 = _mm256_and_si256(_mm256_set1_epi32(63), c4);
261-
return c5;
262-
}
263-
264-
static FAISS_ALWAYS_INLINE simd8float32
265-
decode_8_components(const uint8_t* code, int i) {
266-
// // Faster code for Intel CPUs or AMD Zen3+, just keeping it here
267-
// // for the reference, maybe, it becomes used oned day.
268-
// const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3);
269-
// const uint32_t* data32 = (const uint32_t*)data16;
270-
// const uint64_t val = *data32 + ((uint64_t)data16[2] << 32);
271-
// const uint64_t vext = _pdep_u64(val, 0x3F3F3F3F3F3F3F3FULL);
272-
// const __m128i i8 = _mm_set1_epi64x(vext);
273-
// const __m256i i32 = _mm256_cvtepi8_epi32(i8);
274-
// const __m256 f8 = _mm256_cvtepi32_ps(i32);
275-
// const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
276-
// const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
277-
// return _mm256_fmadd_ps(f8, one_255, half_one_255);
278-
279-
__m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3));
280-
__m256 f8 = _mm256_cvtepi32_ps(i8);
281-
// this could also be done with bit manipulations but it is
282-
// not obviously faster
283-
const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
284-
const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
285-
return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255));
286-
}
287-
288-
#endif
289-
290-
#ifdef USE_NEON
291-
static FAISS_ALWAYS_INLINE simd8float32
292-
decode_8_components(const uint8_t* code, int i) {
293-
float32_t result[8] = {};
294-
for (size_t j = 0; j < 8; j++) {
295-
result[j] = decode_component(code, i + j);
296-
}
297-
float32x4_t res1 = vld1q_f32(result);
298-
float32x4_t res2 = vld1q_f32(result + 4);
299-
return simd8float32(float32x4x2_t({res1, res2}));
300-
}
301-
#endif
302112
};
303113

304114
} // namespace scalar_quantizer

0 commit comments

Comments (0)