Skip to content

Commit c66dfb5

Browse files
alwayslove2013 and XuanYang-cn
authored and committed
fix pinecone client
Signed-off-by: min.tian <[email protected]>
1 parent aa6d4dc commit c66dfb5

File tree

2 files changed

+34
-38
lines changed

2 files changed

+34
-38
lines changed

vectordb_bench/backend/clients/pinecone/config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44

55
class PineconeConfig(DBConfig):
    """Connection settings for the Pinecone client.

    Only an API key and a target index name are required; the legacy
    ``environment`` field was removed together with ``pinecone.init``.
    """

    api_key: SecretStr
    index_name: str

    def to_dict(self) -> dict:
        """Return the config as a plain dict, revealing the secret API key."""
        return {
            "api_key": self.api_key.get_secret_value(),
            "index_name": self.index_name,
        }

vectordb_bench/backend/clients/pinecone/pinecone.py

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
import logging
44
from contextlib import contextmanager
55
from typing import Type
6-
6+
import pinecone
77
from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
88
from .config import PineconeConfig
99

1010

1111
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Batch-sizing limits used when computing the upsert batch size:
# cap both the number of vectors per request and the payload size.
PINECONE_MAX_NUM_PER_BATCH = 1000
PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024  # 2MB
1516

1617
class Pinecone(VectorDB):
1718
def __init__(
@@ -23,30 +24,25 @@ def __init__(
2324
**kwargs,
2425
):
2526
"""Initialize wrapper around the milvus vector database."""
26-
self.index_name = db_config["index_name"]
27-
self.api_key = db_config["api_key"]
28-
self.environment = db_config["environment"]
29-
self.batch_size = int(min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH))
30-
# Pincone will make connections with server while import
31-
# so place the import here.
32-
import pinecone
33-
pinecone.init(
34-
api_key=self.api_key, environment=self.environment)
27+
self.index_name = db_config.get("index_name", "")
28+
self.api_key = db_config.get("api_key", "")
29+
self.batch_size = int(
30+
min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)
31+
)
32+
33+
pc = pinecone.Pinecone(api_key=self.api_key)
34+
index = pc.Index(self.index_name)
35+
3536
if drop_old:
36-
list_indexes = pinecone.list_indexes()
37-
if self.index_name in list_indexes:
38-
index = pinecone.Index(self.index_name)
39-
index_dim = index.describe_index_stats()["dimension"]
40-
if (index_dim != dim):
41-
raise ValueError(
42-
f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}")
43-
log.info(
44-
f"Pinecone client delete old index: {self.index_name}")
45-
index.delete(delete_all=True)
46-
index.close()
47-
else:
37+
index_stats = index.describe_index_stats()
38+
index_dim = index_stats["dimension"]
39+
if index_dim != dim:
4840
raise ValueError(
49-
f"Pinecone index {self.index_name} does not exist")
41+
f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}"
42+
)
43+
for namespace in index_stats["namespaces"]:
44+
log.info(f"Pinecone index delete namespace: {namespace}")
45+
index.delete(delete_all=True, namespace=namespace)
5046

5147
self._metadata_key = "meta"
5248

@@ -59,13 +55,10 @@ def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConf
5955
return EmptyDBCaseConfig
6056

6157
@contextmanager
def init(self):
    """Open a per-use Pinecone connection and bind the target index.

    Builds a ``pinecone.Pinecone`` client from the stored API key and
    caches the index handle on ``self.index``; yields once so callers
    can use this as a context manager. The v3+ client needs no explicit
    close, so nothing runs after the yield.
    """
    client = pinecone.Pinecone(api_key=self.api_key)
    self.index = client.Index(self.index_name)
    yield

7063
def ready_to_load(self):
    """No-op hook required by the VectorDB interface — nothing to prepare."""
    pass
@@ -83,11 +76,16 @@ def insert_embeddings(
8376
insert_count = 0
8477
try:
8578
for batch_start_offset in range(0, len(embeddings), self.batch_size):
86-
batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
79+
batch_end_offset = min(
80+
batch_start_offset + self.batch_size, len(embeddings)
81+
)
8782
insert_datas = []
8883
for i in range(batch_start_offset, batch_end_offset):
89-
insert_data = (str(metadata[i]), embeddings[i], {
90-
self._metadata_key: metadata[i]})
84+
insert_data = (
85+
str(metadata[i]),
86+
embeddings[i],
87+
{self._metadata_key: metadata[i]},
88+
)
9189
insert_datas.append(insert_data)
9290
self.index.upsert(insert_datas)
9391
insert_count += batch_end_offset - batch_start_offset
@@ -101,7 +99,7 @@ def search_embedding(
10199
k: int = 100,
102100
filters: dict | None = None,
103101
timeout: int | None = None,
104-
) -> list[tuple[int, float]]:
102+
) -> list[int]:
105103
if filters is None:
106104
pinecone_filters = {}
107105
else:
@@ -111,9 +109,9 @@ def search_embedding(
111109
top_k=k,
112110
vector=query,
113111
filter=pinecone_filters,
114-
)['matches']
112+
)["matches"]
115113
except Exception as e:
116114
print(f"Error querying index: {e}")
117115
raise e
118-
id_res = [int(one_res['id']) for one_res in res]
116+
id_res = [int(one_res["id"]) for one_res in res]
119117
return id_res

0 commit comments

Comments
 (0)