33import logging
44from contextlib import contextmanager
55from typing import Type
6-
6+ import pinecone
77from ..api import VectorDB , DBConfig , DBCaseConfig , EmptyDBCaseConfig , IndexType
88from .config import PineconeConfig
99
1010
1111log = logging .getLogger (__name__ )
1212
1313PINECONE_MAX_NUM_PER_BATCH = 1000
14- PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024 # 2MB
14+ PINECONE_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024 # 2MB
15+
1516
1617class Pinecone (VectorDB ):
1718 def __init__ (
@@ -23,30 +24,25 @@ def __init__(
2324 ** kwargs ,
2425 ):
2526 """Initialize wrapper around the Pinecone vector database."""
26- self .index_name = db_config [ "index_name" ]
27- self .api_key = db_config [ "api_key" ]
28- self .environment = db_config [ "environment" ]
29- self . batch_size = int ( min (PINECONE_MAX_SIZE_PER_BATCH / (dim * 5 ), PINECONE_MAX_NUM_PER_BATCH ) )
30- # Pincone will make connections with server while import
31- # so place the import here.
32- import pinecone
33- pinecone . init (
34- api_key = self . api_key , environment = self . environment )
27+ self .index_name = db_config . get ( "index_name" , "" )
28+ self .api_key = db_config . get ( "api_key" , "" )
29+ self .batch_size = int (
30+ min (PINECONE_MAX_SIZE_PER_BATCH / (dim * 5 ), PINECONE_MAX_NUM_PER_BATCH )
31+ )
32+
33+ pc = pinecone . Pinecone ( api_key = self . api_key )
34+ index = pc . Index ( self . index_name )
35+
3536 if drop_old :
36- list_indexes = pinecone .list_indexes ()
37- if self .index_name in list_indexes :
38- index = pinecone .Index (self .index_name )
39- index_dim = index .describe_index_stats ()["dimension" ]
40- if (index_dim != dim ):
41- raise ValueError (
42- f"Pinecone index { self .index_name } dimension mismatch, expected { index_dim } got { dim } " )
43- log .info (
44- f"Pinecone client delete old index: { self .index_name } " )
45- index .delete (delete_all = True )
46- index .close ()
47- else :
37+ index_stats = index .describe_index_stats ()
38+ index_dim = index_stats ["dimension" ]
39+ if index_dim != dim :
4840 raise ValueError (
49- f"Pinecone index { self .index_name } does not exist" )
41+ f"Pinecone index { self .index_name } dimension mismatch, expected { index_dim } got { dim } "
42+ )
43+ for namespace in index_stats ["namespaces" ]:
44+ log .info (f"Pinecone index delete namespace: { namespace } " )
45+ index .delete (delete_all = True , namespace = namespace )
5046
5147 self ._metadata_key = "meta"
5248
@@ -59,13 +55,10 @@ def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConf
5955 return EmptyDBCaseConfig
6056
6157 @contextmanager
62- def init (self ) -> None :
63- import pinecone
64- pinecone .init (
65- api_key = self .api_key , environment = self .environment )
66- self .index = pinecone .Index (self .index_name )
58+ def init (self ):
59+ pc = pinecone .Pinecone (api_key = self .api_key )
60+ self .index = pc .Index (self .index_name )
6761 yield
68- self .index .close ()
6962
7063 def ready_to_load (self ):
7164 pass
@@ -83,11 +76,16 @@ def insert_embeddings(
8376 insert_count = 0
8477 try :
8578 for batch_start_offset in range (0 , len (embeddings ), self .batch_size ):
86- batch_end_offset = min (batch_start_offset + self .batch_size , len (embeddings ))
79+ batch_end_offset = min (
80+ batch_start_offset + self .batch_size , len (embeddings )
81+ )
8782 insert_datas = []
8883 for i in range (batch_start_offset , batch_end_offset ):
89- insert_data = (str (metadata [i ]), embeddings [i ], {
90- self ._metadata_key : metadata [i ]})
84+ insert_data = (
85+ str (metadata [i ]),
86+ embeddings [i ],
87+ {self ._metadata_key : metadata [i ]},
88+ )
9189 insert_datas .append (insert_data )
9290 self .index .upsert (insert_datas )
9391 insert_count += batch_end_offset - batch_start_offset
@@ -101,7 +99,7 @@ def search_embedding(
10199 k : int = 100 ,
102100 filters : dict | None = None ,
103101 timeout : int | None = None ,
104- ) -> list [tuple [ int , float ] ]:
102+ ) -> list [int ]:
105103 if filters is None :
106104 pinecone_filters = {}
107105 else :
@@ -111,9 +109,9 @@ def search_embedding(
111109 top_k = k ,
112110 vector = query ,
113111 filter = pinecone_filters ,
114- )[' matches' ]
112+ )[" matches" ]
115113 except Exception as e :
116114 print (f"Error querying index: { e } " )
117115 raise e
118- id_res = [int (one_res ['id' ]) for one_res in res ]
116+ id_res = [int (one_res ["id" ]) for one_res in res ]
119117 return id_res
0 commit comments