Skip to content

Commit 1c4dd04

Browse files
Merge pull request #21 from smartscanapp/update/dino
Added dino image embedding provider
2 parents 316ab19 + 1c94d1f commit 1c4dd04

File tree

6 files changed

+148
-57
lines changed

6 files changed

+148
-57
lines changed

ml/src/main/java/com/fpf/smartscansdk/ml/models/providers/embeddings/clip/ClipImageEmbedder.kt

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ package com.fpf.smartscansdk.ml.models.providers.embeddings.clip
33
import android.app.Application
44
import android.content.Context
55
import android.graphics.Bitmap
6+
import androidx.core.graphics.get
67
import com.fpf.smartscansdk.core.embeddings.ImageEmbeddingProvider
78
import com.fpf.smartscansdk.core.embeddings.normalizeL2
9+
import com.fpf.smartscansdk.core.media.centerCrop
810
import com.fpf.smartscansdk.core.processors.BatchProcessor
911
import com.fpf.smartscansdk.ml.data.FilePath
1012
import com.fpf.smartscansdk.ml.data.ModelSource
@@ -14,13 +16,20 @@ import com.fpf.smartscansdk.ml.models.OnnxModel
1416
import com.fpf.smartscansdk.ml.models.FileOnnxLoader
1517
import com.fpf.smartscansdk.ml.models.ResourceOnnxLoader
1618
import kotlinx.coroutines.*
19+
import java.nio.ByteBuffer
20+
import java.nio.ByteOrder
1721
import java.nio.FloatBuffer
1822

1923
// Using ModelSource enables using with bundle model or local model which has been downloaded
20-
class ClipImageEmbedder(
21-
private val context: Context,
22-
modelSource: ModelSource,
23-
) : ImageEmbeddingProvider {
24+
class ClipImageEmbedder(private val context: Context, modelSource: ModelSource, ) : ImageEmbeddingProvider {
25+
companion object {
26+
const val DIM_BATCH_SIZE = 1
27+
const val DIM_PIXEL_SIZE = 3
28+
const val IMAGE_SIZE_X = 224
29+
const val IMAGE_SIZE_Y = 224
30+
val MEAN = floatArrayOf(0.48145467f, 0.4578275f, 0.40821072f)
31+
val STD = floatArrayOf(0.26862955f, 0.2613026f, 0.2757771f)
32+
}
2433
private val model: OnnxModel = when(modelSource){
2534
is FilePath -> OnnxModel(FileOnnxLoader(modelSource.path))
2635
is ResourceId -> OnnxModel(ResourceOnnxLoader(context.resources, modelSource.resId))
@@ -36,7 +45,7 @@ class ClipImageEmbedder(
3645
override suspend fun embed(data: Bitmap): FloatArray = withContext(Dispatchers.Default) {
3746
if (!isInitialized()) throw IllegalStateException("Model not initialized")
3847

39-
val inputShape = longArrayOf(ClipConfig.DIM_BATCH_SIZE.toLong(), ClipConfig.DIM_PIXEL_SIZE.toLong(), ClipConfig.IMAGE_SIZE_X.toLong(), ClipConfig.IMAGE_SIZE_Y.toLong())
48+
val inputShape = longArrayOf(DIM_BATCH_SIZE.toLong(), DIM_PIXEL_SIZE.toLong(), IMAGE_SIZE_X.toLong(), IMAGE_SIZE_Y.toLong())
4049
val imgData: FloatBuffer = preProcess(data)
4150
val inputName = model.getInputNames()?.firstOrNull() ?: throw IllegalStateException("Model inputs not available")
4251
val output = model.run(mapOf(inputName to TensorData.FloatBufferTensor(imgData, inputShape)))
@@ -64,4 +73,27 @@ class ClipImageEmbedder(
6473
closed = true
6574
(model as? AutoCloseable)?.close()
6675
}
76+
77+
private fun preProcess(bitmap: Bitmap): FloatBuffer {
78+
val cropped = centerCrop(bitmap, IMAGE_SIZE_X)
79+
val numFloats = DIM_BATCH_SIZE * DIM_PIXEL_SIZE * IMAGE_SIZE_Y * IMAGE_SIZE_X
80+
val byteBuffer = ByteBuffer.allocateDirect(numFloats * 4).order(ByteOrder.nativeOrder())
81+
val floatBuffer = byteBuffer.asFloatBuffer()
82+
for (c in 0 until DIM_PIXEL_SIZE) {
83+
for (y in 0 until IMAGE_SIZE_Y) {
84+
for (x in 0 until IMAGE_SIZE_X) {
85+
val px = cropped[x, y]
86+
val v = when (c) {
87+
0 -> (px shr 16 and 0xFF) / 255f // R
88+
1 -> (px shr 8 and 0xFF) / 255f // G
89+
else -> (px and 0xFF) / 255f // B
90+
}
91+
val norm = (v - MEAN[c]) / STD[c]
92+
floatBuffer.put(norm)
93+
}
94+
}
95+
}
96+
floatBuffer.rewind()
97+
return floatBuffer
98+
}
6799
}

ml/src/main/java/com/fpf/smartscansdk/ml/models/providers/embeddings/clip/Constants.kt

Lines changed: 0 additions & 10 deletions
This file was deleted.

ml/src/main/java/com/fpf/smartscansdk/ml/models/providers/embeddings/clip/PreProcess.kt

Lines changed: 0 additions & 38 deletions
This file was deleted.
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
package com.fpf.smartscansdk.ml.models.providers.embeddings.dino

import android.app.Application
import android.content.Context
import android.graphics.Bitmap
import androidx.core.graphics.get
import androidx.core.graphics.scale // NOTE(review): appears unused — confirm and drop if so
import com.fpf.smartscansdk.core.embeddings.ImageEmbeddingProvider
import com.fpf.smartscansdk.core.embeddings.normalizeL2
import com.fpf.smartscansdk.core.media.centerCrop
import com.fpf.smartscansdk.core.processors.BatchProcessor
import com.fpf.smartscansdk.ml.data.FilePath
import com.fpf.smartscansdk.ml.data.ModelSource
import com.fpf.smartscansdk.ml.data.ResourceId
import com.fpf.smartscansdk.ml.data.TensorData
import com.fpf.smartscansdk.ml.models.FileOnnxLoader
import com.fpf.smartscansdk.ml.models.OnnxModel
import com.fpf.smartscansdk.ml.models.ResourceOnnxLoader
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.FloatBuffer

/**
 * [ImageEmbeddingProvider] backed by a DINOv2-small ONNX model.
 *
 * Using [ModelSource] enables loading either a bundled resource model
 * ([ResourceId]) or a locally downloaded model file ([FilePath]).
 *
 * The input bitmap is center-cropped to 224x224, converted to CHW float
 * layout and normalized with the ImageNet mean/std before inference; the
 * raw model output is L2-normalized before being returned.
 */
class DinoV2SmallImageEmbedder(
    private val context: Context,
    modelSource: ModelSource,
) : ImageEmbeddingProvider {

    companion object {
        const val DIM_BATCH_SIZE = 1
        const val DIM_PIXEL_SIZE = 3
        const val IMAGE_SIZE_X = 224
        const val IMAGE_SIZE_Y = 224

        // ImageNet normalization constants (RGB order) used for DINOv2 preprocessing.
        val MEAN = floatArrayOf(0.485f, 0.456f, 0.406f)
        val STD = floatArrayOf(0.229f, 0.224f, 0.225f)
    }

    // Resolve the ONNX model from either a file path or a bundled raw resource.
    private val model: OnnxModel = when (modelSource) {
        is FilePath -> OnnxModel(FileOnnxLoader(modelSource.path))
        is ResourceId -> OnnxModel(ResourceOnnxLoader(context.resources, modelSource.resId))
    }

    /** Output embedding dimensionality of DINOv2-small. */
    override val embeddingDim: Int = 384
    private var closed = false

    override suspend fun initialize() = model.loadModel()

    override fun isInitialized() = model.isLoaded()

    /**
     * Embeds a single bitmap into a 384-dim L2-normalized float vector.
     *
     * @throws IllegalStateException if the model is not initialized or its
     *   input names are unavailable.
     */
    override suspend fun embed(data: Bitmap): FloatArray = withContext(Dispatchers.Default) {
        if (!isInitialized()) throw IllegalStateException("Model not initialized")

        // NCHW input: [batch, channels, height, width].
        val inputShape = longArrayOf(DIM_BATCH_SIZE.toLong(), DIM_PIXEL_SIZE.toLong(), IMAGE_SIZE_X.toLong(), IMAGE_SIZE_Y.toLong())
        val imgData: FloatBuffer = preProcess(data)
        val inputName = model.getInputNames()?.firstOrNull() ?: throw IllegalStateException("Model inputs not available")
        val output = model.run(mapOf(inputName to TensorData.FloatBufferTensor(imgData, inputShape)))
        normalizeL2((output.values.first() as Array<FloatArray>)[0])
    }

    /**
     * Embeds a list of bitmaps via [BatchProcessor], collecting the results
     * in order of batch completion.
     */
    override suspend fun embedBatch(data: List<Bitmap>): List<FloatArray> {
        val allEmbeddings = mutableListOf<FloatArray>()

        val processor = object : BatchProcessor<Bitmap, FloatArray>(context = context.applicationContext as Application) {
            override suspend fun onProcess(context: Context, item: Bitmap): FloatArray {
                return embed(item)
            }
            override suspend fun onBatchComplete(context: Context, batch: List<FloatArray>) {
                allEmbeddings.addAll(batch)
            }
        }

        processor.run(data)
        return allEmbeddings
    }

    /** Releases the underlying ONNX session; safe to call more than once. */
    override fun closeSession() {
        if (closed) return
        closed = true
        (model as? AutoCloseable)?.close()
    }

    /**
     * Center-crops [bitmap] to [IMAGE_SIZE_X]x[IMAGE_SIZE_Y] and packs it into
     * a direct [FloatBuffer] in CHW order, normalized as (v/255 - mean) / std
     * per channel.
     */
    private fun preProcess(bitmap: Bitmap): FloatBuffer {
        val cropped = centerCrop(bitmap, IMAGE_SIZE_X)
        val numFloats = DIM_BATCH_SIZE * DIM_PIXEL_SIZE * IMAGE_SIZE_X * IMAGE_SIZE_Y
        // 4 bytes per float; native order for zero-copy hand-off to the ONNX runtime.
        val byteBuffer = ByteBuffer.allocateDirect(numFloats * 4).order(ByteOrder.nativeOrder())
        val floatBuffer = byteBuffer.asFloatBuffer()

        for (c in 0 until DIM_PIXEL_SIZE) { // R, G, B channels
            // BUG FIX: the y loop previously ran `0 until IMAGE_SIZE_X`; this only
            // worked because X == Y. Use IMAGE_SIZE_Y, matching ClipImageEmbedder.
            for (y in 0 until IMAGE_SIZE_Y) {
                for (x in 0 until IMAGE_SIZE_X) {
                    val px = cropped[x, y]
                    val v = when (c) {
                        0 -> ((px shr 16) and 0xFF) / 255f // R
                        1 -> ((px shr 8) and 0xFF) / 255f // G
                        else -> (px and 0xFF) / 255f // B
                    }
                    val norm = (v - MEAN[c]) / STD[c]
                    floatBuffer.put(norm)
                }
            }
        }

        floatBuffer.rewind()
        return floatBuffer
    }
}

ml/src/main/java/com/fpf/smartscansdk/ml/models/providers/embeddings/minilm/MiniLmTextEmbedder.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class MiniLMTextEmbedder(
2525
is ResourceId -> OnnxModel(ResourceOnnxLoader(context.resources, modelSource.resId))
2626
}
2727

28-
private var tokenizer = SimpleTokenizer.fromRawResources(context, R.raw.minilm_vocab, R.raw.minilm_tokenizer_config)
28+
private var tokenizer = MiniLmTokenizer.fromRawResources(context, R.raw.minilm_vocab, R.raw.minilm_tokenizer_config)
2929
private var closed = false
3030
override val embeddingDim: Int = 384 // MiniLM-L6-v2 dimension
3131

ml/src/main/java/com/fpf/smartscansdk/ml/models/providers/embeddings/minilm/Tokenizer.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import org.json.JSONObject
77
import java.io.InputStreamReader
88
import kotlin.collections.toLongArray
99

10-
class SimpleTokenizer(
10+
class MiniLmTokenizer(
1111
private val vocab: Map<String, Int>,
1212
private val maxLen: Int,
1313
private val doLowerCase: Boolean,
@@ -18,7 +18,7 @@ class SimpleTokenizer(
1818
) {
1919

2020
companion object {
21-
fun fromRawResources(context: Context, vocabResId: Int, configResId: Int): SimpleTokenizer {
21+
fun fromRawResources(context: Context, vocabResId: Int, configResId: Int): MiniLmTokenizer {
2222
// Load vocab
2323
val vocabMap: Map<String, Int> = context.resources.openRawResource(vocabResId)
2424
.bufferedReader()
@@ -31,7 +31,7 @@ class SimpleTokenizer(
3131
}
3232
val configJson = JSONObject(configText)
3333

34-
return SimpleTokenizer(
34+
return MiniLmTokenizer(
3535
vocab = vocabMap,
3636
maxLen = configJson.optInt("max_length", 128),
3737
doLowerCase = configJson.optBoolean("do_lower_case", true),

0 commit comments

Comments
 (0)