Skip to content

Commit 88c8d2f

Browse files
silas.jiang, XuanYang-cn, zhuwenxing, thangTang, jac0626
committed
[Cherry-pick] Backport 6 commits to 2.6.4
- enhance: Migrate six APIs to MilvusClient (#3045)
- fix: Add support for numpy ndarrays in Array fields (#3069) (#3070)
- fix: MilvusClient.insert() does not pass **kwargs to underlying insert_rows() call (#3076)
- feat: add detailed traceback to error_handler decorator
- fix(async): include event loop ID in connection alias to prevent reusing closed connections (#3086)
- fix: fixed the key names for retrieving segment info in the MilvusClient (#3098)

Co-authored-by: XuanYang-cn <[email protected]>
Co-authored-by: zhuwenxing <[email protected]>
Co-authored-by: tianhang <[email protected]>
Co-authored-by: jac <[email protected]>
Co-authored-by: wt <[email protected]>
1 parent aac25f3 commit 88c8d2f

File tree

13 files changed

+383
-26
lines changed

13 files changed

+383
-26
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,4 @@ uv.lock
4242

4343
# AI rules
4444
WARP.md
45+
CLAUDE.md

pymilvus/bulk_writer/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
DataType.BFLOAT16_VECTOR.name: lambda x, dim: float16_vector_validator(x, dim, True),
6363
DataType.SPARSE_FLOAT_VECTOR.name: lambda x: sparse_vector_validator(x),
6464
DataType.INT8_VECTOR.name: lambda x, dim: int8_vector_validator(x, dim),
65-
DataType.ARRAY.name: lambda x, cap: isinstance(x, list) and len(x) <= cap,
65+
DataType.ARRAY.name: lambda x, cap: (isinstance(x, (list, np.ndarray)) and len(x) <= cap),
6666
}
6767

6868
NUMPY_TYPE_CREATOR = {

pymilvus/client/entity_helper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ def convert_to_array_arr(objs: List[Any], field_info: Any):
246246

247247

248248
def convert_to_array(obj: List[Any], field_info: Any):
249+
# Convert numpy ndarray to list if needed
250+
if isinstance(obj, np.ndarray):
251+
obj = obj.tolist()
252+
249253
field_data = schema_types.ScalarField()
250254
element_type = field_info.get("element_type", None)
251255
if element_type == DataType.BOOL:

pymilvus/client/grpc_handler.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,14 +1098,15 @@ def hybrid_search(
10981098
)
10991099

11001100
@retry_on_rpc_failure()
1101-
def get_query_segment_info(self, collection_name: str, timeout: float = 30, **kwargs):
1101+
def get_query_segment_info(
1102+
self, collection_name: str, timeout: float = 30, **kwargs
1103+
) -> List[milvus_types.QuerySegmentInfo]:
11021104
req = Prepare.get_query_segment_info_request(collection_name)
11031105
response = self._stub.GetQuerySegmentInfo(
11041106
req, timeout=timeout, metadata=_api_level_md(**kwargs)
11051107
)
1106-
status = response.status
1107-
check_status(status)
1108-
return response.infos # todo: A wrapper class of QuerySegmentInfo
1108+
check_status(response.status)
1109+
return response.infos
11091110

11101111
@retry_on_rpc_failure()
11111112
def create_alias(
@@ -1665,18 +1666,18 @@ def get_flush_state(
16651666
response = self._stub.GetFlushState(req, timeout=timeout, metadata=_api_level_md(**kwargs))
16661667
status = response.status
16671668
check_status(status)
1668-
return response.flushed # todo: A wrapper class of PersistentSegmentInfo
1669+
return response.flushed
16691670

16701671
@retry_on_rpc_failure()
16711672
def get_persistent_segment_infos(
16721673
self, collection_name: str, timeout: Optional[float] = None, **kwargs
1673-
):
1674+
) -> List[milvus_types.PersistentSegmentInfo]:
16741675
req = Prepare.get_persistent_segment_info_request(collection_name)
16751676
response = self._stub.GetPersistentSegmentInfo(
16761677
req, timeout=timeout, metadata=_api_level_md(**kwargs)
16771678
)
16781679
check_status(response.status)
1679-
return response.infos # todo: A wrapper class of PersistentSegmentInfo
1680+
return response.infos
16801681

16811682
def _wait_for_flushed(
16821683
self,

pymilvus/client/prepare.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,10 @@ def _process_struct_field(
543543
struct_sub_field_info: Two-level dict [struct_name][field_name] -> field info
544544
struct_sub_fields_data: Two-level dict [struct_name][field_name] -> FieldData
545545
"""
546+
# Convert numpy ndarray to list if needed
547+
if isinstance(values, np.ndarray):
548+
values = values.tolist()
549+
546550
if not isinstance(values, list):
547551
msg = f"Field '{field_name}': Expected list, got {type(values).__name__}"
548552
raise TypeError(msg)

pymilvus/client/types.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import time
3+
from dataclasses import dataclass
34
from enum import IntEnum
45
from typing import Any, ClassVar, Dict, List, Optional, TypeVar, Union
56

@@ -1372,3 +1373,20 @@ def __str__(self) -> str:
13721373
return str(self.tokens)
13731374

13741375
__repr__ = __str__
1376+
1377+
1378+
@dataclass
1379+
class SegmentInfo:
1380+
segment_id: int
1381+
collection_id: int
1382+
collection_name: str
1383+
num_rows: int
1384+
is_sorted: bool
1385+
state: common_pb2.SegmentState
1386+
level: common_pb2.SegmentLevel
1387+
storage_version: int
1388+
1389+
1390+
@dataclass
1391+
class LoadedSegmentInfo(SegmentInfo):
1392+
mem_size: int

pymilvus/decorators.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import functools
44
import logging
55
import time
6+
import traceback
67
from typing import Any, Callable, Optional
78

89
import grpc
@@ -216,25 +217,37 @@ async def async_handler(*args, **kwargs):
216217
return await func(*args, **kwargs)
217218
except MilvusException as e:
218219
record_dict["RPC error"] = str(datetime.datetime.now())
219-
LOGGER.error(f"RPC error: [{inner_name}], {e}, <Time:{record_dict}>")
220+
tb_str = traceback.format_exc()
221+
LOGGER.error(
222+
f"RPC error: [{inner_name}], {e}, <Time:{record_dict}>\n"
223+
f"Traceback:\n{tb_str}"
224+
)
220225
raise e from e
221226
except grpc.FutureTimeoutError as e:
222227
record_dict["gRPC timeout"] = str(datetime.datetime.now())
228+
tb_str = traceback.format_exc()
223229
LOGGER.error(
224230
f"grpc Timeout: [{inner_name}], <{e.__class__.__name__}: "
225-
f"{e.code()}, {e.details()}>, <Time:{record_dict}>"
231+
f"{e.code()}, {e.details()}>, <Time:{record_dict}>\n"
232+
f"Traceback:\n{tb_str}"
226233
)
227234
raise e from e
228235
except grpc.RpcError as e:
229236
record_dict["gRPC error"] = str(datetime.datetime.now())
237+
tb_str = traceback.format_exc()
230238
LOGGER.error(
231239
f"grpc RpcError: [{inner_name}], <{e.__class__.__name__}: "
232-
f"{e.code()}, {e.details()}>, <Time:{record_dict}>"
240+
f"{e.code()}, {e.details()}>, <Time:{record_dict}>\n"
241+
f"Traceback:\n{tb_str}"
233242
)
234243
raise e from e
235244
except Exception as e:
236245
record_dict["Exception"] = str(datetime.datetime.now())
237-
LOGGER.error(f"Unexpected error: [{inner_name}], {e}, <Time: {record_dict}>")
246+
tb_str = traceback.format_exc()
247+
LOGGER.error(
248+
f"Unexpected error: [{inner_name}], {e}, <Time: {record_dict}>\n"
249+
f"Traceback:\n{tb_str}"
250+
)
238251
raise MilvusException(message=f"Unexpected error, message=<{e!s}>") from e
239252

240253
return async_handler
@@ -250,25 +263,37 @@ def handler(*args, **kwargs):
250263
return func(*args, **kwargs)
251264
except MilvusException as e:
252265
record_dict["RPC error"] = str(datetime.datetime.now())
253-
LOGGER.error(f"RPC error: [{inner_name}], {e}, <Time:{record_dict}>")
266+
tb_str = traceback.format_exc()
267+
LOGGER.error(
268+
f"RPC error: [{inner_name}], {e}, <Time:{record_dict}>\n"
269+
f"Traceback:\n{tb_str}"
270+
)
254271
raise e from e
255272
except grpc.FutureTimeoutError as e:
256273
record_dict["gRPC timeout"] = str(datetime.datetime.now())
274+
tb_str = traceback.format_exc()
257275
LOGGER.error(
258276
f"grpc Timeout: [{inner_name}], <{e.__class__.__name__}: "
259-
f"{e.code()}, {e.details()}>, <Time:{record_dict}>"
277+
f"{e.code()}, {e.details()}>, <Time:{record_dict}>\n"
278+
f"Traceback:\n{tb_str}"
260279
)
261280
raise e from e
262281
except grpc.RpcError as e:
263282
record_dict["gRPC error"] = str(datetime.datetime.now())
283+
tb_str = traceback.format_exc()
264284
LOGGER.error(
265285
f"grpc RpcError: [{inner_name}], <{e.__class__.__name__}: "
266-
f"{e.code()}, {e.details()}>, <Time:{record_dict}>"
286+
f"{e.code()}, {e.details()}>, <Time:{record_dict}>\n"
287+
f"Traceback:\n{tb_str}"
267288
)
268289
raise e from e
269290
except Exception as e:
270291
record_dict["Exception"] = str(datetime.datetime.now())
271-
LOGGER.error(f"Unexpected error: [{inner_name}], {e}, <Time: {record_dict}>")
292+
tb_str = traceback.format_exc()
293+
LOGGER.error(
294+
f"Unexpected error: [{inner_name}], {e}, <Time: {record_dict}>\n"
295+
f"Traceback:\n{tb_str}"
296+
)
272297
raise MilvusException(message=f"Unexpected error, message=<{e!s}>") from e
273298

274299
return handler

pymilvus/milvus_client/_utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import hashlib
23
import logging
34

@@ -33,8 +34,15 @@ def create_connection(
3334
md5.update(token.encode())
3435
auth_fmt = f"{md5.hexdigest()}"
3536

36-
# different uri, auth, db_name cannot share the same connection
37-
not_empty = [v for v in [use_async_fmt, uri, db_name, auth_fmt] if v]
37+
# For async connections, include event loop ID in alias to prevent
38+
# reusing connections from closed event loops
39+
loop_id_fmt = ""
40+
if use_async:
41+
loop = asyncio.get_running_loop()
42+
loop_id_fmt = f"loop{id(loop)}"
43+
44+
# different uri, auth, db_name, and event loop (for async) cannot share the same connection
45+
not_empty = [v for v in [use_async_fmt, uri, db_name, auth_fmt, loop_id_fmt] if v]
3846
using = "-".join(not_empty)
3947

4048
if connections.has_connection(using):

pymilvus/milvus_client/async_milvus_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ async def insert(
287287
conn = self._get_connection()
288288
# Insert into the collection.
289289
res = await conn.insert_rows(
290-
collection_name, data, partition_name=partition_name, timeout=timeout
290+
collection_name, data, partition_name=partition_name, timeout=timeout, **kwargs
291291
)
292292
return OmitZeroDict(
293293
{

pymilvus/milvus_client/milvus_client.py

Lines changed: 119 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66
from pymilvus.client.embedding_list import EmbeddingList
77
from pymilvus.client.search_iterator import SearchIteratorV2
88
from pymilvus.client.types import (
9+
CompactionPlans,
910
ExceptionsMessage,
11+
LoadedSegmentInfo,
1012
LoadState,
1113
OmitZeroDict,
1214
ReplicaInfo,
1315
ResourceGroupConfig,
16+
SegmentInfo,
1417
)
1518
from pymilvus.client.utils import convert_struct_fields_to_user_format, get_params, is_vector_type
1619
from pymilvus.exceptions import (
@@ -21,7 +24,6 @@
2124
PrimaryKeyException,
2225
ServerVersionIncompatibleException,
2326
)
24-
from pymilvus.orm import utility
2527
from pymilvus.orm.collection import CollectionSchema, FieldSchema, Function, FunctionScore
2628
from pymilvus.orm.connections import connections
2729
from pymilvus.orm.constants import FIELDS, METRIC_TYPE, TYPE, UNLIMITED
@@ -67,7 +69,7 @@ def __init__(
6769
self._using = create_connection(
6870
uri, token, db_name, user=user, password=password, timeout=timeout, **kwargs
6971
)
70-
self.is_self_hosted = bool(utility.get_server_type(using=self._using) == "milvus")
72+
self.is_self_hosted = bool(self.get_server_type() == "milvus")
7173

7274
def create_collection(
7375
self,
@@ -219,7 +221,7 @@ def insert(
219221
# Insert into the collection.
220222
try:
221223
res = conn.insert_rows(
222-
collection_name, data, partition_name=partition_name, timeout=timeout
224+
collection_name, data, partition_name=partition_name, timeout=timeout, **kwargs
223225
)
224226
except Exception as ex:
225227
raise ex from ex
@@ -1808,3 +1810,117 @@ def update_replicate_configuration(
18081810
timeout=timeout,
18091811
**kwargs,
18101812
)
1813+
1814+
def flush_all(self, timeout: Optional[float] = None, **kwargs) -> None:
1815+
"""Flush all collections.
1816+
1817+
Args:
1818+
timeout (Optional[float]): An optional duration of time in seconds to allow for the RPC.
1819+
**kwargs: Additional arguments.
1820+
"""
1821+
self._get_connection().flush_all(timeout=timeout, **kwargs)
1822+
1823+
def get_flush_all_state(self, timeout: Optional[float] = None, **kwargs) -> bool:
1824+
"""Get the flush all state.
1825+
1826+
Args:
1827+
timeout (Optional[float]): An optional duration of time in seconds to allow for the RPC.
1828+
**kwargs: Additional arguments.
1829+
1830+
Returns:
1831+
bool: True if flush all operation is completed, False otherwise.
1832+
"""
1833+
return self._get_connection().get_flush_all_state(timeout=timeout, **kwargs)
1834+
1835+
def list_loaded_segments(
1836+
self,
1837+
collection_name: str,
1838+
timeout: Optional[float] = None,
1839+
**kwargs,
1840+
) -> List[LoadedSegmentInfo]:
1841+
"""List loaded segments for a collection.
1842+
1843+
Args:
1844+
collection_name (str): The name of the collection.
1845+
timeout (Optional[float]): An optional duration of time in seconds to allow for the RPC.
1846+
**kwargs: Additional arguments.
1847+
1848+
Returns:
1849+
List[LoadedSegmentInfo]: A list of loaded segment information.
1850+
"""
1851+
infos = self._get_connection().get_query_segment_info(
1852+
collection_name, timeout=timeout, **kwargs
1853+
)
1854+
return [
1855+
LoadedSegmentInfo(
1856+
info.segmentID,
1857+
info.collectionID,
1858+
collection_name,
1859+
info.num_rows,
1860+
info.is_sorted,
1861+
info.state,
1862+
info.level,
1863+
info.storage_version,
1864+
info.mem_size,
1865+
)
1866+
for info in infos
1867+
]
1868+
1869+
def list_persistent_segments(
1870+
self,
1871+
collection_name: str,
1872+
timeout: Optional[float] = None,
1873+
**kwargs,
1874+
) -> List[SegmentInfo]:
1875+
"""List persistent segments for a collection.
1876+
1877+
Args:
1878+
collection_name (str): The name of the collection.
1879+
timeout (Optional[float]): An optional duration of time in seconds to allow for the RPC.
1880+
**kwargs: Additional arguments.
1881+
1882+
Returns:
1883+
List[SegmentInfo]: A list of persistent segment information.
1884+
"""
1885+
infos = self._get_connection().get_persistent_segment_infos(
1886+
collection_name, timeout=timeout, **kwargs
1887+
)
1888+
return [
1889+
SegmentInfo(
1890+
info.segmentID,
1891+
info.collectionID,
1892+
collection_name,
1893+
info.num_rows,
1894+
info.is_sorted,
1895+
info.state,
1896+
info.level,
1897+
info.storage_version,
1898+
)
1899+
for info in infos
1900+
]
1901+
1902+
def get_server_type(self):
1903+
"""Get the server type.
1904+
1905+
Returns:
1906+
str: The server type (e.g., "milvus", "zilliz").
1907+
"""
1908+
return self._get_connection().get_server_type()
1909+
1910+
def get_compaction_plans(
1911+
self,
1912+
job_id: int,
1913+
timeout: Optional[float] = None,
1914+
**kwargs,
1915+
) -> CompactionPlans:
1916+
"""Get compaction plans for a specific job.
1917+
1918+
Args:
1919+
job_id (int): The ID of the compaction job.
1920+
timeout (Optional[float]): An optional duration of time in seconds to allow for the RPC.
1921+
**kwargs: Additional arguments.
1922+
1923+
Returns:
1924+
CompactionPlans: The compaction plans for the specified job.
1925+
"""
1926+
return self._get_connection().get_compaction_plans(job_id, timeout=timeout, **kwargs)

0 commit comments

Comments (0)