Skip to content

Commit fcb6e8c

Browse files
alwayslove2013XuanYang-cn
authored andcommitted
feat: support turbopuffer client
Signed-off-by: min.tian <[email protected]>
1 parent fb0d120 commit fcb6e8c

File tree

6 files changed

+175
-3
lines changed

6 files changed

+175
-3
lines changed

vectordb_bench/backend/clients/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class DB(Enum):
5555
TencentElasticsearch = "TencentElasticsearch"
5656
AliSQL = "AlibabaCloudRDSMySQL"
5757
Doris = "Doris"
58+
TurboPuffer = "TurpoBuffer"
5859

5960
@property
6061
def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915
@@ -187,6 +188,10 @@ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912, C901, PLR0915
187188
from .doris.doris import Doris
188189

189190
return Doris
191+
if self == DB.TurboPuffer:
192+
from .turbopuffer.turbopuffer import TurboPuffer
193+
194+
return TurboPuffer
190195

191196
if self == DB.Test:
192197
from .test.test import Test
@@ -357,6 +362,10 @@ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912, C901, PLR0915
357362
from .doris.config import DorisConfig
358363

359364
return DorisConfig
365+
if self == DB.TurboPuffer:
366+
from .turbopuffer.config import TurboPufferConfig
367+
368+
return TurboPufferConfig
360369

361370
if self == DB.Test:
362371
from .test.config import TestConfig
@@ -537,6 +546,10 @@ def case_config_cls( # noqa: C901, PLR0911, PLR0912, PLR0915
537546
from .doris.config import DorisCaseConfig
538547

539548
return DorisCaseConfig
549+
if self == DB.TurboPuffer:
550+
from .turbopuffer.config import TurboPufferIndexConfig
551+
552+
return TurboPufferIndexConfig
540553

541554
# DB.Pinecone, DB.Chroma, DB.Redis
542555
return EmptyDBCaseConfig

vectordb_bench/backend/clients/cockroachdb/cockroachdb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class CockroachDB(VectorDB):
3535
FilterOp.StrEqual,
3636
]
3737

38-
def __init__(
38+
def __init__( # noqa: PLR0915
3939
self,
4040
dim: int,
4141
db_config: dict,
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from pydantic import BaseModel, SecretStr
2+
3+
from ..api import DBCaseConfig, DBConfig, MetricType
4+
5+
6+
class TurboPufferConfig(DBConfig):
7+
api_key: SecretStr
8+
api_base_url: str
9+
namespace: str = "vdbbench_test"
10+
11+
def to_dict(self) -> dict:
12+
return {
13+
"api_key": self.api_key.get_secret_value(),
14+
"api_base_url": self.api_base_url,
15+
"namespace": self.namespace,
16+
}
17+
18+
19+
class TurboPufferIndexConfig(BaseModel, DBCaseConfig):
20+
metric_type: MetricType | None = None
21+
use_multi_ns_for_filter: bool = False
22+
time_wait_warmup: int = 60 * 1 # 1min
23+
24+
def parse_metric(self) -> str:
25+
if self.metric_type == MetricType.COSINE:
26+
return "cosine_distance"
27+
if self.metric_type == MetricType.L2:
28+
return "euclidean_squared"
29+
30+
msg = f"Not Support {self.metric_type}"
31+
raise ValueError(msg)
32+
33+
def index_param(self) -> dict:
34+
return {}
35+
36+
def search_param(self) -> dict:
37+
return {}
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""Wrapper around the Pinecone vector database over VectorDB"""
2+
3+
import logging
4+
import time
5+
from contextlib import contextmanager
6+
7+
import turbopuffer as tpuf
8+
9+
from vectordb_bench.backend.clients.turbopuffer.config import TurboPufferIndexConfig
10+
from vectordb_bench.backend.filter import Filter, FilterOp
11+
12+
from ..api import VectorDB
13+
14+
log = logging.getLogger(__name__)
15+
16+
17+
class TurboPuffer(VectorDB):
18+
supported_filter_types: list[FilterOp] = [
19+
FilterOp.NonFilter,
20+
FilterOp.NumGE,
21+
FilterOp.StrEqual,
22+
]
23+
24+
def __init__(
25+
self,
26+
dim: int,
27+
db_config: dict,
28+
db_case_config: TurboPufferIndexConfig,
29+
drop_old: bool = False,
30+
with_scalar_labels: bool = False,
31+
**kwargs,
32+
):
33+
"""Initialize wrapper around the milvus vector database."""
34+
self.api_key = db_config.get("api_key", "")
35+
self.api_base_url = db_config.get("api_base_url", "")
36+
self.namespace = db_config.get("namespace", "")
37+
self.db_case_config = db_case_config
38+
self.metric = db_case_config.parse_metric()
39+
40+
self._vector_field = "vector"
41+
self._scalar_id_field = "id"
42+
self._scalar_label_field = "label"
43+
44+
self.with_scalar_labels = with_scalar_labels
45+
if drop_old:
46+
log.info(f"Drop old. delete the namespace: {self.namespace}")
47+
tpuf.api_key = self.api_key
48+
tpuf.api_base_url = self.api_base_url
49+
ns = tpuf.Namespace(self.namespace)
50+
try:
51+
ns.delete_all()
52+
except Exception as e:
53+
log.warning(f"Failed to delete all. Error: {e}")
54+
55+
@contextmanager
56+
def init(self):
57+
tpuf.api_key = self.api_key
58+
tpuf.api_base_url = self.api_base_url
59+
self.ns = tpuf.Namespace(self.namespace)
60+
yield
61+
62+
def optimize(self, data_size: int | None = None):
63+
# turbopuffer responds to the request
64+
# once the cache warming operation has been started.
65+
# It does not wait for the operation to complete,
66+
# which can take multiple minutes for large namespaces.
67+
self.ns.hint_cache_warm()
68+
log.info(f"warming up but no api waiting for complete. just sleep {self.db_case_config.time_wait_warmup}s")
69+
time.sleep(self.db_case_config.time_wait_warmup)
70+
71+
def insert_embeddings(
72+
self,
73+
embeddings: list[list[float]],
74+
metadata: list[int],
75+
labels_data: list[str] | None = None,
76+
**kwargs,
77+
) -> tuple[int, Exception]:
78+
try:
79+
if self.with_scalar_labels:
80+
self.ns.write(
81+
upsert_columns={
82+
self._scalar_id_field: metadata,
83+
self._vector_field: embeddings,
84+
self._scalar_label_field: labels_data,
85+
},
86+
distance_metric=self.metric,
87+
)
88+
else:
89+
self.ns.write(
90+
upsert_columns={
91+
self._scalar_id_field: metadata,
92+
self._vector_field: embeddings,
93+
},
94+
distance_metric=self.metric,
95+
)
96+
except Exception as e:
97+
log.warning(f"Failed to insert. Error: {e}")
98+
return len(embeddings), None
99+
100+
def search_embedding(
101+
self,
102+
query: list[float],
103+
k: int = 100,
104+
timeout: int | None = None,
105+
) -> list[int]:
106+
res = self.ns.query(
107+
rank_by=["vector", "ANN", query],
108+
top_k=k,
109+
filters=self.expr,
110+
)
111+
return [row.id for row in res.rows]
112+
113+
def prepare_filter(self, filters: Filter):
114+
if filters.type == FilterOp.NonFilter:
115+
self.expr = None
116+
elif filters.type == FilterOp.NumGE:
117+
self.expr = [self._scalar_id_field, "Gte", filters.int_value]
118+
elif filters.type == FilterOp.StrEqual:
119+
self.expr = [self._scalar_label_field, "Eq", filters.label_value]
120+
else:
121+
msg = f"Not support Filter for TurboPuffer - {filters}"
122+
raise ValueError(msg)

vectordb_bench/backend/runner/rate_runner.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from vectordb_bench import config
99
from vectordb_bench.backend.clients import api
10-
from vectordb_bench.backend.clients.doris.doris import Doris
1110
from vectordb_bench.backend.dataset import DataSetIterator
1211
from vectordb_bench.backend.utils import time_it
1312

@@ -54,7 +53,7 @@ def _insert_embeddings(db: api.VectorDB, emb: list[list[float]], metadata: list[
5453
db_copy = deepcopy(db)
5554
with db_copy.init():
5655
_insert_embeddings(db_copy, emb, metadata, retry_idx=0)
57-
elif isinstance(db, Doris):
56+
elif db.name == "Doris":
5857
# DorisVectorClient is not thread-safe. Similar to pgvector, create a per-thread client
5958
# by deep-copying the wrapper and forcing lazy re-init inside the thread.
6059
db_copy = deepcopy(db)

vectordb_bench/frontend/config/styles.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def getPatternShape(i):
6969
DB.S3Vectors: "https://assets.zilliz.com/s3_vectors_daf370b4e5.png",
7070
DB.Hologres: "https://img.alicdn.com/imgextra/i3/O1CN01d9qrry1i6lTNa2BRa_!!6000000004364-2-tps-218-200.png",
7171
DB.Doris: "https://doris.apache.org/images/logo.svg",
72+
DB.TurboPuffer: "https://turbopuffer.com/logo2.png",
7273
}
7374

7475
# RedisCloud color: #0D6EFD

0 commit comments

Comments
 (0)