Skip to content

Commit 0031d61

Browse files
generatedunixname89002005232357facebook-github-bot
authored andcommitted
Revert D80734790 (facebookresearch#4563)
Summary: Pull Request resolved: facebookresearch#4563 This diff reverts D80734790 (The context such as a Sandcastle job, Task, SEV, etc. was not provided.) Depends on D80734790 Reviewed By: limqiying Differential Revision: D81147721 fbshipit-source-id: 33031f34660155a82f30cc2327a932e3f47ff141
1 parent ee6b7dd commit 0031d61

File tree

10 files changed

+65
-173
lines changed

10 files changed

+65
-173
lines changed

contrib/torch_utils.py

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,6 @@ def swig_ptr_from_IndicesTensor(x):
8080
return faiss.cast_integer_to_idx_t_ptr(
8181
x.untyped_storage().data_ptr() + x.storage_offset() * 8)
8282

83-
84-
def swig_ptr_from_Int64Tensor(x):
85-
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
86-
assert x.is_contiguous()
87-
assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
88-
return faiss.cast_integer_to_int64_ptr(
89-
x.untyped_storage().data_ptr() + x.storage_offset() * 8)
90-
91-
9283
##################################################################
9384
# utilities
9485
##################################################################
@@ -168,7 +159,7 @@ def torch_replacement_add(self, x, numeric_type = faiss.Float32):
168159
# CPU torch
169160
self.add_ex(n, x_ptr, numeric_type)
170161

171-
def torch_replacement_add_with_ids(self, x, ids, numeric_type = faiss.Float32, ids_type = faiss.Int64):
162+
def torch_replacement_add_with_ids(self, x, ids, numeric_type = faiss.Float32):
172163
if type(x) is np.ndarray:
173164
# forward to faiss __init__.py base method
174165
return self.add_with_ids_numpy(x, ids)
@@ -185,20 +176,17 @@ def torch_replacement_add_with_ids(self, x, ids, numeric_type = faiss.Float32, i
185176

186177
assert type(ids) is torch.Tensor
187178
assert ids.shape == (n, ), 'not same number of vectors as ids'
188-
if ids_type == faiss.Int64:
189-
ids_ptr = swig_ptr_from_Int64Tensor(ids)
190-
else:
191-
raise ValueError("ids type for add_with_ids must be faiss.Int64")
179+
ids_ptr = swig_ptr_from_IndicesTensor(ids)
192180

193181
if x.is_cuda:
194182
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
195183

196184
# On the GPU, use proper stream ordering
197185
with using_stream(self.getResources()):
198-
self.add_with_ids_ex(n, x_ptr, numeric_type, ids_ptr, ids_type)
186+
self.add_with_ids_ex(n, x_ptr, numeric_type, ids_ptr)
199187
else:
200188
# CPU torch
201-
self.add_with_ids_ex(n, x_ptr, numeric_type, ids_ptr, ids_type)
189+
self.add_with_ids_ex(n, x_ptr, numeric_type, ids_ptr)
202190

203191
def torch_replacement_assign(self, x, k, labels=None):
204192
if type(x) is np.ndarray:
@@ -272,18 +260,12 @@ def search_methods_common(x, k, D, I, numeric_type=faiss.Float32):
272260

273261
if I is None:
274262
I = torch.empty(n, k, device=x.device, dtype=torch.int64)
275-
I_ptr = swig_ptr_from_Int64Tensor(I)
276-
I_type = faiss.Int64
277263
else:
278264
assert type(I) is torch.Tensor
279265
assert I.shape == (n, k)
280-
if I.dtype == torch.int64:
281-
I_ptr = swig_ptr_from_Int64Tensor(I)
282-
I_type = faiss.Int64
283-
else:
284-
raise ValueError("labels for search should be int64 type")
266+
I_ptr = swig_ptr_from_IndicesTensor(I)
285267

286-
return x_ptr, D_ptr, I_ptr, D, I, I_type
268+
return x_ptr, D_ptr, I_ptr, D, I
287269

288270
def torch_replacement_search(self, x, k, D=None, I=None, numeric_type=faiss.Float32):
289271
if type(x) is np.ndarray:
@@ -294,17 +276,17 @@ def torch_replacement_search(self, x, k, D=None, I=None, numeric_type=faiss.Floa
294276
n, d = x.shape
295277
assert d == self.d
296278

297-
x_ptr, D_ptr, I_ptr, D, I, I_type = search_methods_common(x, k, D, I)
279+
x_ptr, D_ptr, I_ptr, D, I = search_methods_common(x, k, D, I)
298280

299281
if x.is_cuda:
300282
assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'
301283

302284
# On the GPU, use proper stream ordering
303285
with using_stream(self.getResources()):
304-
self.search_ex(n, x_ptr, numeric_type, k, D_ptr, I_ptr, I_type)
286+
self.search_ex(n, x_ptr, numeric_type, k, D_ptr, I_ptr)
305287
else:
306288
# CPU torch
307-
self.search_ex(n, x_ptr, numeric_type, k, D_ptr, I_ptr, I_type)
289+
self.search_ex(n, x_ptr, numeric_type, k, D_ptr, I_ptr)
308290

309291
return D, I
310292

@@ -317,7 +299,7 @@ def torch_replacement_search_and_reconstruct(self, x, k, D=None, I=None, R=None)
317299
n, d = x.shape
318300
assert d == self.d
319301

320-
x_ptr, D_ptr, I_ptr, D, I, _ = search_methods_common(x, k, D, I)
302+
x_ptr, D_ptr, I_ptr, D, I = search_methods_common(x, k, D, I)
321303

322304
if R is None:
323305
R = torch.empty(n, k, d, device=x.device, dtype=torch.float32)
@@ -347,7 +329,7 @@ def torch_replacement_search_preassigned(self, x, k, Iq, Dq, *, D=None, I=None):
347329
n, d = x.shape
348330
assert d == self.d
349331

350-
x_ptr, D_ptr, I_ptr, D, I, _ = search_methods_common(x, k, D, I)
332+
x_ptr, D_ptr, I_ptr, D, I = search_methods_common(x, k, D, I)
351333

352334
assert Iq.shape == (n, self.nprobe)
353335
Iq = Iq.contiguous()

faiss/Index.h

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include <faiss/MetricType.h>
1414
#include <faiss/impl/FaissAssert.h>
1515

16-
#include <cstdint>
1716
#include <cstdio>
1817
#include <sstream>
1918

@@ -60,28 +59,22 @@ struct DistanceComputer;
6059
enum NumericType {
6160
Float32,
6261
Float16,
63-
Int64,
6462
UInt8,
6563
Int8,
66-
NONE, // corresponding to nullptr
6764
};
6865

6966
inline size_t get_numeric_type_size(NumericType numeric_type) {
7067
switch (numeric_type) {
71-
case NumericType::Int64:
72-
return 8;
7368
case NumericType::Float32:
7469
return 4;
7570
case NumericType::Float16:
7671
return 2;
7772
case NumericType::UInt8:
7873
case NumericType::Int8:
7974
return 1;
80-
case NumericType::NONE:
81-
return 0;
8275
default:
8376
FAISS_THROW_MSG(
84-
"Unknown Numeric Type. Only supports Float32, Float16, Int64, UInt8, Int8");
77+
"Unknown Numeric Type. Only supports Float32, Float16");
8578
}
8679
}
8780

@@ -176,17 +169,11 @@ struct Index {
176169
idx_t n,
177170
const void* x,
178171
NumericType numeric_type,
179-
const void* xids,
180-
NumericType xids_type) {
181-
if (numeric_type == NumericType::Float32 &&
182-
xids_type == NumericType::Int64) {
183-
add_with_ids(
184-
n,
185-
static_cast<const float*>(x),
186-
static_cast<const idx_t*>(xids));
172+
const idx_t* xids) {
173+
if (numeric_type == NumericType::Float32) {
174+
add_with_ids(n, static_cast<const float*>(x), xids);
187175
} else {
188-
FAISS_THROW_MSG(
189-
"Index::add_with_ids: unsupported numeric type or xids type");
176+
FAISS_THROW_MSG("Index::add_with_ids: unsupported numeric type");
190177
}
191178
}
192179

@@ -215,20 +202,17 @@ struct Index {
215202
NumericType numeric_type,
216203
idx_t k,
217204
float* distances,
218-
void* labels,
219-
NumericType labels_type,
205+
idx_t* labels,
220206
const SearchParameters* params = nullptr) const {
221-
if (numeric_type == NumericType::Float32 &&
222-
labels_type == NumericType::Int64) {
207+
if (numeric_type == NumericType::Float32) {
223208
search(n,
224209
static_cast<const float*>(x),
225210
k,
226211
distances,
227-
static_cast<idx_t*>(labels),
212+
labels,
228213
params);
229214
} else {
230-
FAISS_THROW_MSG(
231-
"Index::search: unsupported numeric type or label type");
215+
FAISS_THROW_MSG("Index::search: unsupported numeric type");
232216
}
233217
}
234218

faiss/IndexBinary.h

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -86,17 +86,12 @@ struct IndexBinary {
8686
idx_t n,
8787
const void* x,
8888
NumericType numeric_type,
89-
const void* xids,
90-
NumericType xids_type) {
91-
if (numeric_type == NumericType::UInt8 &&
92-
xids_type == NumericType::Int64) {
93-
add_with_ids(
94-
n,
95-
static_cast<const uint8_t*>(x),
96-
static_cast<const idx_t*>(xids));
89+
const idx_t* xids) {
90+
if (numeric_type == NumericType::UInt8) {
91+
add_with_ids(n, static_cast<const uint8_t*>(x), xids);
9792
} else {
9893
FAISS_THROW_MSG(
99-
"IndexBinary::add_with_ids: unsupported numeric type or xids type");
94+
"IndexBinary::add_with_ids: unsupported numeric type");
10095
}
10196
};
10297

@@ -122,20 +117,17 @@ struct IndexBinary {
122117
NumericType numeric_type,
123118
idx_t k,
124119
int32_t* distances,
125-
void* labels,
126-
NumericType labels_type,
120+
idx_t* labels,
127121
const SearchParameters* params = nullptr) const {
128-
if (numeric_type == NumericType::UInt8 &&
129-
labels_type == NumericType::Int64) {
122+
if (numeric_type == NumericType::UInt8) {
130123
search(n,
131124
static_cast<const uint8_t*>(x),
132125
k,
133126
distances,
134-
static_cast<idx_t*>(labels),
127+
labels,
135128
params);
136129
} else {
137-
FAISS_THROW_MSG(
138-
"IndexBinary::search: unsupported numeric type or label type");
130+
FAISS_THROW_MSG("IndexBinary::search: unsupported numeric type");
139131
}
140132
};
141133

faiss/IndexIDMap.cpp

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include <cinttypes>
1313
#include <cstdint>
1414
#include <cstdio>
15-
#include "Index.h"
1615
#include "faiss/Index.h"
1716

1817
#include <faiss/impl/AuxIndexStructures.h>
@@ -109,14 +108,10 @@ void IndexIDMapTemplate<IndexT>::add_with_ids_ex(
109108
idx_t n,
110109
const void* x,
111110
NumericType numeric_type,
112-
const void* xids,
113-
NumericType xids_type) {
114-
FAISS_THROW_IF_NOT_MSG(
115-
xids_type == NumericType::Int64,
116-
"IndexIDMapTemplate::add_with_ids_ex only supports int64 as xids type");
111+
const idx_t* xids) {
117112
index->add_ex(n, x, numeric_type);
118113
for (idx_t i = 0; i < n; i++) {
119-
id_map.push_back(static_cast<const idx_t*>(xids)[i]);
114+
id_map.push_back(xids[i]);
120115
}
121116
this->ntotal = index->ntotal;
122117
}
@@ -130,8 +125,7 @@ void IndexIDMapTemplate<IndexT>::add_with_ids(
130125
n,
131126
static_cast<const void*>(x),
132127
component_t_to_numeric<typename IndexT::component_t>(),
133-
static_cast<const void*>(xids),
134-
NumericType::Int64);
128+
xids);
135129
}
136130

137131
template <typename IndexT>
@@ -179,8 +173,7 @@ void IndexIDMapTemplate<IndexT>::search_ex(
179173
NumericType numeric_type,
180174
idx_t k,
181175
typename IndexT::distance_t* distances,
182-
void* labels,
183-
NumericType labels_type,
176+
idx_t* labels,
184177
const SearchParameters* params) const {
185178
IDSelectorTranslated this_idtrans(this->id_map, nullptr);
186179
ScopedSelChange sel_change;
@@ -201,13 +194,8 @@ void IndexIDMapTemplate<IndexT>::search_ex(
201194
sel_change.set(params_non_const, &this_idtrans);
202195
}
203196
}
204-
index->search_ex(
205-
n, x, numeric_type, k, distances, labels, labels_type, params);
206-
207-
FAISS_THROW_IF_NOT_MSG(
208-
labels_type == NumericType::Int64,
209-
"IndexIDMapTemplate::search_ex only supports int64 as labels type");
210-
idx_t* li = static_cast<idx_t*>(labels);
197+
index->search_ex(n, x, numeric_type, k, distances, labels, params);
198+
idx_t* li = labels;
211199
#pragma omp parallel for
212200
for (idx_t i = 0; i < n * k; i++) {
213201
li[i] = li[i] < 0 ? li[i] : id_map[li[i]];
@@ -228,8 +216,7 @@ void IndexIDMapTemplate<IndexT>::search(
228216
component_t_to_numeric<typename IndexT::component_t>(),
229217
k,
230218
distances,
231-
static_cast<void*>(labels),
232-
NumericType::Int64,
219+
labels,
233220
params);
234221
}
235222

@@ -319,11 +306,9 @@ void IndexIDMap2Template<IndexT>::add_with_ids_ex(
319306
idx_t n,
320307
const void* x,
321308
NumericType numeric_type,
322-
const void* xids,
323-
NumericType xids_type) {
309+
const idx_t* xids) {
324310
size_t prev_ntotal = this->ntotal;
325-
IndexIDMapTemplate<IndexT>::add_with_ids_ex(
326-
n, x, numeric_type, xids, xids_type);
311+
IndexIDMapTemplate<IndexT>::add_with_ids_ex(n, x, numeric_type, xids);
327312
for (size_t i = prev_ntotal; i < this->ntotal; i++) {
328313
rev_map[this->id_map[i]] = i;
329314
}
@@ -338,8 +323,7 @@ void IndexIDMap2Template<IndexT>::add_with_ids(
338323
n,
339324
static_cast<const void*>(x),
340325
component_t_to_numeric<typename IndexT::component_t>(),
341-
static_cast<const void*>(xids),
342-
NumericType::Int64);
326+
xids);
343327
}
344328

345329
template <typename IndexT>

faiss/IndexIDMap.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ struct IndexIDMapTemplate : IndexT {
3535
idx_t n,
3636
const void* x,
3737
NumericType numeric_type,
38-
const void* xids,
39-
NumericType xids_type) override;
38+
const idx_t* xids) override;
4039

4140
/// this will fail. Use add_with_ids
4241
void add(idx_t n, const component_t* x) override;
@@ -55,8 +54,7 @@ struct IndexIDMapTemplate : IndexT {
5554
NumericType numeric_type,
5655
idx_t k,
5756
distance_t* distances,
58-
void* labels,
59-
NumericType labels_type,
57+
idx_t* labels,
6058
const SearchParameters* params = nullptr) const override;
6159

6260
void train(idx_t n, const component_t* x) override;
@@ -110,8 +108,7 @@ struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {
110108
idx_t n,
111109
const void* x,
112110
NumericType numeric_type,
113-
const void* xids,
114-
NumericType xids_type) override;
111+
const idx_t* xids) override;
115112

116113
size_t remove_ids(const IDSelector& sel) override;
117114

0 commit comments

Comments
 (0)