Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions faiss/utils/distances.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include <faiss/impl/platform_macros.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/simd_levels.h>

namespace faiss {

Expand All @@ -27,15 +28,27 @@ struct IDSelector;
/// Squared L2 distance between two vectors
float fvec_L2sqr(const float* x, const float* y, size_t d);

template <SIMDLevel>
float fvec_L2sqr(const float* x, const float* y, size_t d);

/// inner product
float fvec_inner_product(const float* x, const float* y, size_t d);

template <SIMDLevel>
float fvec_inner_product(const float* x, const float* y, size_t d);

/// L1 distance
float fvec_L1(const float* x, const float* y, size_t d);

template <SIMDLevel>
float fvec_L1(const float* x, const float* y, size_t d);

/// infinity distance
float fvec_Linf(const float* x, const float* y, size_t d);

template <SIMDLevel>
float fvec_Linf(const float* x, const float* y, size_t d);

/// Special version of inner product that computes 4 distances
/// between x and yi, which is performance oriented.
void fvec_inner_product_batch_4(
Expand All @@ -50,6 +63,19 @@ void fvec_inner_product_batch_4(
float& dis2,
float& dis3);

template <SIMDLevel>
void fvec_inner_product_batch_4(
const float* x,
const float* y0,
const float* y1,
const float* y2,
const float* y3,
const size_t d,
float& dis0,
float& dis1,
float& dis2,
float& dis3);

/// Special version of L2sqr that computes 4 distances
/// between x and yi, which is performance oriented.
void fvec_L2sqr_batch_4(
Expand All @@ -64,6 +90,19 @@ void fvec_L2sqr_batch_4(
float& dis2,
float& dis3);

template <SIMDLevel>
void fvec_L2sqr_batch_4(
const float* x,
const float* y0,
const float* y1,
const float* y2,
const float* y3,
const size_t d,
float& dis0,
float& dis1,
float& dis2,
float& dis3);

/** Compute pairwise distances between sets of vectors
*
* @param d dimension of the vectors
Expand Down Expand Up @@ -93,6 +132,14 @@ void fvec_inner_products_ny(
size_t d,
size_t ny);

template <SIMDLevel>
void fvec_inner_products_ny(
float* ip, /* output inner product */
const float* x,
const float* y,
size_t d,
size_t ny);

/* compute ny square L2 distance between x and a set of contiguous y vectors */
void fvec_L2sqr_ny(
float* dis,
Expand All @@ -101,6 +148,14 @@ void fvec_L2sqr_ny(
size_t d,
size_t ny);

template <SIMDLevel>
void fvec_L2sqr_ny(
float* dis,
const float* x,
const float* y,
size_t d,
size_t ny);

/* compute ny square L2 distance between x and a set of transposed contiguous
y vectors. squared lengths of y should be provided as well */
void fvec_L2sqr_ny_transposed(
Expand All @@ -112,6 +167,16 @@ void fvec_L2sqr_ny_transposed(
size_t d_offset,
size_t ny);

template <SIMDLevel>
void fvec_L2sqr_ny_transposed(
float* dis,
const float* x,
const float* y,
const float* y_sqlen,
size_t d,
size_t d_offset,
size_t ny);

/* compute ny square L2 distance between x and a set of contiguous y vectors
and return the index of the nearest vector.
return 0 if ny == 0. */
Expand All @@ -122,6 +187,14 @@ size_t fvec_L2sqr_ny_nearest(
size_t d,
size_t ny);

template <SIMDLevel>
size_t fvec_L2sqr_ny_nearest(
float* distances_tmp_buffer,
const float* x,
const float* y,
size_t d,
size_t ny);

/* compute ny square L2 distance between x and a set of transposed contiguous
y vectors and return the index of the nearest vector.
squared lengths of y should be provided as well
Expand All @@ -135,9 +208,22 @@ size_t fvec_L2sqr_ny_nearest_y_transposed(
size_t d_offset,
size_t ny);

template <SIMDLevel>
size_t fvec_L2sqr_ny_nearest_y_transposed(
float* distances_tmp_buffer,
const float* x,
const float* y,
const float* y_sqlen,
size_t d,
size_t d_offset,
size_t ny);

/** squared norm of a vector */
float fvec_norm_L2sqr(const float* x, size_t d);

template <SIMDLevel>
float fvec_norm_L2sqr(const float* x, size_t d);

/** compute the L2 norms for a set of vectors
*
* @param norms output norms, size nx
Expand Down Expand Up @@ -473,6 +559,10 @@ void compute_PQ_dis_tables_dsub2(
*/
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c);

/* same statically */
template <SIMDLevel>
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c);

/** same as fvec_madd, also return index of the min of the result table
* @return index of the min of table c
*/
Expand All @@ -483,4 +573,12 @@ int fvec_madd_and_argmin(
const float* b,
float* c);

template <SIMDLevel>
int fvec_madd_and_argmin(
size_t n,
const float* a,
float bf,
const float* b,
float* c);

} // namespace faiss
Loading
Loading