diff --git a/faiss/utils/distances.h b/faiss/utils/distances.h
index 80d2cfc699..3531b10845 100644
--- a/faiss/utils/distances.h
+++ b/faiss/utils/distances.h
@@ -15,6 +15,7 @@
 #include
 #include
+#include
 
 namespace faiss {
 
@@ -27,15 +28,27 @@ struct IDSelector;
 /// Squared L2 distance between two vectors
 float fvec_L2sqr(const float* x, const float* y, size_t d);
 
+template
+float fvec_L2sqr(const float* x, const float* y, size_t d);
+
 /// inner product
 float fvec_inner_product(const float* x, const float* y, size_t d);
 
+template
+float fvec_inner_product(const float* x, const float* y, size_t d);
+
 /// L1 distance
 float fvec_L1(const float* x, const float* y, size_t d);
 
+template
+float fvec_L1(const float* x, const float* y, size_t d);
+
 /// infinity distance
 float fvec_Linf(const float* x, const float* y, size_t d);
 
+template
+float fvec_Linf(const float* x, const float* y, size_t d);
+
 /// Special version of inner product that computes 4 distances
 /// between x and yi, which is performance oriented.
 void fvec_inner_product_batch_4(
@@ -50,6 +63,19 @@ void fvec_inner_product_batch_4(
         float& dis2,
         float& dis3);
 
+template
+void fvec_inner_product_batch_4(
+        const float* x,
+        const float* y0,
+        const float* y1,
+        const float* y2,
+        const float* y3,
+        const size_t d,
+        float& dis0,
+        float& dis1,
+        float& dis2,
+        float& dis3);
+
 /// Special version of L2sqr that computes 4 distances
 /// between x and yi, which is performance oriented.
 void fvec_L2sqr_batch_4(
@@ -64,6 +90,19 @@ void fvec_L2sqr_batch_4(
         float& dis2,
         float& dis3);
 
+template
+void fvec_L2sqr_batch_4(
+        const float* x,
+        const float* y0,
+        const float* y1,
+        const float* y2,
+        const float* y3,
+        const size_t d,
+        float& dis0,
+        float& dis1,
+        float& dis2,
+        float& dis3);
+
 /** Compute pairwise distances between sets of vectors
  *
  * @param d dimension of the vectors
@@ -93,6 +132,14 @@ void fvec_inner_products_ny(
         size_t d,
         size_t ny);
 
+template
+void fvec_inner_products_ny(
+        float* ip, /* output inner product */
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny);
+
 /* compute ny square L2 distance between x and a set of contiguous y vectors */
 void fvec_L2sqr_ny(
         float* dis,
@@ -101,6 +148,14 @@ void fvec_L2sqr_ny(
         size_t d,
         size_t ny);
 
+template
+void fvec_L2sqr_ny(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny);
+
 /* compute ny square L2 distance between x and a set of transposed contiguous
    y vectors. squared lengths of y should be provided as well */
 void fvec_L2sqr_ny_transposed(
@@ -112,6 +167,16 @@ void fvec_L2sqr_ny_transposed(
         size_t d_offset,
         size_t ny);
 
+template
+void fvec_L2sqr_ny_transposed(
+        float* dis,
+        const float* x,
+        const float* y,
+        const float* y_sqlen,
+        size_t d,
+        size_t d_offset,
+        size_t ny);
+
 /* compute ny square L2 distance between x and a set of contiguous y vectors
    and return the index of the nearest vector.
    return 0 if ny == 0. */
@@ -122,6 +187,14 @@ size_t fvec_L2sqr_ny_nearest(
         size_t d,
         size_t ny);
 
+template
+size_t fvec_L2sqr_ny_nearest(
+        float* distances_tmp_buffer,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny);
+
 /* compute ny square L2 distance between x and a set of transposed contiguous
    y vectors and return the index of the nearest vector.
squared lengths of y should be provided as well @@ -135,9 +208,22 @@ size_t fvec_L2sqr_ny_nearest_y_transposed( size_t d_offset, size_t ny); +template +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny); + /** squared norm of a vector */ float fvec_norm_L2sqr(const float* x, size_t d); +template +float fvec_norm_L2sqr(const float* x, size_t d); + /** compute the L2 norms for a set of vectors * * @param norms output norms, size nx @@ -473,6 +559,10 @@ void compute_PQ_dis_tables_dsub2( */ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c); +/* same statically */ +template +void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c); + /** same as fvec_madd, also return index of the min of the result table * @return index of the min of table c */ @@ -483,4 +573,12 @@ int fvec_madd_and_argmin( const float* b, float* c); +template +int fvec_madd_and_argmin( + size_t n, + const float* a, + float bf, + const float* b, + float* c); + } // namespace faiss diff --git a/faiss/utils/distances_simd.cpp b/faiss/utils/distances_simd.cpp index c6ff8b57cb..ab174a5a54 100644 --- a/faiss/utils/distances_simd.cpp +++ b/faiss/utils/distances_simd.cpp @@ -10,7 +10,6 @@ #include #include -#include #include #include #include @@ -19,85 +18,28 @@ #include #include -#ifdef __SSE3__ -#include -#endif - -#if defined(__AVX512F__) -#include -#elif defined(__AVX2__) -#include -#endif - -#ifdef __ARM_FEATURE_SVE -#include -#endif - -#ifdef __aarch64__ -#include -#endif +#define AUTOVEC_LEVEL SIMDLevel::NONE +#include namespace faiss { -#ifdef __AVX__ -#define USE_AVX -#endif - -/********************************************************* - * Optimized distance computations - *********************************************************/ - -/* Functions to compute: - - L2 distance between 2 vectors - - inner product between 2 vectors - - L2 norm of a vector - - The functions should probably not be invoked when a large number of - vectors are be processed in batch (in which case Matrix multiply - is faster), but may be useful for comparing vectors isolated in - memory. - - Works with any vectors of any dimension, even unaligned (in which - case they are slower). 
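// ----------------------------------------------------------------------------
// Aside: the header declarations above add, next to each scalar entry point, a
// function template parameterized on the SIMD level (the template arguments
// are elided in this extract), and distances_simd.cpp below provides the
// SIMDLevel::NONE specializations plus thin dispatchers built on
// DISPATCH_SIMDLevel. A minimal, self-contained sketch of that pattern
// follows; the enumerators other than NONE, get_simd_level(), and the sketch
// namespace are hypothetical illustrations rather than the library's actual
// API, and DISPATCH_SIMDLevel is only presumed to expand to a selection of
// this general shape.

#include <stddef.h>

namespace simd_dispatch_sketch {

enum class SIMDLevel { NONE, AVX2, AVX512F }; // enumerators beyond NONE assumed

// primary template: one specialization is compiled per SIMD level
template <SIMDLevel L>
float fvec_L2sqr(const float* x, const float* y, size_t d);

// scalar fallback: a plain loop the compiler is free to auto-vectorize
template <>
inline float fvec_L2sqr<SIMDLevel::NONE>(
        const float* x,
        const float* y,
        size_t d) {
    float res = 0;
    for (size_t i = 0; i < d; i++) {
        const float tmp = x[i] - y[i];
        res += tmp * tmp;
    }
    return res;
}

// hypothetical runtime query standing in for whatever DISPATCH_SIMDLevel uses
inline SIMDLevel get_simd_level() {
    return SIMDLevel::NONE;
}

// exported non-template entry point: forward to the best available level
inline float fvec_L2sqr(const float* x, const float* y, size_t d) {
    switch (get_simd_level()) {
        // case SIMDLevel::AVX2:
        //     return fvec_L2sqr<SIMDLevel::AVX2>(x, y, d); // if compiled in
        default:
            return fvec_L2sqr<SIMDLevel::NONE>(x, y, d);
    }
}

} // namespace simd_dispatch_sketch
// ----------------------------------------------------------------------------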
- +/******* +Functions with SIMDLevel::NONE */ -/********************************************************* - * Reference implementations - */ - -float fvec_L1_ref(const float* x, const float* y, size_t d) { - size_t i; - float res = 0; - for (i = 0; i < d; i++) { - const float tmp = x[i] - y[i]; - res += fabs(tmp); - } - return res; -} - -float fvec_Linf_ref(const float* x, const float* y, size_t d) { - size_t i; - float res = 0; - for (i = 0; i < d; i++) { - res = fmax(res, fabs(x[i] - y[i])); - } - return res; -} - -void fvec_L2sqr_ny_ref( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - for (size_t i = 0; i < ny; i++) { - dis[i] = fvec_L2sqr(x, y, d); - y += d; - } +template <> +void fvec_madd( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + for (size_t i = 0; i < n; i++) + c[i] = a[i] + bf * b[i]; } -void fvec_L2sqr_ny_y_transposed_ref( +template <> +void fvec_L2sqr_ny_transposed( float* dis, const float* x, const float* y, @@ -120,13 +62,50 @@ void fvec_L2sqr_ny_y_transposed_ref( } } -size_t fvec_L2sqr_ny_nearest_ref( +template <> +void fvec_inner_products_ny( + float* ip, + const float* x, + const float* y, + size_t d, + size_t ny) { +// BLAS slower for the use cases here +#if 0 +{ + FINTEGER di = d; + FINTEGER nyi = ny; + float one = 1.0, zero = 0.0; + FINTEGER onei = 1; + sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei); +} +#endif + for (size_t i = 0; i < ny; i++) { + ip[i] = fvec_inner_product(x, y, d); + y += d; + } +} + +template <> +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + for (size_t i = 0; i < ny; i++) { + dis[i] = fvec_L2sqr(x, y, d); + y += d; + } +} + +template <> +size_t fvec_L2sqr_ny_nearest( float* distances_tmp_buffer, const float* x, const float* y, size_t d, size_t ny) { - fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny); + fvec_L2sqr_ny(distances_tmp_buffer, x, y, d, ny); size_t nearest_idx = 0; float min_dis = HUGE_VALF; @@ -141,7 +120,8 @@ size_t fvec_L2sqr_ny_nearest_ref( return nearest_idx; } -size_t fvec_L2sqr_ny_nearest_y_transposed_ref( +template <> +size_t fvec_L2sqr_ny_nearest_y_transposed( float* distances_tmp_buffer, const float* x, const float* y, @@ -149,7 +129,7 @@ size_t fvec_L2sqr_ny_nearest_y_transposed_ref( size_t d, size_t d_offset, size_t ny) { - fvec_L2sqr_ny_y_transposed_ref( + fvec_L2sqr_ny_transposed( distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); size_t nearest_idx = 0; @@ -165,73 +145,54 @@ size_t fvec_L2sqr_ny_nearest_y_transposed_ref( return nearest_idx; } -void fvec_inner_products_ny_ref( - float* ip, - const float* x, - const float* y, - size_t d, - size_t ny) { - // BLAS slower for the use cases here -#if 0 - { - FINTEGER di = d; - FINTEGER nyi = ny; - float one = 1.0, zero = 0.0; - FINTEGER onei = 1; - sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei); - } -#endif - for (size_t i = 0; i < ny; i++) { - ip[i] = fvec_inner_product(x, y, d); - y += d; +template <> +int fvec_madd_and_argmin( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + float vmin = 1e20; + int imin = -1; + + for (size_t i = 0; i < n; i++) { + c[i] = a[i] + bf * b[i]; + if (c[i] < vmin) { + vmin = c[i]; + imin = i; + } } + return imin; } /********************************************************* - * Autovectorized implementations + * dispatching functions */ -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN -float fvec_inner_product(const float* x, const float* y, size_t d) { - float res = 0.F; - 
FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i != d; ++i) { - res += x[i] * y[i]; - } - return res; +float fvec_L1(const float* x, const float* y, size_t d) { + DISPATCH_SIMDLevel(fvec_L1, x, y, d); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN -float fvec_norm_L2sqr(const float* x, size_t d) { - // the double in the _ref is suspected to be a typo. Some of the manual - // implementations this replaces used float. - float res = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i != d; ++i) { - res += x[i] * x[i]; - } +float fvec_Linf(const float* x, const float* y, size_t d) { + DISPATCH_SIMDLevel(fvec_Linf, x, y, d); +} - return res; +// dispatching functions + +float fvec_norm_L2sqr(const float* x, size_t d) { + DISPATCH_SIMDLevel(fvec_norm_L2sqr, x, d); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN float fvec_L2sqr(const float* x, const float* y, size_t d) { - size_t i; - float res = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (i = 0; i < d; i++) { - const float tmp = x[i] - y[i]; - res += tmp * tmp; - } - return res; + DISPATCH_SIMDLevel(fvec_L2sqr, x, y, d); +} + +float fvec_inner_product(const float* x, const float* y, size_t d) { + DISPATCH_SIMDLevel(fvec_inner_product, x, y, d); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END /// Special version of inner product that computes 4 distances /// between x and yi -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN void fvec_inner_product_batch_4( const float* __restrict x, const float* __restrict y0, @@ -243,28 +204,22 @@ void fvec_inner_product_batch_4( float& dis1, float& dis2, float& dis3) { - float d0 = 0; - float d1 = 0; - float d2 = 0; - float d3 = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i < d; ++i) { - d0 += x[i] * y0[i]; - d1 += x[i] * y1[i]; - d2 += x[i] * y2[i]; - d3 += x[i] * y3[i]; - } - - dis0 = d0; - dis1 = d1; - dis2 = d2; - dis3 = d3; + DISPATCH_SIMDLevel( + fvec_inner_product_batch_4, + x, + y0, + y1, + y2, + y3, + d, + dis0, + dis1, + dis2, + dis3); } -FAISS_PRAGMA_IMPRECISE_FUNCTION_END /// Special version of L2sqr that computes 4 distances /// between x and yi, which is performance oriented. -FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN void fvec_L2sqr_batch_4( const float* x, const float* y0, @@ -276,3326 +231,72 @@ void fvec_L2sqr_batch_4( float& dis1, float& dis2, float& dis3) { - float d0 = 0; - float d1 = 0; - float d2 = 0; - float d3 = 0; - FAISS_PRAGMA_IMPRECISE_LOOP - for (size_t i = 0; i < d; ++i) { - const float q0 = x[i] - y0[i]; - const float q1 = x[i] - y1[i]; - const float q2 = x[i] - y2[i]; - const float q3 = x[i] - y3[i]; - d0 += q0 * q0; - d1 += q1 * q1; - d2 += q2 * q2; - d3 += q3 * q3; - } - - dis0 = d0; - dis1 = d1; - dis2 = d2; - dis3 = d3; -} -FAISS_PRAGMA_IMPRECISE_FUNCTION_END - -/********************************************************* - * SSE and AVX implementations - */ - -#ifdef __SSE3__ - -// reads 0 <= d < 4 floats as __m128 -static inline __m128 masked_read(int d, const float* x) { - assert(0 <= d && d < 4); - ALIGNED(16) float buf[4] = {0, 0, 0, 0}; - switch (d) { - case 3: - buf[2] = x[2]; - [[fallthrough]]; - case 2: - buf[1] = x[1]; - [[fallthrough]]; - case 1: - buf[0] = x[0]; - } - return _mm_load_ps(buf); - // cannot use AVX2 _mm_mask_set1_epi32 -} - -namespace { - -/// helper function -inline float horizontal_sum(const __m128 v) { - // say, v is [x0, x1, x2, x3] - - // v0 is [x2, x3, ..., ...] - const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); - // v1 is [x0 + x2, x1 + x3, ..., ...] 
- const __m128 v1 = _mm_add_ps(v, v0); - // v2 is [x1 + x3, ..., .... ,...] - __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); - // v3 is [x0 + x1 + x2 + x3, ..., ..., ...] - const __m128 v3 = _mm_add_ps(v1, v2); - // return v3[0] - return _mm_cvtss_f32(v3); -} - -#ifdef __AVX2__ -/// helper function for AVX2 -inline float horizontal_sum(const __m256 v) { - // add high and low parts - const __m128 v0 = - _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); - // perform horizontal sum on v0 - return horizontal_sum(v0); -} -#endif - -#ifdef __AVX512F__ -/// helper function for AVX512 -inline float horizontal_sum(const __m512 v) { - // performs better than adding the high and low parts - return _mm512_reduce_add_ps(v); -} -#endif - -/// Function that does a component-wise operation between x and y -/// to compute L2 distances. ElementOp can then be used in the fvec_op_ny -/// functions below -struct ElementOpL2 { - static float op(float x, float y) { - float tmp = x - y; - return tmp * tmp; - } - - static __m128 op(__m128 x, __m128 y) { - __m128 tmp = _mm_sub_ps(x, y); - return _mm_mul_ps(tmp, tmp); - } - -#ifdef __AVX2__ - static __m256 op(__m256 x, __m256 y) { - __m256 tmp = _mm256_sub_ps(x, y); - return _mm256_mul_ps(tmp, tmp); - } -#endif - -#ifdef __AVX512F__ - static __m512 op(__m512 x, __m512 y) { - __m512 tmp = _mm512_sub_ps(x, y); - return _mm512_mul_ps(tmp, tmp); - } -#endif -}; - -/// Function that does a component-wise operation between x and y -/// to compute inner products -struct ElementOpIP { - static float op(float x, float y) { - return x * y; - } - - static __m128 op(__m128 x, __m128 y) { - return _mm_mul_ps(x, y); - } - -#ifdef __AVX2__ - static __m256 op(__m256 x, __m256 y) { - return _mm256_mul_ps(x, y); - } -#endif - -#ifdef __AVX512F__ - static __m512 op(__m512 x, __m512 y) { - return _mm512_mul_ps(x, y); - } -#endif -}; - -template -void fvec_op_ny_D1(float* dis, const float* x, const float* y, size_t ny) { - float x0s = x[0]; - __m128 x0 = _mm_set_ps(x0s, x0s, x0s, x0s); - - size_t i; - for (i = 0; i + 3 < ny; i += 4) { - __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = _mm_cvtss_f32(accu); - __m128 tmp = _mm_shuffle_ps(accu, accu, 1); - dis[i + 1] = _mm_cvtss_f32(tmp); - tmp = _mm_shuffle_ps(accu, accu, 2); - dis[i + 2] = _mm_cvtss_f32(tmp); - tmp = _mm_shuffle_ps(accu, accu, 3); - dis[i + 3] = _mm_cvtss_f32(tmp); - } - while (i < ny) { // handle non-multiple-of-4 case - dis[i++] = ElementOp::op(x0s, *y++); - } -} - -template -void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) { - __m128 x0 = _mm_set_ps(x[1], x[0], x[1], x[0]); - - size_t i; - for (i = 0; i + 1 < ny; i += 2) { - __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); - y += 4; - accu = _mm_hadd_ps(accu, accu); - dis[i] = _mm_cvtss_f32(accu); - accu = _mm_shuffle_ps(accu, accu, 3); - dis[i + 1] = _mm_cvtss_f32(accu); - } - if (i < ny) { // handle odd case - dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]); - } + DISPATCH_SIMDLevel( + fvec_L2sqr_batch_4, x, y0, y1, y2, y3, d, dis0, dis1, dis2, dis3); } -#if defined(__AVX512F__) - -template <> -void fvec_op_ny_D2( +void fvec_L2sqr_ny_transposed( float* dis, const float* x, const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D2-vectors per loop. 
- _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - - for (i = 0; i < ny16 * 16; i += 16) { - _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); - - // load 16x2 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. - - __m512 v0; - __m512 v1; - - transpose_16x2( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - v0, - v1); - - // compute distances (dot product) - __m512 distances = _mm512_mul_ps(m0, v0); - distances = _mm512_fmadd_ps(m1, v1, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 32; // move to the next set of 16x2 elements - } - } - - if (i < ny) { - // process leftovers - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float distance = x0 * y[0] + x1 * y[1]; - y += 2; - dis[i] = distance; - } - } + DISPATCH_SIMDLevel( + fvec_L2sqr_ny_transposed, dis, x, y, y_sqlen, d, d_offset, ny); } -template <> -void fvec_op_ny_D2( - float* dis, +void fvec_inner_products_ny( + float* ip, /* output inner product */ const float* x, const float* y, + size_t d, size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D2-vectors per loop. - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - - for (i = 0; i < ny16 * 16; i += 16) { - _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); - - // load 16x2 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. - - __m512 v0; - __m512 v1; - - transpose_16x2( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - v0, - v1); - - // compute differences - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - - // compute squares of differences - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 32; // move to the next set of 16x2 elements - } - } - - if (i < ny) { - // process leftovers - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float sub0 = x0 - y[0]; - float sub1 = x1 - y[1]; - float distance = sub0 * sub0 + sub1 * sub1; - - y += 2; - dis[i] = distance; - } - } + DISPATCH_SIMDLevel(fvec_inner_products_ny, ip, x, y, d, ny); } -#elif defined(__AVX2__) - -template <> -void fvec_op_ny_D2( +void fvec_L2sqr_ny( float* dis, const float* x, const float* y, + size_t d, size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D2-vectors per loop. - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); - - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - - for (i = 0; i < ny8 * 8; i += 8) { - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - // load 8x2 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m256 v0; - __m256 v1; - - transpose_8x2( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - v0, - v1); - - // compute distances - __m256 distances = _mm256_mul_ps(m0, v0); - distances = _mm256_fmadd_ps(m1, v1, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 16; - } - } - - if (i < ny) { - // process leftovers - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float distance = x0 * y[0] + x1 * y[1]; - y += 2; - dis[i] = distance; - } - } + DISPATCH_SIMDLevel(fvec_L2sqr_ny, dis, x, y, d, ny); } -template <> -void fvec_op_ny_D2( - float* dis, +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, const float* x, const float* y, + size_t d, size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D2-vectors per loop. - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); - - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - - for (i = 0; i < ny8 * 8; i += 8) { - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - // load 8x2 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. - - __m256 v0; - __m256 v1; - - transpose_8x2( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - v0, - v1); + DISPATCH_SIMDLevel( + fvec_L2sqr_ny_nearest, distances_tmp_buffer, x, y, d, ny); +} - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + DISPATCH_SIMDLevel( + fvec_L2sqr_ny_nearest_y_transposed, + distances_tmp_buffer, + x, + y, + y_sqlen, + d, + d_offset, + ny); +} - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 16; - } - } - - if (i < ny) { - // process leftovers - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float sub0 = x0 - y[0]; - float sub1 = x1 - y[1]; - float distance = sub0 * sub0 + sub1 * sub1; - - y += 2; - dis[i] = distance; - } - } -} - -#endif - -template -void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) { - __m128 x0 = _mm_loadu_ps(x); - - for (size_t i = 0; i < ny; i++) { - __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = horizontal_sum(accu); - } -} - -#if defined(__AVX512F__) - -template <> -void fvec_op_ny_D4( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D4-vectors per loop. - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - - for (i = 0; i < ny16 * 16; i += 16) { - // load 16x4 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - - transpose_16x4( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - v0, - v1, - v2, - v3); - - // compute distances - __m512 distances = _mm512_mul_ps(m0, v0); - distances = _mm512_fmadd_ps(m1, v1, distances); - distances = _mm512_fmadd_ps(m2, v2, distances); - distances = _mm512_fmadd_ps(m3, v3, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 64; // move to the next set of 16x4 elements - } - } - - if (i < ny) { - // process leftovers - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = horizontal_sum(accu); - } - } -} - -template <> -void fvec_op_ny_D4( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D4-vectors per loop. - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - - for (i = 0; i < ny16 * 16; i += 16) { - // load 16x4 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. - - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - - transpose_16x4( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - v0, - v1, - v2, - v3); - - // compute differences - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - const __m512 d2 = _mm512_sub_ps(m2, v2); - const __m512 d3 = _mm512_sub_ps(m3, v3); - - // compute squares of differences - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - distances = _mm512_fmadd_ps(d2, d2, distances); - distances = _mm512_fmadd_ps(d3, d3, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 64; // move to the next set of 16x4 elements - } - } - - if (i < ny) { - // process leftovers - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = horizontal_sum(accu); - } - } -} - -#elif defined(__AVX2__) - -template <> -void fvec_op_ny_D4( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D4-vectors per loop. - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - - for (i = 0; i < ny8 * 8; i += 8) { - // load 8x4 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - - transpose_8x4( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - v0, - v1, - v2, - v3); - - // compute distances - __m256 distances = _mm256_mul_ps(m0, v0); - distances = _mm256_fmadd_ps(m1, v1, distances); - distances = _mm256_fmadd_ps(m2, v2, distances); - distances = _mm256_fmadd_ps(m3, v3, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 32; - } - } - - if (i < ny) { - // process leftovers - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpIP::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = horizontal_sum(accu); - } - } -} - -template <> -void fvec_op_ny_D4( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D4-vectors per loop. - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - - for (i = 0; i < ny8 * 8; i += 8) { - // load 8x4 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. - - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - - transpose_8x4( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - v0, - v1, - v2, - v3); - - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); - const __m256 d2 = _mm256_sub_ps(m2, v2); - const __m256 d3 = _mm256_sub_ps(m3, v3); - - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - distances = _mm256_fmadd_ps(d2, d2, distances); - distances = _mm256_fmadd_ps(d3, d3, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 32; - } - } - - if (i < ny) { - // process leftovers - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); - y += 4; - dis[i] = horizontal_sum(accu); - } - } -} - -#endif - -template -void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) { - __m128 x0 = _mm_loadu_ps(x); - __m128 x1 = _mm_loadu_ps(x + 4); - - for (size_t i = 0; i < ny; i++) { - __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); - y += 4; - accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y))); - y += 4; - accu = _mm_hadd_ps(accu, accu); - accu = _mm_hadd_ps(accu, accu); - dis[i] = _mm_cvtss_f32(accu); - } -} - -#if defined(__AVX512F__) - -template <> -void fvec_op_ny_D8( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D16-vectors per loop. - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - const __m512 m4 = _mm512_set1_ps(x[4]); - const __m512 m5 = _mm512_set1_ps(x[5]); - const __m512 m6 = _mm512_set1_ps(x[6]); - const __m512 m7 = _mm512_set1_ps(x[7]); - - for (i = 0; i < ny16 * 16; i += 16) { - // load 16x8 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - __m512 v4; - __m512 v5; - __m512 v6; - __m512 v7; - - transpose_16x8( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - _mm512_loadu_ps(y + 4 * 16), - _mm512_loadu_ps(y + 5 * 16), - _mm512_loadu_ps(y + 6 * 16), - _mm512_loadu_ps(y + 7 * 16), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - // compute distances - __m512 distances = _mm512_mul_ps(m0, v0); - distances = _mm512_fmadd_ps(m1, v1, distances); - distances = _mm512_fmadd_ps(m2, v2, distances); - distances = _mm512_fmadd_ps(m3, v3, distances); - distances = _mm512_fmadd_ps(m4, v4, distances); - distances = _mm512_fmadd_ps(m5, v5, distances); - distances = _mm512_fmadd_ps(m6, v6, distances); - distances = _mm512_fmadd_ps(m7, v7, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 128; // 16 floats * 8 rows - } - } - - if (i < ny) { - // process leftovers - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y)); - y += 8; - dis[i] = horizontal_sum(accu); - } - } -} - -template <> -void fvec_op_ny_D8( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny16 = ny / 16; - size_t i = 0; - - if (ny16 > 0) { - // process 16 D16-vectors per loop. - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - const __m512 m4 = _mm512_set1_ps(x[4]); - const __m512 m5 = _mm512_set1_ps(x[5]); - const __m512 m6 = _mm512_set1_ps(x[6]); - const __m512 m7 = _mm512_set1_ps(x[7]); - - for (i = 0; i < ny16 * 16; i += 16) { - // load 16x8 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - __m512 v4; - __m512 v5; - __m512 v6; - __m512 v7; - - transpose_16x8( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - _mm512_loadu_ps(y + 4 * 16), - _mm512_loadu_ps(y + 5 * 16), - _mm512_loadu_ps(y + 6 * 16), - _mm512_loadu_ps(y + 7 * 16), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - // compute differences - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - const __m512 d2 = _mm512_sub_ps(m2, v2); - const __m512 d3 = _mm512_sub_ps(m3, v3); - const __m512 d4 = _mm512_sub_ps(m4, v4); - const __m512 d5 = _mm512_sub_ps(m5, v5); - const __m512 d6 = _mm512_sub_ps(m6, v6); - const __m512 d7 = _mm512_sub_ps(m7, v7); - - // compute squares of differences - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - distances = _mm512_fmadd_ps(d2, d2, distances); - distances = _mm512_fmadd_ps(d3, d3, distances); - distances = _mm512_fmadd_ps(d4, d4, distances); - distances = _mm512_fmadd_ps(d5, d5, distances); - distances = _mm512_fmadd_ps(d6, d6, distances); - distances = _mm512_fmadd_ps(d7, d7, distances); - - // store - _mm512_storeu_ps(dis + i, distances); - - y += 128; // 16 floats * 8 rows - } - } - - if (i < ny) { - // process leftovers - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y)); - y += 8; - dis[i] = horizontal_sum(accu); - } - } -} - -#elif defined(__AVX2__) - -template <> -void fvec_op_ny_D8( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D8-vectors per loop. - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - const __m256 m4 = _mm256_set1_ps(x[4]); - const __m256 m5 = _mm256_set1_ps(x[5]); - const __m256 m6 = _mm256_set1_ps(x[6]); - const __m256 m7 = _mm256_set1_ps(x[7]); - - for (i = 0; i < ny8 * 8; i += 8) { - // load 8x8 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - __m256 v4; - __m256 v5; - __m256 v6; - __m256 v7; - - transpose_8x8( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - _mm256_loadu_ps(y + 4 * 8), - _mm256_loadu_ps(y + 5 * 8), - _mm256_loadu_ps(y + 6 * 8), - _mm256_loadu_ps(y + 7 * 8), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - // compute distances - __m256 distances = _mm256_mul_ps(m0, v0); - distances = _mm256_fmadd_ps(m1, v1, distances); - distances = _mm256_fmadd_ps(m2, v2, distances); - distances = _mm256_fmadd_ps(m3, v3, distances); - distances = _mm256_fmadd_ps(m4, v4, distances); - distances = _mm256_fmadd_ps(m5, v5, distances); - distances = _mm256_fmadd_ps(m6, v6, distances); - distances = _mm256_fmadd_ps(m7, v7, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 64; - } - } - - if (i < ny) { - // process leftovers - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpIP::op(x0, _mm256_loadu_ps(y)); - y += 8; - dis[i] = horizontal_sum(accu); - } - } -} - -template <> -void fvec_op_ny_D8( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t ny8 = ny / 8; - size_t i = 0; - - if (ny8 > 0) { - // process 8 D8-vectors per loop. - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - const __m256 m4 = _mm256_set1_ps(x[4]); - const __m256 m5 = _mm256_set1_ps(x[5]); - const __m256 m6 = _mm256_set1_ps(x[6]); - const __m256 m7 = _mm256_set1_ps(x[7]); - - for (i = 0; i < ny8 * 8; i += 8) { - // load 8x8 matrix and transpose it in registers. - // the typical bottleneck is memory access, so - // let's trade instructions for the bandwidth. 
- - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - __m256 v4; - __m256 v5; - __m256 v6; - __m256 v7; - - transpose_8x8( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - _mm256_loadu_ps(y + 4 * 8), - _mm256_loadu_ps(y + 5 * 8), - _mm256_loadu_ps(y + 6 * 8), - _mm256_loadu_ps(y + 7 * 8), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); - const __m256 d2 = _mm256_sub_ps(m2, v2); - const __m256 d3 = _mm256_sub_ps(m3, v3); - const __m256 d4 = _mm256_sub_ps(m4, v4); - const __m256 d5 = _mm256_sub_ps(m5, v5); - const __m256 d6 = _mm256_sub_ps(m6, v6); - const __m256 d7 = _mm256_sub_ps(m7, v7); - - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - distances = _mm256_fmadd_ps(d2, d2, distances); - distances = _mm256_fmadd_ps(d3, d3, distances); - distances = _mm256_fmadd_ps(d4, d4, distances); - distances = _mm256_fmadd_ps(d5, d5, distances); - distances = _mm256_fmadd_ps(d6, d6, distances); - distances = _mm256_fmadd_ps(d7, d7, distances); - - // store - _mm256_storeu_ps(dis + i, distances); - - y += 64; - } - } - - if (i < ny) { - // process leftovers - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y)); - y += 8; - dis[i] = horizontal_sum(accu); - } - } -} - -#endif - -template -void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) { - __m128 x0 = _mm_loadu_ps(x); - __m128 x1 = _mm_loadu_ps(x + 4); - __m128 x2 = _mm_loadu_ps(x + 8); - - for (size_t i = 0; i < ny; i++) { - __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); - y += 4; - accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y))); - y += 4; - accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y))); - y += 4; - dis[i] = horizontal_sum(accu); - } -} - -} // anonymous namespace - -void fvec_L2sqr_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - // optimized for a few special cases - -#define DISPATCH(dval) \ - case dval: \ - fvec_op_ny_D##dval(dis, x, y, ny); \ - return; - - switch (d) { - DISPATCH(1) - DISPATCH(2) - DISPATCH(4) - DISPATCH(8) - DISPATCH(12) - default: - fvec_L2sqr_ny_ref(dis, x, y, d, ny); - return; - } -#undef DISPATCH -} - -void fvec_inner_products_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { -#define DISPATCH(dval) \ - case dval: \ - fvec_op_ny_D##dval(dis, x, y, ny); \ - return; - - switch (d) { - DISPATCH(1) - DISPATCH(2) - DISPATCH(4) - DISPATCH(8) - DISPATCH(12) - default: - fvec_inner_products_ny_ref(dis, x, y, d, ny); - return; - } -#undef DISPATCH -} - -#if defined(__AVX512F__) - -template -void fvec_L2sqr_ny_y_transposed_D( - float* distances, - const float* x, - const float* y, - const float* y_sqlen, - const size_t d_offset, - size_t ny) { - // current index being processed - size_t i = 0; - - // squared length of x - float x_sqlen = 0; - for (size_t j = 0; j < DIM; j++) { - x_sqlen += x[j] * x[j]; - } - - // process 16 vectors per loop - const size_t ny16 = ny / 16; - - if (ny16 > 0) { - // m[i] = (2 * x[i], ... 
2 * x[i]) - __m512 m[DIM]; - for (size_t j = 0; j < DIM; j++) { - m[j] = _mm512_set1_ps(x[j]); - m[j] = _mm512_add_ps(m[j], m[j]); // m[j] = 2 * x[j] - } - - __m512 x_sqlen_ymm = _mm512_set1_ps(x_sqlen); - - for (; i < ny16 * 16; i += 16) { - // Load vectors for 16 dimensions - __m512 v[DIM]; - for (size_t j = 0; j < DIM; j++) { - v[j] = _mm512_loadu_ps(y + j * d_offset); - } - - // Compute dot products - __m512 dp = _mm512_fnmadd_ps(m[0], v[0], x_sqlen_ymm); - for (size_t j = 1; j < DIM; j++) { - dp = _mm512_fnmadd_ps(m[j], v[j], dp); - } - - // Compute y^2 - (2 * x, y) + x^2 - __m512 distances_v = _mm512_add_ps(_mm512_loadu_ps(y_sqlen), dp); - - _mm512_storeu_ps(distances + i, distances_v); - - // Scroll y and y_sqlen forward - y += 16; - y_sqlen += 16; - } - } - - if (i < ny) { - // Process leftovers - for (; i < ny; i++) { - float dp = 0; - for (size_t j = 0; j < DIM; j++) { - dp += x[j] * y[j * d_offset]; - } - - // Compute y^2 - 2 * (x, y), which is sufficient for looking for the - // lowest distance. - const float distance = y_sqlen[0] - 2 * dp + x_sqlen; - distances[i] = distance; - - y += 1; - y_sqlen += 1; - } - } -} - -#elif defined(__AVX2__) - -template -void fvec_L2sqr_ny_y_transposed_D( - float* distances, - const float* x, - const float* y, - const float* y_sqlen, - const size_t d_offset, - size_t ny) { - // current index being processed - size_t i = 0; - - // squared length of x - float x_sqlen = 0; - for (size_t j = 0; j < DIM; j++) { - x_sqlen += x[j] * x[j]; - } - - // process 8 vectors per loop. - const size_t ny8 = ny / 8; - - if (ny8 > 0) { - // m[i] = (2 * x[i], ... 2 * x[i]) - __m256 m[DIM]; - for (size_t j = 0; j < DIM; j++) { - m[j] = _mm256_set1_ps(x[j]); - m[j] = _mm256_add_ps(m[j], m[j]); - } - - __m256 x_sqlen_ymm = _mm256_set1_ps(x_sqlen); - - for (; i < ny8 * 8; i += 8) { - // collect dim 0 for 8 D4-vectors. - const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset); - - // compute dot products - // this is x^2 - 2x[0]*y[0] - __m256 dp = _mm256_fnmadd_ps(m[0], v0, x_sqlen_ymm); - - for (size_t j = 1; j < DIM; j++) { - // collect dim j for 8 D4-vectors. - const __m256 vj = _mm256_loadu_ps(y + j * d_offset); - dp = _mm256_fnmadd_ps(m[j], vj, dp); - } - - // we've got x^2 - (2x, y) at this point - - // y^2 - (2x, y) + x^2 - __m256 distances_v = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), dp); - - _mm256_storeu_ps(distances + i, distances_v); - - // scroll y and y_sqlen forward. - y += 8; - y_sqlen += 8; - } - } - - if (i < ny) { - // process leftovers - for (; i < ny; i++) { - float dp = 0; - for (size_t j = 0; j < DIM; j++) { - dp += x[j] * y[j * d_offset]; - } - - // compute y^2 - 2 * (x, y), which is sufficient for looking for the - // lowest distance. 
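            // Note on the algebra used throughout these transposed kernels:
            //     ||x - y_i||^2 = ||y_i||^2 - 2 * <x, y_i> + ||x||^2
            // With y stored dimension-major (one component every d_offset
            // floats) and ||y_i||^2 precomputed in y_sqlen, each distance
            // costs one dot product plus two additions; the *_nearest_*
            // variants further drop the constant ||x||^2 term, since it does
            // not affect which index attains the minimum.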
- const float distance = y_sqlen[0] - 2 * dp + x_sqlen; - distances[i] = distance; - - y += 1; - y_sqlen += 1; - } - } -} - -#endif - -void fvec_L2sqr_ny_transposed( - float* dis, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - // optimized for a few special cases - -#ifdef __AVX2__ -#define DISPATCH(dval) \ - case dval: \ - return fvec_L2sqr_ny_y_transposed_D( \ - dis, x, y, y_sqlen, d_offset, ny); - - switch (d) { - DISPATCH(1) - DISPATCH(2) - DISPATCH(4) - DISPATCH(8) - default: - return fvec_L2sqr_ny_y_transposed_ref( - dis, x, y, y_sqlen, d, d_offset, ny); - } -#undef DISPATCH -#else - // non-AVX2 case - return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny); -#endif -} - -#if defined(__AVX512F__) - -size_t fvec_L2sqr_ny_nearest_D2( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. - - size_t i = 0; - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - const size_t ny16 = ny / 16; - if (ny16 > 0) { - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - __m512 min_distances = _mm512_set1_ps(HUGE_VALF); - __m512i min_indices = _mm512_set1_epi32(0); - - __m512i current_indices = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - const __m512i indices_increment = _mm512_set1_epi32(16); - - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - - for (; i < ny16 * 16; i += 16) { - _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); - - __m512 v0; - __m512 v1; - - transpose_16x2( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - v0, - v1); - - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - - __mmask16 comparison = - _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); - - min_distances = _mm512_min_ps(distances, min_distances); - min_indices = _mm512_mask_blend_epi32( - comparison, min_indices, current_indices); - - current_indices = - _mm512_add_epi32(current_indices, indices_increment); - - y += 32; - } - - alignas(64) float min_distances_scalar[16]; - alignas(64) uint32_t min_indices_scalar[16]; - _mm512_store_ps(min_distances_scalar, min_distances); - _mm512_store_epi32(min_indices_scalar, min_indices); - - for (size_t j = 0; j < 16; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float sub0 = x0 - y[0]; - float sub1 = x1 - y[1]; - float distance = sub0 * sub0 + sub1 * sub1; - - y += 2; - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -size_t fvec_L2sqr_ny_nearest_D4( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. 
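    // (distances_tmp_buffer is presumably retained only so the dimension-
    // specialized variants keep the same signature as the generic path, which
    // materializes all ny distances before scanning them; here the running
    // minima stay in registers.)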
- - size_t i = 0; - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - const size_t ny16 = ny / 16; - - if (ny16 > 0) { - __m512 min_distances = _mm512_set1_ps(HUGE_VALF); - __m512i min_indices = _mm512_set1_epi32(0); - - __m512i current_indices = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - const __m512i indices_increment = _mm512_set1_epi32(16); - - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - - for (; i < ny16 * 16; i += 16) { - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - - transpose_16x4( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - v0, - v1, - v2, - v3); - - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - const __m512 d2 = _mm512_sub_ps(m2, v2); - const __m512 d3 = _mm512_sub_ps(m3, v3); - - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - distances = _mm512_fmadd_ps(d2, d2, distances); - distances = _mm512_fmadd_ps(d3, d3, distances); - - __mmask16 comparison = - _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); - - min_distances = _mm512_min_ps(distances, min_distances); - min_indices = _mm512_mask_blend_epi32( - comparison, min_indices, current_indices); - - current_indices = - _mm512_add_epi32(current_indices, indices_increment); - - y += 64; - } - - alignas(64) float min_distances_scalar[16]; - alignas(64) uint32_t min_indices_scalar[16]; - _mm512_store_ps(min_distances_scalar, min_distances); - _mm512_store_epi32(min_indices_scalar, min_indices); - - for (size_t j = 0; j < 16; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); - y += 4; - const float distance = horizontal_sum(accu); - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -size_t fvec_L2sqr_ny_nearest_D8( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. 
- - size_t i = 0; - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - const size_t ny16 = ny / 16; - if (ny16 > 0) { - __m512 min_distances = _mm512_set1_ps(HUGE_VALF); - __m512i min_indices = _mm512_set1_epi32(0); - - __m512i current_indices = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - const __m512i indices_increment = _mm512_set1_epi32(16); - - const __m512 m0 = _mm512_set1_ps(x[0]); - const __m512 m1 = _mm512_set1_ps(x[1]); - const __m512 m2 = _mm512_set1_ps(x[2]); - const __m512 m3 = _mm512_set1_ps(x[3]); - - const __m512 m4 = _mm512_set1_ps(x[4]); - const __m512 m5 = _mm512_set1_ps(x[5]); - const __m512 m6 = _mm512_set1_ps(x[6]); - const __m512 m7 = _mm512_set1_ps(x[7]); - - for (; i < ny16 * 16; i += 16) { - __m512 v0; - __m512 v1; - __m512 v2; - __m512 v3; - __m512 v4; - __m512 v5; - __m512 v6; - __m512 v7; - - transpose_16x8( - _mm512_loadu_ps(y + 0 * 16), - _mm512_loadu_ps(y + 1 * 16), - _mm512_loadu_ps(y + 2 * 16), - _mm512_loadu_ps(y + 3 * 16), - _mm512_loadu_ps(y + 4 * 16), - _mm512_loadu_ps(y + 5 * 16), - _mm512_loadu_ps(y + 6 * 16), - _mm512_loadu_ps(y + 7 * 16), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - const __m512 d0 = _mm512_sub_ps(m0, v0); - const __m512 d1 = _mm512_sub_ps(m1, v1); - const __m512 d2 = _mm512_sub_ps(m2, v2); - const __m512 d3 = _mm512_sub_ps(m3, v3); - const __m512 d4 = _mm512_sub_ps(m4, v4); - const __m512 d5 = _mm512_sub_ps(m5, v5); - const __m512 d6 = _mm512_sub_ps(m6, v6); - const __m512 d7 = _mm512_sub_ps(m7, v7); - - __m512 distances = _mm512_mul_ps(d0, d0); - distances = _mm512_fmadd_ps(d1, d1, distances); - distances = _mm512_fmadd_ps(d2, d2, distances); - distances = _mm512_fmadd_ps(d3, d3, distances); - distances = _mm512_fmadd_ps(d4, d4, distances); - distances = _mm512_fmadd_ps(d5, d5, distances); - distances = _mm512_fmadd_ps(d6, d6, distances); - distances = _mm512_fmadd_ps(d7, d7, distances); - - __mmask16 comparison = - _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); - - min_distances = _mm512_min_ps(distances, min_distances); - min_indices = _mm512_mask_blend_epi32( - comparison, min_indices, current_indices); - - current_indices = - _mm512_add_epi32(current_indices, indices_increment); - - y += 128; - } - - alignas(64) float min_distances_scalar[16]; - alignas(64) uint32_t min_indices_scalar[16]; - _mm512_store_ps(min_distances_scalar, min_distances); - _mm512_store_epi32(min_indices_scalar, min_indices); - - for (size_t j = 0; j < 16; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y)); - y += 8; - const float distance = horizontal_sum(accu); - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -#elif defined(__AVX2__) - -size_t fvec_L2sqr_ny_nearest_D2( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. - - // current index being processed - size_t i = 0; - - // min distance and the index of the closest vector so far - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - // process 8 D2-vectors per loop. 
- const size_t ny8 = ny / 8; - if (ny8 > 0) { - _mm_prefetch((const char*)y, _MM_HINT_T0); - _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); - - // track min distance and the closest vector independently - // for each of 8 AVX2 components. - __m256 min_distances = _mm256_set1_ps(HUGE_VALF); - __m256i min_indices = _mm256_set1_epi32(0); - - __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - const __m256i indices_increment = _mm256_set1_epi32(8); - - // 1 value per register - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - - for (; i < ny8 * 8; i += 8) { - _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); - - __m256 v0; - __m256 v1; - - transpose_8x2( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - v0, - v1); - - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); - - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - - // compare the new distances to the min distances - // for each of 8 AVX2 components. - __m256 comparison = - _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); - - // update min distances and indices with closest vectors if needed. - min_distances = _mm256_min_ps(distances, min_distances); - min_indices = _mm256_castps_si256(_mm256_blendv_ps( - _mm256_castsi256_ps(current_indices), - _mm256_castsi256_ps(min_indices), - comparison)); - - // update current indices values. Basically, +8 to each of the - // 8 AVX2 components. - current_indices = - _mm256_add_epi32(current_indices, indices_increment); - - // scroll y forward (8 vectors 2 DIM each). - y += 16; - } - - // dump values and find the minimum distance / minimum index - float min_distances_scalar[8]; - uint32_t min_indices_scalar[8]; - _mm256_storeu_ps(min_distances_scalar, min_distances); - _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); - - for (size_t j = 0; j < 8; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - // process leftovers. - // the following code is not optimal, but it is rarely invoked. - float x0 = x[0]; - float x1 = x[1]; - - for (; i < ny; i++) { - float sub0 = x0 - y[0]; - float sub1 = x1 - y[1]; - float distance = sub0 * sub0 + sub1 * sub1; - - y += 2; - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -size_t fvec_L2sqr_ny_nearest_D4( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. - - // current index being processed - size_t i = 0; - - // min distance and the index of the closest vector so far - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - // process 8 D4-vectors per loop. - const size_t ny8 = ny / 8; - - if (ny8 > 0) { - // track min distance and the closest vector independently - // for each of 8 AVX2 components. 
- __m256 min_distances = _mm256_set1_ps(HUGE_VALF); - __m256i min_indices = _mm256_set1_epi32(0); - - __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - const __m256i indices_increment = _mm256_set1_epi32(8); - - // 1 value per register - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - - for (; i < ny8 * 8; i += 8) { - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - - transpose_8x4( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - v0, - v1, - v2, - v3); - - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); - const __m256 d2 = _mm256_sub_ps(m2, v2); - const __m256 d3 = _mm256_sub_ps(m3, v3); - - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - distances = _mm256_fmadd_ps(d2, d2, distances); - distances = _mm256_fmadd_ps(d3, d3, distances); - - // compare the new distances to the min distances - // for each of 8 AVX2 components. - __m256 comparison = - _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); - - // update min distances and indices with closest vectors if needed. - min_distances = _mm256_min_ps(distances, min_distances); - min_indices = _mm256_castps_si256(_mm256_blendv_ps( - _mm256_castsi256_ps(current_indices), - _mm256_castsi256_ps(min_indices), - comparison)); - - // update current indices values. Basically, +8 to each of the - // 8 AVX2 components. - current_indices = - _mm256_add_epi32(current_indices, indices_increment); - - // scroll y forward (8 vectors 4 DIM each). - y += 32; - } - - // dump values and find the minimum distance / minimum index - float min_distances_scalar[8]; - uint32_t min_indices_scalar[8]; - _mm256_storeu_ps(min_distances_scalar, min_distances); - _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); - - for (size_t j = 0; j < 8; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - // process leftovers - __m128 x0 = _mm_loadu_ps(x); - - for (; i < ny; i++) { - __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); - y += 4; - const float distance = horizontal_sum(accu); - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -size_t fvec_L2sqr_ny_nearest_D8( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - // this implementation does not use distances_tmp_buffer. - - // current index being processed - size_t i = 0; - - // min distance and the index of the closest vector so far - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - // process 8 D8-vectors per loop. - const size_t ny8 = ny / 8; - if (ny8 > 0) { - // track min distance and the closest vector independently - // for each of 8 AVX2 components. 
- __m256 min_distances = _mm256_set1_ps(HUGE_VALF); - __m256i min_indices = _mm256_set1_epi32(0); - - __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - const __m256i indices_increment = _mm256_set1_epi32(8); - - // 1 value per register - const __m256 m0 = _mm256_set1_ps(x[0]); - const __m256 m1 = _mm256_set1_ps(x[1]); - const __m256 m2 = _mm256_set1_ps(x[2]); - const __m256 m3 = _mm256_set1_ps(x[3]); - - const __m256 m4 = _mm256_set1_ps(x[4]); - const __m256 m5 = _mm256_set1_ps(x[5]); - const __m256 m6 = _mm256_set1_ps(x[6]); - const __m256 m7 = _mm256_set1_ps(x[7]); - - for (; i < ny8 * 8; i += 8) { - __m256 v0; - __m256 v1; - __m256 v2; - __m256 v3; - __m256 v4; - __m256 v5; - __m256 v6; - __m256 v7; - - transpose_8x8( - _mm256_loadu_ps(y + 0 * 8), - _mm256_loadu_ps(y + 1 * 8), - _mm256_loadu_ps(y + 2 * 8), - _mm256_loadu_ps(y + 3 * 8), - _mm256_loadu_ps(y + 4 * 8), - _mm256_loadu_ps(y + 5 * 8), - _mm256_loadu_ps(y + 6 * 8), - _mm256_loadu_ps(y + 7 * 8), - v0, - v1, - v2, - v3, - v4, - v5, - v6, - v7); - - // compute differences - const __m256 d0 = _mm256_sub_ps(m0, v0); - const __m256 d1 = _mm256_sub_ps(m1, v1); - const __m256 d2 = _mm256_sub_ps(m2, v2); - const __m256 d3 = _mm256_sub_ps(m3, v3); - const __m256 d4 = _mm256_sub_ps(m4, v4); - const __m256 d5 = _mm256_sub_ps(m5, v5); - const __m256 d6 = _mm256_sub_ps(m6, v6); - const __m256 d7 = _mm256_sub_ps(m7, v7); - - // compute squares of differences - __m256 distances = _mm256_mul_ps(d0, d0); - distances = _mm256_fmadd_ps(d1, d1, distances); - distances = _mm256_fmadd_ps(d2, d2, distances); - distances = _mm256_fmadd_ps(d3, d3, distances); - distances = _mm256_fmadd_ps(d4, d4, distances); - distances = _mm256_fmadd_ps(d5, d5, distances); - distances = _mm256_fmadd_ps(d6, d6, distances); - distances = _mm256_fmadd_ps(d7, d7, distances); - - // compare the new distances to the min distances - // for each of 8 AVX2 components. - __m256 comparison = - _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); - - // update min distances and indices with closest vectors if needed. - min_distances = _mm256_min_ps(distances, min_distances); - min_indices = _mm256_castps_si256(_mm256_blendv_ps( - _mm256_castsi256_ps(current_indices), - _mm256_castsi256_ps(min_indices), - comparison)); - - // update current indices values. Basically, +8 to each of the - // 8 AVX2 components. - current_indices = - _mm256_add_epi32(current_indices, indices_increment); - - // scroll y forward (8 vectors 8 DIM each). 
- y += 64; - } - - // dump values and find the minimum distance / minimum index - float min_distances_scalar[8]; - uint32_t min_indices_scalar[8]; - _mm256_storeu_ps(min_distances_scalar, min_distances); - _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); - - for (size_t j = 0; j < 8; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - // process leftovers - __m256 x0 = _mm256_loadu_ps(x); - - for (; i < ny; i++) { - __m256 accu = ElementOpL2::op(x0, _mm256_loadu_ps(y)); - y += 8; - const float distance = horizontal_sum(accu); - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - } - } - - return current_min_index; -} - -#else -size_t fvec_L2sqr_ny_nearest_D2( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 2, ny); -} - -size_t fvec_L2sqr_ny_nearest_D4( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 4, ny); -} - -size_t fvec_L2sqr_ny_nearest_D8( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, 8, ny); -} -#endif - -size_t fvec_L2sqr_ny_nearest( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t d, - size_t ny) { - // optimized for a few special cases -#define DISPATCH(dval) \ - case dval: \ - return fvec_L2sqr_ny_nearest_D##dval(distances_tmp_buffer, x, y, ny); - - switch (d) { - DISPATCH(2) - DISPATCH(4) - DISPATCH(8) - default: - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); - } -#undef DISPATCH -} - -#if defined(__AVX512F__) - -template -size_t fvec_L2sqr_ny_nearest_y_transposed_D( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - const size_t d_offset, - size_t ny) { - // This implementation does not use distances_tmp_buffer. - - // Current index being processed - size_t i = 0; - - // Min distance and the index of the closest vector so far - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - // Process 16 vectors per loop - const size_t ny16 = ny / 16; - - if (ny16 > 0) { - // Track min distance and the closest vector independently - // for each of 16 AVX-512 components. - __m512 min_distances = _mm512_set1_ps(HUGE_VALF); - __m512i min_indices = _mm512_set1_epi32(0); - - __m512i current_indices = _mm512_setr_epi32( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - const __m512i indices_increment = _mm512_set1_epi32(16); - - // m[i] = (2 * x[i], ... 2 * x[i]) - __m512 m[DIM]; - for (size_t j = 0; j < DIM; j++) { - m[j] = _mm512_set1_ps(x[j]); - m[j] = _mm512_add_ps(m[j], m[j]); - } - - for (; i < ny16 * 16; i += 16) { - // Compute dot products - const __m512 v0 = _mm512_loadu_ps(y + 0 * d_offset); - __m512 dp = _mm512_mul_ps(m[0], v0); - for (size_t j = 1; j < DIM; j++) { - const __m512 vj = _mm512_loadu_ps(y + j * d_offset); - dp = _mm512_fmadd_ps(m[j], vj, dp); - } - - // Compute y^2 - (2 * x, y), which is sufficient for looking for the - // lowest distance. - // x^2 is the constant that can be avoided. 
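-            // (||x - y||^2 = ||x||^2 - 2 * <x, y> + ||y||^2; the dropped
-            // ||x||^2 term is the same for every y, so the argmin over y
-            // is unchanged.)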
- const __m512 distances = - _mm512_sub_ps(_mm512_loadu_ps(y_sqlen), dp); - - // Compare the new distances to the min distances - __mmask16 comparison = - _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS); - - // Update min distances and indices with closest vectors if needed - min_distances = - _mm512_mask_blend_ps(comparison, distances, min_distances); - min_indices = _mm512_castps_si512(_mm512_mask_blend_ps( - comparison, - _mm512_castsi512_ps(current_indices), - _mm512_castsi512_ps(min_indices))); - - // Update current indices values. Basically, +16 to each of the 16 - // AVX-512 components. - current_indices = - _mm512_add_epi32(current_indices, indices_increment); - - // Scroll y and y_sqlen forward. - y += 16; - y_sqlen += 16; - } - - // Dump values and find the minimum distance / minimum index - float min_distances_scalar[16]; - uint32_t min_indices_scalar[16]; - _mm512_storeu_ps(min_distances_scalar, min_distances); - _mm512_storeu_si512((__m512i*)(min_indices_scalar), min_indices); - - for (size_t j = 0; j < 16; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - // Process leftovers - for (; i < ny; i++) { - float dp = 0; - for (size_t j = 0; j < DIM; j++) { - dp += x[j] * y[j * d_offset]; - } - - // Compute y^2 - 2 * (x, y), which is sufficient for looking for the - // lowest distance. - const float distance = y_sqlen[0] - 2 * dp; - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - - y += 1; - y_sqlen += 1; - } - } - - return current_min_index; -} - -#elif defined(__AVX2__) - -template -size_t fvec_L2sqr_ny_nearest_y_transposed_D( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - const size_t d_offset, - size_t ny) { - // this implementation does not use distances_tmp_buffer. - - // current index being processed - size_t i = 0; - - // min distance and the index of the closest vector so far - float current_min_distance = HUGE_VALF; - size_t current_min_index = 0; - - // process 8 vectors per loop. - const size_t ny8 = ny / 8; - - if (ny8 > 0) { - // track min distance and the closest vector independently - // for each of 8 AVX2 components. - __m256 min_distances = _mm256_set1_ps(HUGE_VALF); - __m256i min_indices = _mm256_set1_epi32(0); - - __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); - const __m256i indices_increment = _mm256_set1_epi32(8); - - // m[i] = (2 * x[i], ... 2 * x[i]) - __m256 m[DIM]; - for (size_t j = 0; j < DIM; j++) { - m[j] = _mm256_set1_ps(x[j]); - m[j] = _mm256_add_ps(m[j], m[j]); - } - - for (; i < ny8 * 8; i += 8) { - // collect dim 0 for 8 D4-vectors. - const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset); - // compute dot products - __m256 dp = _mm256_mul_ps(m[0], v0); - - for (size_t j = 1; j < DIM; j++) { - // collect dim j for 8 D4-vectors. - const __m256 vj = _mm256_loadu_ps(y + j * d_offset); - dp = _mm256_fmadd_ps(m[j], vj, dp); - } - - // compute y^2 - (2 * x, y), which is sufficient for looking for the - // lowest distance. - // x^2 is the constant that can be avoided. - const __m256 distances = - _mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dp); - - // compare the new distances to the min distances - // for each of 8 AVX2 components. - const __m256 comparison = - _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); - - // update min distances and indices with closest vectors if needed. 
- min_distances = - _mm256_blendv_ps(distances, min_distances, comparison); - min_indices = _mm256_castps_si256(_mm256_blendv_ps( - _mm256_castsi256_ps(current_indices), - _mm256_castsi256_ps(min_indices), - comparison)); - - // update current indices values. Basically, +8 to each of the - // 8 AVX2 components. - current_indices = - _mm256_add_epi32(current_indices, indices_increment); - - // scroll y and y_sqlen forward. - y += 8; - y_sqlen += 8; - } - - // dump values and find the minimum distance / minimum index - float min_distances_scalar[8]; - uint32_t min_indices_scalar[8]; - _mm256_storeu_ps(min_distances_scalar, min_distances); - _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); - - for (size_t j = 0; j < 8; j++) { - if (current_min_distance > min_distances_scalar[j]) { - current_min_distance = min_distances_scalar[j]; - current_min_index = min_indices_scalar[j]; - } - } - } - - if (i < ny) { - // process leftovers - for (; i < ny; i++) { - float dp = 0; - for (size_t j = 0; j < DIM; j++) { - dp += x[j] * y[j * d_offset]; - } - - // compute y^2 - 2 * (x, y), which is sufficient for looking for the - // lowest distance. - const float distance = y_sqlen[0] - 2 * dp; - - if (current_min_distance > distance) { - current_min_distance = distance; - current_min_index = i; - } - - y += 1; - y_sqlen += 1; - } - } - - return current_min_index; -} - -#endif - -size_t fvec_L2sqr_ny_nearest_y_transposed( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - // optimized for a few special cases -#ifdef __AVX2__ -#define DISPATCH(dval) \ - case dval: \ - return fvec_L2sqr_ny_nearest_y_transposed_D( \ - distances_tmp_buffer, x, y, y_sqlen, d_offset, ny); - - switch (d) { - DISPATCH(1) - DISPATCH(2) - DISPATCH(4) - DISPATCH(8) - default: - return fvec_L2sqr_ny_nearest_y_transposed_ref( - distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); - } -#undef DISPATCH -#else - // non-AVX2 case - return fvec_L2sqr_ny_nearest_y_transposed_ref( - distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); -#endif -} - -#endif - -#ifdef USE_AVX - -float fvec_L1(const float* x, const float* y, size_t d) { - __m256 msum1 = _mm256_setzero_ps(); - // signmask used for absolute value - __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL)); - - while (d >= 8) { - __m256 mx = _mm256_loadu_ps(x); - x += 8; - __m256 my = _mm256_loadu_ps(y); - y += 8; - // subtract - const __m256 a_m_b = _mm256_sub_ps(mx, my); - // find sum of absolute value of distances (manhattan distance) - msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b)); - d -= 8; - } - - __m128 msum2 = _mm256_extractf128_ps(msum1, 1); - msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0)); - __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL)); - - if (d >= 4) { - __m128 mx = _mm_loadu_ps(x); - x += 4; - __m128 my = _mm_loadu_ps(y); - y += 4; - const __m128 a_m_b = _mm_sub_ps(mx, my); - msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b)); - d -= 4; - } - - if (d > 0) { - __m128 mx = masked_read(d, x); - __m128 my = masked_read(d, y); - __m128 a_m_b = _mm_sub_ps(mx, my); - msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b)); - } - - msum2 = _mm_hadd_ps(msum2, msum2); - msum2 = _mm_hadd_ps(msum2, msum2); - return _mm_cvtss_f32(msum2); -} - -float fvec_Linf(const float* x, const float* y, size_t d) { - __m256 msum1 = _mm256_setzero_ps(); - // signmask used for absolute value - __m256 signmask = 
_mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL)); - - while (d >= 8) { - __m256 mx = _mm256_loadu_ps(x); - x += 8; - __m256 my = _mm256_loadu_ps(y); - y += 8; - // subtract - const __m256 a_m_b = _mm256_sub_ps(mx, my); - // find max of absolute value of distances (chebyshev distance) - msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b)); - d -= 8; - } - - __m128 msum2 = _mm256_extractf128_ps(msum1, 1); - msum2 = _mm_max_ps(msum2, _mm256_extractf128_ps(msum1, 0)); - __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL)); - - if (d >= 4) { - __m128 mx = _mm_loadu_ps(x); - x += 4; - __m128 my = _mm_loadu_ps(y); - y += 4; - const __m128 a_m_b = _mm_sub_ps(mx, my); - msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); - d -= 4; - } - - if (d > 0) { - __m128 mx = masked_read(d, x); - __m128 my = masked_read(d, y); - __m128 a_m_b = _mm_sub_ps(mx, my); - msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); - } - - msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2); - msum2 = _mm_max_ps(msum2, _mm_shuffle_ps(msum2, msum2, 1)); - return _mm_cvtss_f32(msum2); -} - -#elif defined(__SSE3__) // But not AVX - -float fvec_L1(const float* x, const float* y, size_t d) { - return fvec_L1_ref(x, y, d); -} - -float fvec_Linf(const float* x, const float* y, size_t d) { - return fvec_Linf_ref(x, y, d); -} - -#elif defined(__ARM_FEATURE_SVE) - -struct ElementOpIP { - static svfloat32_t op(svbool_t pg, svfloat32_t x, svfloat32_t y) { - return svmul_f32_x(pg, x, y); - } - static svfloat32_t merge( - svbool_t pg, - svfloat32_t z, - svfloat32_t x, - svfloat32_t y) { - return svmla_f32_x(pg, z, x, y); - } -}; - -template -void fvec_op_ny_sve_d1(float* dis, const float* x, const float* y, size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svdup_n_f32(x[0]); - size_t i = 0; - for (; i + lanes4 < ny; i += lanes4) { - svfloat32_t y0 = svld1_f32(pg, y); - svfloat32_t y1 = svld1_f32(pg, y + lanes); - svfloat32_t y2 = svld1_f32(pg, y + lanes2); - svfloat32_t y3 = svld1_f32(pg, y + lanes3); - y0 = ElementOp::op(pg, x0, y0); - y1 = ElementOp::op(pg, x0, y1); - y2 = ElementOp::op(pg, x0, y2); - y3 = ElementOp::op(pg, x0, y3); - svst1_f32(pg, dis, y0); - svst1_f32(pg, dis + lanes, y1); - svst1_f32(pg, dis + lanes2, y2); - svst1_f32(pg, dis + lanes3, y3); - y += lanes4; - dis += lanes4; - } - const svbool_t pg0 = svwhilelt_b32_u64(i, ny); - const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny); - const svbool_t pg2 = svwhilelt_b32_u64(i + lanes2, ny); - const svbool_t pg3 = svwhilelt_b32_u64(i + lanes3, ny); - svfloat32_t y0 = svld1_f32(pg0, y); - svfloat32_t y1 = svld1_f32(pg1, y + lanes); - svfloat32_t y2 = svld1_f32(pg2, y + lanes2); - svfloat32_t y3 = svld1_f32(pg3, y + lanes3); - y0 = ElementOp::op(pg0, x0, y0); - y1 = ElementOp::op(pg1, x0, y1); - y2 = ElementOp::op(pg2, x0, y2); - y3 = ElementOp::op(pg3, x0, y3); - svst1_f32(pg0, dis, y0); - svst1_f32(pg1, dis + lanes, y1); - svst1_f32(pg2, dis + lanes2, y2); - svst1_f32(pg3, dis + lanes3, y3); -} - -template -void fvec_op_ny_sve_d2(float* dis, const float* x, const float* y, size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svdup_n_f32(x[0]); - const svfloat32_t x1 = svdup_n_f32(x[1]); - size_t i = 0; - for (; i + lanes2 < ny; i += lanes2) { - const 
svfloat32x2_t y0 = svld2_f32(pg, y); - const svfloat32x2_t y1 = svld2_f32(pg, y + lanes2); - svfloat32_t y00 = svget2_f32(y0, 0); - const svfloat32_t y01 = svget2_f32(y0, 1); - svfloat32_t y10 = svget2_f32(y1, 0); - const svfloat32_t y11 = svget2_f32(y1, 1); - y00 = ElementOp::op(pg, x0, y00); - y10 = ElementOp::op(pg, x0, y10); - y00 = ElementOp::merge(pg, y00, x1, y01); - y10 = ElementOp::merge(pg, y10, x1, y11); - svst1_f32(pg, dis, y00); - svst1_f32(pg, dis + lanes, y10); - y += lanes4; - dis += lanes2; - } - const svbool_t pg0 = svwhilelt_b32_u64(i, ny); - const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny); - const svfloat32x2_t y0 = svld2_f32(pg0, y); - const svfloat32x2_t y1 = svld2_f32(pg1, y + lanes2); - svfloat32_t y00 = svget2_f32(y0, 0); - const svfloat32_t y01 = svget2_f32(y0, 1); - svfloat32_t y10 = svget2_f32(y1, 0); - const svfloat32_t y11 = svget2_f32(y1, 1); - y00 = ElementOp::op(pg0, x0, y00); - y10 = ElementOp::op(pg1, x0, y10); - y00 = ElementOp::merge(pg0, y00, x1, y01); - y10 = ElementOp::merge(pg1, y10, x1, y11); - svst1_f32(pg0, dis, y00); - svst1_f32(pg1, dis + lanes, y10); -} - -template -void fvec_op_ny_sve_d4(float* dis, const float* x, const float* y, size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svdup_n_f32(x[0]); - const svfloat32_t x1 = svdup_n_f32(x[1]); - const svfloat32_t x2 = svdup_n_f32(x[2]); - const svfloat32_t x3 = svdup_n_f32(x[3]); - size_t i = 0; - for (; i + lanes < ny; i += lanes) { - const svfloat32x4_t y0 = svld4_f32(pg, y); - svfloat32_t y00 = svget4_f32(y0, 0); - const svfloat32_t y01 = svget4_f32(y0, 1); - svfloat32_t y02 = svget4_f32(y0, 2); - const svfloat32_t y03 = svget4_f32(y0, 3); - y00 = ElementOp::op(pg, x0, y00); - y02 = ElementOp::op(pg, x2, y02); - y00 = ElementOp::merge(pg, y00, x1, y01); - y02 = ElementOp::merge(pg, y02, x3, y03); - y00 = svadd_f32_x(pg, y00, y02); - svst1_f32(pg, dis, y00); - y += lanes4; - dis += lanes; - } - const svbool_t pg0 = svwhilelt_b32_u64(i, ny); - const svfloat32x4_t y0 = svld4_f32(pg0, y); - svfloat32_t y00 = svget4_f32(y0, 0); - const svfloat32_t y01 = svget4_f32(y0, 1); - svfloat32_t y02 = svget4_f32(y0, 2); - const svfloat32_t y03 = svget4_f32(y0, 3); - y00 = ElementOp::op(pg0, x0, y00); - y02 = ElementOp::op(pg0, x2, y02); - y00 = ElementOp::merge(pg0, y00, x1, y01); - y02 = ElementOp::merge(pg0, y02, x3, y03); - y00 = svadd_f32_x(pg0, y00, y02); - svst1_f32(pg0, dis, y00); -} - -template -void fvec_op_ny_sve_d8(float* dis, const float* x, const float* y, size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes4 = lanes * 4; - const size_t lanes8 = lanes * 8; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svdup_n_f32(x[0]); - const svfloat32_t x1 = svdup_n_f32(x[1]); - const svfloat32_t x2 = svdup_n_f32(x[2]); - const svfloat32_t x3 = svdup_n_f32(x[3]); - const svfloat32_t x4 = svdup_n_f32(x[4]); - const svfloat32_t x5 = svdup_n_f32(x[5]); - const svfloat32_t x6 = svdup_n_f32(x[6]); - const svfloat32_t x7 = svdup_n_f32(x[7]); - size_t i = 0; - for (; i + lanes < ny; i += lanes) { - const svfloat32x4_t ya = svld4_f32(pg, y); - const svfloat32x4_t yb = svld4_f32(pg, y + lanes4); - const svfloat32_t ya0 = svget4_f32(ya, 0); - const svfloat32_t ya1 = svget4_f32(ya, 1); - const svfloat32_t ya2 = svget4_f32(ya, 2); - const svfloat32_t ya3 = svget4_f32(ya, 3); - const svfloat32_t yb0 = svget4_f32(yb, 0); - const svfloat32_t yb1 = svget4_f32(yb, 1); - const svfloat32_t yb2 = 
svget4_f32(yb, 2); - const svfloat32_t yb3 = svget4_f32(yb, 3); - svfloat32_t y0 = svuzp1(ya0, yb0); - const svfloat32_t y1 = svuzp1(ya1, yb1); - svfloat32_t y2 = svuzp1(ya2, yb2); - const svfloat32_t y3 = svuzp1(ya3, yb3); - svfloat32_t y4 = svuzp2(ya0, yb0); - const svfloat32_t y5 = svuzp2(ya1, yb1); - svfloat32_t y6 = svuzp2(ya2, yb2); - const svfloat32_t y7 = svuzp2(ya3, yb3); - y0 = ElementOp::op(pg, x0, y0); - y2 = ElementOp::op(pg, x2, y2); - y4 = ElementOp::op(pg, x4, y4); - y6 = ElementOp::op(pg, x6, y6); - y0 = ElementOp::merge(pg, y0, x1, y1); - y2 = ElementOp::merge(pg, y2, x3, y3); - y4 = ElementOp::merge(pg, y4, x5, y5); - y6 = ElementOp::merge(pg, y6, x7, y7); - y0 = svadd_f32_x(pg, y0, y2); - y4 = svadd_f32_x(pg, y4, y6); - y0 = svadd_f32_x(pg, y0, y4); - svst1_f32(pg, dis, y0); - y += lanes8; - dis += lanes; - } - const svbool_t pg0 = svwhilelt_b32_u64(i, ny); - const svbool_t pga = svwhilelt_b32_u64(i * 2, ny * 2); - const svbool_t pgb = svwhilelt_b32_u64(i * 2 + lanes, ny * 2); - const svfloat32x4_t ya = svld4_f32(pga, y); - const svfloat32x4_t yb = svld4_f32(pgb, y + lanes4); - const svfloat32_t ya0 = svget4_f32(ya, 0); - const svfloat32_t ya1 = svget4_f32(ya, 1); - const svfloat32_t ya2 = svget4_f32(ya, 2); - const svfloat32_t ya3 = svget4_f32(ya, 3); - const svfloat32_t yb0 = svget4_f32(yb, 0); - const svfloat32_t yb1 = svget4_f32(yb, 1); - const svfloat32_t yb2 = svget4_f32(yb, 2); - const svfloat32_t yb3 = svget4_f32(yb, 3); - svfloat32_t y0 = svuzp1(ya0, yb0); - const svfloat32_t y1 = svuzp1(ya1, yb1); - svfloat32_t y2 = svuzp1(ya2, yb2); - const svfloat32_t y3 = svuzp1(ya3, yb3); - svfloat32_t y4 = svuzp2(ya0, yb0); - const svfloat32_t y5 = svuzp2(ya1, yb1); - svfloat32_t y6 = svuzp2(ya2, yb2); - const svfloat32_t y7 = svuzp2(ya3, yb3); - y0 = ElementOp::op(pg0, x0, y0); - y2 = ElementOp::op(pg0, x2, y2); - y4 = ElementOp::op(pg0, x4, y4); - y6 = ElementOp::op(pg0, x6, y6); - y0 = ElementOp::merge(pg0, y0, x1, y1); - y2 = ElementOp::merge(pg0, y2, x3, y3); - y4 = ElementOp::merge(pg0, y4, x5, y5); - y6 = ElementOp::merge(pg0, y6, x7, y7); - y0 = svadd_f32_x(pg0, y0, y2); - y4 = svadd_f32_x(pg0, y4, y6); - y0 = svadd_f32_x(pg0, y0, y4); - svst1_f32(pg0, dis, y0); - y += lanes8; - dis += lanes; -} - -template -void fvec_op_ny_sve_lanes1( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svld1_f32(pg, x); - size_t i = 0; - for (; i + 3 < ny; i += 4) { - svfloat32_t y0 = svld1_f32(pg, y); - svfloat32_t y1 = svld1_f32(pg, y + lanes); - svfloat32_t y2 = svld1_f32(pg, y + lanes2); - svfloat32_t y3 = svld1_f32(pg, y + lanes3); - y += lanes4; - y0 = ElementOp::op(pg, x0, y0); - y1 = ElementOp::op(pg, x0, y1); - y2 = ElementOp::op(pg, x0, y2); - y3 = ElementOp::op(pg, x0, y3); - dis[i] = svaddv_f32(pg, y0); - dis[i + 1] = svaddv_f32(pg, y1); - dis[i + 2] = svaddv_f32(pg, y2); - dis[i + 3] = svaddv_f32(pg, y3); - } - for (; i < ny; ++i) { - svfloat32_t y0 = svld1_f32(pg, y); - y += lanes; - y0 = ElementOp::op(pg, x0, y0); - dis[i] = svaddv_f32(pg, y0); - } -} - -template -void fvec_op_ny_sve_lanes2( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = 
svld1_f32(pg, x); - const svfloat32_t x1 = svld1_f32(pg, x + lanes); - size_t i = 0; - for (; i + 1 < ny; i += 2) { - svfloat32_t y00 = svld1_f32(pg, y); - const svfloat32_t y01 = svld1_f32(pg, y + lanes); - svfloat32_t y10 = svld1_f32(pg, y + lanes2); - const svfloat32_t y11 = svld1_f32(pg, y + lanes3); - y += lanes4; - y00 = ElementOp::op(pg, x0, y00); - y10 = ElementOp::op(pg, x0, y10); - y00 = ElementOp::merge(pg, y00, x1, y01); - y10 = ElementOp::merge(pg, y10, x1, y11); - dis[i] = svaddv_f32(pg, y00); - dis[i + 1] = svaddv_f32(pg, y10); - } - if (i < ny) { - svfloat32_t y0 = svld1_f32(pg, y); - const svfloat32_t y1 = svld1_f32(pg, y + lanes); - y0 = ElementOp::op(pg, x0, y0); - y0 = ElementOp::merge(pg, y0, x1, y1); - dis[i] = svaddv_f32(pg, y0); - } -} - -template -void fvec_op_ny_sve_lanes3( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svld1_f32(pg, x); - const svfloat32_t x1 = svld1_f32(pg, x + lanes); - const svfloat32_t x2 = svld1_f32(pg, x + lanes2); - for (size_t i = 0; i < ny; ++i) { - svfloat32_t y0 = svld1_f32(pg, y); - const svfloat32_t y1 = svld1_f32(pg, y + lanes); - svfloat32_t y2 = svld1_f32(pg, y + lanes2); - y += lanes3; - y0 = ElementOp::op(pg, x0, y0); - y0 = ElementOp::merge(pg, y0, x1, y1); - y0 = ElementOp::merge(pg, y0, x2, y2); - dis[i] = svaddv_f32(pg, y0); - } -} - -template -void fvec_op_ny_sve_lanes4( - float* dis, - const float* x, - const float* y, - size_t ny) { - const size_t lanes = svcntw(); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const size_t lanes4 = lanes * 4; - const svbool_t pg = svptrue_b32(); - const svfloat32_t x0 = svld1_f32(pg, x); - const svfloat32_t x1 = svld1_f32(pg, x + lanes); - const svfloat32_t x2 = svld1_f32(pg, x + lanes2); - const svfloat32_t x3 = svld1_f32(pg, x + lanes3); - for (size_t i = 0; i < ny; ++i) { - svfloat32_t y0 = svld1_f32(pg, y); - const svfloat32_t y1 = svld1_f32(pg, y + lanes); - svfloat32_t y2 = svld1_f32(pg, y + lanes2); - const svfloat32_t y3 = svld1_f32(pg, y + lanes3); - y += lanes4; - y0 = ElementOp::op(pg, x0, y0); - y2 = ElementOp::op(pg, x2, y2); - y0 = ElementOp::merge(pg, y0, x1, y1); - y2 = ElementOp::merge(pg, y2, x3, y3); - y0 = svadd_f32_x(pg, y0, y2); - dis[i] = svaddv_f32(pg, y0); - } -} - -void fvec_L2sqr_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - fvec_L2sqr_ny_ref(dis, x, y, d, ny); -} - -void fvec_L2sqr_ny_transposed( - float* dis, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny); -} - -size_t fvec_L2sqr_ny_nearest( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t d, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); -} - -size_t fvec_L2sqr_ny_nearest_y_transposed( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_nearest_y_transposed_ref( - distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); -} - -float fvec_L1(const float* x, const float* y, size_t d) { - return fvec_L1_ref(x, y, d); -} - -float fvec_Linf(const float* x, const float* y, size_t d) { - return fvec_Linf_ref(x, y, d); -} - -void fvec_inner_products_ny( - float* dis, - const 
float* x, - const float* y, - size_t d, - size_t ny) { - const size_t lanes = svcntw(); - switch (d) { - case 1: - fvec_op_ny_sve_d1(dis, x, y, ny); - break; - case 2: - fvec_op_ny_sve_d2(dis, x, y, ny); - break; - case 4: - fvec_op_ny_sve_d4(dis, x, y, ny); - break; - case 8: - fvec_op_ny_sve_d8(dis, x, y, ny); - break; - default: - if (d == lanes) - fvec_op_ny_sve_lanes1(dis, x, y, ny); - else if (d == lanes * 2) - fvec_op_ny_sve_lanes2(dis, x, y, ny); - else if (d == lanes * 3) - fvec_op_ny_sve_lanes3(dis, x, y, ny); - else if (d == lanes * 4) - fvec_op_ny_sve_lanes4(dis, x, y, ny); - else - fvec_inner_products_ny_ref(dis, x, y, d, ny); - break; - } -} - -#elif defined(__aarch64__) - -// not optimized for ARM -void fvec_L2sqr_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - fvec_L2sqr_ny_ref(dis, x, y, d, ny); -} - -void fvec_L2sqr_ny_transposed( - float* dis, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny); -} - -size_t fvec_L2sqr_ny_nearest( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t d, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); -} - -size_t fvec_L2sqr_ny_nearest_y_transposed( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_nearest_y_transposed_ref( - distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); -} - -float fvec_L1(const float* x, const float* y, size_t d) { - return fvec_L1_ref(x, y, d); -} - -float fvec_Linf(const float* x, const float* y, size_t d) { - return fvec_Linf_ref(x, y, d); -} - -void fvec_inner_products_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - fvec_inner_products_ny_ref(dis, x, y, d, ny); -} - -#else -// scalar implementation - -float fvec_L1(const float* x, const float* y, size_t d) { - return fvec_L1_ref(x, y, d); -} - -float fvec_Linf(const float* x, const float* y, size_t d) { - return fvec_Linf_ref(x, y, d); -} - -void fvec_L2sqr_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - fvec_L2sqr_ny_ref(dis, x, y, d, ny); -} - -void fvec_L2sqr_ny_transposed( - float* dis, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_y_transposed_ref(dis, x, y, y_sqlen, d, d_offset, ny); -} - -size_t fvec_L2sqr_ny_nearest( - float* distances_tmp_buffer, - const float* x, - const float* y, - size_t d, - size_t ny) { - return fvec_L2sqr_ny_nearest_ref(distances_tmp_buffer, x, y, d, ny); -} - -size_t fvec_L2sqr_ny_nearest_y_transposed( - float* distances_tmp_buffer, - const float* x, - const float* y, - const float* y_sqlen, - size_t d, - size_t d_offset, - size_t ny) { - return fvec_L2sqr_ny_nearest_y_transposed_ref( - distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); -} - -void fvec_inner_products_ny( - float* dis, - const float* x, - const float* y, - size_t d, - size_t ny) { - fvec_inner_products_ny_ref(dis, x, y, d, ny); -} - -#endif - -/*************************************************************************** - * heavily optimized table computations - ***************************************************************************/ - -[[maybe_unused]] static inline void fvec_madd_ref( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - for (size_t 
i = 0; i < n; i++) { - c[i] = a[i] + bf * b[i]; - } -} - -#if defined(__AVX512F__) - -static inline void fvec_madd_avx512( - const size_t n, - const float* __restrict a, - const float bf, - const float* __restrict b, - float* __restrict c) { - const size_t n16 = n / 16; - const size_t n_for_masking = n % 16; - - const __m512 bfmm = _mm512_set1_ps(bf); - - size_t idx = 0; - for (idx = 0; idx < n16 * 16; idx += 16) { - const __m512 ax = _mm512_loadu_ps(a + idx); - const __m512 bx = _mm512_loadu_ps(b + idx); - const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); - _mm512_storeu_ps(c + idx, abmul); - } - - if (n_for_masking > 0) { - const __mmask16 mask = (1 << n_for_masking) - 1; - - const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx); - const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx); - const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); - _mm512_mask_storeu_ps(c + idx, mask, abmul); - } -} - -#elif defined(__AVX2__) - -static inline void fvec_madd_avx2( - const size_t n, - const float* __restrict a, - const float bf, - const float* __restrict b, - float* __restrict c) { - // - const size_t n8 = n / 8; - const size_t n_for_masking = n % 8; - - const __m256 bfmm = _mm256_set1_ps(bf); - - size_t idx = 0; - for (idx = 0; idx < n8 * 8; idx += 8) { - const __m256 ax = _mm256_loadu_ps(a + idx); - const __m256 bx = _mm256_loadu_ps(b + idx); - const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); - _mm256_storeu_ps(c + idx, abmul); - } - - if (n_for_masking > 0) { - __m256i mask; - switch (n_for_masking) { - case 1: - mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1); - break; - case 2: - mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1); - break; - case 3: - mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1); - break; - case 4: - mask = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1); - break; - case 5: - mask = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1); - break; - case 6: - mask = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1); - break; - case 7: - mask = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1); - break; - } - - const __m256 ax = _mm256_maskload_ps(a + idx, mask); - const __m256 bx = _mm256_maskload_ps(b + idx, mask); - const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); - _mm256_maskstore_ps(c + idx, mask, abmul); - } -} - -#endif - -#ifdef __SSE3__ - -[[maybe_unused]] static inline void fvec_madd_sse( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - n >>= 2; - __m128 bf4 = _mm_set_ps1(bf); - __m128* a4 = (__m128*)a; - __m128* b4 = (__m128*)b; - __m128* c4 = (__m128*)c; - - while (n--) { - *c4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4)); - b4++; - a4++; - c4++; - } -} - -void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { -#ifdef __AVX512F__ - fvec_madd_avx512(n, a, bf, b, c); -#elif __AVX2__ - fvec_madd_avx2(n, a, bf, b, c); -#else - if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) - fvec_madd_sse(n, a, bf, b, c); - else - fvec_madd_ref(n, a, bf, b, c); -#endif -} - -#elif defined(__ARM_FEATURE_SVE) - -void fvec_madd( - const size_t n, - const float* __restrict a, - const float bf, - const float* __restrict b, - float* __restrict c) { - const size_t lanes = static_cast(svcntw()); - const size_t lanes2 = lanes * 2; - const size_t lanes3 = lanes * 3; - const size_t lanes4 = lanes * 4; - size_t i = 0; - for (; i + lanes4 < n; i += lanes4) { - const auto mask = svptrue_b32(); - const auto ai0 = svld1_f32(mask, a + i); - const auto ai1 = svld1_f32(mask, a + i + lanes); - const auto ai2 = svld1_f32(mask, a + i + 
lanes2); - const auto ai3 = svld1_f32(mask, a + i + lanes3); - const auto bi0 = svld1_f32(mask, b + i); - const auto bi1 = svld1_f32(mask, b + i + lanes); - const auto bi2 = svld1_f32(mask, b + i + lanes2); - const auto bi3 = svld1_f32(mask, b + i + lanes3); - const auto ci0 = svmla_n_f32_x(mask, ai0, bi0, bf); - const auto ci1 = svmla_n_f32_x(mask, ai1, bi1, bf); - const auto ci2 = svmla_n_f32_x(mask, ai2, bi2, bf); - const auto ci3 = svmla_n_f32_x(mask, ai3, bi3, bf); - svst1_f32(mask, c + i, ci0); - svst1_f32(mask, c + i + lanes, ci1); - svst1_f32(mask, c + i + lanes2, ci2); - svst1_f32(mask, c + i + lanes3, ci3); - } - const auto mask0 = svwhilelt_b32_u64(i, n); - const auto mask1 = svwhilelt_b32_u64(i + lanes, n); - const auto mask2 = svwhilelt_b32_u64(i + lanes2, n); - const auto mask3 = svwhilelt_b32_u64(i + lanes3, n); - const auto ai0 = svld1_f32(mask0, a + i); - const auto ai1 = svld1_f32(mask1, a + i + lanes); - const auto ai2 = svld1_f32(mask2, a + i + lanes2); - const auto ai3 = svld1_f32(mask3, a + i + lanes3); - const auto bi0 = svld1_f32(mask0, b + i); - const auto bi1 = svld1_f32(mask1, b + i + lanes); - const auto bi2 = svld1_f32(mask2, b + i + lanes2); - const auto bi3 = svld1_f32(mask3, b + i + lanes3); - const auto ci0 = svmla_n_f32_x(mask0, ai0, bi0, bf); - const auto ci1 = svmla_n_f32_x(mask1, ai1, bi1, bf); - const auto ci2 = svmla_n_f32_x(mask2, ai2, bi2, bf); - const auto ci3 = svmla_n_f32_x(mask3, ai3, bi3, bf); - svst1_f32(mask0, c + i, ci0); - svst1_f32(mask1, c + i + lanes, ci1); - svst1_f32(mask2, c + i + lanes2, ci2); - svst1_f32(mask3, c + i + lanes3, ci3); -} - -#elif defined(__aarch64__) - -void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { - const size_t n_simd = n - (n & 3); - const float32x4_t bfv = vdupq_n_f32(bf); - size_t i; - for (i = 0; i < n_simd; i += 4) { - const float32x4_t ai = vld1q_f32(a + i); - const float32x4_t bi = vld1q_f32(b + i); - const float32x4_t ci = vfmaq_f32(ai, bfv, bi); - vst1q_f32(c + i, ci); - } - for (; i < n; ++i) - c[i] = a[i] + bf * b[i]; -} - -#else - -void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { - fvec_madd_ref(n, a, bf, b, c); -} - -#endif - -static inline int fvec_madd_and_argmin_ref( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - float vmin = 1e20; - int imin = -1; - - for (size_t i = 0; i < n; i++) { - c[i] = a[i] + bf * b[i]; - if (c[i] < vmin) { - vmin = c[i]; - imin = i; - } - } - return imin; -} - -#ifdef __SSE3__ - -static inline int fvec_madd_and_argmin_sse( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - n >>= 2; - __m128 bf4 = _mm_set_ps1(bf); - __m128 vmin4 = _mm_set_ps1(1e20); - __m128i imin4 = _mm_set1_epi32(-1); - __m128i idx4 = _mm_set_epi32(3, 2, 1, 0); - __m128i inc4 = _mm_set1_epi32(4); - __m128* a4 = (__m128*)a; - __m128* b4 = (__m128*)b; - __m128* c4 = (__m128*)c; - - while (n--) { - __m128 vc4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4)); - *c4 = vc4; - __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); - // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower! 
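-        // select idx4 where the new value is smaller, otherwise keep imin4
-        // (blendv emulated with and/andnot/or).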
- - imin4 = _mm_or_si128( - _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); - vmin4 = _mm_min_ps(vmin4, vc4); - b4++; - a4++; - c4++; - idx4 = _mm_add_epi32(idx4, inc4); - } - - // 4 values -> 2 - { - idx4 = _mm_shuffle_epi32(imin4, 3 << 2 | 2); - __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 3 << 2 | 2); - __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); - imin4 = _mm_or_si128( - _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); - vmin4 = _mm_min_ps(vmin4, vc4); - } - // 2 values -> 1 - { - idx4 = _mm_shuffle_epi32(imin4, 1); - __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 1); - __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); - imin4 = _mm_or_si128( - _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); - // vmin4 = _mm_min_ps (vmin4, vc4); - } - return _mm_cvtsi128_si32(imin4); -} - -int fvec_madd_and_argmin( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) { - return fvec_madd_and_argmin_sse(n, a, bf, b, c); - } else { - return fvec_madd_and_argmin_ref(n, a, bf, b, c); - } -} - -#elif defined(__aarch64__) - -int fvec_madd_and_argmin( - size_t n, - const float* a, - float bf, - const float* b, - float* c) { - float32x4_t vminv = vdupq_n_f32(1e20); - uint32x4_t iminv = vdupq_n_u32(static_cast(-1)); - size_t i; - { - const size_t n_simd = n - (n & 3); - const uint32_t iota[] = {0, 1, 2, 3}; - uint32x4_t iv = vld1q_u32(iota); - const uint32x4_t incv = vdupq_n_u32(4); - const float32x4_t bfv = vdupq_n_f32(bf); - for (i = 0; i < n_simd; i += 4) { - const float32x4_t ai = vld1q_f32(a + i); - const float32x4_t bi = vld1q_f32(b + i); - const float32x4_t ci = vfmaq_f32(ai, bfv, bi); - vst1q_f32(c + i, ci); - const uint32x4_t less_than = vcltq_f32(ci, vminv); - vminv = vminq_f32(ci, vminv); - iminv = vorrq_u32( - vandq_u32(less_than, iv), - vandq_u32(vmvnq_u32(less_than), iminv)); - iv = vaddq_u32(iv, incv); - } - } - float vmin = vminvq_f32(vminv); - uint32_t imin; - { - const float32x4_t vminy = vdupq_n_f32(vmin); - const uint32x4_t equals = vceqq_f32(vminv, vminy); - imin = vminvq_u32(vorrq_u32( - vandq_u32(equals, iminv), - vandq_u32( - vmvnq_u32(equals), - vdupq_n_u32(std::numeric_limits::max())))); - } - for (; i < n; ++i) { - c[i] = a[i] + bf * b[i]; - if (c[i] < vmin) { - vmin = c[i]; - imin = static_cast(i); - } - } - return static_cast(imin); -} - -#else +void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) { + DISPATCH_SIMDLevel(fvec_madd, n, a, bf, b, c); +} int fvec_madd_and_argmin( size_t n, @@ -3603,11 +304,9 @@ int fvec_madd_and_argmin( float bf, const float* b, float* c) { - return fvec_madd_and_argmin_ref(n, a, bf, b, c); + DISPATCH_SIMDLevel(fvec_madd_and_argmin, n, a, bf, b, c); } -#endif - /*************************************************************************** * PQ tables computations ***************************************************************************/ diff --git a/faiss/utils/extra_distances-inl.h b/faiss/utils/extra_distances-inl.h index 6a374ed518..066ba55590 100644 --- a/faiss/utils/extra_distances-inl.h +++ b/faiss/utils/extra_distances-inl.h @@ -59,13 +59,6 @@ inline float VectorDistance::operator()( const float* x, const float* y) const { return fvec_Linf(x, y, d); - /* - float vmax = 0; - for (size_t i = 0; i < d; i++) { - float diff = fabs (x[i] - y[i]); - if (diff > vmax) vmax = diff; - } - return vmax;*/ } template <> diff --git a/faiss/utils/simd_impl/distances_aarch64.cpp 
b/faiss/utils/simd_impl/distances_aarch64.cpp new file mode 100644 index 0000000000..33ad9bbc4f --- /dev/null +++ b/faiss/utils/simd_impl/distances_aarch64.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#define AUTOVEC_LEVEL SIMDLevel::ARM_NEON +#include + +namespace faiss { + +template <> +void fvec_madd( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + const size_t n_simd = n - (n & 3); + const float32x4_t bfv = vdupq_n_f32(bf); + size_t i; + for (i = 0; i < n_simd; i += 4) { + const float32x4_t ai = vld1q_f32(a + i); + const float32x4_t bi = vld1q_f32(b + i); + const float32x4_t ci = vfmaq_f32(ai, bfv, bi); + vst1q_f32(c + i, ci); + } + for (; i < n; ++i) + c[i] = a[i] + bf * b[i]; +} + +template <> +void fvec_L2sqr_ny_transposed( + float* dis, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny); + +template <> +void fvec_inner_products_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_inner_products_ny(dis, x, y, d, ny); +} + +template <> +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny(dis, x, y, d, ny); +} + +template <> +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny_nearest(distances_tmp_buffer, x, y, d, ny); +} + +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + return fvec_L2sqr_ny_nearest_y_transposed_ref( + distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); +} + +template <> +int fvec_madd_and_argmin( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + float32x4_t vminv = vdupq_n_f32(1e20); + uint32x4_t iminv = vdupq_n_u32(static_cast(-1)); + size_t i; + { + const size_t n_simd = n - (n & 3); + const uint32_t iota[] = {0, 1, 2, 3}; + uint32x4_t iv = vld1q_u32(iota); + const uint32x4_t incv = vdupq_n_u32(4); + const float32x4_t bfv = vdupq_n_f32(bf); + for (i = 0; i < n_simd; i += 4) { + const float32x4_t ai = vld1q_f32(a + i); + const float32x4_t bi = vld1q_f32(b + i); + const float32x4_t ci = vfmaq_f32(ai, bfv, bi); + vst1q_f32(c + i, ci); + const uint32x4_t less_than = vcltq_f32(ci, vminv); + vminv = vminq_f32(ci, vminv); + iminv = vorrq_u32( + vandq_u32(less_than, iv), + vandq_u32(vmvnq_u32(less_than), iminv)); + iv = vaddq_u32(iv, incv); + } + } + float vmin = vminvq_f32(vminv); + uint32_t imin; + { + const float32x4_t vminy = vdupq_n_f32(vmin); + const uint32x4_t equals = vceqq_f32(vminv, vminy); + imin = vminvq_u32(vorrq_u32( + vandq_u32(equals, iminv), + vandq_u32( + vmvnq_u32(equals), + vdupq_n_u32(std::numeric_limits::max())))); + } + for (; i < n; ++i) { + c[i] = a[i] + bf * b[i]; + if (c[i] < vmin) { + vmin = c[i]; + imin = static_cast(i); + } + } + return static_cast(imin); +} + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_arm_sve.cpp b/faiss/utils/simd_impl/distances_arm_sve.cpp new file mode 100644 index 0000000000..3bd4227da0 --- /dev/null +++ b/faiss/utils/simd_impl/distances_arm_sve.cpp @@ -0,0 +1,496 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#define AUTOVEC_LEVEL SIMDLevel::ARM_SVE +#include + +namespace faiss { + +template <> +void fvec_madd( + const size_t n, + const float* __restrict a, + const float bf, + const float* __restrict b, + float* __restrict c) { + const size_t lanes = static_cast(svcntw()); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + size_t i = 0; + for (; i + lanes4 < n; i += lanes4) { + const auto mask = svptrue_b32(); + const auto ai0 = svld1_f32(mask, a + i); + const auto ai1 = svld1_f32(mask, a + i + lanes); + const auto ai2 = svld1_f32(mask, a + i + lanes2); + const auto ai3 = svld1_f32(mask, a + i + lanes3); + const auto bi0 = svld1_f32(mask, b + i); + const auto bi1 = svld1_f32(mask, b + i + lanes); + const auto bi2 = svld1_f32(mask, b + i + lanes2); + const auto bi3 = svld1_f32(mask, b + i + lanes3); + const auto ci0 = svmla_n_f32_x(mask, ai0, bi0, bf); + const auto ci1 = svmla_n_f32_x(mask, ai1, bi1, bf); + const auto ci2 = svmla_n_f32_x(mask, ai2, bi2, bf); + const auto ci3 = svmla_n_f32_x(mask, ai3, bi3, bf); + svst1_f32(mask, c + i, ci0); + svst1_f32(mask, c + i + lanes, ci1); + svst1_f32(mask, c + i + lanes2, ci2); + svst1_f32(mask, c + i + lanes3, ci3); + } + const auto mask0 = svwhilelt_b32_u64(i, n); + const auto mask1 = svwhilelt_b32_u64(i + lanes, n); + const auto mask2 = svwhilelt_b32_u64(i + lanes2, n); + const auto mask3 = svwhilelt_b32_u64(i + lanes3, n); + const auto ai0 = svld1_f32(mask0, a + i); + const auto ai1 = svld1_f32(mask1, a + i + lanes); + const auto ai2 = svld1_f32(mask2, a + i + lanes2); + const auto ai3 = svld1_f32(mask3, a + i + lanes3); + const auto bi0 = svld1_f32(mask0, b + i); + const auto bi1 = svld1_f32(mask1, b + i + lanes); + const auto bi2 = svld1_f32(mask2, b + i + lanes2); + const auto bi3 = svld1_f32(mask3, b + i + lanes3); + const auto ci0 = svmla_n_f32_x(mask0, ai0, bi0, bf); + const auto ci1 = svmla_n_f32_x(mask1, ai1, bi1, bf); + const auto ci2 = svmla_n_f32_x(mask2, ai2, bi2, bf); + const auto ci3 = svmla_n_f32_x(mask3, ai3, bi3, bf); + svst1_f32(mask0, c + i, ci0); + svst1_f32(mask1, c + i + lanes, ci1); + svst1_f32(mask2, c + i + lanes2, ci2); + svst1_f32(mask3, c + i + lanes3, ci3); +} + +template <> +void fvec_L2sqr_ny_transposed( + float* dis, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny); + +struct ElementOpIP { + static svfloat32_t op(svbool_t pg, svfloat32_t x, svfloat32_t y) { + return svmul_f32_x(pg, x, y); + } + static svfloat32_t merge( + svbool_t pg, + svfloat32_t z, + svfloat32_t x, + svfloat32_t y) { + return svmla_f32_x(pg, z, x, y); + } +}; + +template +void fvec_op_ny_sve_d1(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + size_t i = 0; + for (; i + lanes4 < ny; i += lanes4) { + svfloat32_t y0 = svld1_f32(pg, y); + svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + svfloat32_t y3 = svld1_f32(pg, y + lanes3); + y0 = ElementOp::op(pg, x0, y0); + y1 = ElementOp::op(pg, x0, y1); + y2 = ElementOp::op(pg, x0, y2); + y3 = ElementOp::op(pg, x0, y3); + svst1_f32(pg, dis, y0); + svst1_f32(pg, dis + lanes, y1); + 
svst1_f32(pg, dis + lanes2, y2); + svst1_f32(pg, dis + lanes3, y3); + y += lanes4; + dis += lanes4; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny); + const svbool_t pg2 = svwhilelt_b32_u64(i + lanes2, ny); + const svbool_t pg3 = svwhilelt_b32_u64(i + lanes3, ny); + svfloat32_t y0 = svld1_f32(pg0, y); + svfloat32_t y1 = svld1_f32(pg1, y + lanes); + svfloat32_t y2 = svld1_f32(pg2, y + lanes2); + svfloat32_t y3 = svld1_f32(pg3, y + lanes3); + y0 = ElementOp::op(pg0, x0, y0); + y1 = ElementOp::op(pg1, x0, y1); + y2 = ElementOp::op(pg2, x0, y2); + y3 = ElementOp::op(pg3, x0, y3); + svst1_f32(pg0, dis, y0); + svst1_f32(pg1, dis + lanes, y1); + svst1_f32(pg2, dis + lanes2, y2); + svst1_f32(pg3, dis + lanes3, y3); +} + +template +void fvec_op_ny_sve_d2(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + const svfloat32_t x1 = svdup_n_f32(x[1]); + size_t i = 0; + for (; i + lanes2 < ny; i += lanes2) { + const svfloat32x2_t y0 = svld2_f32(pg, y); + const svfloat32x2_t y1 = svld2_f32(pg, y + lanes2); + svfloat32_t y00 = svget2_f32(y0, 0); + const svfloat32_t y01 = svget2_f32(y0, 1); + svfloat32_t y10 = svget2_f32(y1, 0); + const svfloat32_t y11 = svget2_f32(y1, 1); + y00 = ElementOp::op(pg, x0, y00); + y10 = ElementOp::op(pg, x0, y10); + y00 = ElementOp::merge(pg, y00, x1, y01); + y10 = ElementOp::merge(pg, y10, x1, y11); + svst1_f32(pg, dis, y00); + svst1_f32(pg, dis + lanes, y10); + y += lanes4; + dis += lanes2; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svbool_t pg1 = svwhilelt_b32_u64(i + lanes, ny); + const svfloat32x2_t y0 = svld2_f32(pg0, y); + const svfloat32x2_t y1 = svld2_f32(pg1, y + lanes2); + svfloat32_t y00 = svget2_f32(y0, 0); + const svfloat32_t y01 = svget2_f32(y0, 1); + svfloat32_t y10 = svget2_f32(y1, 0); + const svfloat32_t y11 = svget2_f32(y1, 1); + y00 = ElementOp::op(pg0, x0, y00); + y10 = ElementOp::op(pg1, x0, y10); + y00 = ElementOp::merge(pg0, y00, x1, y01); + y10 = ElementOp::merge(pg1, y10, x1, y11); + svst1_f32(pg0, dis, y00); + svst1_f32(pg1, dis + lanes, y10); +} + +template +void fvec_op_ny_sve_d4(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + const svfloat32_t x1 = svdup_n_f32(x[1]); + const svfloat32_t x2 = svdup_n_f32(x[2]); + const svfloat32_t x3 = svdup_n_f32(x[3]); + size_t i = 0; + for (; i + lanes < ny; i += lanes) { + const svfloat32x4_t y0 = svld4_f32(pg, y); + svfloat32_t y00 = svget4_f32(y0, 0); + const svfloat32_t y01 = svget4_f32(y0, 1); + svfloat32_t y02 = svget4_f32(y0, 2); + const svfloat32_t y03 = svget4_f32(y0, 3); + y00 = ElementOp::op(pg, x0, y00); + y02 = ElementOp::op(pg, x2, y02); + y00 = ElementOp::merge(pg, y00, x1, y01); + y02 = ElementOp::merge(pg, y02, x3, y03); + y00 = svadd_f32_x(pg, y00, y02); + svst1_f32(pg, dis, y00); + y += lanes4; + dis += lanes; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svfloat32x4_t y0 = svld4_f32(pg0, y); + svfloat32_t y00 = svget4_f32(y0, 0); + const svfloat32_t y01 = svget4_f32(y0, 1); + svfloat32_t y02 = svget4_f32(y0, 2); + const svfloat32_t y03 = svget4_f32(y0, 3); + y00 = ElementOp::op(pg0, x0, y00); + y02 = ElementOp::op(pg0, x2, y02); + y00 = 
ElementOp::merge(pg0, y00, x1, y01); + y02 = ElementOp::merge(pg0, y02, x3, y03); + y00 = svadd_f32_x(pg0, y00, y02); + svst1_f32(pg0, dis, y00); +} + +template +void fvec_op_ny_sve_d8(float* dis, const float* x, const float* y, size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes4 = lanes * 4; + const size_t lanes8 = lanes * 8; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svdup_n_f32(x[0]); + const svfloat32_t x1 = svdup_n_f32(x[1]); + const svfloat32_t x2 = svdup_n_f32(x[2]); + const svfloat32_t x3 = svdup_n_f32(x[3]); + const svfloat32_t x4 = svdup_n_f32(x[4]); + const svfloat32_t x5 = svdup_n_f32(x[5]); + const svfloat32_t x6 = svdup_n_f32(x[6]); + const svfloat32_t x7 = svdup_n_f32(x[7]); + size_t i = 0; + for (; i + lanes < ny; i += lanes) { + const svfloat32x4_t ya = svld4_f32(pg, y); + const svfloat32x4_t yb = svld4_f32(pg, y + lanes4); + const svfloat32_t ya0 = svget4_f32(ya, 0); + const svfloat32_t ya1 = svget4_f32(ya, 1); + const svfloat32_t ya2 = svget4_f32(ya, 2); + const svfloat32_t ya3 = svget4_f32(ya, 3); + const svfloat32_t yb0 = svget4_f32(yb, 0); + const svfloat32_t yb1 = svget4_f32(yb, 1); + const svfloat32_t yb2 = svget4_f32(yb, 2); + const svfloat32_t yb3 = svget4_f32(yb, 3); + svfloat32_t y0 = svuzp1(ya0, yb0); + const svfloat32_t y1 = svuzp1(ya1, yb1); + svfloat32_t y2 = svuzp1(ya2, yb2); + const svfloat32_t y3 = svuzp1(ya3, yb3); + svfloat32_t y4 = svuzp2(ya0, yb0); + const svfloat32_t y5 = svuzp2(ya1, yb1); + svfloat32_t y6 = svuzp2(ya2, yb2); + const svfloat32_t y7 = svuzp2(ya3, yb3); + y0 = ElementOp::op(pg, x0, y0); + y2 = ElementOp::op(pg, x2, y2); + y4 = ElementOp::op(pg, x4, y4); + y6 = ElementOp::op(pg, x6, y6); + y0 = ElementOp::merge(pg, y0, x1, y1); + y2 = ElementOp::merge(pg, y2, x3, y3); + y4 = ElementOp::merge(pg, y4, x5, y5); + y6 = ElementOp::merge(pg, y6, x7, y7); + y0 = svadd_f32_x(pg, y0, y2); + y4 = svadd_f32_x(pg, y4, y6); + y0 = svadd_f32_x(pg, y0, y4); + svst1_f32(pg, dis, y0); + y += lanes8; + dis += lanes; + } + const svbool_t pg0 = svwhilelt_b32_u64(i, ny); + const svbool_t pga = svwhilelt_b32_u64(i * 2, ny * 2); + const svbool_t pgb = svwhilelt_b32_u64(i * 2 + lanes, ny * 2); + const svfloat32x4_t ya = svld4_f32(pga, y); + const svfloat32x4_t yb = svld4_f32(pgb, y + lanes4); + const svfloat32_t ya0 = svget4_f32(ya, 0); + const svfloat32_t ya1 = svget4_f32(ya, 1); + const svfloat32_t ya2 = svget4_f32(ya, 2); + const svfloat32_t ya3 = svget4_f32(ya, 3); + const svfloat32_t yb0 = svget4_f32(yb, 0); + const svfloat32_t yb1 = svget4_f32(yb, 1); + const svfloat32_t yb2 = svget4_f32(yb, 2); + const svfloat32_t yb3 = svget4_f32(yb, 3); + svfloat32_t y0 = svuzp1(ya0, yb0); + const svfloat32_t y1 = svuzp1(ya1, yb1); + svfloat32_t y2 = svuzp1(ya2, yb2); + const svfloat32_t y3 = svuzp1(ya3, yb3); + svfloat32_t y4 = svuzp2(ya0, yb0); + const svfloat32_t y5 = svuzp2(ya1, yb1); + svfloat32_t y6 = svuzp2(ya2, yb2); + const svfloat32_t y7 = svuzp2(ya3, yb3); + y0 = ElementOp::op(pg0, x0, y0); + y2 = ElementOp::op(pg0, x2, y2); + y4 = ElementOp::op(pg0, x4, y4); + y6 = ElementOp::op(pg0, x6, y6); + y0 = ElementOp::merge(pg0, y0, x1, y1); + y2 = ElementOp::merge(pg0, y2, x3, y3); + y4 = ElementOp::merge(pg0, y4, x5, y5); + y6 = ElementOp::merge(pg0, y6, x7, y7); + y0 = svadd_f32_x(pg0, y0, y2); + y4 = svadd_f32_x(pg0, y4, y6); + y0 = svadd_f32_x(pg0, y0, y4); + svst1_f32(pg0, dis, y0); + y += lanes8; + dis += lanes; +} + +template +void fvec_op_ny_sve_lanes1( + float* dis, + const float* x, + const float* y, + size_t ny) { + 
const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + size_t i = 0; + for (; i + 3 < ny; i += 4) { + svfloat32_t y0 = svld1_f32(pg, y); + svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + svfloat32_t y3 = svld1_f32(pg, y + lanes3); + y += lanes4; + y0 = ElementOp::op(pg, x0, y0); + y1 = ElementOp::op(pg, x0, y1); + y2 = ElementOp::op(pg, x0, y2); + y3 = ElementOp::op(pg, x0, y3); + dis[i] = svaddv_f32(pg, y0); + dis[i + 1] = svaddv_f32(pg, y1); + dis[i + 2] = svaddv_f32(pg, y2); + dis[i + 3] = svaddv_f32(pg, y3); + } + for (; i < ny; ++i) { + svfloat32_t y0 = svld1_f32(pg, y); + y += lanes; + y0 = ElementOp::op(pg, x0, y0); + dis[i] = svaddv_f32(pg, y0); + } +} + +template +void fvec_op_ny_sve_lanes2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + const svfloat32_t x1 = svld1_f32(pg, x + lanes); + size_t i = 0; + for (; i + 1 < ny; i += 2) { + svfloat32_t y00 = svld1_f32(pg, y); + const svfloat32_t y01 = svld1_f32(pg, y + lanes); + svfloat32_t y10 = svld1_f32(pg, y + lanes2); + const svfloat32_t y11 = svld1_f32(pg, y + lanes3); + y += lanes4; + y00 = ElementOp::op(pg, x0, y00); + y10 = ElementOp::op(pg, x0, y10); + y00 = ElementOp::merge(pg, y00, x1, y01); + y10 = ElementOp::merge(pg, y10, x1, y11); + dis[i] = svaddv_f32(pg, y00); + dis[i + 1] = svaddv_f32(pg, y10); + } + if (i < ny) { + svfloat32_t y0 = svld1_f32(pg, y); + const svfloat32_t y1 = svld1_f32(pg, y + lanes); + y0 = ElementOp::op(pg, x0, y0); + y0 = ElementOp::merge(pg, y0, x1, y1); + dis[i] = svaddv_f32(pg, y0); + } +} + +template +void fvec_op_ny_sve_lanes3( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + const svfloat32_t x1 = svld1_f32(pg, x + lanes); + const svfloat32_t x2 = svld1_f32(pg, x + lanes2); + for (size_t i = 0; i < ny; ++i) { + svfloat32_t y0 = svld1_f32(pg, y); + const svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + y += lanes3; + y0 = ElementOp::op(pg, x0, y0); + y0 = ElementOp::merge(pg, y0, x1, y1); + y0 = ElementOp::merge(pg, y0, x2, y2); + dis[i] = svaddv_f32(pg, y0); + } +} + +template +void fvec_op_ny_sve_lanes4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t lanes = svcntw(); + const size_t lanes2 = lanes * 2; + const size_t lanes3 = lanes * 3; + const size_t lanes4 = lanes * 4; + const svbool_t pg = svptrue_b32(); + const svfloat32_t x0 = svld1_f32(pg, x); + const svfloat32_t x1 = svld1_f32(pg, x + lanes); + const svfloat32_t x2 = svld1_f32(pg, x + lanes2); + const svfloat32_t x3 = svld1_f32(pg, x + lanes3); + for (size_t i = 0; i < ny; ++i) { + svfloat32_t y0 = svld1_f32(pg, y); + const svfloat32_t y1 = svld1_f32(pg, y + lanes); + svfloat32_t y2 = svld1_f32(pg, y + lanes2); + const svfloat32_t y3 = svld1_f32(pg, y + lanes3); + y += lanes4; + y0 = ElementOp::op(pg, x0, y0); + y2 = ElementOp::op(pg, x2, y2); + y0 = ElementOp::merge(pg, y0, x1, y1); + y2 = ElementOp::merge(pg, y2, x3, y3); + y0 = 
svadd_f32_x(pg, y0, y2); + dis[i] = svaddv_f32(pg, y0); + } +} + +template <> +void fvec_inner_products_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + const size_t lanes = svcntw(); + switch (d) { + case 1: + fvec_op_ny_sve_d1(dis, x, y, ny); + break; + case 2: + fvec_op_ny_sve_d2(dis, x, y, ny); + break; + case 4: + fvec_op_ny_sve_d4(dis, x, y, ny); + break; + case 8: + fvec_op_ny_sve_d8(dis, x, y, ny); + break; + default: + if (d == lanes) + fvec_op_ny_sve_lanes1(dis, x, y, ny); + else if (d == lanes * 2) + fvec_op_ny_sve_lanes2(dis, x, y, ny); + else if (d == lanes * 3) + fvec_op_ny_sve_lanes3(dis, x, y, ny); + else if (d == lanes * 4) + fvec_op_ny_sve_lanes4(dis, x, y, ny); + else + fvec_inner_products_ny(dis, x, y, d, ny); + break; + } +} + +template <> +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny(dis, x, y, d, ny); +} + +template <> +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny_nearest( + distances_tmp_buffer, x, y, d, ny); +} + +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + return fvec_L2sqr_ny_nearest_y_transposed( + distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); +} + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_autovec-inl.h b/faiss/utils/simd_impl/distances_autovec-inl.h new file mode 100644 index 0000000000..62d13eb38e --- /dev/null +++ b/faiss/utils/simd_impl/distances_autovec-inl.h @@ -0,0 +1,153 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace faiss { + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_norm_L2sqr(const float* x, size_t d) { + // the double in the _ref is suspected to be a typo. Some of the manual + // implementations this replaces used float. 
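+    // accumulating in float matches those float-based kernels; exact
+    // rounding may still differ slightly, since the compiler chooses the
+    // reduction order under FAISS_PRAGMA_IMPRECISE_LOOP.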
+ float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i != d; ++i) { + res += x[i] * x[i]; + } + + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_L2sqr(const float* x, const float* y, size_t d) { + size_t i; + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (i = 0; i < d; i++) { + const float tmp = x[i] - y[i]; + res += tmp * tmp; + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_inner_product( + const float* x, + const float* y, + size_t d) { + float res = 0.F; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i != d; ++i) { + res += x[i] * y[i]; + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_L1(const float* x, const float* y, size_t d) { + size_t i; + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (i = 0; i < d; i++) { + const float tmp = x[i] - y[i]; + res += fabs(tmp); + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +float fvec_Linf(const float* x, const float* y, size_t d) { + float res = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; i++) { + res = fmax(res, fabs(x[i] - y[i])); + } + return res; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +void fvec_inner_product_batch_4( + const float* __restrict x, + const float* __restrict y0, + const float* __restrict y1, + const float* __restrict y2, + const float* __restrict y3, + const size_t d, + float& dis0, + float& dis1, + float& dis2, + float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; ++i) { + d0 += x[i] * y0[i]; + d1 += x[i] * y1[i]; + d2 += x[i] * y2[i]; + d3 += x[i] * y3[i]; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN +template <> +void fvec_L2sqr_batch_4( + const float* x, + const float* y0, + const float* y1, + const float* y2, + const float* y3, + const size_t d, + float& dis0, + float& dis1, + float& dis2, + float& dis3) { + float d0 = 0; + float d1 = 0; + float d2 = 0; + float d3 = 0; + FAISS_PRAGMA_IMPRECISE_LOOP + for (size_t i = 0; i < d; ++i) { + const float q0 = x[i] - y0[i]; + const float q1 = x[i] - y1[i]; + const float q2 = x[i] - y2[i]; + const float q3 = x[i] - y3[i]; + d0 += q0 * q0; + d1 += q1 * q1; + d2 += q2 * q2; + d3 += q3 * q3; + } + + dis0 = d0; + dis1 = d1; + dis2 = d2; + dis3 = d3; +} +FAISS_PRAGMA_IMPRECISE_FUNCTION_END + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_avx.cpp b/faiss/utils/simd_impl/distances_avx.cpp new file mode 100644 index 0000000000..c29e64c91f --- /dev/null +++ b/faiss/utils/simd_impl/distances_avx.cpp @@ -0,0 +1,99 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +#ifdef __AVX__ + +float fvec_L1(const float* x, const float* y, size_t d) { + __m256 msum1 = _mm256_setzero_ps(); + // signmask used for absolute value + __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL)); + + while (d >= 8) { + __m256 mx = _mm256_loadu_ps(x); + x += 8; + __m256 my = _mm256_loadu_ps(y); + y += 8; + // subtract + const __m256 a_m_b = _mm256_sub_ps(mx, my); + // find sum of absolute value of distances (manhattan distance) + msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b)); + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0)); + __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL)); + + if (d >= 4) { + __m128 mx = _mm_loadu_ps(x); + x += 4; + __m128 my = _mm_loadu_ps(y); + y += 4; + const __m128 a_m_b = _mm_sub_ps(mx, my); + msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read(d, x); + __m128 my = masked_read(d, y); + __m128 a_m_b = _mm_sub_ps(mx, my); + msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + } + + msum2 = _mm_hadd_ps(msum2, msum2); + msum2 = _mm_hadd_ps(msum2, msum2); + return _mm_cvtss_f32(msum2); +} + +float fvec_Linf(const float* x, const float* y, size_t d) { + __m256 msum1 = _mm256_setzero_ps(); + // signmask used for absolute value + __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL)); + + while (d >= 8) { + __m256 mx = _mm256_loadu_ps(x); + x += 8; + __m256 my = _mm256_loadu_ps(y); + y += 8; + // subtract + const __m256 a_m_b = _mm256_sub_ps(mx, my); + // find max of absolute value of distances (chebyshev distance) + msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b)); + d -= 8; + } + + __m128 msum2 = _mm256_extractf128_ps(msum1, 1); + msum2 = _mm_max_ps(msum2, _mm256_extractf128_ps(msum1, 0)); + __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL)); + + if (d >= 4) { + __m128 mx = _mm_loadu_ps(x); + x += 4; + __m128 my = _mm_loadu_ps(y); + y += 4; + const __m128 a_m_b = _mm_sub_ps(mx, my); + msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + d -= 4; + } + + if (d > 0) { + __m128 mx = masked_read(d, x); + __m128 my = masked_read(d, y); + __m128 a_m_b = _mm_sub_ps(mx, my); + msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b)); + } + + msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2); + msum2 = _mm_max_ps(msum2, _mm_shuffle_ps(msum2, msum2, 1)); + return _mm_cvtss_f32(msum2); +} + +#endif diff --git a/faiss/utils/simd_impl/distances_avx2.cpp b/faiss/utils/simd_impl/distances_avx2.cpp new file mode 100644 index 0000000000..acfcbabe17 --- /dev/null +++ b/faiss/utils/simd_impl/distances_avx2.cpp @@ -0,0 +1,1178 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#define AUTOVEC_LEVEL SIMDLevel::AVX2 +#include + +#include +#include + +namespace faiss { + +template <> +void fvec_madd( + const size_t n, + const float* __restrict a, + const float bf, + const float* __restrict b, + float* __restrict c) { + // + const size_t n8 = n / 8; + const size_t n_for_masking = n % 8; + + const __m256 bfmm = _mm256_set1_ps(bf); + + size_t idx = 0; + for (idx = 0; idx < n8 * 8; idx += 8) { + const __m256 ax = _mm256_loadu_ps(a + idx); + const __m256 bx = _mm256_loadu_ps(b + idx); + const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); + _mm256_storeu_ps(c + idx, abmul); + } + + if (n_for_masking > 0) { + __m256i mask; + switch (n_for_masking) { + case 1: + mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1); + break; + case 2: + mask = _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1); + break; + case 3: + mask = _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1); + break; + case 4: + mask = _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1); + break; + case 5: + mask = _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1); + break; + case 6: + mask = _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1); + break; + case 7: + mask = _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1); + break; + } + + const __m256 ax = _mm256_maskload_ps(a + idx, mask); + const __m256 bx = _mm256_maskload_ps(b + idx, mask); + const __m256 abmul = _mm256_fmadd_ps(bfmm, bx, ax); + _mm256_maskstore_ps(c + idx, mask, abmul); + } +} + +template +void fvec_L2sqr_ny_y_transposed_D( + float* distances, + const float* x, + const float* y, + const float* y_sqlen, + const size_t d_offset, + size_t ny) { + // current index being processed + size_t i = 0; + + // squared length of x + float x_sqlen = 0; + for (size_t j = 0; j < DIM; j++) { + x_sqlen += x[j] * x[j]; + } + + // process 8 vectors per loop. + const size_t ny8 = ny / 8; + + if (ny8 > 0) { + // m[i] = (2 * x[i], ... 2 * x[i]) + __m256 m[DIM]; + for (size_t j = 0; j < DIM; j++) { + m[j] = _mm256_set1_ps(x[j]); + m[j] = _mm256_add_ps(m[j], m[j]); + } + + __m256 x_sqlen_ymm = _mm256_set1_ps(x_sqlen); + + for (; i < ny8 * 8; i += 8) { + // collect dim 0 for 8 D4-vectors. + const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset); + + // compute dot products + // this is x^2 - 2x[0]*y[0] + __m256 dp = _mm256_fnmadd_ps(m[0], v0, x_sqlen_ymm); + + for (size_t j = 1; j < DIM; j++) { + // collect dim j for 8 D4-vectors. + const __m256 vj = _mm256_loadu_ps(y + j * d_offset); + dp = _mm256_fnmadd_ps(m[j], vj, dp); + } + + // we've got x^2 - (2x, y) at this point + + // y^2 - (2x, y) + x^2 + __m256 distances_v = _mm256_add_ps(_mm256_loadu_ps(y_sqlen), dp); + + _mm256_storeu_ps(distances + i, distances_v); + + // scroll y and y_sqlen forward. + y += 8; + y_sqlen += 8; + } + } + + if (i < ny) { + // process leftovers + for (; i < ny; i++) { + float dp = 0; + for (size_t j = 0; j < DIM; j++) { + dp += x[j] * y[j * d_offset]; + } + + // compute y^2 - 2 * (x, y), which is sufficient for looking for the + // lowest distance. 
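+            // (x_sqlen is added back below, so this leftover path yields the
+            // full ||x - y||^2 = x^2 - 2 * (x, y) + y^2, matching the main
+            // loop above.)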
+ const float distance = y_sqlen[0] - 2 * dp + x_sqlen; + distances[i] = distance; + + y += 1; + y_sqlen += 1; + } + } +} + +template <> +void fvec_L2sqr_ny_transposed( + float* dis, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + // optimized for a few special cases +#define DISPATCH(dval) \ + case dval: \ + return fvec_L2sqr_ny_y_transposed_D( \ + dis, x, y, y_sqlen, d_offset, ny); + + switch (d) { + DISPATCH(1) + DISPATCH(2) + DISPATCH(4) + DISPATCH(8) + default: + return fvec_L2sqr_ny_transposed( + dis, x, y, y_sqlen, d, d_offset, ny); + } +#undef DISPATCH +} + +struct AVX2ElementOpIP : public ElementOpIP { + using ElementOpIP::op; + static __m256 op(__m256 x, __m256 y) { + return _mm256_mul_ps(x, y); + } +}; + +struct AVX2ElementOpL2 : public ElementOpL2 { + using ElementOpL2::op; + + static __m256 op(__m256 x, __m256 y) { + __m256 tmp = _mm256_sub_ps(x, y); + return _mm256_mul_ps(tmp, tmp); + } +}; + +/// helper function for AVX2 +inline float horizontal_sum(const __m256 v) { + // add high and low parts + const __m128 v0 = + _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); + // perform horizontal sum on v0 + return horizontal_sum(v0); +} + +template <> +void fvec_op_ny_D2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D2-vectors per loop. + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); + + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + + for (i = 0; i < ny8 * 8; i += 8) { + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + // load 8x2 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m256 v0; + __m256 v1; + + transpose_8x2( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + v0, + v1); + + // compute distances + __m256 distances = _mm256_mul_ps(m0, v0); + distances = _mm256_fmadd_ps(m1, v1, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 16; + } + } + + if (i < ny) { + // process leftovers + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float distance = x0 * y[0] + x1 * y[1]; + y += 2; + dis[i] = distance; + } + } +} + +template <> +void fvec_op_ny_D2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D2-vectors per loop. + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); + + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + + for (i = 0; i < ny8 * 8; i += 8) { + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + // load 8x2 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
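+            // after the transpose, v0 holds dim 0 and v1 holds dim 1 of the
+            // next 8 y vectors.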
+ + __m256 v0; + __m256 v1; + + transpose_8x2( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + v0, + v1); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 16; + } + } + + if (i < ny) { + // process leftovers + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float sub0 = x0 - y[0]; + float sub1 = x1 - y[1]; + float distance = sub0 * sub0 + sub1 * sub1; + + y += 2; + dis[i] = distance; + } + } +} + +template <> +void fvec_op_ny_D4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D4-vectors per loop. + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + + for (i = 0; i < ny8 * 8; i += 8) { + // load 8x4 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + + transpose_8x4( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + v0, + v1, + v2, + v3); + + // compute distances + __m256 distances = _mm256_mul_ps(m0, v0); + distances = _mm256_fmadd_ps(m1, v1, distances); + distances = _mm256_fmadd_ps(m2, v2, distances); + distances = _mm256_fmadd_ps(m3, v3, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 32; + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = AVX2ElementOpIP::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D4-vectors per loop. + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + + for (i = 0; i < ny8 * 8; i += 8) { + // load 8x4 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
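+            // after the transpose, v0..v3 hold dims 0..3 of the next 8 y vectors.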
+ + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + + transpose_8x4( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + v0, + v1, + v2, + v3); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + const __m256 d2 = _mm256_sub_ps(m2, v2); + const __m256 d3 = _mm256_sub_ps(m3, v3); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + distances = _mm256_fmadd_ps(d2, d2, distances); + distances = _mm256_fmadd_ps(d3, d3, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 32; + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = AVX2ElementOpL2::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D8( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D8-vectors per loop. + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + const __m256 m4 = _mm256_set1_ps(x[4]); + const __m256 m5 = _mm256_set1_ps(x[5]); + const __m256 m6 = _mm256_set1_ps(x[6]); + const __m256 m7 = _mm256_set1_ps(x[7]); + + for (i = 0; i < ny8 * 8; i += 8) { + // load 8x8 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + __m256 v4; + __m256 v5; + __m256 v6; + __m256 v7; + + transpose_8x8( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + _mm256_loadu_ps(y + 4 * 8), + _mm256_loadu_ps(y + 5 * 8), + _mm256_loadu_ps(y + 6 * 8), + _mm256_loadu_ps(y + 7 * 8), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute distances + __m256 distances = _mm256_mul_ps(m0, v0); + distances = _mm256_fmadd_ps(m1, v1, distances); + distances = _mm256_fmadd_ps(m2, v2, distances); + distances = _mm256_fmadd_ps(m3, v3, distances); + distances = _mm256_fmadd_ps(m4, v4, distances); + distances = _mm256_fmadd_ps(m5, v5, distances); + distances = _mm256_fmadd_ps(m6, v6, distances); + distances = _mm256_fmadd_ps(m7, v7, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 64; + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX2ElementOpIP::op(x0, _mm256_loadu_ps(y)); + y += 8; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D8( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny8 = ny / 8; + size_t i = 0; + + if (ny8 > 0) { + // process 8 D8-vectors per loop. + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + const __m256 m4 = _mm256_set1_ps(x[4]); + const __m256 m5 = _mm256_set1_ps(x[5]); + const __m256 m6 = _mm256_set1_ps(x[6]); + const __m256 m7 = _mm256_set1_ps(x[7]); + + for (i = 0; i < ny8 * 8; i += 8) { + // load 8x8 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
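+            // after the transpose, v0..v7 hold dims 0..7 of the next 8 y vectors.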
+ + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + __m256 v4; + __m256 v5; + __m256 v6; + __m256 v7; + + transpose_8x8( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + _mm256_loadu_ps(y + 4 * 8), + _mm256_loadu_ps(y + 5 * 8), + _mm256_loadu_ps(y + 6 * 8), + _mm256_loadu_ps(y + 7 * 8), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + const __m256 d2 = _mm256_sub_ps(m2, v2); + const __m256 d3 = _mm256_sub_ps(m3, v3); + const __m256 d4 = _mm256_sub_ps(m4, v4); + const __m256 d5 = _mm256_sub_ps(m5, v5); + const __m256 d6 = _mm256_sub_ps(m6, v6); + const __m256 d7 = _mm256_sub_ps(m7, v7); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + distances = _mm256_fmadd_ps(d2, d2, distances); + distances = _mm256_fmadd_ps(d3, d3, distances); + distances = _mm256_fmadd_ps(d4, d4, distances); + distances = _mm256_fmadd_ps(d5, d5, distances); + distances = _mm256_fmadd_ps(d6, d6, distances); + distances = _mm256_fmadd_ps(d7, d7, distances); + + // store + _mm256_storeu_ps(dis + i, distances); + + y += 64; + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX2ElementOpL2::op(x0, _mm256_loadu_ps(y)); + y += 8; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_inner_products_ny( + float* ip, /* output inner product */ + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_inner_products_ny_ref(ip, x, y, d, ny); +} + +template <> +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny_ref(dis, x, y, d, ny); +} + +template <> +size_t fvec_L2sqr_ny_nearest_D2( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. + // current index being processed + size_t i = 0; + + // min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // process 8 D2-vectors per loop. + const size_t ny8 = ny / 8; + if (ny8 > 0) { + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 16), _MM_HINT_T0); + + // track min distance and the closest vector independently + // for each of 8 AVX2 components. + __m256 min_distances = _mm256_set1_ps(HUGE_VALF); + __m256i min_indices = _mm256_set1_epi32(0); + + __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + const __m256i indices_increment = _mm256_set1_epi32(8); + + // 1 value per register + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + + for (; i < ny8 * 8; i += 8) { + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + __m256 v0; + __m256 v1; + + transpose_8x2( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + v0, + v1); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + + // compare the new distances to the min distances + // for each of 8 AVX2 components. + __m256 comparison = + _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); + + // update min distances and indices with closest vectors if needed. 
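+            // comparison is set where the previous minimum is still smaller;
+            // blendv keeps the old index in those lanes and takes the current
+            // index otherwise.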
+ min_distances = _mm256_min_ps(distances, min_distances); + min_indices = _mm256_castps_si256(_mm256_blendv_ps( + _mm256_castsi256_ps(current_indices), + _mm256_castsi256_ps(min_indices), + comparison)); + + // update current indices values. Basically, +8 to each of the + // 8 AVX2 components. + current_indices = + _mm256_add_epi32(current_indices, indices_increment); + + // scroll y forward (8 vectors 2 DIM each). + y += 16; + } + + // dump values and find the minimum distance / minimum index + float min_distances_scalar[8]; + uint32_t min_indices_scalar[8]; + _mm256_storeu_ps(min_distances_scalar, min_distances); + _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 8; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // process leftovers. + // the following code is not optimal, but it is rarely invoked. + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float sub0 = x0 - y[0]; + float sub1 = x1 - y[1]; + float distance = sub0 * sub0 + sub1 * sub1; + + y += 2; + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest_D4( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. + + // current index being processed + size_t i = 0; + + // min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // process 8 D4-vectors per loop. + const size_t ny8 = ny / 8; + + if (ny8 > 0) { + // track min distance and the closest vector independently + // for each of 8 AVX2 components. + __m256 min_distances = _mm256_set1_ps(HUGE_VALF); + __m256i min_indices = _mm256_set1_epi32(0); + + __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + const __m256i indices_increment = _mm256_set1_epi32(8); + + // 1 value per register + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + + for (; i < ny8 * 8; i += 8) { + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + + transpose_8x4( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + v0, + v1, + v2, + v3); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + const __m256 d2 = _mm256_sub_ps(m2, v2); + const __m256 d3 = _mm256_sub_ps(m3, v3); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + distances = _mm256_fmadd_ps(d2, d2, distances); + distances = _mm256_fmadd_ps(d3, d3, distances); + + // compare the new distances to the min distances + // for each of 8 AVX2 components. + __m256 comparison = + _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); + + // update min distances and indices with closest vectors if needed. + min_distances = _mm256_min_ps(distances, min_distances); + min_indices = _mm256_castps_si256(_mm256_blendv_ps( + _mm256_castsi256_ps(current_indices), + _mm256_castsi256_ps(min_indices), + comparison)); + + // update current indices values. Basically, +8 to each of the + // 8 AVX2 components. 
+ current_indices = + _mm256_add_epi32(current_indices, indices_increment); + + // scroll y forward (8 vectors 4 DIM each). + y += 32; + } + + // dump values and find the minimum distance / minimum index + float min_distances_scalar[8]; + uint32_t min_indices_scalar[8]; + _mm256_storeu_ps(min_distances_scalar, min_distances); + _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 8; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); + y += 4; + const float distance = horizontal_sum(accu); + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest_D8( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. + + // current index being processed + size_t i = 0; + + // min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // process 8 D8-vectors per loop. + const size_t ny8 = ny / 8; + if (ny8 > 0) { + // track min distance and the closest vector independently + // for each of 8 AVX2 components. + __m256 min_distances = _mm256_set1_ps(HUGE_VALF); + __m256i min_indices = _mm256_set1_epi32(0); + + __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + const __m256i indices_increment = _mm256_set1_epi32(8); + + // 1 value per register + const __m256 m0 = _mm256_set1_ps(x[0]); + const __m256 m1 = _mm256_set1_ps(x[1]); + const __m256 m2 = _mm256_set1_ps(x[2]); + const __m256 m3 = _mm256_set1_ps(x[3]); + + const __m256 m4 = _mm256_set1_ps(x[4]); + const __m256 m5 = _mm256_set1_ps(x[5]); + const __m256 m6 = _mm256_set1_ps(x[6]); + const __m256 m7 = _mm256_set1_ps(x[7]); + + for (; i < ny8 * 8; i += 8) { + __m256 v0; + __m256 v1; + __m256 v2; + __m256 v3; + __m256 v4; + __m256 v5; + __m256 v6; + __m256 v7; + + transpose_8x8( + _mm256_loadu_ps(y + 0 * 8), + _mm256_loadu_ps(y + 1 * 8), + _mm256_loadu_ps(y + 2 * 8), + _mm256_loadu_ps(y + 3 * 8), + _mm256_loadu_ps(y + 4 * 8), + _mm256_loadu_ps(y + 5 * 8), + _mm256_loadu_ps(y + 6 * 8), + _mm256_loadu_ps(y + 7 * 8), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute differences + const __m256 d0 = _mm256_sub_ps(m0, v0); + const __m256 d1 = _mm256_sub_ps(m1, v1); + const __m256 d2 = _mm256_sub_ps(m2, v2); + const __m256 d3 = _mm256_sub_ps(m3, v3); + const __m256 d4 = _mm256_sub_ps(m4, v4); + const __m256 d5 = _mm256_sub_ps(m5, v5); + const __m256 d6 = _mm256_sub_ps(m6, v6); + const __m256 d7 = _mm256_sub_ps(m7, v7); + + // compute squares of differences + __m256 distances = _mm256_mul_ps(d0, d0); + distances = _mm256_fmadd_ps(d1, d1, distances); + distances = _mm256_fmadd_ps(d2, d2, distances); + distances = _mm256_fmadd_ps(d3, d3, distances); + distances = _mm256_fmadd_ps(d4, d4, distances); + distances = _mm256_fmadd_ps(d5, d5, distances); + distances = _mm256_fmadd_ps(d6, d6, distances); + distances = _mm256_fmadd_ps(d7, d7, distances); + + // compare the new distances to the min distances + // for each of 8 AVX2 components. 
+ __m256 comparison = + _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); + + // update min distances and indices with closest vectors if needed. + min_distances = _mm256_min_ps(distances, min_distances); + min_indices = _mm256_castps_si256(_mm256_blendv_ps( + _mm256_castsi256_ps(current_indices), + _mm256_castsi256_ps(min_indices), + comparison)); + + // update current indices values. Basically, +8 to each of the + // 8 AVX2 components. + current_indices = + _mm256_add_epi32(current_indices, indices_increment); + + // scroll y forward (8 vectors 8 DIM each). + y += 64; + } + + // dump values and find the minimum distance / minimum index + float min_distances_scalar[8]; + uint32_t min_indices_scalar[8]; + _mm256_storeu_ps(min_distances_scalar, min_distances); + _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 8; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX2ElementOpL2::op(x0, _mm256_loadu_ps(y)); + y += 8; + const float distance = horizontal_sum(accu); + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + return fvec_L2sqr_ny_nearest_x86( + distances_tmp_buffer, + x, + y, + d, + ny, + &fvec_L2sqr_ny_nearest_D2, + &fvec_L2sqr_ny_nearest_D4, + &fvec_L2sqr_ny_nearest_D8); +} + +template +size_t fvec_L2sqr_ny_nearest_y_transposed_D( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + const size_t d_offset, + size_t ny) { + // this implementation does not use distances_tmp_buffer. + + // current index being processed + size_t i = 0; + + // min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // process 8 vectors per loop. + const size_t ny8 = ny / 8; + + if (ny8 > 0) { + // track min distance and the closest vector independently + // for each of 8 AVX2 components. + __m256 min_distances = _mm256_set1_ps(HUGE_VALF); + __m256i min_indices = _mm256_set1_epi32(0); + + __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + const __m256i indices_increment = _mm256_set1_epi32(8); + + // m[i] = (2 * x[i], ... 2 * x[i]) + __m256 m[DIM]; + for (size_t j = 0; j < DIM; j++) { + m[j] = _mm256_set1_ps(x[j]); + m[j] = _mm256_add_ps(m[j], m[j]); + } + + for (; i < ny8 * 8; i += 8) { + // collect dim 0 for 8 D4-vectors. + const __m256 v0 = _mm256_loadu_ps(y + 0 * d_offset); + // compute dot products + __m256 dp = _mm256_mul_ps(m[0], v0); + + for (size_t j = 1; j < DIM; j++) { + // collect dim j for 8 D4-vectors. + const __m256 vj = _mm256_loadu_ps(y + j * d_offset); + dp = _mm256_fmadd_ps(m[j], vj, dp); + } + + // compute y^2 - (2 * x, y), which is sufficient for looking for the + // lowest distance. + // x^2 is the constant that can be avoided. + const __m256 distances = + _mm256_sub_ps(_mm256_loadu_ps(y_sqlen), dp); + + // compare the new distances to the min distances + // for each of 8 AVX2 components. 
+ const __m256 comparison = + _mm256_cmp_ps(min_distances, distances, _CMP_LT_OS); + + // update min distances and indices with closest vectors if needed. + min_distances = + _mm256_blendv_ps(distances, min_distances, comparison); + min_indices = _mm256_castps_si256(_mm256_blendv_ps( + _mm256_castsi256_ps(current_indices), + _mm256_castsi256_ps(min_indices), + comparison)); + + // update current indices values. Basically, +8 to each of the + // 8 AVX2 components. + current_indices = + _mm256_add_epi32(current_indices, indices_increment); + + // scroll y and y_sqlen forward. + y += 8; + y_sqlen += 8; + } + + // dump values and find the minimum distance / minimum index + float min_distances_scalar[8]; + uint32_t min_indices_scalar[8]; + _mm256_storeu_ps(min_distances_scalar, min_distances); + _mm256_storeu_si256((__m256i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 8; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // process leftovers + for (; i < ny; i++) { + float dp = 0; + for (size_t j = 0; j < DIM; j++) { + dp += x[j] * y[j * d_offset]; + } + + // compute y^2 - 2 * (x, y), which is sufficient for looking for the + // lowest distance. + const float distance = y_sqlen[0] - 2 * dp; + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + + y += 1; + y_sqlen += 1; + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { +// optimized for a few special cases +#define DISPATCH(dval) \ + case dval: \ + return fvec_L2sqr_ny_nearest_y_transposed_D( \ + distances_tmp_buffer, x, y, y_sqlen, d_offset, ny); + + switch (d) { + DISPATCH(1) + DISPATCH(2) + DISPATCH(4) + DISPATCH(8) + default: + return fvec_L2sqr_ny_nearest_y_transposed( + distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); + } +#undef DISPATCH +} + +template <> +int fvec_madd_and_argmin( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + return fvec_madd_and_argmin_sse(n, a, bf, b, c); +} + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_avx512.cpp b/faiss/utils/simd_impl/distances_avx512.cpp new file mode 100644 index 0000000000..06d5b399f4 --- /dev/null +++ b/faiss/utils/simd_impl/distances_avx512.cpp @@ -0,0 +1,1092 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#define AUTOVEC_LEVEL SIMDLevel::AVX512 +#include +#include +#include + +namespace faiss { + +template <> +void fvec_madd( + const size_t n, + const float* __restrict a, + const float bf, + const float* __restrict b, + float* __restrict c) { + const size_t n16 = n / 16; + const size_t n_for_masking = n % 16; + + const __m512 bfmm = _mm512_set1_ps(bf); + + size_t idx = 0; + for (idx = 0; idx < n16 * 16; idx += 16) { + const __m512 ax = _mm512_loadu_ps(a + idx); + const __m512 bx = _mm512_loadu_ps(b + idx); + const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); + _mm512_storeu_ps(c + idx, abmul); + } + + if (n_for_masking > 0) { + const __mmask16 mask = (1 << n_for_masking) - 1; + + const __m512 ax = _mm512_maskz_loadu_ps(mask, a + idx); + const __m512 bx = _mm512_maskz_loadu_ps(mask, b + idx); + const __m512 abmul = _mm512_fmadd_ps(bfmm, bx, ax); + _mm512_mask_storeu_ps(c + idx, mask, abmul); + } +} + +template +void fvec_L2sqr_ny_y_transposed_D( + float* distances, + const float* x, + const float* y, + const float* y_sqlen, + const size_t d_offset, + size_t ny) { + // current index being processed + size_t i = 0; + + // squared length of x + float x_sqlen = 0; + for (size_t j = 0; j < DIM; j++) { + x_sqlen += x[j] * x[j]; + } + + // process 16 vectors per loop + const size_t ny16 = ny / 16; + + if (ny16 > 0) { + // m[i] = (2 * x[i], ... 2 * x[i]) + __m512 m[DIM]; + for (size_t j = 0; j < DIM; j++) { + m[j] = _mm512_set1_ps(x[j]); + m[j] = _mm512_add_ps(m[j], m[j]); // m[j] = 2 * x[j] + } + + __m512 x_sqlen_ymm = _mm512_set1_ps(x_sqlen); + + for (; i < ny16 * 16; i += 16) { + // Load vectors for 16 dimensions + __m512 v[DIM]; + for (size_t j = 0; j < DIM; j++) { + v[j] = _mm512_loadu_ps(y + j * d_offset); + } + + // Compute dot products + __m512 dp = _mm512_fnmadd_ps(m[0], v[0], x_sqlen_ymm); + for (size_t j = 1; j < DIM; j++) { + dp = _mm512_fnmadd_ps(m[j], v[j], dp); + } + + // Compute y^2 - (2 * x, y) + x^2 + __m512 distances_v = _mm512_add_ps(_mm512_loadu_ps(y_sqlen), dp); + + _mm512_storeu_ps(distances + i, distances_v); + + // Scroll y and y_sqlen forward + y += 16; + y_sqlen += 16; + } + } + + if (i < ny) { + // Process leftovers + for (; i < ny; i++) { + float dp = 0; + for (size_t j = 0; j < DIM; j++) { + dp += x[j] * y[j * d_offset]; + } + + // Compute y^2 - 2 * (x, y), which is sufficient for looking for the + // lowest distance. 
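+            // (The x_sqlen term is added back below, so this leftover path
+            // yields the full ||x - y||^2, matching the main loop above.)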
+ const float distance = y_sqlen[0] - 2 * dp + x_sqlen; + distances[i] = distance; + + y += 1; + y_sqlen += 1; + } + } +} + +template <> +void fvec_L2sqr_ny_transposed( + float* dis, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + // optimized for a few special cases +#define DISPATCH(dval) \ + case dval: \ + return fvec_L2sqr_ny_y_transposed_D( \ + dis, x, y, y_sqlen, d_offset, ny); + + switch (d) { + DISPATCH(1) + DISPATCH(2) + DISPATCH(4) + DISPATCH(8) + default: + return fvec_L2sqr_ny_transposed( + dis, x, y, y_sqlen, d, d_offset, ny); + } +#undef DISPATCH +} + +struct AVX512ElementOpIP : public ElementOpIP { + using ElementOpIP::op; + static __m512 op(__m512 x, __m512 y) { + return _mm512_mul_ps(x, y); + } + static __m256 op(__m256 x, __m256 y) { + return _mm256_mul_ps(x, y); + } +}; + +struct AVX512ElementOpL2 : public ElementOpL2 { + using ElementOpL2::op; + static __m512 op(__m512 x, __m512 y) { + __m512 tmp = _mm512_sub_ps(x, y); + return _mm512_mul_ps(tmp, tmp); + } + static __m256 op(__m256 x, __m256 y) { + __m256 tmp = _mm256_sub_ps(x, y); + return _mm256_mul_ps(tmp, tmp); + } +}; + +/// helper function for AVX512 +inline float horizontal_sum(const __m512 v) { + // performs better than adding the high and low parts + return _mm512_reduce_add_ps(v); +} + +inline float horizontal_sum(const __m256 v) { + // add high and low parts + const __m128 v0 = + _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); + // perform horizontal sum on v0 + return horizontal_sum(v0); +} + +template <> +void fvec_op_ny_D2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D2-vectors per loop. + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + + for (i = 0; i < ny16 * 16; i += 16) { + _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); + + // load 16x2 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m512 v0; + __m512 v1; + + transpose_16x2( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + v0, + v1); + + // compute distances (dot product) + __m512 distances = _mm512_mul_ps(m0, v0); + distances = _mm512_fmadd_ps(m1, v1, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 32; // move to the next set of 16x2 elements + } + } + + if (i < ny) { + // process leftovers + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float distance = x0 * y[0] + x1 * y[1]; + y += 2; + dis[i] = distance; + } + } +} + +template <> +void fvec_op_ny_D2( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D2-vectors per loop. + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + + for (i = 0; i < ny16 * 16; i += 16) { + _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); + + // load 16x2 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
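+            // after the transpose, v0 holds dim 0 and v1 holds dim 1 of the
+            // next 16 y vectors.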
+ + __m512 v0; + __m512 v1; + + transpose_16x2( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + v0, + v1); + + // compute differences + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + + // compute squares of differences + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 32; // move to the next set of 16x2 elements + } + } + + if (i < ny) { + // process leftovers + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float sub0 = x0 - y[0]; + float sub1 = x1 - y[1]; + float distance = sub0 * sub0 + sub1 * sub1; + + y += 2; + dis[i] = distance; + } + } +} + +template <> +void fvec_op_ny_D4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D4-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x4 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + + transpose_16x4( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + v0, + v1, + v2, + v3); + + // compute distances + __m512 distances = _mm512_mul_ps(m0, v0); + distances = _mm512_fmadd_ps(m1, v1, distances); + distances = _mm512_fmadd_ps(m2, v2, distances); + distances = _mm512_fmadd_ps(m3, v3, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 64; // move to the next set of 16x4 elements + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = AVX512ElementOpIP::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D4( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D4-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x4 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. 
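+            // after the transpose, v0..v3 hold dims 0..3 of the next 16 y vectors.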
+ + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + + transpose_16x4( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + v0, + v1, + v2, + v3); + + // compute differences + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + + // compute squares of differences + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 64; // move to the next set of 16x4 elements + } + } + + if (i < ny) { + // process leftovers + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = AVX512ElementOpL2::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D8( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D16-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + const __m512 m4 = _mm512_set1_ps(x[4]); + const __m512 m5 = _mm512_set1_ps(x[5]); + const __m512 m6 = _mm512_set1_ps(x[6]); + const __m512 m7 = _mm512_set1_ps(x[7]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x8 matrix and transpose it in registers. + // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + __m512 v4; + __m512 v5; + __m512 v6; + __m512 v7; + + transpose_16x8( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + _mm512_loadu_ps(y + 4 * 16), + _mm512_loadu_ps(y + 5 * 16), + _mm512_loadu_ps(y + 6 * 16), + _mm512_loadu_ps(y + 7 * 16), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute distances + __m512 distances = _mm512_mul_ps(m0, v0); + distances = _mm512_fmadd_ps(m1, v1, distances); + distances = _mm512_fmadd_ps(m2, v2, distances); + distances = _mm512_fmadd_ps(m3, v3, distances); + distances = _mm512_fmadd_ps(m4, v4, distances); + distances = _mm512_fmadd_ps(m5, v5, distances); + distances = _mm512_fmadd_ps(m6, v6, distances); + distances = _mm512_fmadd_ps(m7, v7, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 128; // 16 floats * 8 rows + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX512ElementOpIP::op(x0, _mm256_loadu_ps(y)); + y += 8; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_op_ny_D8( + float* dis, + const float* x, + const float* y, + size_t ny) { + const size_t ny16 = ny / 16; + size_t i = 0; + + if (ny16 > 0) { + // process 16 D16-vectors per loop. + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + const __m512 m4 = _mm512_set1_ps(x[4]); + const __m512 m5 = _mm512_set1_ps(x[5]); + const __m512 m6 = _mm512_set1_ps(x[6]); + const __m512 m7 = _mm512_set1_ps(x[7]); + + for (i = 0; i < ny16 * 16; i += 16) { + // load 16x8 matrix and transpose it in registers. 
+ // the typical bottleneck is memory access, so + // let's trade instructions for the bandwidth. + + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + __m512 v4; + __m512 v5; + __m512 v6; + __m512 v7; + + transpose_16x8( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + _mm512_loadu_ps(y + 4 * 16), + _mm512_loadu_ps(y + 5 * 16), + _mm512_loadu_ps(y + 6 * 16), + _mm512_loadu_ps(y + 7 * 16), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + // compute differences + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + const __m512 d4 = _mm512_sub_ps(m4, v4); + const __m512 d5 = _mm512_sub_ps(m5, v5); + const __m512 d6 = _mm512_sub_ps(m6, v6); + const __m512 d7 = _mm512_sub_ps(m7, v7); + + // compute squares of differences + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + distances = _mm512_fmadd_ps(d4, d4, distances); + distances = _mm512_fmadd_ps(d5, d5, distances); + distances = _mm512_fmadd_ps(d6, d6, distances); + distances = _mm512_fmadd_ps(d7, d7, distances); + + // store + _mm512_storeu_ps(dis + i, distances); + + y += 128; // 16 floats * 8 rows + } + } + + if (i < ny) { + // process leftovers + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX512ElementOpL2::op(x0, _mm256_loadu_ps(y)); + y += 8; + dis[i] = horizontal_sum(accu); + } + } +} + +template <> +void fvec_inner_products_ny( + float* ip, /* output inner product */ + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_inner_products_ny_ref(ip, x, y, d, ny); +} + +template <> +void fvec_L2sqr_ny( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + fvec_L2sqr_ny_ref(dis, x, y, d, ny); +} + +template <> +size_t fvec_L2sqr_ny_nearest_D2( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. 
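+    // a running minimum and its index are tracked in each of the 16 lanes,
+    // then reduced to a single scalar after the main loop.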
+ + size_t i = 0; + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + const size_t ny16 = ny / 16; + if (ny16 > 0) { + _mm_prefetch((const char*)y, _MM_HINT_T0); + _mm_prefetch((const char*)(y + 32), _MM_HINT_T0); + + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + + for (; i < ny16 * 16; i += 16) { + _mm_prefetch((const char*)(y + 64), _MM_HINT_T0); + + __m512 v0; + __m512 v1; + + transpose_16x2( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + v0, + v1); + + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + + __mmask16 comparison = + _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); + + min_distances = _mm512_min_ps(distances, min_distances); + min_indices = _mm512_mask_blend_epi32( + comparison, min_indices, current_indices); + + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + y += 32; + } + + alignas(64) float min_distances_scalar[16]; + alignas(64) uint32_t min_indices_scalar[16]; + _mm512_store_ps(min_distances_scalar, min_distances); + _mm512_store_epi32(min_indices_scalar, min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + float x0 = x[0]; + float x1 = x[1]; + + for (; i < ny; i++) { + float sub0 = x0 - y[0]; + float sub1 = x1 - y[1]; + float distance = sub0 * sub0 + sub1 * sub1; + + y += 2; + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest_D4( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. 
+ + size_t i = 0; + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + const size_t ny16 = ny / 16; + + if (ny16 > 0) { + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + for (; i < ny16 * 16; i += 16) { + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + + transpose_16x4( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + v0, + v1, + v2, + v3); + + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + + __mmask16 comparison = + _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); + + min_distances = _mm512_min_ps(distances, min_distances); + min_indices = _mm512_mask_blend_epi32( + comparison, min_indices, current_indices); + + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + y += 64; + } + + alignas(64) float min_distances_scalar[16]; + alignas(64) uint32_t min_indices_scalar[16]; + _mm512_store_ps(min_distances_scalar, min_distances); + _mm512_store_epi32(min_indices_scalar, min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + __m128 x0 = _mm_loadu_ps(x); + + for (; i < ny; i++) { + __m128 accu = ElementOpL2::op(x0, _mm_loadu_ps(y)); + y += 4; + const float distance = horizontal_sum(accu); + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest_D8( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny) { + // this implementation does not use distances_tmp_buffer. 
+ + size_t i = 0; + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + const size_t ny16 = ny / 16; + if (ny16 > 0) { + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + const __m512 m0 = _mm512_set1_ps(x[0]); + const __m512 m1 = _mm512_set1_ps(x[1]); + const __m512 m2 = _mm512_set1_ps(x[2]); + const __m512 m3 = _mm512_set1_ps(x[3]); + + const __m512 m4 = _mm512_set1_ps(x[4]); + const __m512 m5 = _mm512_set1_ps(x[5]); + const __m512 m6 = _mm512_set1_ps(x[6]); + const __m512 m7 = _mm512_set1_ps(x[7]); + + for (; i < ny16 * 16; i += 16) { + __m512 v0; + __m512 v1; + __m512 v2; + __m512 v3; + __m512 v4; + __m512 v5; + __m512 v6; + __m512 v7; + + transpose_16x8( + _mm512_loadu_ps(y + 0 * 16), + _mm512_loadu_ps(y + 1 * 16), + _mm512_loadu_ps(y + 2 * 16), + _mm512_loadu_ps(y + 3 * 16), + _mm512_loadu_ps(y + 4 * 16), + _mm512_loadu_ps(y + 5 * 16), + _mm512_loadu_ps(y + 6 * 16), + _mm512_loadu_ps(y + 7 * 16), + v0, + v1, + v2, + v3, + v4, + v5, + v6, + v7); + + const __m512 d0 = _mm512_sub_ps(m0, v0); + const __m512 d1 = _mm512_sub_ps(m1, v1); + const __m512 d2 = _mm512_sub_ps(m2, v2); + const __m512 d3 = _mm512_sub_ps(m3, v3); + const __m512 d4 = _mm512_sub_ps(m4, v4); + const __m512 d5 = _mm512_sub_ps(m5, v5); + const __m512 d6 = _mm512_sub_ps(m6, v6); + const __m512 d7 = _mm512_sub_ps(m7, v7); + + __m512 distances = _mm512_mul_ps(d0, d0); + distances = _mm512_fmadd_ps(d1, d1, distances); + distances = _mm512_fmadd_ps(d2, d2, distances); + distances = _mm512_fmadd_ps(d3, d3, distances); + distances = _mm512_fmadd_ps(d4, d4, distances); + distances = _mm512_fmadd_ps(d5, d5, distances); + distances = _mm512_fmadd_ps(d6, d6, distances); + distances = _mm512_fmadd_ps(d7, d7, distances); + + __mmask16 comparison = + _mm512_cmp_ps_mask(distances, min_distances, _CMP_LT_OS); + + min_distances = _mm512_min_ps(distances, min_distances); + min_indices = _mm512_mask_blend_epi32( + comparison, min_indices, current_indices); + + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + y += 128; + } + + alignas(64) float min_distances_scalar[16]; + alignas(64) uint32_t min_indices_scalar[16]; + _mm512_store_ps(min_distances_scalar, min_distances); + _mm512_store_epi32(min_indices_scalar, min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + __m256 x0 = _mm256_loadu_ps(x); + + for (; i < ny; i++) { + __m256 accu = AVX512ElementOpL2::op(x0, _mm256_loadu_ps(y)); + y += 8; + const float distance = horizontal_sum(accu); + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + } + } + + return current_min_index; +} + +template <> +size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny) { + return fvec_L2sqr_ny_nearest_x86( + distances_tmp_buffer, + x, + y, + d, + ny, + &fvec_L2sqr_ny_nearest_D2, + &fvec_L2sqr_ny_nearest_D4, + &fvec_L2sqr_ny_nearest_D8); +} + +template <> +size_t fvec_L2sqr_ny_nearest_y_transposed( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + size_t d, + size_t d_offset, + size_t ny) { + return 
fvec_L2sqr_ny_nearest_y_transposed( + distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny); +} + +// TODO: Following functions are not used in the current codebase. Check AVX2 , +// respective implementation has been used +template +size_t fvec_L2sqr_ny_nearest_y_transposed_D( + float* distances_tmp_buffer, + const float* x, + const float* y, + const float* y_sqlen, + const size_t d_offset, + size_t ny) { + // This implementation does not use distances_tmp_buffer. + + // Current index being processed + size_t i = 0; + + // Min distance and the index of the closest vector so far + float current_min_distance = HUGE_VALF; + size_t current_min_index = 0; + + // Process 16 vectors per loop + const size_t ny16 = ny / 16; + + if (ny16 > 0) { + // Track min distance and the closest vector independently + // for each of 16 AVX-512 components. + __m512 min_distances = _mm512_set1_ps(HUGE_VALF); + __m512i min_indices = _mm512_set1_epi32(0); + + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + const __m512i indices_increment = _mm512_set1_epi32(16); + + // m[i] = (2 * x[i], ... 2 * x[i]) + __m512 m[DIM]; + for (size_t j = 0; j < DIM; j++) { + m[j] = _mm512_set1_ps(x[j]); + m[j] = _mm512_add_ps(m[j], m[j]); + } + + for (; i < ny16 * 16; i += 16) { + // Compute dot products + const __m512 v0 = _mm512_loadu_ps(y + 0 * d_offset); + __m512 dp = _mm512_mul_ps(m[0], v0); + for (size_t j = 1; j < DIM; j++) { + const __m512 vj = _mm512_loadu_ps(y + j * d_offset); + dp = _mm512_fmadd_ps(m[j], vj, dp); + } + + // Compute y^2 - (2 * x, y), which is sufficient for looking for the + // lowest distance. + // x^2 is the constant that can be avoided. + const __m512 distances = + _mm512_sub_ps(_mm512_loadu_ps(y_sqlen), dp); + + // Compare the new distances to the min distances + __mmask16 comparison = + _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS); + + // Update min distances and indices with closest vectors if needed + min_distances = + _mm512_mask_blend_ps(comparison, distances, min_distances); + min_indices = _mm512_castps_si512(_mm512_mask_blend_ps( + comparison, + _mm512_castsi512_ps(current_indices), + _mm512_castsi512_ps(min_indices))); + + // Update current indices values. Basically, +16 to each of the 16 + // AVX-512 components. + current_indices = + _mm512_add_epi32(current_indices, indices_increment); + + // Scroll y and y_sqlen forward. + y += 16; + y_sqlen += 16; + } + + // Dump values and find the minimum distance / minimum index + float min_distances_scalar[16]; + uint32_t min_indices_scalar[16]; + _mm512_storeu_ps(min_distances_scalar, min_distances); + _mm512_storeu_si512((__m512i*)(min_indices_scalar), min_indices); + + for (size_t j = 0; j < 16; j++) { + if (current_min_distance > min_distances_scalar[j]) { + current_min_distance = min_distances_scalar[j]; + current_min_index = min_indices_scalar[j]; + } + } + } + + if (i < ny) { + // Process leftovers + for (; i < ny; i++) { + float dp = 0; + for (size_t j = 0; j < DIM; j++) { + dp += x[j] * y[j * d_offset]; + } + + // Compute y^2 - 2 * (x, y), which is sufficient for looking for the + // lowest distance. 
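+            // As in the vectorized loop above, x^2 is constant across all
+            // candidates, so dropping it does not change the argmin.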
+ const float distance = y_sqlen[0] - 2 * dp; + + if (current_min_distance > distance) { + current_min_distance = distance; + current_min_index = i; + } + + y += 1; + y_sqlen += 1; + } + } + + return current_min_index; +} + +template <> +int fvec_madd_and_argmin( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + return fvec_madd_and_argmin_sse(n, a, bf, b, c); +} + +} // namespace faiss diff --git a/faiss/utils/simd_impl/distances_sse-inl.h b/faiss/utils/simd_impl/distances_sse-inl.h new file mode 100644 index 0000000000..a5151750cb --- /dev/null +++ b/faiss/utils/simd_impl/distances_sse-inl.h @@ -0,0 +1,385 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace faiss { + +[[maybe_unused]] static inline void fvec_madd_sse( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + n >>= 2; + __m128 bf4 = _mm_set_ps1(bf); + __m128* a4 = (__m128*)a; + __m128* b4 = (__m128*)b; + __m128* c4 = (__m128*)c; + + while (n--) { + *c4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4)); + b4++; + a4++; + c4++; + } +} + +/// helper function +inline float horizontal_sum(const __m128 v) { + // say, v is [x0, x1, x2, x3] + + // v0 is [x2, x3, ..., ...] + const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); + // v1 is [x0 + x2, x1 + x3, ..., ...] + const __m128 v1 = _mm_add_ps(v, v0); + // v2 is [x1 + x3, ..., .... ,...] + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); + // v3 is [x0 + x1 + x2 + x3, ..., ..., ...] + const __m128 v3 = _mm_add_ps(v1, v2); + // return v3[0] + return _mm_cvtss_f32(v3); +} + +/// Function that does a component-wise operation between x and y +/// to compute inner products +struct ElementOpIP { + static float op(float x, float y) { + return x * y; + } + + static __m128 op(__m128 x, __m128 y) { + return _mm_mul_ps(x, y); + } +}; + +/// Function that does a component-wise operation between x and y +/// to compute L2 distances. 
ElementOp can then be used in the fvec_op_ny +/// functions below +struct ElementOpL2 { + static float op(float x, float y) { + float tmp = x - y; + return tmp * tmp; + } + + static __m128 op(__m128 x, __m128 y) { + __m128 tmp = _mm_sub_ps(x, y); + return _mm_mul_ps(tmp, tmp); + } +}; + +template +void fvec_op_ny_D1(float* dis, const float* x, const float* y, size_t ny) { + float x0s = x[0]; + __m128 x0 = _mm_set_ps(x0s, x0s, x0s, x0s); + + size_t i; + for (i = 0; i + 3 < ny; i += 4) { + __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = _mm_cvtss_f32(accu); + __m128 tmp = _mm_shuffle_ps(accu, accu, 1); + dis[i + 1] = _mm_cvtss_f32(tmp); + tmp = _mm_shuffle_ps(accu, accu, 2); + dis[i + 2] = _mm_cvtss_f32(tmp); + tmp = _mm_shuffle_ps(accu, accu, 3); + dis[i + 3] = _mm_cvtss_f32(tmp); + } + while (i < ny) { // handle non-multiple-of-4 case + dis[i++] = ElementOp::op(x0s, *y++); + } +} + +template +void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) { + __m128 x0 = _mm_set_ps(x[1], x[0], x[1], x[0]); + + size_t i; + for (i = 0; i + 1 < ny; i += 2) { + __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); + y += 4; + accu = _mm_hadd_ps(accu, accu); + dis[i] = _mm_cvtss_f32(accu); + accu = _mm_shuffle_ps(accu, accu, 3); + dis[i + 1] = _mm_cvtss_f32(accu); + } + if (i < ny) { // handle odd case + dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]); + } +} + +template +void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) { + __m128 x0 = _mm_loadu_ps(x); + + for (size_t i = 0; i < ny; i++) { + __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); + y += 4; + dis[i] = horizontal_sum(accu); + } +} + +template +void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) { + __m128 x0 = _mm_loadu_ps(x); + __m128 x1 = _mm_loadu_ps(x + 4); + + for (size_t i = 0; i < ny; i++) { + __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); + y += 4; + accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y))); + y += 4; + accu = _mm_hadd_ps(accu, accu); + accu = _mm_hadd_ps(accu, accu); + dis[i] = _mm_cvtss_f32(accu); + } +} + +template +void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) { + __m128 x0 = _mm_loadu_ps(x); + __m128 x1 = _mm_loadu_ps(x + 4); + __m128 x2 = _mm_loadu_ps(x + 8); + + for (size_t i = 0; i < ny; i++) { + __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y)); + y += 4; + accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y))); + y += 4; + accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y))); + y += 4; + dis[i] = horizontal_sum(accu); + } +} + +template +void fvec_inner_products_ny_ref( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { +#define DISPATCH(dval) \ + case dval: \ + fvec_op_ny_D##dval(dis, x, y, ny); \ + return; + + switch (d) { + DISPATCH(1) + DISPATCH(2) + DISPATCH(4) + DISPATCH(8) + DISPATCH(12) + default: + fvec_inner_products_ny(dis, x, y, d, ny); + return; + } +#undef DISPATCH +} + +template +void fvec_L2sqr_ny_ref( + float* dis, + const float* x, + const float* y, + size_t d, + size_t ny) { + // optimized for a few special cases + +#define DISPATCH(dval) \ + case dval: \ + fvec_op_ny_D##dval(dis, x, y, ny); \ + return; + + switch (d) { + DISPATCH(1) + DISPATCH(2) + DISPATCH(4) + DISPATCH(8) + DISPATCH(12) + default: + fvec_L2sqr_ny(dis, x, y, d, ny); + return; + } +#undef DISPATCH +} + +template +size_t fvec_L2sqr_ny_nearest_D2( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny); + +template +size_t 
fvec_L2sqr_ny_nearest_D4( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny); + +template +size_t fvec_L2sqr_ny_nearest_D8( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t ny); + +template +size_t fvec_L2sqr_ny_nearest_x86( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny, + size_t (*fvec_L2sqr_ny_nearest_D2_func)( + float*, + const float*, + const float*, + size_t) = &fvec_L2sqr_ny_nearest_D2, + size_t (*fvec_L2sqr_ny_nearest_D4_func)( + float*, + const float*, + const float*, + size_t) = &fvec_L2sqr_ny_nearest_D4, + size_t (*fvec_L2sqr_ny_nearest_D8_func)( + float*, + const float*, + const float*, + size_t) = &fvec_L2sqr_ny_nearest_D8); + +template +size_t fvec_L2sqr_ny_nearest_x86( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny, + size_t (*fvec_L2sqr_ny_nearest_D2_func)( + float*, + const float*, + const float*, + size_t), + size_t (*fvec_L2sqr_ny_nearest_D4_func)( + float*, + const float*, + const float*, + size_t), + size_t (*fvec_L2sqr_ny_nearest_D8_func)( + float*, + const float*, + const float*, + size_t)) { + switch (d) { + case 2: + return fvec_L2sqr_ny_nearest_D2_func( + distances_tmp_buffer, x, y, ny); + case 4: + return fvec_L2sqr_ny_nearest_D4_func( + distances_tmp_buffer, x, y, ny); + case 8: + return fvec_L2sqr_ny_nearest_D8_func( + distances_tmp_buffer, x, y, ny); + } + + return fvec_L2sqr_ny_nearest( + distances_tmp_buffer, x, y, d, ny); +} + +template +inline size_t fvec_L2sqr_ny_nearest( + float* distances_tmp_buffer, + const float* x, + const float* y, + size_t d, + size_t ny); + +static inline int fvec_madd_and_argmin_sse_ref( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + n >>= 2; + __m128 bf4 = _mm_set_ps1(bf); + __m128 vmin4 = _mm_set_ps1(1e20); + __m128i imin4 = _mm_set1_epi32(-1); + __m128i idx4 = _mm_set_epi32(3, 2, 1, 0); + __m128i inc4 = _mm_set1_epi32(4); + __m128* a4 = (__m128*)a; + __m128* b4 = (__m128*)b; + __m128* c4 = (__m128*)c; + + while (n--) { + __m128 vc4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4)); + *c4 = vc4; + __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); + // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower! 
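+        // Emulate the blend with and/andnot: take idx4 in lanes where the new
+        // value vc4 is smaller than the running minimum, keep imin4 elsewhere.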
+ + imin4 = _mm_or_si128( + _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); + vmin4 = _mm_min_ps(vmin4, vc4); + b4++; + a4++; + c4++; + idx4 = _mm_add_epi32(idx4, inc4); + } + + // 4 values -> 2 + { + idx4 = _mm_shuffle_epi32(imin4, 3 << 2 | 2); + __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 3 << 2 | 2); + __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); + imin4 = _mm_or_si128( + _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); + vmin4 = _mm_min_ps(vmin4, vc4); + } + // 2 values -> 1 + { + idx4 = _mm_shuffle_epi32(imin4, 1); + __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 1); + __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4)); + imin4 = _mm_or_si128( + _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4)); + // vmin4 = _mm_min_ps (vmin4, vc4); + } + return _mm_cvtsi128_si32(imin4); +} + +static inline int fvec_madd_and_argmin_sse( + size_t n, + const float* a, + float bf, + const float* b, + float* c) { + if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0) + return fvec_madd_and_argmin_sse_ref(n, a, bf, b, c); + + return fvec_madd_and_argmin(n, a, bf, b, c); +} + +// reads 0 <= d < 4 floats as __m128 +static inline __m128 masked_read(int d, const float* x) { + assert(0 <= d && d < 4); + ALIGNED(16) float buf[4] = {0, 0, 0, 0}; + switch (d) { + case 3: + buf[2] = x[2]; + [[fallthrough]]; + case 2: + buf[1] = x[1]; + [[fallthrough]]; + case 1: + buf[0] = x[0]; + } + return _mm_load_ps(buf); + // cannot use AVX2 _mm_mask_set1_epi32 +} + +} // namespace faiss diff --git a/faiss/utils/simd_levels.cpp b/faiss/utils/simd_levels.cpp new file mode 100644 index 0000000000..3f1769b289 --- /dev/null +++ b/faiss/utils/simd_levels.cpp @@ -0,0 +1,171 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace faiss { + +SIMDLevel SIMDConfig::level = SIMDLevel::NONE; +std::unordered_set& SIMDConfig::supported_simd_levels() { + static std::unordered_set levels; + return levels; +} + +// it is there to make sure the constructor runs +static SIMDConfig dummy_config; + +SIMDConfig::SIMDConfig(const char** faiss_simd_level_env) { + // added to support dependency injection + const char* env_var = faiss_simd_level_env ? 
*faiss_simd_level_env + : getenv("FAISS_SIMD_LEVEL"); + + // check environment variable for SIMD level is explicitly set + if (!env_var) { + level = auto_detect_simd_level(); + } else { + auto matched_level = to_simd_level(env_var); + if (matched_level.has_value()) { + set_level(matched_level.value()); + supported_simd_levels().clear(); + supported_simd_levels().insert(matched_level.value()); + } else { + fprintf(stderr, + "FAISS_SIMD_LEVEL is set to %s, which is unknown\n", + env_var); + exit(1); + } + } + supported_simd_levels().insert(SIMDLevel::NONE); +} + +void SIMDConfig::set_level(SIMDLevel l) { + level = l; +} + +SIMDLevel SIMDConfig::get_level() { + return level; +} + +std::string SIMDConfig::get_level_name() { + return to_string(level).value_or(""); +} + +bool SIMDConfig::is_simd_level_available(SIMDLevel l) { + return supported_simd_levels().find(l) != supported_simd_levels().end(); +} + +SIMDLevel SIMDConfig::auto_detect_simd_level() { + SIMDLevel level = SIMDLevel::NONE; + +#if defined(__x86_64__) && \ + (defined(COMPILE_SIMD_AVX2) || defined(COMPILE_SIMD_AVX512)) + unsigned int eax, ebx, ecx, edx; + + eax = 1; + ecx = 0; + asm volatile("cpuid" + : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx) + : "a"(eax), "c"(ecx)); + + bool has_avx = (ecx & (1 << 28)) != 0; + + bool has_xsave_osxsave = + (ecx & ((1 << 26) | (1 << 27))) == ((1 << 26) | (1 << 27)); + + bool avx_supported = false; + if (has_avx && has_xsave_osxsave) { + unsigned int xcr0; + asm volatile("xgetbv" : "=a"(xcr0), "=d"(edx) : "c"(0)); + avx_supported = (xcr0 & 6) == 6; + } + + if (avx_supported) { + eax = 7; + ecx = 0; + asm volatile("cpuid" + : "=a"(eax), "=b"(ebx), "=c"(ecx), "=d"(edx) + : "a"(eax), "c"(ecx)); + + unsigned int xcr0; + asm volatile("xgetbv" : "=a"(xcr0), "=d"(edx) : "c"(0)); + +#if defined(COMPILE_SIMD_AVX2) || defined(COMPILE_SIMD_AVX512) + bool has_avx2 = (ebx & (1 << 5)) != 0; + if (has_avx2) { + SIMDConfig::supported_simd_levels().insert(SIMDLevel::AVX2); + level = SIMDLevel::AVX2; + } + +#if defined(COMPILE_SIMD_AVX512) + bool cpu_has_avx512f = (ebx & (1 << 16)) != 0; + bool os_supports_avx512 = (xcr0 & 0xE0) == 0xE0; + bool has_avx512f = cpu_has_avx512f && os_supports_avx512; + if (has_avx512f) { + bool has_avx512cd = (ebx & (1 << 28)) != 0; + bool has_avx512vl = (ebx & (1 << 31)) != 0; + bool has_avx512dq = (ebx & (1 << 17)) != 0; + bool has_avx512bw = (ebx & (1 << 30)) != 0; + if (has_avx512bw && has_avx512cd && has_avx512vl && has_avx512dq) { + level = SIMDLevel::AVX512; + supported_simd_levels().insert(SIMDLevel::AVX512); + } + } +#endif // defined(COMPILE_SIMD_AVX512) +#endif // defined(COMPILE_SIMD_AVX2)|| defined(COMPILE_SIMD_AVX512) + } +#endif // defined(__x86_64__) && (defined(COMPILE_SIMD_AVX2) || + // defined(COMPILE_SIMD_AVX512)) + +#if defined(__aarch64__) && defined(__ARM_NEON) && \ + defined(COMPILE_SIMD_ARM_NEON) + // ARM NEON is standard on aarch64 + supported_simd_levels().insert(SIMDLevel::ARM_NEON); + level = SIMDLevel::ARM_NEON; + // TODO: Add ARM SVE detection when needed + // For now, we default to ARM_NEON as it's universally supported on aarch64 +#endif + + return level; +} + +std::optional to_string(SIMDLevel level) { + switch (level) { + case SIMDLevel::NONE: + return "NONE"; + case SIMDLevel::AVX2: + return "AVX2"; + case SIMDLevel::AVX512: + return "AVX512"; + case SIMDLevel::ARM_NEON: + return "ARM_NEON"; + default: + return std::nullopt; + } + return std::nullopt; +} + +std::optional to_simd_level(const std::string& level_str) { + if (level_str == "NONE") { + 
return SIMDLevel::NONE; + } + if (level_str == "AVX2") { + return SIMDLevel::AVX2; + } + if (level_str == "AVX512") { + return SIMDLevel::AVX512; + } + if (level_str == "ARM_NEON") { + return SIMDLevel::ARM_NEON; + } + + return std::nullopt; +} + +} // namespace faiss diff --git a/faiss/utils/simd_levels.h b/faiss/utils/simd_levels.h new file mode 100644 index 0000000000..95b2decc0b --- /dev/null +++ b/faiss/utils/simd_levels.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace faiss { + +#define COMPILE_SIMD_NONE + +enum class SIMDLevel { + NONE, + // x86 + AVX2, + AVX512, + // arm & aarch64 + ARM_NEON, + + COUNT +}; + +std::optional to_string(SIMDLevel level); + +std::optional to_simd_level(const std::string& level_str); + +/* Current SIMD configuration. This static class manages the current SIMD level + * and intializes it from the cpuid and the FAISS_SIMD_LEVEL + * environment variable */ +struct SIMDConfig { + static SIMDLevel level; + static std::unordered_set& supported_simd_levels(); + + typedef SIMDLevel (*DetectSIMDLevelFunc)(); + static SIMDLevel auto_detect_simd_level(); + + SIMDConfig(const char** faiss_simd_level_env = nullptr); + + static void set_level(SIMDLevel level); + static SIMDLevel get_level(); + static std::string get_level_name(); + + static bool is_simd_level_available(SIMDLevel level); +}; + +/*********************** x86 SIMD */ + +#ifdef COMPILE_SIMD_AVX2 +#define DISPATCH_SIMDLevel_AVX2(f, ...) \ + case SIMDLevel::AVX2: \ + return f(__VA_ARGS__) +#else +#define DISPATCH_SIMDLevel_AVX2(f, ...) +#endif + +#ifdef COMPILE_SIMD_AVX512 +#define DISPATCH_SIMDLevel_AVX512(f, ...) \ + case SIMDLevel::AVX512: \ + return f(__VA_ARGS__) +#else +#define DISPATCH_SIMDLevel_AVX512(f, ...) +#endif + +/* dispatch function f to f */ + +#define DISPATCH_SIMDLevel(f, ...) \ + switch (SIMDConfig::level) { \ + case SIMDLevel::NONE: \ + return f(__VA_ARGS__); \ + DISPATCH_SIMDLevel_AVX2(f, __VA_ARGS__); \ + DISPATCH_SIMDLevel_AVX512(f, __VA_ARGS__); \ + default: \ + FAISS_ASSERT(!"Invalid SIMD level"); \ + } + +} // namespace faiss diff --git a/tests/test_distances_simd.cpp b/tests/test_distances_simd.cpp index 539fe2a419..dda33c3e72 100644 --- a/tests/test_distances_simd.cpp +++ b/tests/test_distances_simd.cpp @@ -39,104 +39,352 @@ void fvec_L2sqr_ny_ref( } } -// test templated versions of fvec_L2sqr_ny -TEST(TestFvecL2sqrNy, D2) { - // we're using int values in order to get 100% accurate - // results with floats. 
- std::default_random_engine rng(123); - std::uniform_int_distribution u(0, 32); +void remove_simd_level_if_exists( + std::unordered_set& levels, + faiss::SIMDLevel level) { + std::erase_if( + levels, [level](faiss::SIMDLevel elem) { return elem == level; }); +} - for (const auto dim : {2, 4, 8, 12}) { - std::vector x(dim, 0); - for (size_t i = 0; i < x.size(); i++) { - x[i] = u(rng); +class DistancesSIMDTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + original_simd_level = faiss::SIMDConfig::get_level(); + std::iota(dims.begin(), dims.end(), 1); + + ntests = 4; + + simd_level = GetParam(); + faiss::SIMDConfig::set_level(simd_level); + + EXPECT_EQ(faiss::SIMDConfig::get_level(), simd_level); + + rng = std::default_random_engine(123); + uniform = std::uniform_int_distribution(0, 32); + } + + void TearDown() override { + faiss::SIMDConfig::set_level(original_simd_level); + } + + std::tuple, std::vector>> + SetupTestData(int dims, int ny) { + std::vector x(dims); + std::vector> y(ny, std::vector(dims)); + + for (size_t i = 0; i < dims; i++) { + x[i] = uniform(rng); + for (size_t j = 0; j < ny; j++) { + y[j][i] = uniform(rng); + } + } + return std::make_tuple(x, y); + } + + std::vector flatten_2d_vector( + const std::vector>& v) { + std::vector flat_v; + for (const auto& vec : v) { + flat_v.insert(flat_v.end(), vec.begin(), vec.end()); + } + return flat_v; + } + + faiss::SIMDLevel simd_level = faiss::SIMDLevel::NONE; + faiss::SIMDLevel original_simd_level = faiss::SIMDLevel::NONE; + std::default_random_engine rng; + std::uniform_int_distribution uniform; + + std::vector dims = {128}; + int ntests = 1; +}; + +TEST_P(DistancesSIMDTest, LinfDistance_chebyshev_distance) { + for (int i = 0; i < ntests; ++i) { // repeat tests + for (const auto dim : dims) { // test different dimensions + int ny = 1; + auto [x, y] = SetupTestData(dim, ny); + for (int k = 0; k < ny; ++k) { // test different vectors + float distance = faiss::fvec_Linf(x.data(), y[k].data(), dim); + float ref_distance = 0; + + for (int j = 0; j < dim; ++j) { + ref_distance = + std::max(ref_distance, std::abs(x[j] - y[k][j])); + } + ASSERT_EQ(distance, ref_distance); + } } + } +} - for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) { - std::vector y(nrows * dim); - for (size_t i = 0; i < y.size(); i++) { - y[i] = u(rng); +TEST_P(DistancesSIMDTest, inner_product_batch_4) { + for (int i = 0; i < ntests; ++i) { + int dim = 128; + int ny = 4; + auto [x, y] = SetupTestData(dim, ny); + + std::vector true_distances(ny, 0.F); + for (int j = 0; j < ny; ++j) { + for (int k = 0; k < dim; ++k) { + true_distances[j] += x[k] * y[j][k]; } + } - std::vector distances(nrows, 0); - faiss::fvec_L2sqr_ny( - distances.data(), x.data(), y.data(), dim, nrows); + std::vector actual_distances(ny, 0.F); + faiss::fvec_inner_product_batch_4( + x.data(), + y[0].data(), + y[1].data(), + y[2].data(), + y[3].data(), + dim, + actual_distances[0], + actual_distances[1], + actual_distances[2], + actual_distances[3]); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching fvec_inner_product_batch4 results for test = " + << i; + } +} + +TEST_P(DistancesSIMDTest, fvec_L2sqr) { + for (int i = 0; i < ntests; ++i) { + int ny = 1; + for (const auto dim : dims) { + auto [x, y] = SetupTestData(dim, ny); + float true_distance = 0.F; + for (int k = 0; k < dim; ++k) { + const float tmp = x[k] - y[0][k]; + true_distance += tmp * tmp; + } - std::vector distances_ref(nrows, 0); - fvec_L2sqr_ny_ref( - distances_ref.data(), x.data(), y.data(), 
dim, nrows); + float actual_distance = + faiss::fvec_L2sqr(x.data(), y[0].data(), dim); - ASSERT_EQ(distances, distances_ref) - << "Mismatching results for dim = " << dim - << ", nrows = " << nrows; + ASSERT_EQ(actual_distance, true_distance) + << "Mismatching fvec_L2sqr results for test = " << i; } } } -// fvec_inner_products_ny -TEST(TestFvecInnerProductsNy, D2) { - // we're using int values in order to get 100% accurate - // results with floats. - std::default_random_engine rng(123); - std::uniform_int_distribution u(0, 32); +TEST_P(DistancesSIMDTest, L2sqr_batch_4) { + for (int i = 0; i < ntests; ++i) { + int dim = 128; + int ny = 4; + auto [x, y] = SetupTestData(dim, ny); + + std::vector true_distances(ny, 0.F); + for (int j = 0; j < ny; ++j) { + for (int k = 0; k < dim; ++k) { + const float tmp = x[k] - y[j][k]; + true_distances[j] += tmp * tmp; + } + } + + std::vector actual_distances(ny, 0.F); + faiss::fvec_L2sqr_batch_4( + x.data(), + y[0].data(), + y[1].data(), + y[2].data(), + y[3].data(), + dim, + actual_distances[0], + actual_distances[1], + actual_distances[2], + actual_distances[3]); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching fvec_L2sqr_batch_4 results for test = " << i; + } +} +TEST_P(DistancesSIMDTest, fvec_L2sqr_ny) { for (const auto dim : {2, 4, 8, 12}) { - std::vector x(dim, 0); - for (size_t i = 0; i < x.size(); i++) { - x[i] = u(rng); - } + for (const auto ny : {1, 2, 5, 10, 15, 20, 25}) { + auto [x, y] = SetupTestData(dim, ny); + + std::vector actual_distances(ny, 0.F); - for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) { - std::vector y(nrows * dim); - for (size_t i = 0; i < y.size(); i++) { - y[i] = u(rng); + std::vector flat_y; + for (auto y_ : y) { + flat_y.insert(flat_y.end(), y_.begin(), y_.end()); } - std::vector distances(nrows, 0); + std::vector true_distances(ny, 0.F); + for (int i = 0; i < ny; ++i) { + for (int k = 0; k < dim; ++k) { + const float tmp = x[k] - y[i][k]; + true_distances[i] += tmp * tmp; + } + } + + faiss::fvec_L2sqr_ny( + actual_distances.data(), x.data(), flat_y.data(), dim, ny); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching fvec_L2sqr_ny results for dim = " << dim + << ", ny = " << ny; + } + } +} + +TEST_P(DistancesSIMDTest, fvec_inner_products_ny) { + for (const auto dim : {2, 4, 8, 12}) { + for (const auto ny : {1, 2, 5, 10, 15, 20, 25}) { + auto [x, y] = SetupTestData(dim, ny); + auto flat_y = flatten_2d_vector(y); + + std::vector actual_distances(ny, 0.F); faiss::fvec_inner_products_ny( - distances.data(), x.data(), y.data(), dim, nrows); + actual_distances.data(), x.data(), flat_y.data(), dim, ny); - std::vector distances_ref(nrows, 0); - fvec_inner_products_ny_ref( - distances_ref.data(), x.data(), y.data(), dim, nrows); + std::vector true_distances(ny, 0.F); + for (int i = 0; i < ny; ++i) { + for (int k = 0; k < dim; ++k) { + true_distances[i] += x[k] * y[i][k]; + } + } - ASSERT_EQ(distances, distances_ref) - << "Mismatching results for dim = " << dim - << ", nrows = " << nrows; + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching fvec_inner_products_ny results for dim = " + << dim << ", ny = " << ny; } } } -TEST(TestFvecL2sqr, distances_L2_squared_y_transposed) { - // ints instead of floats for 100% accuracy +TEST_P(DistancesSIMDTest, L2SqrNYNearest) { std::default_random_engine rng(123); std::uniform_int_distribution uniform(0, 32); + int dim = 128; + int ny = 11; + + auto [x, y] = SetupTestData(dim, ny); + auto flat_y = flatten_2d_vector(y); + + std::vector 
true_tmp_buffer_distances(ny, 0.F); + for (int i = 0; i < ny; ++i) { + for (int k = 0; k < dim; ++k) { + const float tmp = x[k] - y[i][k]; + true_tmp_buffer_distances[i] += tmp * tmp; + } + } + + size_t true_nearest_idx = 0; + float min_dis = HUGE_VALF; + + for (size_t i = 0; i < ny; i++) { + if (true_tmp_buffer_distances[i] < min_dis) { + min_dis = true_tmp_buffer_distances[i]; + true_nearest_idx = i; + } + } + + std::vector actual_distances(ny); + auto actual_nearest_index = faiss::fvec_L2sqr_ny_nearest( + actual_distances.data(), x.data(), flat_y.data(), dim, ny); + + EXPECT_EQ(actual_nearest_index, true_nearest_idx); +} + +TEST_P(DistancesSIMDTest, multiple_add) { + // modulo 8 results - 16 is to repeat the while loop in the function + for (const auto dim : {8, 9, 10, 11, 12, 13, 14, 15, 16}) { + auto [x, y] = SetupTestData(dim, 1); + const float bf = uniform(rng); + std::vector true_distances(dim); + for (size_t i = 0; i < x.size(); i++) { + true_distances[i] = x[i] + bf * y[0][i]; + } + + std::vector actual_distances(dim); + faiss::fvec_madd( + x.size(), x.data(), bf, y[0].data(), actual_distances.data()); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching fvec_madd results for nrows = " << dim; + } +} + +TEST_P(DistancesSIMDTest, manhattan_distance) { + // modulo 8 results - 16 is to repeat the while loop in the function + for (const auto dim : {8, 9, 10, 11, 12, 13, 14, 15, 16}) { + auto [x, y] = SetupTestData(dim, 1); + float true_distance = 0; + for (size_t i = 0; i < x.size(); i++) { + true_distance += std::abs(x[i] - y[0][i]); + } + + auto actual_distances = faiss::fvec_L1(x.data(), y[0].data(), x.size()); + + ASSERT_EQ(actual_distances, true_distance) + << "Mismatching fvec_Linf results for nrows = " << dim; + } +} + +TEST_P(DistancesSIMDTest, add_value) { + for (const auto dim : {1, 2, 5, 10, 15, 20, 25}) { + auto [x, y] = SetupTestData(dim, 1); + const float b = uniform(rng); // value to add + std::vector true_distances(dim); + for (size_t i = 0; i < x.size(); i++) { + true_distances[i] = x[i] + b; + } + + std::vector actual_distances(dim); + faiss::fvec_add(x.size(), x.data(), b, actual_distances.data()); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching array-value fvec_add results for nrows = " + << dim; + } +} + +TEST_P(DistancesSIMDTest, add_array) { + for (const auto dim : {1, 2, 5, 10, 15, 20, 25}) { + auto [x, y] = SetupTestData(dim, 1); + std::vector true_distances(dim); + for (size_t i = 0; i < x.size(); i++) { + true_distances[i] = x[i] + y[0][i]; + } + + std::vector actual_distances(dim); + faiss::fvec_add( + x.size(), x.data(), y[0].data(), actual_distances.data()); + + ASSERT_EQ(actual_distances, true_distances) + << "Mismatching array-array fvec_add results for nrows = " + << dim; + } +} + +TEST_P(DistancesSIMDTest, distances_L2_squared_y_transposed) { // modulo 8 results - 16 is to repeat the loop in the function int ny = 11; // this value will hit all the codepaths for (const auto d : {1, 2, 3, 4, 5, 6, 7, 8, 16}) { - // initialize inputs - std::vector x(d); + auto [x, y] = SetupTestData(d, ny); float x_sqlen = 0; - for (size_t i = 0; i < x.size(); i++) { - x[i] = uniform(rng); + for (size_t i = 0; i < d; ++i) { x_sqlen += x[i] * x[i]; } - std::vector y(d * ny); + auto flat_y = flatten_2d_vector(y); std::vector y_sqlens(ny, 0); - for (size_t i = 0; i < ny; i++) { - for (size_t j = 0; j < y.size(); j++) { - y[j] = uniform(rng); - y_sqlens[i] += y[j] * y[j]; + for (size_t i = 0; i < ny; ++i) { + for (size_t j = 0; j < d; ++j) { + 
y_sqlens[i] += flat_y[j] * flat_y[j]; } } // perform function std::vector true_distances(ny, 0); - for (size_t i = 0; i < ny; i++) { + for (size_t i = 0; i < ny; ++i) { float dp = 0; - for (size_t j = 0; j < d; j++) { - dp += x[j] * y[i + j * ny]; + for (size_t j = 0; j < d; ++j) { + dp += x[j] * flat_y[i + j * ny]; } true_distances[i] = x_sqlen + y_sqlens[i] - 2 * dp; } @@ -145,7 +393,7 @@ TEST(TestFvecL2sqr, distances_L2_squared_y_transposed) { faiss::fvec_L2sqr_ny_transposed( distances.data(), x.data(), - y.data(), + flat_y.data(), y_sqlens.data(), d, ny, // no need for special offset to test all lines of code @@ -156,39 +404,34 @@ TEST(TestFvecL2sqr, distances_L2_squared_y_transposed) { } } -TEST(TestFvecL2sqr, nearest_L2_squared_y_transposed) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); - +TEST_P(DistancesSIMDTest, nearest_L2_squared_y_transposed) { // modulo 8 results - 16 is to repeat the loop in the function int ny = 11; // this value will hit all the codepaths - for (const auto d : {1, 2, 3, 4, 5, 6, 7, 8, 16}) { - // initialize inputs - std::vector x(d); - float x_sqlen = 0; - for (size_t i = 0; i < x.size(); i++) { - x[i] = uniform(rng); + for (const auto dim : {1, 2, 3, 4, 5, 6, 7, 8, 16}) { + auto [x, y] = SetupTestData(dim, ny); + float x_sqlen = 0.F; + for (size_t i = 0; i < dim; i++) { x_sqlen += x[i] * x[i]; } - std::vector y(d * ny); + + auto flat_y = flatten_2d_vector(y); std::vector y_sqlens(ny, 0); + for (size_t i = 0; i < ny; i++) { - for (size_t j = 0; j < y.size(); j++) { - y[j] = uniform(rng); - y_sqlens[i] += y[j] * y[j]; + for (size_t j = 0; j < dim; j++) { + y_sqlens[i] += y[i][j] * y[i][j]; } } - // get distances std::vector distances(ny, 0); for (size_t i = 0; i < ny; i++) { float dp = 0; - for (size_t j = 0; j < d; j++) { - dp += x[j] * y[i + j * ny]; + for (size_t j = 0; j < dim; j++) { + dp += x[j] * flat_y[i + j * ny]; } distances[i] = x_sqlen + y_sqlens[i] - 2 * dp; } + // find nearest size_t true_nearest_idx = 0; float min_dis = HUGE_VALF; @@ -200,135 +443,42 @@ TEST(TestFvecL2sqr, nearest_L2_squared_y_transposed) { } std::vector buffer(ny); - size_t nearest_idx = faiss::fvec_L2sqr_ny_nearest_y_transposed( + size_t actual_nearest_idx = faiss::fvec_L2sqr_ny_nearest_y_transposed( buffer.data(), x.data(), - y.data(), + flat_y.data(), y_sqlens.data(), - d, + dim, ny, // no need for special offset to test all lines of code ny); - ASSERT_EQ(nearest_idx, true_nearest_idx) + ASSERT_EQ(actual_nearest_idx, true_nearest_idx) << "Mismatching fvec_L2sqr_ny_nearest_y_transposed results for d = " - << d; + << dim; } } -TEST(TestFvecL1, manhattan_distance) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); +std::vector GetSupportedSIMDLevels() { + std::vector supported_levels = {faiss::SIMDLevel::NONE}; - // modulo 8 results - 16 is to repeat the while loop in the function - for (const auto nrows : {8, 9, 10, 11, 12, 13, 14, 15, 16}) { - std::vector x(nrows); - std::vector y(nrows); - float true_distance = 0; - for (size_t i = 0; i < x.size(); i++) { - x[i] = uniform(rng); - y[i] = uniform(rng); - true_distance += std::abs(x[i] - y[i]); + for (int level = static_cast(faiss::SIMDLevel::NONE) + 1; + level < static_cast(faiss::SIMDLevel::COUNT); + level++) { + faiss::SIMDLevel simd_level = static_cast(level); + if (faiss::SIMDConfig::is_simd_level_available(simd_level)) { + 
supported_levels.push_back(simd_level); } - - auto distance = faiss::fvec_L1(x.data(), y.data(), x.size()); - - ASSERT_EQ(distance, true_distance) - << "Mismatching fvec_Linf results for nrows = " << nrows; } -} -TEST(TestFvecLinf, chebyshev_distance) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); + EXPECT_TRUE(supported_levels.size() > 0); - // modulo 8 results - 16 is to repeat the while loop in the function - for (const auto nrows : {8, 9, 10, 11, 12, 13, 14, 15, 16}) { - std::vector x(nrows); - std::vector y(nrows); - float true_distance = 0; - for (size_t i = 0; i < x.size(); i++) { - x[i] = uniform(rng); - y[i] = uniform(rng); - true_distance = std::max(true_distance, std::abs(x[i] - y[i])); - } - - auto distance = faiss::fvec_Linf(x.data(), y.data(), x.size()); - - ASSERT_EQ(distance, true_distance) - << "Mismatching fvec_Linf results for nrows = " << nrows; - } + return std::vector( + supported_levels.begin(), supported_levels.end()); } -TEST(TestFvecMadd, multiple_add) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); - - // modulo 8 results - 16 is to repeat the while loop in the function - for (const auto nrows : {8, 9, 10, 11, 12, 13, 14, 15, 16}) { - std::vector a(nrows); - std::vector b(nrows); - const float bf = uniform(rng); - std::vector true_distances(nrows); - for (size_t i = 0; i < a.size(); i++) { - a[i] = uniform(rng); - b[i] = uniform(rng); - true_distances[i] = a[i] + bf * b[i]; - } - - std::vector distances(nrows); - faiss::fvec_madd(a.size(), a.data(), bf, b.data(), distances.data()); - - ASSERT_EQ(distances, true_distances) - << "Mismatching fvec_madd results for nrows = " << nrows; - } +::testing::internal::ParamGenerator SupportedSIMDLevels() { + std::vector levels = GetSupportedSIMDLevels(); + return ::testing::ValuesIn(levels); } -TEST(TestFvecAdd, add_array) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); - - for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) { - std::vector a(nrows); - std::vector b(nrows); - std::vector true_distances(nrows); - for (size_t i = 0; i < a.size(); i++) { - a[i] = uniform(rng); - b[i] = uniform(rng); - true_distances[i] = a[i] + b[i]; - } - - std::vector distances(nrows); - faiss::fvec_add(a.size(), a.data(), b.data(), distances.data()); - - ASSERT_EQ(distances, true_distances) - << "Mismatching array-array fvec_add results for nrows = " - << nrows; - } -} - -TEST(TestFvecAdd, add_value) { - // ints instead of floats for 100% accuracy - std::default_random_engine rng(123); - std::uniform_int_distribution uniform(0, 32); - - for (const auto nrows : {1, 2, 5, 10, 15, 20, 25}) { - std::vector a(nrows); - const float b = uniform(rng); // value to add - std::vector true_distances(nrows); - for (size_t i = 0; i < a.size(); i++) { - a[i] = uniform(rng); - true_distances[i] = a[i] + b; - } - - std::vector distances(nrows); - faiss::fvec_add(a.size(), a.data(), b, distances.data()); - - ASSERT_EQ(distances, true_distances) - << "Mismatching array-value fvec_add results for nrows = " - << nrows; - } -} +INSTANTIATE_TEST_SUITE_P(SIMDLevels, DistancesSIMDTest, SupportedSIMDLevels()); diff --git a/tests/test_simd_levels.cpp b/tests/test_simd_levels.cpp new file mode 100644 index 0000000000..64da6e77b9 --- /dev/null +++ b/tests/test_simd_levels.cpp @@ -0,0 +1,268 @@ +/* + * Copyright 
(c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#ifdef __x86_64__ +#include +#endif + +#include + +#ifdef __x86_64__ +bool run_avx2_computation() { +#if defined(__AVX2__) + alignas(32) int result[8]; + alignas(32) int input1[8] = {1, 2, 3, 4, 5, 6, 7, 8}; + alignas(32) int input2[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + + __m256i vec1 = _mm256_load_si256(reinterpret_cast<__m256i*>(input1)); + __m256i vec2 = _mm256_load_si256(reinterpret_cast<__m256i*>(input2)); + __m256i vec_result = _mm256_add_epi32(vec1, vec2); + _mm256_store_si256(reinterpret_cast<__m256i*>(result), vec_result); + + return true; +#else + return false; +#endif // __AVX2__ +} + +bool run_avx512f_computation() { +#ifdef __AVX512F__ + alignas(64) long long result[8]; + alignas(64) long long input1[8] = {1, 2, 3, 4, 5, 6, 7, 8}; + alignas(64) long long input2[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + + __m512i vec1 = _mm512_load_si512(reinterpret_cast(input1)); + __m512i vec2 = _mm512_load_si512(reinterpret_cast(input2)); + __m512i vec_result = _mm512_add_epi64(vec1, vec2); + _mm512_store_si512(reinterpret_cast<__m512i*>(result), vec_result); + + return true; +#else + return false; +#endif // __AVX512F__ +} + +bool run_avx512cd_computation() { + EXPECT_TRUE(run_avx512f_computation()); +#ifdef __AVX512CD__ + + __m512i indices = _mm512_set_epi32( + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + __m512i conflict_mask = _mm512_conflict_epi32(indices); + + alignas(64) int mask_array[16]; + _mm512_store_epi32(mask_array, conflict_mask); + return true; +#else + return false; +#endif // __AVX512CD__ +} + +bool run_avx512vl_computation() { + EXPECT_TRUE(run_avx512f_computation()); + +#ifdef __AVX512VL__ + __m256i vec1 = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + __m256i vec2 = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7); + __m256i result = _mm256_add_epi32(vec1, vec2); + alignas(32) int result_array[8]; + _mm256_store_si256(reinterpret_cast<__m256i*>(result_array), result); + return true; +#else + return false; +#endif // __AVX512VL__ +} + +bool run_avx512dq_computation() { + EXPECT_TRUE(run_avx512f_computation()); + +#ifdef __AVX512DQ__ + __m512i vec1 = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0); + __m512i vec2 = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); + __m512i result = _mm512_add_epi64(vec1, vec2); + + alignas(64) long long result_array[8]; + _mm512_store_si512(result_array, result); + return true; +#else + return false; +#endif // __AVX512DQ__ +} + +bool run_avx512bw_computation() { + EXPECT_TRUE(run_avx512f_computation()); + +#ifdef __AVX512BW__ + std::vector input1(64, 0); + __m512i vec1 = + _mm512_loadu_si512(reinterpret_cast(input1.data())); + std::vector input2(64, 7); + __m512i vec2 = + _mm512_loadu_si512(reinterpret_cast(input2.data())); + __m512i result = _mm512_add_epi8(vec1, vec2); + + alignas(64) int8_t result_array[64]; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(result_array), result); + + return true; +#else + return false; +#endif // __AVX512BW__ +} +#endif // __x86_64__ + +TEST(SIMDConfig, simd_level_auto_detect_architecture_only) { + faiss::SIMDLevel detected_level = + faiss::SIMDConfig::auto_detect_simd_level(); + +#if defined(__x86_64__) && \ + (defined(__AVX2__) || \ + (defined(__AVX512F__) && defined(__AVX512CD__) && \ + defined(__AVX512VL__) && defined(__AVX512BW__) && \ + defined(__AVX512DQ__))) + EXPECT_TRUE( + detected_level == faiss::SIMDLevel::AVX2 || + 
detected_level == faiss::SIMDLevel::AVX512); +#elif defined(__aarch64__) && defined(__ARM_NEON) + // Uncomment following line when dynamic dispatch is enabled for ARM_NEON + // EXPECT_TRUE(detected_level == faiss::SIMDLevel::ARM_NEON); +#else + EXPECT_EQ(detected_level, faiss::SIMDLevel::NONE); +#endif + EXPECT_TRUE(detected_level != faiss::SIMDLevel::COUNT); +} + +#ifdef __x86_64__ +TEST(SIMDConfig, successful_avx2_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX2)) { + auto actual_result = run_avx2_computation(); + EXPECT_TRUE(actual_result); + } +} + +TEST(SIMDConfig, on_avx512f_supported_we_should_avx2_support_as_well) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + EXPECT_TRUE( + simd_config.is_simd_level_available(faiss::SIMDLevel::AVX2)); + } +} + +TEST(SIMDConfig, successful_avx512f_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + auto actual_result = run_avx512f_computation(); + EXPECT_TRUE(actual_result); + } +} + +TEST(SIMDConfig, successful_avx512cd_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + auto actual = run_avx512cd_computation(); + EXPECT_TRUE(actual); + } +} + +TEST(SIMDConfig, successful_avx512vl_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + auto actual = run_avx512vl_computation(); + EXPECT_TRUE(actual); + } +} + +TEST(SIMDConfig, successful_avx512dq_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + EXPECT_TRUE( + simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)); + auto actual = run_avx512dq_computation(); + EXPECT_TRUE(actual); + } +} + +TEST(SIMDConfig, successful_avx512bw_execution_on_x86arch) { + faiss::SIMDConfig simd_config(nullptr); + + if (simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)) { + EXPECT_TRUE( + simd_config.is_simd_level_available(faiss::SIMDLevel::AVX512)); + auto actual = run_avx512bw_computation(); + EXPECT_TRUE(actual); + // EXPECT_TRUE(actual.first); + // EXPECT_EQ(actual.second, std::vector(64, 7)); + } +} +#endif // __x86_64__ + +TEST(SIMDConfig, override_simd_level) { + // const char* faiss_env_var_neon = "ARM_NEON"; + // faiss::SIMDConfig simd_neon_config(&faiss_env_var_neon); + // EXPECT_EQ(simd_neon_config.level, faiss::SIMDLevel::ARM_NEON); + + // EXPECT_EQ(simd_neon_config.supported_simd_levels().size(), 2); + // EXPECT_TRUE(simd_neon_config.is_simd_level_available( + // faiss::SIMDLevel::ARM_NEON)); + + const char* faiss_env_var_avx512 = "AVX512"; + faiss::SIMDConfig simd_avx512_config(&faiss_env_var_avx512); + EXPECT_EQ(simd_avx512_config.level, faiss::SIMDLevel::AVX512); + EXPECT_EQ(simd_avx512_config.supported_simd_levels().size(), 2); + EXPECT_TRUE(simd_avx512_config.is_simd_level_available( + faiss::SIMDLevel::AVX512)); +} + +TEST(SIMDConfig, simd_config_get_level_name) { + // const char* faiss_env_var_neon = "ARM_NEON"; + // faiss::SIMDConfig simd_neon_config(&faiss_env_var_neon); + // EXPECT_EQ(simd_neon_config.level, faiss::SIMDLevel::ARM_NEON); + // EXPECT_TRUE(simd_neon_config.is_simd_level_available( + // faiss::SIMDLevel::ARM_NEON)); + // EXPECT_EQ(faiss_env_var_neon, 
simd_neon_config.get_level_name());
+
+    const char* faiss_env_var_avx512 = "AVX512";
+    faiss::SIMDConfig simd_avx512_config(&faiss_env_var_avx512);
+    EXPECT_EQ(simd_avx512_config.level, faiss::SIMDLevel::AVX512);
+    EXPECT_TRUE(simd_avx512_config.is_simd_level_available(
+            faiss::SIMDLevel::AVX512));
+    EXPECT_EQ(faiss_env_var_avx512, simd_avx512_config.get_level_name());
+}
+
+TEST(SIMDLevel, get_level_name_from_enum) {
+    EXPECT_EQ("NONE", to_string(faiss::SIMDLevel::NONE).value_or(""));
+    EXPECT_EQ("AVX2", to_string(faiss::SIMDLevel::AVX2).value_or(""));
+    EXPECT_EQ("AVX512", to_string(faiss::SIMDLevel::AVX512).value_or(""));
+    // EXPECT_EQ("ARM_NEON",
+    //         to_string(faiss::SIMDLevel::ARM_NEON).value_or(""));
+
+    int actual_num_simd_levels = static_cast<int>(faiss::SIMDLevel::COUNT);
+    EXPECT_EQ(4, actual_num_simd_levels);
+    // Check that all SIMD levels have a name (COUNT is excluded because it is
+    // not a real SIMD level)
+    for (int i = 0; i < actual_num_simd_levels; ++i) {
+        faiss::SIMDLevel simd_level = static_cast<faiss::SIMDLevel>(i);
+        EXPECT_TRUE(faiss::to_string(simd_level).has_value());
+    }
+}
+
+TEST(SIMDLevel, to_simd_level_from_string) {
+    EXPECT_EQ(faiss::SIMDLevel::NONE, faiss::to_simd_level("NONE"));
+    EXPECT_EQ(faiss::SIMDLevel::AVX2, faiss::to_simd_level("AVX2"));
+    EXPECT_EQ(faiss::SIMDLevel::AVX512, faiss::to_simd_level("AVX512"));
+    // EXPECT_EQ(faiss::SIMDLevel::ARM_NEON, faiss::to_simd_level("ARM_NEON"));
+    EXPECT_FALSE(faiss::to_simd_level("INVALID").has_value());
+}
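+
+// Note: to exercise a specific dispatch level by hand, the FAISS_SIMD_LEVEL
+// environment variable can be set to one of NONE, AVX2, AVX512 or ARM_NEON
+// before launching the test binary (the exact binary path depends on the
+// build tree), e.g. FAISS_SIMD_LEVEL=AVX2 ./tests/test_simd_levels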