@@ -12,6 +12,7 @@ import java.io.RandomAccessFile
1212import java.nio.ByteBuffer
1313import java.nio.ByteOrder
1414import java.nio.channels.FileChannel
15+ import kotlin.collections.map
1516
1617class FileEmbeddingStore (
1718 private val file : File ,
@@ -25,13 +26,10 @@ class FileEmbeddingStore(
2526 }
2627
2728 private var cache: LinkedHashMap <Long , Embedding >? = null
29+ private var cachedIds: List <Long >? = null
2830
2931 override val exists: Boolean get() = file.exists()
3032
31- override val isCached: Boolean
32- get() = cache != null
33-
34-
3533 // prevent OOM in FileEmbeddingStore.save() by batching writes
3634 private suspend fun save (embeddingsList : List <Embedding >): Unit = withContext(Dispatchers .IO ) {
3735 if (embeddingsList.isEmpty()) return @withContext
@@ -99,6 +97,19 @@ class FileEmbeddingStore(
9997 }
10098 }
10199
/**
 * Looks up the embeddings for the given [ids].
 *
 * The result follows the order of [ids]; an id with no matching embedding is
 * skipped, and an id listed more than once yields its embedding more than once.
 *
 * When the in-memory `cache` is present it is used directly. Otherwise every
 * stored embedding is loaded through the no-argument `get()` and indexed by id
 * for the duration of this call (this function itself does not assign `cache`).
 *
 * @param ids ids of the embeddings to fetch.
 * @return the embeddings that were found, in request order.
 */
suspend fun get(ids: List<Long>): List<Embedding> = withContext(Dispatchers.IO) {
    // Prefer the cached map; fall back to a throwaway id -> embedding index.
    val index = cache ?: get().associateBy { it.id }
    ids.mapNotNull { index[it] }
}
112+
102113 override suspend fun add (newEmbeddings : List <Embedding >): Unit = withContext(Dispatchers .IO ) {
103114 if (newEmbeddings.isEmpty()) return @withContext
104115
@@ -185,29 +196,42 @@ class FileEmbeddingStore(
185196 }
186197 }
187198
188- suspend fun get (ids : List <Long >): List <Embedding > = withContext(Dispatchers .IO ) {
189- val map = cache ? : run {
190- val all = get()
191- LinkedHashMap (all.associateBy { it.id })
192- }
193- val embeddings = mutableListOf<Embedding >()
194199
195- for (id in ids) {
196- map.get(id)?.let { embeddings.add(it) }
197- }
198- embeddings
/**
 * Drops the in-memory embedding cache.
 *
 * Only `cache` is reset here; `cachedIds` is left untouched.
 * NOTE(review): presumably intentional so paged `query(start, end)` calls keep
 * working after the cache is released — confirm against callers.
 */
override fun clear() {
    cache = null
}
200203
201- suspend fun get (id : Long ): Embedding ? = withContext(Dispatchers .IO ) {
202- val map = cache ? : run {
203- val all = get()
204- LinkedHashMap (all.associateBy { it.id })
204+
/**
 * Runs a new search against all stored embeddings and remembers the ids of the
 * hits in `cachedIds`.
 *
 * `cachedIds` is reset to `null` at the start of every call, so it stays `null`
 * when the store is empty or when no index is selected by `getTopN`.
 *
 * @param embedding query vector handed to `getSimilarities`.
 * @param topK passed through to `getTopN` (presumably the maximum number of hits — confirm).
 * @param threshold passed through to `getTopN` (presumably a minimum similarity — confirm).
 * @return the stored embeddings at the indices returned by `getTopN`, in that order.
 */
override suspend fun query(embedding: FloatArray, topK: Int, threshold: Float): List<Embedding> {
    // A new search invalidates the ids remembered from the previous one.
    cachedIds = null

    val candidates = get()
    if (candidates.isEmpty()) return emptyList()

    val scores = getSimilarities(embedding, candidates.map { it.embeddings })
    val hitIndices = getTopN(scores, topK, threshold)
    if (hitIndices.isEmpty()) return emptyList()

    val hits = hitIndices.map { candidates[it] }
    // Remember hit ids, in result order, for later paged access.
    cachedIds = hits.map { it.id }
    return hits
}
208226
209- override fun clear (){
210- cache = null
/**
 * Returns one page of the ids remembered in `cachedIds`.
 *
 * [start] is clamped to at least 0 and [end] to at most the number of
 * remembered ids; the page covers the half-open range `[start, end)`.
 *
 * @param start index of the first remembered id to include.
 * @param end index one past the last remembered id to include.
 * @return the embeddings for that slice, resolved through `get(ids)`; an empty
 *   list when no ids are remembered or the clamped range is empty.
 */
suspend fun query(start: Int, end: Int): List<Embedding> {
    val hitIds = cachedIds ?: return emptyList()
    val from = maxOf(start, 0)
    val until = minOf(end, hitIds.size)
    return if (from < until) get(hitIds.subList(from, until)) else emptyList()
}
212236
213237}
0 commit comments