diff --git a/faiss/utils/partitioning.cpp b/faiss/utils/partitioning.cpp index b87f6afd53..4f49465621 100644 --- a/faiss/utils/partitioning.cpp +++ b/faiss/utils/partitioning.cpp @@ -266,6 +266,180 @@ simd16uint16 max_func(simd16uint16 v, simd16uint16 thr16) { } } +#if defined(__AVX512F__) && defined(__AVX512VBMI2__) + +template +void count_lt_and_eq( + const uint16_t* vals, + int n, + uint16_t thresh, + size_t& n_lt, + size_t& n_eq) { + n_lt = 0; + n_eq = 0; + + size_t local_n_lt = 0; + size_t local_n_eq = 0; + + int i = 0; + constexpr int VEC_SIZE = 32; + + constexpr int cmp_op = C::is_max ? _MM_CMPINT_LT : _MM_CMPINT_GT; + + __m512i v_thresh = _mm512_set1_epi16(thresh); + + for (; i + VEC_SIZE <= n; i += VEC_SIZE) { + __m512i v_vals = _mm512_loadu_si512(vals + i); + + __mmask32 k_lt = _mm512_cmp_epu16_mask(v_vals, v_thresh, cmp_op); + __mmask32 k_eq = _mm512_cmp_epu16_mask(v_vals, v_thresh, _MM_CMPINT_EQ); + __mmask32 k_eq_only = k_eq & ~k_lt; + + local_n_lt += _mm_popcnt_u32(k_lt); + local_n_eq += _mm_popcnt_u32(k_eq_only); + } + + for (; i < n; i++) { + uint16_t v = vals[i]; + if (C::cmp(thresh, v)) { + local_n_lt++; + } else if (v == thresh) { + local_n_eq++; + } + } + + n_lt = local_n_lt; + n_eq = local_n_eq; +} + +template +int simd_compress_array( + uint16_t* vals, + int* ids, + size_t n, + uint16_t thresh, + int n_eq) { + constexpr int cmp_op = C::is_max ? _MM_CMPINT_LT : _MM_CMPINT_GT; + + int wp = 0; + size_t i = 0; + + constexpr int VEC_SIZE = 16; + __m512i v_thresh = _mm512_set1_epi32(static_cast(thresh)); + + for (; i + VEC_SIZE <= n; i += VEC_SIZE) { + __m256i v_vals_u16 = _mm256_loadu_si256((__m256i*)(vals + i)); + __m512i v_ids_s32 = _mm512_loadu_si512(ids + i); + __m512i v_vals_s32 = _mm512_cvtepu16_epi32(v_vals_u16); + + __mmask16 k_primary = + _mm512_cmp_epi32_mask(v_vals_s32, v_thresh, cmp_op); + __mmask16 k_equal_to_add = 0; + + int num_to_take = 0; + + if (n_eq > 0) { + __mmask16 k_equal = + _mm512_cmp_epi32_mask(v_vals_s32, v_thresh, _MM_CMPINT_EQ); + __mmask16 k_equal_only = k_equal & ~k_primary; + + int num_eq_found = _mm_popcnt_u32(k_equal_only); + num_to_take = std::min(n_eq, num_eq_found); + + k_equal_to_add = + _pdep_u32((uint32_t(1) << num_to_take) - 1, k_equal_only); + } + + __mmask16 k_final = k_primary | k_equal_to_add; + _mm256_mask_compressstoreu_epi16(vals + wp, k_final, v_vals_u16); + _mm512_mask_compressstoreu_epi32(ids + wp, k_final, v_ids_s32); + + wp += _mm_popcnt_u32(k_final); + n_eq -= num_to_take; + } + + for (; i < n; i++) { + if (C::cmp(thresh, vals[i])) { + vals[wp] = vals[i]; + ids[wp] = ids[i]; + wp++; + } else if (n_eq > 0 && vals[i] == thresh) { + vals[wp] = vals[i]; + ids[wp] = ids[i]; + wp++; + n_eq--; + } + } + + assert(n_eq == 0); + return wp; +} + +template +int simd_compress_array( + uint16_t* vals, + int64_t* ids, + size_t n, + uint16_t thresh, + int n_eq) { + constexpr int cmp_op = C::is_max ? _MM_CMPINT_LT : _MM_CMPINT_GT; + + int wp = 0; + size_t i = 0; + + constexpr int VEC_SIZE = 8; + __m512i v_thresh = _mm512_set1_epi64(static_cast(thresh)); + + for (; i + VEC_SIZE <= n; i += VEC_SIZE) { + __m128i v_vals_u16 = _mm_loadu_si128((__m128i*)(vals + i)); + __m512i v_ids_s64 = _mm512_loadu_si512(ids + i); + __m512i v_vals_s64 = _mm512_cvtepu16_epi64(v_vals_u16); + + __mmask8 k_primary = + _mm512_cmp_epi64_mask(v_vals_s64, v_thresh, cmp_op); + __mmask8 k_equal_to_add = 0; + + int num_to_take = 0; + + if (n_eq > 0) { + __mmask8 k_equal = + _mm512_cmp_epi64_mask(v_vals_s64, v_thresh, _MM_CMPINT_EQ); + __mmask8 k_equal_only = k_equal & ~k_primary; + + int num_eq_found = _mm_popcnt_u32(k_equal_only); + num_to_take = std::min(n_eq, num_eq_found); + + k_equal_to_add = + _pdep_u32((uint32_t(1) << num_to_take) - 1, k_equal_only); + } + + __mmask8 k_final = k_primary | k_equal_to_add; + _mm_mask_compressstoreu_epi16(vals + wp, k_final, v_vals_u16); + _mm512_mask_compressstoreu_epi64(ids + wp, k_final, v_ids_s64); + + wp += _mm_popcnt_u32(k_final); + n_eq -= num_to_take; + } + + for (; i < n; i++) { + if (C::cmp(thresh, vals[i])) { + vals[wp] = vals[i]; + ids[wp] = ids[i]; + wp++; + } else if (n_eq > 0 && vals[i] == thresh) { + vals[wp] = vals[i]; + ids[wp] = ids[i]; + wp++; + n_eq--; + } + } + + assert(n_eq == 0); + return wp; +} + +#else + template void count_lt_and_eq( const uint16_t* vals, @@ -385,6 +559,8 @@ int simd_compress_array( return wp; } +#endif + // #define MICRO_BENCHMARK static uint64_t get_cy() {