1+ package com.fpf.smartscansdk.ml.models.providers.embeddings.dino
2+
3+ import android.app.Application
4+ import android.content.Context
5+ import android.graphics.Bitmap
6+ import androidx.core.graphics.get
7+ import com.fpf.smartscansdk.core.embeddings.ImageEmbeddingProvider
8+ import com.fpf.smartscansdk.core.embeddings.normalizeL2
9+ import com.fpf.smartscansdk.core.media.centerCrop
10+ import com.fpf.smartscansdk.core.processors.BatchProcessor
11+ import com.fpf.smartscansdk.ml.data.FilePath
12+ import com.fpf.smartscansdk.ml.data.ModelSource
13+ import com.fpf.smartscansdk.ml.data.ResourceId
14+ import com.fpf.smartscansdk.ml.data.TensorData
15+ import com.fpf.smartscansdk.ml.models.FileOnnxLoader
16+ import com.fpf.smartscansdk.ml.models.OnnxModel
17+ import com.fpf.smartscansdk.ml.models.ResourceOnnxLoader
18+ import kotlinx.coroutines.Dispatchers
19+ import kotlinx.coroutines.withContext
20+ import java.nio.ByteBuffer
21+ import java.nio.ByteOrder
22+ import java.nio.FloatBuffer
23+ import androidx.core.graphics.scale
24+
25+
/**
 * [ImageEmbeddingProvider] backed by an ONNX model that turns a [Bitmap] into a
 * fixed-length ([embeddingDim]) float vector.
 *
 * The model is created from [modelSource] at construction time but is only loaded by
 * [initialize]; [embed] fails with [IllegalStateException] until [isInitialized] is true.
 *
 * @param context used to resolve resources for [ResourceId] sources and to obtain the
 *   application context for batch processing.
 * @param modelSource where the ONNX model comes from: a [FilePath] or a [ResourceId].
 */
class DinoV2SmallImageEmbedder(
    private val context: Context,
    modelSource: ModelSource,
) : ImageEmbeddingProvider {

    // Pick the loader that matches the kind of model source supplied by the caller.
    private val model: OnnxModel = when (modelSource) {
        is FilePath -> OnnxModel(FileOnnxLoader(modelSource.path))
        is ResourceId -> OnnxModel(ResourceOnnxLoader(context.resources, modelSource.resId))
    }

    /** Length of the vectors this provider declares it produces. */
    override val embeddingDim: Int = 384

    // Guards closeSession() so the underlying model is closed at most once.
    private var closed = false

    /** Loads the underlying model; delegates to [OnnxModel.loadModel]. */
    override suspend fun initialize() = model.loadModel()

    /** Reports whether the underlying model is loaded; delegates to [OnnxModel.isLoaded]. */
    override fun isInitialized() = model.isLoaded()

    /**
     * Preprocesses [data], runs the model on it and returns the first row of the model's
     * first output after passing it through [normalizeL2].
     *
     * The work is dispatched to [Dispatchers.Default].
     *
     * @throws IllegalStateException if the model is not initialized or exposes no input names.
     */
    override suspend fun embed(data: Bitmap): FloatArray = withContext(Dispatchers.Default) {
        check(isInitialized()) { " Model not initialized" }

        // preProcess() fills the buffer channel -> row (y) -> column (x), so the tensor is
        // declared as (batch, channels, height, width).
        val inputShape = longArrayOf(
            DIM_BATCH_SIZE.toLong(),
            DIM_PIXEL_SIZE.toLong(),
            IMAGE_SIZE_Y.toLong(),
            IMAGE_SIZE_X.toLong(),
        )
        val imgData: FloatBuffer = preProcess(data)
        val inputName = checkNotNull(model.getInputNames()?.firstOrNull()) { " Model inputs not available" }
        val output = model.run(mapOf(inputName to TensorData.FloatBufferTensor(imgData, inputShape)))

        // The first output is expected to be a 2-D float array; its first row is the embedding.
        @Suppress("UNCHECKED_CAST")
        val rows = output.values.first() as Array<FloatArray>
        normalizeL2(rows[0])
    }

    /**
     * Embeds every bitmap in [data] by feeding them through a [BatchProcessor] that calls
     * [embed] per item, and returns all collected embeddings.
     *
     * NOTE(review): results are appended in the order [BatchProcessor] reports completed
     * batches — confirm this matches the order of [data].
     */
    override suspend fun embedBatch(data: List<Bitmap>): List<FloatArray> {
        val allEmbeddings = mutableListOf<FloatArray>()

        val processor = object : BatchProcessor<Bitmap, FloatArray>(
            context = context.applicationContext as Application,
        ) {
            override suspend fun onProcess(context: Context, item: Bitmap): FloatArray = embed(item)

            override suspend fun onBatchComplete(context: Context, batch: List<FloatArray>) {
                allEmbeddings.addAll(batch)
            }
        }

        processor.run(data)
        return allEmbeddings
    }

    /**
     * Closes the underlying model if it is [AutoCloseable]. Calls after the first one
     * return immediately.
     */
    override fun closeSession() {
        if (closed) return
        closed = true
        (model as? AutoCloseable)?.close()
    }

    /**
     * Converts [bitmap] into the model's input buffer: a direct, native-order [FloatBuffer]
     * holding [DIM_PIXEL_SIZE] planes (R, then G, then B) of
     * [IMAGE_SIZE_Y] x [IMAGE_SIZE_X] values, each computed as
     * `(channel / 255 - MEAN[c]) / STD[c]`.
     *
     * The top-left [IMAGE_SIZE_X] x [IMAGE_SIZE_Y] region of the [centerCrop] result is read;
     * presumably `centerCrop` already returns exactly that size — TODO confirm.
     */
    private fun preProcess(bitmap: Bitmap): FloatBuffer {
        val cropped = centerCrop(bitmap, IMAGE_SIZE_X)

        // Read the whole region once (row-major) instead of calling getPixel for every
        // pixel of every channel.
        val pixelCount = IMAGE_SIZE_X * IMAGE_SIZE_Y
        val pixels = IntArray(pixelCount)
        cropped.getPixels(pixels, 0, IMAGE_SIZE_X, 0, 0, IMAGE_SIZE_X, IMAGE_SIZE_Y)

        val floatBuffer = ByteBuffer
            .allocateDirect(DIM_BATCH_SIZE * DIM_PIXEL_SIZE * pixelCount * Float.SIZE_BYTES)
            .order(ByteOrder.nativeOrder())
            .asFloatBuffer()

        for (c in 0 until DIM_PIXEL_SIZE) {
            // Packed ARGB ints: R sits at bits 16-23, G at 8-15, B at 0-7.
            val shift = 16 - 8 * c
            val mean = MEAN[c]
            val std = STD[c]
            for (px in pixels) {
                val v = ((px shr shift) and 0xFF) / 255f
                floatBuffer.put((v - mean) / std)
            }
        }

        // Reset the position so the consumer reads from the start of the buffer.
        floatBuffer.rewind()
        return floatBuffer
    }

    companion object {
        /** Number of images per model invocation. */
        const val DIM_BATCH_SIZE = 1

        /** Number of colour channels written per pixel (R, G, B). */
        const val DIM_PIXEL_SIZE = 3

        /** Width, in pixels, of the model input. */
        const val IMAGE_SIZE_X = 224

        /** Height, in pixels, of the model input. */
        const val IMAGE_SIZE_Y = 224

        /** Per-channel (R, G, B) value subtracted from the 0..1 pixel value in preprocessing. */
        val MEAN = floatArrayOf(0.485f, 0.456f, 0.406f)

        /** Per-channel (R, G, B) divisor applied after subtracting [MEAN] in preprocessing. */
        val STD = floatArrayOf(0.229f, 0.224f, 0.225f)
    }
}
0 commit comments