Skip to content

Commit 316ab19

Browse files
Merge pull request #20 from smartscanapp/update/store
Update store and removed retriever (breaking changes)
2 parents 7d23c73 + fe193c3 commit 316ab19

File tree

4 files changed

+51
-77
lines changed

4 files changed

+51
-77
lines changed

core/src/main/java/com/fpf/smartscansdk/core/embeddings/FileEmbeddingRetriever.kt

Lines changed: 0 additions & 43 deletions
This file was deleted.

core/src/main/java/com/fpf/smartscansdk/core/embeddings/FileEmbeddingStore.kt

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import java.io.RandomAccessFile
1212
import java.nio.ByteBuffer
1313
import java.nio.ByteOrder
1414
import java.nio.channels.FileChannel
15+
import kotlin.collections.map
1516

1617
class FileEmbeddingStore(
1718
private val file: File,
@@ -25,13 +26,10 @@ class FileEmbeddingStore(
2526
}
2627

2728
private var cache: LinkedHashMap<Long, Embedding>? = null
29+
private var cachedIds: List<Long>? = null
2830

2931
override val exists: Boolean get() = file.exists()
3032

31-
override val isCached: Boolean
32-
get() = cache != null
33-
34-
3533
// prevent OOM in FileEmbeddingStore.save() by batching writes
3634
private suspend fun save(embeddingsList: List<Embedding>): Unit = withContext(Dispatchers.IO) {
3735
if (embeddingsList.isEmpty()) return@withContext
@@ -99,6 +97,19 @@ class FileEmbeddingStore(
9997
}
10098
}
10199

100+
suspend fun get(ids: List<Long>): List<Embedding> = withContext(Dispatchers.IO) {
101+
val map = cache ?: run {
102+
val all = get()
103+
LinkedHashMap(all.associateBy { it.id })
104+
}
105+
val embeddings = mutableListOf<Embedding>()
106+
107+
for (id in ids) {
108+
map.get(id)?.let { embeddings.add(it) }
109+
}
110+
embeddings
111+
}
112+
102113
override suspend fun add(newEmbeddings: List<Embedding>): Unit = withContext(Dispatchers.IO) {
103114
if (newEmbeddings.isEmpty()) return@withContext
104115

@@ -185,29 +196,42 @@ class FileEmbeddingStore(
185196
}
186197
}
187198

188-
suspend fun get(ids: List<Long>): List<Embedding> = withContext(Dispatchers.IO) {
189-
val map = cache ?: run {
190-
val all = get()
191-
LinkedHashMap(all.associateBy { it.id })
192-
}
193-
val embeddings = mutableListOf<Embedding>()
194199

195-
for (id in ids) {
196-
map.get(id)?.let { embeddings.add(it) }
197-
}
198-
embeddings
200+
override fun clear(){
201+
cache = null
199202
}
200203

201-
suspend fun get(id: Long): Embedding? = withContext(Dispatchers.IO) {
202-
val map = cache ?: run {
203-
val all = get()
204-
LinkedHashMap(all.associateBy { it.id })
204+
205+
override suspend fun query(embedding: FloatArray, topK: Int, threshold: Float): List<Embedding> {
206+
207+
cachedIds = null // clear on new search
208+
209+
val storedEmbeddings = get()
210+
211+
if (storedEmbeddings.isEmpty()) return emptyList()
212+
213+
val similarities = getSimilarities(embedding, storedEmbeddings.map { it.embeddings })
214+
val resultIndices = getTopN(similarities, topK, threshold)
215+
216+
if (resultIndices.isEmpty()) return emptyList()
217+
218+
val idsToCache = mutableListOf<Long>()
219+
val results = resultIndices.map{idx ->
220+
idsToCache.add( storedEmbeddings[idx].id)
221+
storedEmbeddings[idx]
205222
}
206-
map.get(id)
223+
cachedIds = idsToCache
224+
return results
207225
}
208226

209-
override fun clear(){
210-
cache = null
227+
suspend fun query(start: Int, end: Int): List<Embedding> {
228+
val ids = cachedIds ?: return emptyList()
229+
val s = start.coerceAtLeast(0)
230+
val e = end.coerceAtMost(ids.size)
231+
if (s >= e) return emptyList()
232+
233+
val batch = get(ids.subList(s, e))
234+
return batch
211235
}
212236

213237
}

core/src/main/java/com/fpf/smartscansdk/core/embeddings/IEmbeddingStore.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,15 @@ package com.fpf.smartscansdk.core.embeddings
33
import com.fpf.smartscansdk.core.data.Embedding
44

55
interface IEmbeddingStore {
6-
val isCached: Boolean
76
val exists: Boolean
87
suspend fun add(newEmbeddings: List<Embedding>)
98
suspend fun remove(ids: List<Long>)
109
suspend fun get(): List<Embedding>
1110
fun clear()
11+
12+
suspend fun query(
13+
embedding: FloatArray,
14+
topK: Int,
15+
threshold: Float
16+
): List<Embedding>
1217
}

core/src/main/java/com/fpf/smartscansdk/core/embeddings/IRetriever.kt

Lines changed: 0 additions & 12 deletions
This file was deleted.

0 commit comments

Comments
 (0)