@@ -35,7 +35,47 @@ async def retrieve(
3535 async def batch_retrieve (
3636 self , queries : list [str ], top_k : int
3737 ) -> list [list [tuple [float , KnowledgeNode ]]]:
38- return [[]]
38+ return [
39+ [(ix , n ) for ix , n in enumerate (self .nodes [:top_k ])]
40+ for jx in range (len (queries ))
41+ ]
42+
43+ async def delete_node (self , node_id : str ) -> bool :
44+ return True
45+
46+ async def clear (self ) -> None :
47+ self .nodes .clear ()
48+
49+ async def count (self ) -> int :
50+ return len (self .nodes )
51+
52+ async def persist (self ) -> None :
53+ pass
54+
55+ async def load (self ) -> None :
56+ pass
57+
58+
59+ class DummyNoEncodeNoBatchRetrievalKnowledgeStore (
60+ BaseAsyncNoEncodeKnowledgeStore
61+ ):
62+ nodes : list [KnowledgeNode ] = []
63+
64+ async def load_node (self , node : KnowledgeNode ) -> None :
65+ self .nodes .append (node )
66+
67+ async def load_nodes (self , nodes : list [KnowledgeNode ]) -> None :
68+ await asyncio .gather (* (self .load_node (n ) for n in nodes ))
69+
70+ async def retrieve (
71+ self , query : str , top_k : int
72+ ) -> list [tuple [float , KnowledgeNode ]]:
73+ return [(ix , n ) for ix , n in enumerate (self .nodes [:top_k ])]
74+
75+ async def batch_retrieve (
76+ self , queries : list [str ], top_k : int
77+ ) -> list [list [tuple [float , KnowledgeNode ]]]:
78+ raise NotImplementedError
3979
4080 async def delete_node (self , node_id : str ) -> bool :
4181 return True
@@ -64,6 +104,24 @@ async def dummy_store() -> BaseAsyncNoEncodeKnowledgeStore:
64104 return dummy_store
65105
66106
107+ @pytest .fixture ()
108+ async def dummy_store_no_batch_retrieval () -> BaseAsyncNoEncodeKnowledgeStore :
109+ dummy_store = DummyNoEncodeNoBatchRetrievalKnowledgeStore ()
110+ nodes = [
111+ KnowledgeNode (node_type = NodeType .TEXT , text_content = "Dummy text" )
112+ for _ in range (5 )
113+ ]
114+ await dummy_store .load_nodes (nodes )
115+ return dummy_store
116+
117+
118+ @pytest .fixture ()
119+ def knowledge_store (
120+ request : pytest .FixtureRequest ,
121+ ) -> BaseAsyncNoEncodeKnowledgeStore :
122+ return request .getfixturevalue (request .param )
123+
124+
67125def test_rag_system_init (
68126 mock_generator : BaseGenerator ,
69127 dummy_store : BaseAsyncNoEncodeKnowledgeStore ,
@@ -307,17 +365,22 @@ async def test_rag_system_format_context(
307365
308366
309367@pytest .mark .asyncio
368+ @pytest .mark .parametrize (
369+ "knowledge_store" ,
370+ ["dummy_store" , "dummy_store_no_batch_retrieval" ],
371+ indirect = True ,
372+ )
310373async def test_rag_system_batch_retrieve (
374+ knowledge_store : BaseAsyncNoEncodeKnowledgeStore ,
311375 mock_generator : BaseGenerator ,
312- dummy_store : BaseAsyncNoEncodeKnowledgeStore ,
313376) -> None :
314377 # build rag system
315378 rag_config = RAGConfig (
316379 top_k = 2 ,
317380 )
318381 rag_system = AsyncNoEncodeRAGSystem (
319382 generator = mock_generator ,
320- knowledge_store = dummy_store ,
383+ knowledge_store = knowledge_store ,
321384 rag_config = rag_config ,
322385 )
323386
0 commit comments