Skip to content

Commit 01e6c49

Browse files
committed
update: add test for retriever
1 parent 9d82735 commit 01e6c49

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package com.fpf.smartscansdk.extensions.embeddings
2+
3+
import com.fpf.smartscansdk.core.ml.embeddings.Embedding
4+
import kotlinx.coroutines.test.runTest
5+
import org.junit.jupiter.api.Assertions.assertEquals
6+
import org.junit.jupiter.api.Assertions.assertTrue
7+
import org.junit.jupiter.api.TestInstance
8+
import org.junit.jupiter.api.io.TempDir
9+
import org.junit.jupiter.api.Test
10+
import java.io.File
11+
12+
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
13+
class FileEmbeddingRetrieverTest {
14+
15+
@TempDir
16+
lateinit var tempDir: File
17+
18+
private val embeddingLength = 4
19+
20+
private fun embedding(id: Long, date: Long, values: FloatArray) =
21+
Embedding(id, date, values)
22+
23+
private fun createStore(fileName: String = "embeddings.bin") =
24+
FileEmbeddingStore(tempDir, fileName, embeddingLength)
25+
26+
@Test
27+
fun `query batch retrieval with start and end works`() = runTest {
28+
val store = createStore()
29+
val retriever = FileEmbeddingRetriever(store)
30+
31+
val embeddings = listOf(
32+
embedding(1, 100, floatArrayOf(0.1f, 0.2f, 0.3f, 0.4f)),
33+
embedding(2, 200, floatArrayOf(0.5f, 0.6f, 0.7f, 0.8f)),
34+
embedding(3, 300, floatArrayOf(0.9f, 1.0f, 1.1f, 1.2f))
35+
)
36+
store.add(embeddings)
37+
38+
// trigger initial query to populate cachedIds
39+
retriever.query(floatArrayOf(0.1f, 0.2f, 0.3f, 0.4f), topK = 3, threshold = 0f)
40+
41+
// fetch first two cached embeddings (order-agnostic)
42+
val batch1 = retriever.query(0, 2)
43+
assertEquals(2, batch1.size)
44+
assertTrue(batch1.map { it.id }.all { it in listOf(1L, 2L, 3L) })
45+
46+
// fetch last cached embedding
47+
val batch2 = retriever.query(2, 3)
48+
assertEquals(1, batch2.size)
49+
assertTrue(batch2[0].id in listOf(1L, 2L, 3L))
50+
51+
// out-of-bounds requests return empty
52+
val batch3 = retriever.query(3, 5)
53+
assertEquals(0, batch3.size)
54+
}
55+
}

0 commit comments

Comments
 (0)