Skip to content

Commit 1391d09

Browse files
committed
update: add get overloads
1 parent d33a6a6 commit 1391d09

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

extensions/src/main/java/com/fpf/smartscansdk/extensions/embeddings/FileEmbeddingStore.kt

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class FileEmbeddingStore(
7878
}
7979

8080
// This explicitly makes clear the design constraints that requires the full index to be loaded in memory
81-
override suspend fun getAll(): List<Embedding> = withContext(Dispatchers.IO){
81+
override suspend fun get(): List<Embedding> = withContext(Dispatchers.IO){
8282
cache?.let { return@withContext it.values.toList() };
8383

8484
FileInputStream(file).channel.use { ch ->
@@ -170,7 +170,7 @@ class FileEmbeddingStore(
170170
try {
171171
val map = cache ?: run {
172172
// Load all embeddings into the map if cache is empty
173-
val all = getAll()
173+
val all = get()
174174
LinkedHashMap(all.associateBy { it.id })
175175
}
176176

@@ -188,6 +188,26 @@ class FileEmbeddingStore(
188188
}
189189
}
190190

191+
suspend fun get(ids: List<Long>): List<Embedding> = withContext(Dispatchers.IO) {
192+
val map = cache ?: run {
193+
val all = get()
194+
LinkedHashMap(all.associateBy { it.id })
195+
}
196+
val embeddings = mutableListOf<Embedding>()
197+
198+
for (id in ids) {
199+
map.get(id)?.let { embeddings.add(it) }
200+
}
201+
embeddings
202+
}
203+
204+
suspend fun get(id: Long): Embedding? = withContext(Dispatchers.IO) {
205+
val map = cache ?: run {
206+
val all = get()
207+
LinkedHashMap(all.associateBy { it.id })
208+
}
209+
map.get(id)
210+
}
191211

192212
override fun clear(){
193213
cache = null

0 commit comments

Comments
 (0)