Skip to content

Commit 0f873cc

Browse files
Merge pull request #5 from dev-diaries41/feat/query
feat: support retrieval using start and end indices
2 parents 3983381 + 01e6c49 commit 0f873cc

File tree

5 files changed

+110
-18
lines changed

5 files changed

+110
-18
lines changed

core/src/main/java/com/fpf/smartscansdk/core/ml/embeddings/EmbeddingTypes.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ interface IEmbeddingStore {
2222
val exists: Boolean
2323
suspend fun add(newEmbeddings: List<Embedding>)
2424
suspend fun remove(ids: List<Long>)
25-
suspend fun getAll(): List<Embedding> // getAll used instead of get to make clear that loading full index in memory is required
25+
suspend fun get(): List<Embedding>
2626
fun clear()
2727
}
2828

extensions/src/main/java/com/fpf/smartscansdk/extensions/embeddings/FileEmbeddingRetriever.kt

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,43 @@ import com.fpf.smartscansdk.core.ml.embeddings.getTopN
88
class FileEmbeddingRetriever(
99
private val store: FileEmbeddingStore
1010
): IRetriever {
11+
12+
private var cachedIds: List<Long>? = null
13+
1114
override suspend fun query(
1215
embedding: FloatArray,
1316
topK: Int,
1417
threshold: Float
1518
): List<Embedding> {
1619

17-
val storedEmbeddings = store.getAll()
20+
cachedIds = null // clear on new search
1821

19-
if (storedEmbeddings.isEmpty()) {
20-
return emptyList()
21-
}
22+
val storedEmbeddings = store.get()
23+
24+
if (storedEmbeddings.isEmpty()) return emptyList()
2225

2326
val similarities = getSimilarities(embedding, storedEmbeddings.map { it.embeddings })
24-
val results = getTopN(similarities, topK, threshold)
27+
val resultIndices = getTopN(similarities, topK, threshold)
2528

26-
if (results.isEmpty()) {
27-
return emptyList()
29+
if (resultIndices.isEmpty()) return emptyList()
30+
31+
val idsToCache = mutableListOf<Long>()
32+
val results = resultIndices.map{idx ->
33+
idsToCache.add( storedEmbeddings[idx].id)
34+
storedEmbeddings[idx]
2835
}
36+
cachedIds = idsToCache
37+
return results
38+
}
39+
40+
suspend fun query(start: Int, end: Int): List<Embedding> {
41+
val ids = cachedIds ?: return emptyList()
42+
val s = start.coerceAtLeast(0)
43+
val e = end.coerceAtMost(ids.size)
44+
if (s >= e) return emptyList()
2945

30-
return results.map{idx -> storedEmbeddings[idx]}
46+
val batch = store.get(ids.subList(s, e))
47+
return batch
3148
}
3249
}
3350

extensions/src/main/java/com/fpf/smartscansdk/extensions/embeddings/FileEmbeddingStore.kt

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class FileEmbeddingStore(
7878
}
7979

8080
// This explicitly makes clear the design constraints that requires the full index to be loaded in memory
81-
override suspend fun getAll(): List<Embedding> = withContext(Dispatchers.IO){
81+
override suspend fun get(): List<Embedding> = withContext(Dispatchers.IO){
8282
cache?.let { return@withContext it.values.toList() };
8383

8484
FileInputStream(file).channel.use { ch ->
@@ -170,7 +170,7 @@ class FileEmbeddingStore(
170170
try {
171171
val map = cache ?: run {
172172
// Load all embeddings into the map if cache is empty
173-
val all = getAll()
173+
val all = get()
174174
LinkedHashMap(all.associateBy { it.id })
175175
}
176176

@@ -188,6 +188,26 @@ class FileEmbeddingStore(
188188
}
189189
}
190190

191+
suspend fun get(ids: List<Long>): List<Embedding> = withContext(Dispatchers.IO) {
192+
val map = cache ?: run {
193+
val all = get()
194+
LinkedHashMap(all.associateBy { it.id })
195+
}
196+
val embeddings = mutableListOf<Embedding>()
197+
198+
for (id in ids) {
199+
map.get(id)?.let { embeddings.add(it) }
200+
}
201+
embeddings
202+
}
203+
204+
suspend fun get(id: Long): Embedding? = withContext(Dispatchers.IO) {
205+
val map = cache ?: run {
206+
val all = get()
207+
LinkedHashMap(all.associateBy { it.id })
208+
}
209+
map.get(id)
210+
}
191211

192212
override fun clear(){
193213
cache = null
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package com.fpf.smartscansdk.extensions.embeddings
2+
3+
import com.fpf.smartscansdk.core.ml.embeddings.Embedding
4+
import kotlinx.coroutines.test.runTest
5+
import org.junit.jupiter.api.Assertions.assertEquals
6+
import org.junit.jupiter.api.Assertions.assertTrue
7+
import org.junit.jupiter.api.TestInstance
8+
import org.junit.jupiter.api.io.TempDir
9+
import org.junit.jupiter.api.Test
10+
import java.io.File
11+
12+
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
13+
class FileEmbeddingRetrieverTest {
14+
15+
@TempDir
16+
lateinit var tempDir: File
17+
18+
private val embeddingLength = 4
19+
20+
private fun embedding(id: Long, date: Long, values: FloatArray) =
21+
Embedding(id, date, values)
22+
23+
private fun createStore(fileName: String = "embeddings.bin") =
24+
FileEmbeddingStore(tempDir, fileName, embeddingLength)
25+
26+
@Test
27+
fun `query batch retrieval with start and end works`() = runTest {
28+
val store = createStore()
29+
val retriever = FileEmbeddingRetriever(store)
30+
31+
val embeddings = listOf(
32+
embedding(1, 100, floatArrayOf(0.1f, 0.2f, 0.3f, 0.4f)),
33+
embedding(2, 200, floatArrayOf(0.5f, 0.6f, 0.7f, 0.8f)),
34+
embedding(3, 300, floatArrayOf(0.9f, 1.0f, 1.1f, 1.2f))
35+
)
36+
store.add(embeddings)
37+
38+
// trigger initial query to populate cachedIds
39+
retriever.query(floatArrayOf(0.1f, 0.2f, 0.3f, 0.4f), topK = 3, threshold = 0f)
40+
41+
// fetch first two cached embeddings (order-agnostic)
42+
val batch1 = retriever.query(0, 2)
43+
assertEquals(2, batch1.size)
44+
assertTrue(batch1.map { it.id }.all { it in listOf(1L, 2L, 3L) })
45+
46+
// fetch last cached embedding
47+
val batch2 = retriever.query(2, 3)
48+
assertEquals(1, batch2.size)
49+
assertTrue(batch2[0].id in listOf(1L, 2L, 3L))
50+
51+
// out-of-bounds requests return empty
52+
val batch3 = retriever.query(3, 5)
53+
assertEquals(0, batch3.size)
54+
}
55+
}

extensions/src/test/kotlin/com/fpf/smartscansdk/extensions/embeddings/FileEmbeddingStoreTest.kt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class FileEmbeddingStoreTest {
5151
)
5252

5353
store.add(embeddings)
54-
val loaded = store.getAll()
54+
val loaded = store.get()
5555

5656
Assertions.assertEquals(2, loaded.size)
5757
Assertions.assertEquals(embeddings[0].id, loaded[0].id)
@@ -72,7 +72,7 @@ class FileEmbeddingStoreTest {
7272
store.add(first)
7373
store.add(second)
7474

75-
val all = store.getAll()
75+
val all = store.get()
7676
Assertions.assertEquals(2, all.size)
7777
Assertions.assertEquals(1, all[0].id)
7878
Assertions.assertEquals(2, all[1].id)
@@ -90,7 +90,7 @@ class FileEmbeddingStoreTest {
9090
store.add(embeddings)
9191
store.remove(listOf(2L))
9292

93-
val remaining = store.getAll()
93+
val remaining = store.get()
9494
Assertions.assertEquals(2, remaining.size)
9595
Assertions.assertFalse(remaining.any { it.id == 2L })
9696
}
@@ -101,7 +101,7 @@ class FileEmbeddingStoreTest {
101101
val embeddings = listOf(embedding(1, 100, FloatArray(embeddingLength) { 0.1f }))
102102
store.add(embeddings)
103103

104-
val firstLoad = store.getAll()
104+
val firstLoad = store.get()
105105
if(store.useCache){
106106
Assertions.assertTrue(store.isCached)
107107
}else{
@@ -111,7 +111,7 @@ class FileEmbeddingStoreTest {
111111
store.clear()
112112
Assertions.assertFalse(store.isCached)
113113

114-
val secondLoad = store.getAll()
114+
val secondLoad = store.get()
115115
assertTrue(firstLoad.zip(secondLoad).all { (a, b) ->
116116
a.id == b.id && a.date == b.date && a.embeddings.contentEquals(b.embeddings)
117117
})
@@ -123,13 +123,13 @@ class FileEmbeddingStoreTest {
123123
val embeddings = listOf(embedding(1, 100, FloatArray(embeddingLength) { 0.1f }))
124124
store.add(embeddings)
125125

126-
val firstLoad = store.getAll()
126+
val firstLoad = store.get()
127127
Assertions.assertFalse(store.isCached)
128128

129129
store.clear()
130130
Assertions.assertFalse(store.isCached)
131131

132-
val secondLoad = store.getAll()
132+
val secondLoad = store.get()
133133
assertTrue(firstLoad.zip(secondLoad).all { (a, b) ->
134134
a.id == b.id && a.date == b.date && a.embeddings.contentEquals(b.embeddings)
135135
})

0 commit comments

Comments
 (0)