
Commit b3bc6e2

mdouze authored and facebook-github-bot committed
Use simdlib abstraction in ScalarQuantizer implementation, split off training code, split quantizer code into headers, make headers more independent
Summary: Move the interface of the SIMD functions to the simdXfloat32 API so the code can be shared across instruction sets. Begin splitting up ScalarQuantizer.cpp; the split-off quantizer code is purely in header files for now. Differential Revision: D72945865
1 parent d2058fc commit b3bc6e2
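As context for the simdXfloat32 change: the decode kernels in this commit return the simdlib wrapper types (simd8float32, simd16float32) instead of raw __m256/__m512/float32x4x2_t registers, so the AVX2, AVX-512 and NEON variants can share one interface. Below is a minimal sketch of the pattern, using a hypothetical toy wrapper rather than faiss's actual simdlib API:

#include <cstdint>

#if defined(__AVX2__)
#include <immintrin.h>
// Toy stand-in for simd8float32: one wrapper type per register width.
struct toy_simd8float32 {
    __m256 v;
    explicit toy_simd8float32(__m256 x) : v(x) {}
};
#else
// Scalar fallback with the same shape, so callers compile unchanged.
struct toy_simd8float32 {
    float v[8];
};
#endif

// Decode kernels follow this pattern: the per-ISA intrinsics stay inside,
// and the shared wrapper type is all that callers ever see.
inline toy_simd8float32 toy_decode8_uniform(const uint8_t* code, int i) {
#if defined(__AVX2__)
    __m128i i8 = _mm_loadl_epi64((const __m128i*)(code + i));
    __m256i i32 = _mm256_cvtepu8_epi32(i8);
    __m256 f8 = _mm256_cvtepi32_ps(i32);
    // same affine map as Codec8bit::decode_component: (q + 0.5) / 255
    return toy_simd8float32(_mm256_fmadd_ps(
            f8, _mm256_set1_ps(1.f / 255.f), _mm256_set1_ps(0.5f / 255.f)));
#else
    toy_simd8float32 r;
    for (int j = 0; j < 8; j++) {
        r.v[j] = (code[i + j] + 0.5f) / 255.0f;
    }
    return r;
#endif
}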

File tree

8 files changed (+1959, -1853 lines)

faiss/impl/ScalarQuantizer.cpp

Lines changed: 114 additions & 1851 deletions
Large diffs are not rendered by default.

faiss/impl/ScalarQuantizer.h

Lines changed: 0 additions & 2 deletions
@@ -5,8 +5,6 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-// -*- c++ -*-
-
 #pragma once
 
 #include <faiss/impl/AuxIndexStructures.h>
Lines changed: 305 additions & 0 deletions
@@ -0,0 +1,305 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <faiss/impl/ScalarQuantizer.h>

namespace faiss {

namespace scalar_quantizer {

/*******************************************************************
 * Codec: converts between values in [0, 1] and an index in a code
 * array. The "i" parameter is the vector component index (not byte
 * index).
 */

struct Codec8bit {
    static FAISS_ALWAYS_INLINE void encode_component(
            float x,
            uint8_t* code,
            int i) {
        code[i] = (int)(255 * x);
    }

    static FAISS_ALWAYS_INLINE float decode_component(
            const uint8_t* code,
            int i) {
        return (code[i] + 0.5f) / 255.0f;
    }

#if defined(__AVX512F__)
    static FAISS_ALWAYS_INLINE simd16float32
    decode_16_components(const uint8_t* code, int i) {
        const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i));
        const __m512i i32 = _mm512_cvtepu8_epi32(c16);
        const __m512 f16 = _mm512_cvtepi32_ps(i32);
        const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f);
        const __m512 one_255 = _mm512_set1_ps(1.f / 255.f);
        return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255));
    }
#elif defined(__AVX2__)
    static FAISS_ALWAYS_INLINE simd8float32
    decode_8_components(const uint8_t* code, int i) {
        const uint64_t c8 = *(uint64_t*)(code + i);

        const __m128i i8 = _mm_set1_epi64x(c8);
        const __m256i i32 = _mm256_cvtepu8_epi32(i8);
        const __m256 f8 = _mm256_cvtepi32_ps(i32);
        const __m256 half_one_255 = _mm256_set1_ps(0.5f / 255.f);
        const __m256 one_255 = _mm256_set1_ps(1.f / 255.f);
        return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255));
    }
#endif

#ifdef USE_NEON
    static FAISS_ALWAYS_INLINE simd8float32
    decode_8_components(const uint8_t* code, int i) {
        float32_t result[8] = {};
        for (size_t j = 0; j < 8; j++) {
            result[j] = decode_component(code, i + j);
        }
        float32x4_t res1 = vld1q_f32(result);
        float32x4_t res2 = vld1q_f32(result + 4);
        return simd8float32(float32x4x2_t{res1, res2});
    }
#endif
};

struct Codec4bit {
    static FAISS_ALWAYS_INLINE void encode_component(
            float x,
            uint8_t* code,
            int i) {
        code[i / 2] |= (int)(x * 15.0) << ((i & 1) << 2);
    }

    static FAISS_ALWAYS_INLINE float decode_component(
            const uint8_t* code,
            int i) {
        return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
    }

#if defined(__AVX512F__)
    static FAISS_ALWAYS_INLINE simd16float32
    decode_16_components(const uint8_t* code, int i) {
        uint64_t c8 = *(uint64_t*)(code + (i >> 1));
        uint64_t mask = 0x0f0f0f0f0f0f0f0f;
        uint64_t c8ev = c8 & mask;
        uint64_t c8od = (c8 >> 4) & mask;

        __m128i c16 =
                _mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od));
        __m256i c8lo = _mm256_cvtepu8_epi32(c16);
        __m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8));
        __m512i i16 = _mm512_castsi256_si512(c8lo);
        i16 = _mm512_inserti32x8(i16, c8hi, 1);
        __m512 f16 = _mm512_cvtepi32_ps(i16);
        const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f);
        const __m512 one_255 = _mm512_set1_ps(1.f / 15.f);
        return simd16float32(_mm512_fmadd_ps(f16, one_255, half_one_255));
    }
#elif defined(__AVX2__)
    static FAISS_ALWAYS_INLINE simd8float32
    decode_8_components(const uint8_t* code, int i) {
        uint32_t c4 = *(uint32_t*)(code + (i >> 1));
        uint32_t mask = 0x0f0f0f0f;
        uint32_t c4ev = c4 & mask;
        uint32_t c4od = (c4 >> 4) & mask;

        // the 8 lower bytes of c8 contain the values
        __m128i c8 =
                _mm_unpacklo_epi8(_mm_set1_epi32(c4ev), _mm_set1_epi32(c4od));
        __m128i c4lo = _mm_cvtepu8_epi32(c8);
        __m128i c4hi = _mm_cvtepu8_epi32(_mm_srli_si128(c8, 4));
        __m256i i8 = _mm256_castsi128_si256(c4lo);
        i8 = _mm256_insertf128_si256(i8, c4hi, 1);
        __m256 f8 = _mm256_cvtepi32_ps(i8);
        __m256 half = _mm256_set1_ps(0.5f);
        f8 = _mm256_add_ps(f8, half);
        __m256 one_255 = _mm256_set1_ps(1.f / 15.f);
        return simd8float32(_mm256_mul_ps(f8, one_255));
    }
#endif

#ifdef USE_NEON
    static FAISS_ALWAYS_INLINE simd8float32
    decode_8_components(const uint8_t* code, int i) {
        float32_t result[8] = {};
        for (size_t j = 0; j < 8; j++) {
            result[j] = decode_component(code, i + j);
        }
        float32x4_t res1 = vld1q_f32(result);
        float32x4_t res2 = vld1q_f32(result + 4);
        return simd8float32({res1, res2});
    }
#endif
};

struct Codec6bit {
    static FAISS_ALWAYS_INLINE void encode_component(
            float x,
            uint8_t* code,
            int i) {
        int bits = (int)(x * 63.0);
        code += (i >> 2) * 3;
        switch (i & 3) {
            case 0:
                code[0] |= bits;
                break;
            case 1:
                code[0] |= bits << 6;
                code[1] |= bits >> 2;
                break;
            case 2:
                code[1] |= bits << 4;
                code[2] |= bits >> 4;
                break;
            case 3:
                code[2] |= bits << 2;
                break;
        }
    }

    static FAISS_ALWAYS_INLINE float decode_component(
            const uint8_t* code,
            int i) {
        uint8_t bits;
        code += (i >> 2) * 3;
        switch (i & 3) {
            case 0:
                bits = code[0] & 0x3f;
                break;
            case 1:
                bits = code[0] >> 6;
                bits |= (code[1] & 0xf) << 2;
                break;
            case 2:
                bits = code[1] >> 4;
                bits |= (code[2] & 3) << 4;
                break;
            case 3:
                bits = code[2] >> 2;
                break;
        }
        return (bits + 0.5f) / 63.0f;
    }

#if defined(__AVX512F__)

    static FAISS_ALWAYS_INLINE simd16float32
    decode_16_components(const uint8_t* code, int i) {
        // pure AVX512 implementation (not necessarily the fastest).
        // see:
        // https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h

        // clang-format off

        // 16 components, 16x6 bit = 12 bytes
        const __m128i bit_6v =
                _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3);
        const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v);

        // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F
        // 00 01 02 03
        const __m256i shuffle_mask = _mm256_setr_epi16(
                0xFF00, 0x0100, 0x0201, 0xFF02,
                0xFF03, 0x0403, 0x0504, 0xFF05,
                0xFF06, 0x0706, 0x0807, 0xFF08,
                0xFF09, 0x0A09, 0x0B0A, 0xFF0B);
        const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask);

        // 0: xxxxxxxx xx543210
        // 1: xxxx5432 10xxxxxx
        // 2: xxxxxx54 3210xxxx
        // 3: xxxxxxxx 543210xx
        const __m256i shift_right_v = _mm256_setr_epi16(
                0x0U, 0x6U, 0x4U, 0x2U,
                0x0U, 0x6U, 0x4U, 0x2U,
                0x0U, 0x6U, 0x4U, 0x2U,
                0x0U, 0x6U, 0x4U, 0x2U);
        __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v);

        // remove unneeded bits
        shuffled_shifted =
                _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F));

        // scale
        const __m512 f8 =
                _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted));
        const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f);
        const __m512 one_255 = _mm512_set1_ps(1.f / 63.f);
        return simd16float32(_mm512_fmadd_ps(f8, one_255, half_one_255));

        // clang-format on
    }

#elif defined(__AVX2__)

    /* Load 6 bytes that represent 8 6-bit values, return them as an
     * 8*32 bit vector register */
    static FAISS_ALWAYS_INLINE __m256i load6(const uint16_t* code16) {
        const __m128i perm = _mm_set_epi8(
                -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0);
        const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0);

        // load 6 bytes
        __m128i c1 =
                _mm_set_epi16(0, 0, 0, 0, 0, code16[2], code16[1], code16[0]);

        // put in 8 * 32 bits
        __m128i c2 = _mm_shuffle_epi8(c1, perm);
        __m256i c3 = _mm256_cvtepi16_epi32(c2);

        // shift and mask out useless bits
        __m256i c4 = _mm256_srlv_epi32(c3, shifts);
        __m256i c5 = _mm256_and_si256(_mm256_set1_epi32(63), c4);
        return c5;
    }

    static FAISS_ALWAYS_INLINE simd8float32
    decode_8_components(const uint8_t* code, int i) {
        // // Faster code for Intel CPUs or AMD Zen3+, just keeping it here
        // // for the reference, maybe it becomes used one day.
        // const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3);
        // const uint32_t* data32 = (const uint32_t*)data16;
        // const uint64_t val = *data32 + ((uint64_t)data16[2] << 32);
        // const uint64_t vext = _pdep_u64(val, 0x3F3F3F3F3F3F3F3FULL);
        // const __m128i i8 = _mm_set1_epi64x(vext);
        // const __m256i i32 = _mm256_cvtepi8_epi32(i8);
        // const __m256 f8 = _mm256_cvtepi32_ps(i32);
        // const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
        // const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
        // return _mm256_fmadd_ps(f8, one_255, half_one_255);

        __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3));
        __m256 f8 = _mm256_cvtepi32_ps(i8);
        // this could also be done with bit manipulations but it is
        // not obviously faster
        const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
        const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
        return simd8float32(_mm256_fmadd_ps(f8, one_255, half_one_255));
    }

#endif

#ifdef USE_NEON
    static FAISS_ALWAYS_INLINE simd8float32
    decode_8_components(const uint8_t* code, int i) {
        float32_t result[8] = {};
        for (size_t j = 0; j < 8; j++) {
            result[j] = decode_component(code, i + j);
        }
        float32x4_t res1 = vld1q_f32(result);
        float32x4_t res2 = vld1q_f32(result + 4);
        return simd8float32(float32x4x2_t({res1, res2}));
    }
#endif
};

} // namespace scalar_quantizer
} // namespace faiss
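As a usage note for the codecs above (a sketch; it assumes the new header, whose path is not shown on this page, is included so that FAISS_ALWAYS_INLINE and the codec structs are in scope): each codec maps a component x in [0, 1] to a b-bit integer q = floor(x * (2^b - 1)) and decodes back to the bin center (q + 0.5) / (2^b - 1), so the round-trip error is at most half a bin width.

#include <cstdint>
#include <cstdio>

int main() {
    using faiss::scalar_quantizer::Codec8bit;
    using faiss::scalar_quantizer::Codec6bit;

    // encode_component uses |=, so the code arrays must start zeroed
    uint8_t code8[4] = {0, 0, 0, 0};
    uint8_t code6[3] = {0, 0, 0}; // 4 components x 6 bits = 3 bytes

    float x = 0.7f;
    for (int i = 0; i < 4; i++) {
        Codec8bit::encode_component(x, code8, i);
        Codec6bit::encode_component(x, code6, i);
    }

    // 8 bit: q = 178, decodes to (178 + 0.5) / 255  = 0.7000
    // 6 bit: q = 44,  decodes to (44 + 0.5) / 63   ~= 0.7063
    std::printf("8 bit: %f\n", Codec8bit::decode_component(code8, 0));
    std::printf("6 bit: %f\n", Codec6bit::decode_component(code6, 0));
    return 0;
}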

0 commit comments