Skip to content

Commit 71d1478

Browse files
authored
Update default batch_retrieve for RAG system classes (#441)
* default batch retrieve for rag systems * coverage * more coverage * parametrize no encode * add not implemented batch retrieval async rag system test * changelog
1 parent 219c65b commit 71d1478

File tree

10 files changed

+238
-36
lines changed

10 files changed

+238
-36
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
1212

1313
### Added
1414

15+
- Update `batch_retrieve` for RAGSystems to use `batch_retrieve` of Knowledge Stores if implemented (#441)
1516
- Implement `batch_retrieve` for Qdrant sync/async knowledge stores (#439)
1617
- Add `batch_retrieve` to KnowledgeStore classes that raise `NotImplementedError` by default (#436)
1718
- Add batch methods for RAGSystem (#270)

src/fed_rag/core/no_encode_rag_system/_asynchronous.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,18 @@ async def batch_retrieve(
6666
self, queries: list[str]
6767
) -> list[list[SourceNode]]:
6868
"""Batch retrieve from KnowledgeStore."""
69-
# TODO: move this to knowledge store batch retrieve once implemented
70-
raw_retrieval_tasks = [
71-
self.knowledge_store.retrieve(
72-
query=query, top_k=self.rag_config.top_k
69+
try:
70+
raw_retrieval_results = await self.knowledge_store.batch_retrieve(
71+
queries=queries, top_k=self.rag_config.top_k
7372
)
74-
for query in queries
75-
]
76-
raw_retrieval_results = await asyncio.gather(*raw_retrieval_tasks)
73+
except NotImplementedError:
74+
raw_retrieval_tasks = [
75+
self.knowledge_store.retrieve(
76+
query=query, top_k=self.rag_config.top_k
77+
)
78+
for query in queries
79+
]
80+
raw_retrieval_results = await asyncio.gather(*raw_retrieval_tasks)
7781
return [
7882
[SourceNode(score=el[0], node=el[1]) for el in raw_result]
7983
for raw_result in raw_retrieval_results

src/fed_rag/core/no_encode_rag_system/_synchronous.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,17 @@ def retrieve(self, query: str) -> list[SourceNode]:
6363

6464
def batch_retrieve(self, queries: list[str]) -> list[list[SourceNode]]:
6565
"""Batch retrieve from NoEncodeKnowledgeStore."""
66-
# TODO: move this to knowledge store batch retrieve once implemented
67-
raw_retrieval_results = [
68-
self.knowledge_store.retrieve(
69-
query=query, top_k=self.rag_config.top_k
66+
try:
67+
raw_retrieval_results = self.knowledge_store.batch_retrieve(
68+
queries=queries, top_k=self.rag_config.top_k
7069
)
71-
for query in queries
72-
]
70+
except NotImplementedError:
71+
raw_retrieval_results = [
72+
self.knowledge_store.retrieve(
73+
query=query, top_k=self.rag_config.top_k
74+
)
75+
for query in queries
76+
]
7377
return [
7478
[SourceNode(score=el[0], node=el[1]) for el in raw_result]
7579
for raw_result in raw_retrieval_results

src/fed_rag/core/rag_system/_asynchronous.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ async def batch_query(self, queries: list[str]) -> list[RAGResponse]:
5555
async def retrieve(self, query: str) -> list[SourceNode]:
5656
"""Retrieve from KnowledgeStore."""
5757
query_emb: list[float] = self.retriever.encode_query(query).tolist()
58-
# TODO: move this to knowledge store batch retrieve once implemented
5958
raw_retrieval_result = await self.knowledge_store.retrieve(
6059
query_emb=query_emb, top_k=self.rag_config.top_k
6160
)
@@ -70,13 +69,19 @@ async def batch_retrieve(
7069
query_embs: list[list[float]] = self.retriever.encode_query(
7170
queries
7271
).tolist()
73-
raw_retrieval_tasks = [
74-
self.knowledge_store.retrieve(
75-
query_emb=query_emb, top_k=self.rag_config.top_k
72+
try:
73+
raw_retrieval_results = await self.knowledge_store.batch_retrieve(
74+
query_embs=query_embs, top_k=self.rag_config.top_k
7675
)
77-
for query_emb in query_embs
78-
]
79-
raw_retrieval_results = await asyncio.gather(*raw_retrieval_tasks)
76+
except NotImplementedError:
77+
raw_retrieval_tasks = [
78+
self.knowledge_store.retrieve(
79+
query_emb=query_emb, top_k=self.rag_config.top_k
80+
)
81+
for query_emb in query_embs
82+
]
83+
raw_retrieval_results = await asyncio.gather(*raw_retrieval_tasks)
84+
8085
return [
8186
[SourceNode(score=el[0], node=el[1]) for el in raw_result]
8287
for raw_result in raw_retrieval_results

src/fed_rag/core/rag_system/_synchronous.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,18 @@ def batch_retrieve(self, queries: list[str]) -> list[list[SourceNode]]:
6666
query_embs: list[list[float]] = self.retriever.encode_query(
6767
queries
6868
).tolist()
69-
# TODO: move this to knowledge store batch retrieve once implemented
70-
raw_retrieval_results = [
71-
self.knowledge_store.retrieve(
72-
query_emb=query_emb, top_k=self.rag_config.top_k
69+
try:
70+
raw_retrieval_results = self.knowledge_store.batch_retrieve(
71+
query_embs=query_embs, top_k=self.rag_config.top_k
7372
)
74-
for query_emb in query_embs
75-
]
73+
except NotImplementedError:
74+
raw_retrieval_results = [
75+
self.knowledge_store.retrieve(
76+
query_emb=query_emb, top_k=self.rag_config.top_k
77+
)
78+
for query_emb in query_embs
79+
]
80+
7681
return [
7782
[SourceNode(score=el[0], node=el[1]) for el in raw_result]
7883
for raw_result in raw_retrieval_results

src/fed_rag/knowledge_stores/in_memory.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def cosine_sim(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
4646
similarities = similarities.to("cpu")
4747
similarities = similarities.tolist()[0]
4848
zipped = list(zip(nodes, similarities))
49-
# scores.sort(key=lambda tup: tup[1], reverse=True)
5049
sorted_similarities = sorted(zipped, key=lambda row: row[1], reverse=True)
5150
return sorted_similarities[:top_k]
5251

@@ -83,7 +82,6 @@ def load_nodes(self, nodes: list[KnowledgeNode]) -> None:
8382
def retrieve(
8483
self, query_emb: list[float], top_k: int = DEFAULT_TOP_K
8584
) -> list[tuple[float, KnowledgeNode]]:
86-
# all_nodes = list(self._data.values())
8785
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8886
query_emb = torch.tensor(query_emb).to(device)
8987
if not torch.is_tensor(self._data_storage):

tests/rag_system/conftest.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,38 @@ async def load(self) -> None:
101101
pass
102102

103103

104+
class DummyAsyncNoBatchRetrievalKnowledgeStore(BaseAsyncKnowledgeStore):
105+
nodes: list[KnowledgeNode] = []
106+
107+
async def load_node(self, node: KnowledgeNode) -> None:
108+
self.nodes.append(node)
109+
110+
async def retrieve(
111+
self, query_emb: list[float], top_k: int
112+
) -> list[tuple[float, KnowledgeNode]]:
113+
return []
114+
115+
async def batch_retrieve(
116+
self, query_embs: list[list[float]], top_k: int
117+
) -> list[list[tuple[float, KnowledgeNode]]]:
118+
raise NotImplementedError
119+
120+
async def delete_node(self, node_id: str) -> bool:
121+
return True
122+
123+
async def clear(self) -> None:
124+
self.nodes.clear()
125+
126+
async def count(self) -> int:
127+
return len(self.nodes)
128+
129+
async def persist(self) -> None:
130+
pass
131+
132+
async def load(self) -> None:
133+
pass
134+
135+
104136
@pytest.fixture
105137
def mock_retriever() -> MockRetriever:
106138
return MockRetriever()

tests/rag_system/test_async_no_encode_rag_system.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,47 @@ async def retrieve(
3535
async def batch_retrieve(
3636
self, queries: list[str], top_k: int
3737
) -> list[list[tuple[float, KnowledgeNode]]]:
38-
return [[]]
38+
return [
39+
[(ix, n) for ix, n in enumerate(self.nodes[:top_k])]
40+
for jx in range(len(queries))
41+
]
42+
43+
async def delete_node(self, node_id: str) -> bool:
44+
return True
45+
46+
async def clear(self) -> None:
47+
self.nodes.clear()
48+
49+
async def count(self) -> int:
50+
return len(self.nodes)
51+
52+
async def persist(self) -> None:
53+
pass
54+
55+
async def load(self) -> None:
56+
pass
57+
58+
59+
class DummyNoEncodeNoBatchRetrievalKnowledgeStore(
60+
BaseAsyncNoEncodeKnowledgeStore
61+
):
62+
nodes: list[KnowledgeNode] = []
63+
64+
async def load_node(self, node: KnowledgeNode) -> None:
65+
self.nodes.append(node)
66+
67+
async def load_nodes(self, nodes: list[KnowledgeNode]) -> None:
68+
await asyncio.gather(*(self.load_node(n) for n in nodes))
69+
70+
async def retrieve(
71+
self, query: str, top_k: int
72+
) -> list[tuple[float, KnowledgeNode]]:
73+
return [(ix, n) for ix, n in enumerate(self.nodes[:top_k])]
74+
75+
async def batch_retrieve(
76+
self, queries: list[str], top_k: int
77+
) -> list[list[tuple[float, KnowledgeNode]]]:
78+
raise NotImplementedError
3979

4080
async def delete_node(self, node_id: str) -> bool:
4181
return True
@@ -64,6 +104,24 @@ async def dummy_store() -> BaseAsyncNoEncodeKnowledgeStore:
64104
return dummy_store
65105

66106

107+
@pytest.fixture()
108+
async def dummy_store_no_batch_retrieval() -> BaseAsyncNoEncodeKnowledgeStore:
109+
dummy_store = DummyNoEncodeNoBatchRetrievalKnowledgeStore()
110+
nodes = [
111+
KnowledgeNode(node_type=NodeType.TEXT, text_content="Dummy text")
112+
for _ in range(5)
113+
]
114+
await dummy_store.load_nodes(nodes)
115+
return dummy_store
116+
117+
118+
@pytest.fixture()
119+
def knowledge_store(
120+
request: pytest.FixtureRequest,
121+
) -> BaseAsyncNoEncodeKnowledgeStore:
122+
return request.getfixturevalue(request.param)
123+
124+
67125
def test_rag_system_init(
68126
mock_generator: BaseGenerator,
69127
dummy_store: BaseAsyncNoEncodeKnowledgeStore,
@@ -307,17 +365,22 @@ async def test_rag_system_format_context(
307365

308366

309367
@pytest.mark.asyncio
368+
@pytest.mark.parametrize(
369+
"knowledge_store",
370+
["dummy_store", "dummy_store_no_batch_retrieval"],
371+
indirect=True,
372+
)
310373
async def test_rag_system_batch_retrieve(
374+
knowledge_store: BaseAsyncNoEncodeKnowledgeStore,
311375
mock_generator: BaseGenerator,
312-
dummy_store: BaseAsyncNoEncodeKnowledgeStore,
313376
) -> None:
314377
# build rag system
315378
rag_config = RAGConfig(
316379
top_k=2,
317380
)
318381
rag_system = AsyncNoEncodeRAGSystem(
319382
generator=mock_generator,
320-
knowledge_store=dummy_store,
383+
knowledge_store=knowledge_store,
321384
rag_config=rag_config,
322385
)
323386

tests/rag_system/test_async_rag_system.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,34 @@
55

66
from fed_rag import AsyncRAGSystem, RAGConfig
77
from fed_rag.base.generator import BaseGenerator
8+
from fed_rag.base.knowledge_store import BaseAsyncKnowledgeStore
89
from fed_rag.base.retriever import BaseRetriever
910
from fed_rag.data_structures import KnowledgeNode, SourceNode
1011
from fed_rag.exceptions import RAGSystemError
1112

12-
from .conftest import DummyAsyncKnowledgeStore, MockGenerator, MockRetriever
13+
from .conftest import (
14+
DummyAsyncKnowledgeStore,
15+
DummyAsyncNoBatchRetrievalKnowledgeStore,
16+
MockGenerator,
17+
MockRetriever,
18+
)
19+
20+
21+
@pytest.fixture()
22+
def dummy_store() -> BaseAsyncKnowledgeStore:
23+
return DummyAsyncKnowledgeStore()
24+
25+
26+
@pytest.fixture()
27+
def dummy_store_no_batch_retrieval() -> BaseAsyncKnowledgeStore:
28+
return DummyAsyncNoBatchRetrievalKnowledgeStore()
29+
30+
31+
@pytest.fixture()
32+
def knowledge_store(
33+
request: pytest.FixtureRequest,
34+
) -> BaseAsyncKnowledgeStore:
35+
return request.getfixturevalue(request.param)
1336

1437

1538
@pytest.mark.asyncio
@@ -208,9 +231,15 @@ async def test_rag_system_retrieve(
208231

209232

210233
@pytest.mark.asyncio
234+
@pytest.mark.parametrize(
235+
"knowledge_store",
236+
["dummy_store", "dummy_store_no_batch_retrieval"],
237+
indirect=True,
238+
)
211239
@patch.object(MockRetriever, "encode_query")
212240
async def test_rag_system_batch_retrieve(
213241
mock_encode_query: MagicMock,
242+
knowledge_store: BaseAsyncKnowledgeStore,
214243
mock_generator: BaseGenerator,
215244
mock_retriever: MockRetriever,
216245
knowledge_nodes: list[KnowledgeNode],
@@ -221,7 +250,6 @@ async def test_rag_system_batch_retrieve(
221250
)
222251

223252
# build rag system
224-
knowledge_store = DummyAsyncKnowledgeStore()
225253
await knowledge_store.load_nodes(nodes=knowledge_nodes)
226254
rag_config = RAGConfig(
227255
top_k=2,

0 commit comments

Comments
 (0)