Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions faiss/utils/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,180 @@ simd16uint16 max_func(simd16uint16 v, simd16uint16 thr16) {
}
}

#if defined(__AVX512F__) && defined(__AVX512VBMI2__)

/** Count the entries of vals[0..n) that fall strictly on the selected side
 * of a threshold, and those that are exactly equal to it.
 *
 * The vector loop uses an unsigned 16-bit comparison: `v < thresh` when
 * `C::is_max` is true, `v > thresh` otherwise. The scalar tail uses
 * `C::cmp(thresh, v)` for the same purpose, so `C` is expected to make the
 * two agree (NOTE(review): confirm against the definition of `C`).
 *
 * @param vals    input values, at least n readable entries
 * @param n       number of entries to scan
 * @param thresh  threshold to compare against
 * @param n_lt    output: number of entries passing the strict comparison
 * @param n_eq    output: number of entries equal to thresh
 */
template <class C>
void count_lt_and_eq(
        const uint16_t* vals,
        int n,
        uint16_t thresh,
        size_t& n_lt,
        size_t& n_eq) {
    // One 512-bit register holds 32 uint16_t lanes.
    constexpr int kLanes = 32;
    constexpr int kStrictPredicate =
            C::is_max ? _MM_CMPINT_LT : _MM_CMPINT_GT;

    // The outputs are cleared first and receive the totals at the very end.
    n_lt = 0;
    n_eq = 0;

    const __m512i thresh_vec = _mm512_set1_epi16(thresh);

    size_t strict_total = 0;
    size_t equal_total = 0;

    // Full 32-lane blocks.
    int pos = 0;
    while (pos + kLanes <= n) {
        const __m512i block = _mm512_loadu_si512(vals + pos);

        const __mmask32 strict_mask =
                _mm512_cmp_epu16_mask(block, thresh_vec, kStrictPredicate);
        const __mmask32 equal_mask =
                _mm512_cmp_epu16_mask(block, thresh_vec, _MM_CMPINT_EQ);
        // Lanes already counted as strict are never counted as equal.
        const __mmask32 equal_only_mask = equal_mask & ~strict_mask;

        strict_total += _mm_popcnt_u32(strict_mask);
        equal_total += _mm_popcnt_u32(equal_only_mask);

        pos += kLanes;
    }

    // Remaining (fewer than 32) entries, one at a time.
    for (; pos < n; ++pos) {
        const uint16_t v = vals[pos];
        if (C::cmp(thresh, v)) {
            ++strict_total;
            continue;
        }
        if (v == thresh) {
            ++equal_total;
        }
    }

    n_lt = strict_total;
    n_eq = equal_total;
}

/** In-place compaction of parallel arrays (16-bit values, 32-bit ids).
 *
 * Keeps every entry that passes the strict comparison against `thresh`
 * (signed 32-bit compare on the zero-extended value: `v < thresh` when
 * `C::is_max`, `v > thresh` otherwise; `C::cmp(thresh, v)` in the scalar
 * tail), plus the first `n_eq` entries, in index order, that are equal to
 * `thresh`. Kept entries are moved to the front of `vals` and `ids`,
 * preserving their relative order.
 *
 * @param vals    values, compacted in place; at least n entries
 * @param ids     ids matching vals, compacted in place; at least n entries
 * @param n       number of input entries
 * @param thresh  threshold value
 * @param n_eq    number of entries equal to thresh that must be kept; the
 *                function asserts that exactly this many were found
 * @return        number of entries kept
 */
template <class C>
int simd_compress_array(
        uint16_t* vals,
        int* ids,
        size_t n,
        uint16_t thresh,
        int n_eq) {
    // 16 ids of 32 bits fill one 512-bit register.
    constexpr int kLanes = 16;
    constexpr int kStrictPredicate =
            C::is_max ? _MM_CMPINT_LT : _MM_CMPINT_GT;

    const __m512i thresh_vec =
            _mm512_set1_epi32(static_cast<int32_t>(thresh));

    int out = 0;   // write cursor, never ahead of the read cursor
    size_t in = 0; // read cursor

    while (in + kLanes <= n) {
        // Everything for this block is loaded before anything is stored,
        // because the destination may overlap the block being read.
        const __m256i vals16 = _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(vals + in));
        const __m512i ids32 = _mm512_loadu_si512(ids + in);
        const __m512i vals32 = _mm512_cvtepu16_epi32(vals16);

        const __mmask16 strict_mask =
                _mm512_cmp_epi32_mask(vals32, thresh_vec, kStrictPredicate);

        __mmask16 tie_mask = 0;
        int ties_taken = 0;

        if (n_eq > 0) {
            const __mmask16 eq_mask =
                    _mm512_cmp_epi32_mask(vals32, thresh_vec, _MM_CMPINT_EQ);
            const __mmask16 tie_candidates = eq_mask & ~strict_mask;

            const int ties_here = _mm_popcnt_u32(tie_candidates);
            ties_taken = std::min(n_eq, ties_here);

            // Deposit `ties_taken` low bits onto the candidate positions:
            // this selects the lowest-indexed candidates only.
            const uint32_t low_bits = (uint32_t(1) << ties_taken) - 1;
            tie_mask = _pdep_u32(low_bits, tie_candidates);
        }

        const __mmask16 keep_mask = strict_mask | tie_mask;
        _mm256_mask_compressstoreu_epi16(vals + out, keep_mask, vals16);
        _mm512_mask_compressstoreu_epi32(ids + out, keep_mask, ids32);

        out += _mm_popcnt_u32(keep_mask);
        n_eq -= ties_taken;
        in += kLanes;
    }

    // Scalar handling of the last (fewer than 16) entries.
    for (; in < n; ++in) {
        const uint16_t v = vals[in];
        const bool is_strict = C::cmp(thresh, v);
        const bool is_tie = !is_strict && n_eq > 0 && v == thresh;
        if (!is_strict && !is_tie) {
            continue;
        }
        vals[out] = v;
        ids[out] = ids[in];
        ++out;
        if (is_tie) {
            --n_eq;
        }
    }

    assert(n_eq == 0);
    return out;
}

/** In-place compaction of parallel arrays (16-bit values, 64-bit ids).
 *
 * Keeps every entry that passes the strict comparison against `thresh`
 * (signed 64-bit compare on the zero-extended value: `v < thresh` when
 * `C::is_max`, `v > thresh` otherwise; `C::cmp(thresh, v)` in the scalar
 * tail), plus the first `n_eq` entries, in index order, that are equal to
 * `thresh`. Kept entries are moved to the front of `vals` and `ids`,
 * preserving their relative order.
 *
 * @param vals    values, compacted in place; at least n entries
 * @param ids     ids matching vals, compacted in place; at least n entries
 * @param n       number of input entries
 * @param thresh  threshold value
 * @param n_eq    number of entries equal to thresh that must be kept; the
 *                function asserts that exactly this many were found
 * @return        number of entries kept
 */
template <class C>
int simd_compress_array(
        uint16_t* vals,
        int64_t* ids,
        size_t n,
        uint16_t thresh,
        int n_eq) {
    // 8 ids of 64 bits fill one 512-bit register.
    constexpr int kLanes = 8;
    constexpr int kStrictPredicate =
            C::is_max ? _MM_CMPINT_LT : _MM_CMPINT_GT;

    const __m512i thresh_vec =
            _mm512_set1_epi64(static_cast<int64_t>(thresh));

    int out = 0;   // write cursor, never ahead of the read cursor
    size_t in = 0; // read cursor

    while (in + kLanes <= n) {
        // Everything for this block is loaded before anything is stored,
        // because the destination may overlap the block being read.
        const __m128i vals16 =
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(vals + in));
        const __m512i ids64 = _mm512_loadu_si512(ids + in);
        const __m512i vals64 = _mm512_cvtepu16_epi64(vals16);

        const __mmask8 strict_mask =
                _mm512_cmp_epi64_mask(vals64, thresh_vec, kStrictPredicate);

        __mmask8 tie_mask = 0;
        int ties_taken = 0;

        if (n_eq > 0) {
            const __mmask8 eq_mask =
                    _mm512_cmp_epi64_mask(vals64, thresh_vec, _MM_CMPINT_EQ);
            const __mmask8 tie_candidates = eq_mask & ~strict_mask;

            const int ties_here = _mm_popcnt_u32(tie_candidates);
            ties_taken = std::min(n_eq, ties_here);

            // Deposit `ties_taken` low bits onto the candidate positions:
            // this selects the lowest-indexed candidates only.
            const uint32_t low_bits = (uint32_t(1) << ties_taken) - 1;
            tie_mask = _pdep_u32(low_bits, tie_candidates);
        }

        const __mmask8 keep_mask = strict_mask | tie_mask;
        _mm_mask_compressstoreu_epi16(vals + out, keep_mask, vals16);
        _mm512_mask_compressstoreu_epi64(ids + out, keep_mask, ids64);

        out += _mm_popcnt_u32(keep_mask);
        n_eq -= ties_taken;
        in += kLanes;
    }

    // Scalar handling of the last (fewer than 8) entries.
    for (; in < n; ++in) {
        const uint16_t v = vals[in];
        const bool is_strict = C::cmp(thresh, v);
        const bool is_tie = !is_strict && n_eq > 0 && v == thresh;
        if (!is_strict && !is_tie) {
            continue;
        }
        vals[out] = v;
        ids[out] = ids[in];
        ++out;
        if (is_tie) {
            --n_eq;
        }
    }

    assert(n_eq == 0);
    return out;
}

#else

template <class C>
void count_lt_and_eq(
const uint16_t* vals,
Expand Down Expand Up @@ -385,6 +559,8 @@ int simd_compress_array(
return wp;
}

#endif

// #define MICRO_BENCHMARK

static uint64_t get_cy() {
Expand Down
Loading