Skip to content

Commit 4d728e9

Browse files
committed
feat: refactor embedding handling and add UTF-8 byte conversion utility
1 parent 579a37a commit 4d728e9

9 files changed

Lines changed: 137 additions & 94 deletions

File tree

app/src/main/java/me/grey/picquery/common/AppModules.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ private val domainModules = module {
7979
context = androidContext(),
8080
imageEncoder = get(),
8181
textEncoder = get(),
82-
embeddingRepository = get(),
8382
objectBoxEmbeddingRepository = get(),
8483
dispatcher = get()
8584
)

app/src/main/java/me/grey/picquery/common/Constants.kt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ object Constants {
1717
PERMISSION_OLD
1818
}
1919

20-
private const val useMobileClip = false
21-
val DIM = if (useMobileClip) 256 else 224
20+
val DIM = 256
2221

2322
const val PRIVACY_URL = "https://grey030.gitee.io/pages/picquery/privacy.html"
2423
const val SOURCE_REPO_URL = "https://github.com/greyovo/PicQuery"

app/src/main/java/me/grey/picquery/data/dao/ObjectBoxEmbeddingDao.kt

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -114,16 +114,13 @@ class ObjectBoxEmbeddingDao(private val embeddingBox: Box<ObjectBoxEmbedding>) {
114114
similarityThreshold: Float = 0.7f,
115115
albumIds: List<Long>? = null
116116
): List<ObjectWithScore<ObjectBoxEmbedding>> {
117-
val query =
118-
embeddingBox
119-
.query()
120-
.nearestNeighbors(ObjectBoxEmbedding_.data, queryVector, topK)
121-
.build()
122-
123-
val results = query.findWithScores().filter { result ->
124-
val cosineSimilarity = 1.0 - result.score
125-
cosineSimilarity > similarityThreshold
126-
}
117+
val results = searchNearestVectorsByScope(
118+
queryVector = queryVector,
119+
topK = topK,
120+
similarityThreshold = similarityThreshold,
121+
albumIds = albumIds,
122+
includeThreshold = false
123+
)
127124

128125
results.forEachIndexed { index, result ->
129126
Timber.d("Result $index:")
@@ -141,27 +138,57 @@ class ObjectBoxEmbeddingDao(private val embeddingBox: Box<ObjectBoxEmbedding>) {
141138
similarityThreshold: Float = 0.95f,
142139
albumIds: List<Long>? = null
143140
): List<ObjectWithScore<ObjectBoxEmbedding>> {
144-
val query =
145-
embeddingBox
146-
.query()
147-
.nearestNeighbors(ObjectBoxEmbedding_.data, queryVector, topK)
148-
.build()
141+
val results = searchNearestVectorsByScope(
142+
queryVector = queryVector,
143+
topK = topK,
144+
similarityThreshold = similarityThreshold,
145+
albumIds = albumIds,
146+
includeThreshold = true
147+
).onEach { result ->
148+
val cosineSimilarity = 1.0 - result.score
149149

150-
val results = query.findWithScores()
151-
.filter { result ->
150+
Timber.d("Photo ID: ${result.get().photoId}")
151+
Timber.d("Score: ${result.score}")
152+
Timber.d("Cosine Similarity: $cosineSimilarity")
153+
Timber.d("Similarity Condition: ${cosineSimilarity >= similarityThreshold}")
154+
}
152155

153-
val cosineSimilarity = 1.0 - result.score
156+
Timber.d("Filtered Results Count: ${results.size}")
157+
158+
return results
159+
}
154160

155-
Timber.d("Photo ID: ${result.get().photoId}")
156-
Timber.d("Score: ${result.score}")
157-
Timber.d("Cosine Similarity: $cosineSimilarity")
158-
Timber.d("Similarity Condition: ${cosineSimilarity >= similarityThreshold}")
161+
private fun searchNearestVectorsByScope(
162+
queryVector: FloatArray,
163+
topK: Int,
164+
similarityThreshold: Float,
165+
albumIds: List<Long>?,
166+
includeThreshold: Boolean
167+
): List<ObjectWithScore<ObjectBoxEmbedding>> {
168+
if (albumIds != null && albumIds.isEmpty()) {
169+
return emptyList()
170+
}
159171

160-
cosineSimilarity >= similarityThreshold
161-
}
172+
val queryBuilder = embeddingBox.query()
173+
if (albumIds != null) {
174+
queryBuilder.`in`(ObjectBoxEmbedding_.albumId, albumIds.toLongArray())
175+
}
162176

163-
Timber.d("Filtered Results Count: ${results.size}")
177+
val query = queryBuilder
178+
.nearestNeighbors(ObjectBoxEmbedding_.data, queryVector, topK)
179+
.build()
164180

165-
return results
181+
return try {
182+
query.findWithScores().filter { result ->
183+
val cosineSimilarity = 1.0 - result.score
184+
if (includeThreshold) {
185+
cosineSimilarity >= similarityThreshold
186+
} else {
187+
cosineSimilarity > similarityThreshold
188+
}
189+
}
190+
} finally {
191+
query.close()
192+
}
166193
}
167194
}

app/src/main/java/me/grey/picquery/domain/EmbeddingService.kt

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ import kotlinx.coroutines.flow.onEach
1616
import kotlinx.coroutines.withContext
1717
import me.grey.picquery.common.encodeProgressCallback
1818
import me.grey.picquery.common.loadThumbnail
19-
import me.grey.picquery.common.preprocess
20-
import me.grey.picquery.data.data_source.EmbeddingRepository
2119
import me.grey.picquery.data.data_source.ObjectBoxEmbeddingRepository
2220
import me.grey.picquery.data.model.ObjectBoxEmbedding
2321
import me.grey.picquery.data.model.Photo
@@ -40,7 +38,6 @@ class EmbeddingService(
4038
private val context: Context,
4139
private val imageEncoder: ImageEncoder,
4240
private val textEncoder: TextEncoder,
43-
private val embeddingRepository: EmbeddingRepository,
4441
private val objectBoxEmbeddingRepository: ObjectBoxEmbeddingRepository,
4542
private val dispatcher: CoroutineDispatcher
4643
) {
@@ -65,7 +62,7 @@ class EmbeddingService(
6562
*/
6663
suspend fun hasEmbedding(): Boolean {
6764
return withContext(dispatcher) {
68-
val total = embeddingRepository.getTotalCount()
65+
val total = objectBoxEmbeddingRepository.getTotalCount()
6966
Timber.tag(TAG).d("Total embedding count $total")
7067
total > 0
7168
}
@@ -123,7 +120,6 @@ class EmbeddingService(
123120
.chunked(CHUNK_SIZE)
124121
.onEach { Timber.tag(TAG).d("Processing batch: ${it.size}") }
125122
.onCompletion {
126-
embeddingRepository.updateCache()
127123
encodingLock = false
128124
Timber.tag(TAG).i("Encoding completed")
129125
}
@@ -153,8 +149,7 @@ class EmbeddingService(
153149
Timber.tag(TAG).w("Unsupported file: '${photo.path}', skip encoding")
154150
return null
155151
}
156-
val prepBitmap = preprocess(thumbnailBitmap)
157-
return PhotoBitmap(photo, prepBitmap)
152+
return PhotoBitmap(photo, thumbnailBitmap)
158153
}
159154

160155
/**
@@ -163,15 +158,15 @@ class EmbeddingService(
163158
private suspend fun saveBatchToEmbedding(items: List<PhotoBitmap>) {
164159
val embeddings = imageEncoder.encodeBatch(items.map { it.bitmap })
165160

166-
embeddings.forEachIndexed { index, feat ->
167-
objectBoxEmbeddingRepository.update(
161+
objectBoxEmbeddingRepository.updateAll(
162+
embeddings.mapIndexed { index, feat ->
168163
ObjectBoxEmbedding(
169164
photoId = items[index].photo.id,
170165
albumId = items[index].photo.albumID,
171166
data = feat
172167
)
173-
)
174-
}
168+
}
169+
)
175170
}
176171

177172
private fun Any.runtimeName(): String = this::class.java.name

app/src/main/java/me/grey/picquery/feature/BPETokenizer.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ private fun whitespaceClean(text: String): String {
5151
return cleanedText
5252
}
5353

54+
internal fun utf8ByteValues(token: String): IntArray {
55+
return token.toByteArray(Charsets.UTF_8)
56+
.map { it.toInt() and 0xFF }
57+
.toIntArray()
58+
}
59+
5460
class BPETokenizer(context: Context, bpePath: String = "bpe_vocab_gz") : Tokenizer() {
5561
companion object {
5662
private const val START_TOKEN = "<|startoftext|>"
@@ -149,7 +155,7 @@ class BPETokenizer(context: Context, bpePath: String = "bpe_vocab_gz") : Tokeniz
149155

150156
// return bpe_tokens
151157
for (token in matches) {
152-
val encodedToken = token.toByteArray().map { byteEncoder[it.toInt()] }.joinToString("")
158+
val encodedToken = utf8ByteValues(token).map { byteEncoder.getValue(it) }.joinToString("")
153159
for (bpeToken in bpe(encodedToken).split(" ")) {
154160
bpeTokens.add(encoder.getValue(bpeToken))
155161
}

app/src/main/java/me/grey/picquery/feature/ImageEncoderONNX.kt

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,31 +52,30 @@ open class ImageEncoderONNX(
5252

5353
val floatBuffer = preprocessor.preprocessBatch(bitmaps) as FloatBuffer
5454

55-
val inputName = ortSession?.inputNames?.iterator()?.next()
55+
val session = checkNotNull(ortSession) { "ONNX image encoder session is closed." }
56+
val inputName = session.inputNames.iterator().next()
5657
val shape: LongArray = longArrayOf(bitmaps.size.toLong(), 3, dim, dim)
57-
ortEnv.use { env ->
58-
val tensor = OnnxTensor.createTensor(env, floatBuffer, shape)
59-
val output: OrtSession.Result? =
60-
ortSession?.run(Collections.singletonMap(inputName, tensor))
61-
val resultBuffer = output?.get(0) as OnnxTensor
62-
Log.d(TAG, "Finish encoding image!")
58+
OnnxTensor.createTensor(ortEnv, floatBuffer, shape).use { tensor ->
59+
session.run(Collections.singletonMap(inputName, tensor)).use { output ->
60+
val resultBuffer = output.get(0) as OnnxTensor
61+
val feat = resultBuffer.floatBuffer
62+
val embeddingSize = 512
63+
val numEmbeddings = feat.capacity() / embeddingSize
64+
val embeddings = mutableListOf<FloatArray>()
6365

64-
val feat = resultBuffer.floatBuffer
65-
val embeddingSize = 512
66-
val numEmbeddings = feat.capacity() / embeddingSize
67-
val embeddings = mutableListOf<FloatArray>()
68-
69-
for (i in 0 until numEmbeddings) {
70-
val start = i * embeddingSize
71-
val embeddingArray = FloatArray(embeddingSize)
72-
feat.position(start)
73-
for (j in 0 until embeddingSize) {
74-
embeddingArray[j] = feat[start + j]
66+
for (i in 0 until numEmbeddings) {
67+
val start = i * embeddingSize
68+
val embeddingArray = FloatArray(embeddingSize)
69+
feat.position(start)
70+
for (j in 0 until embeddingSize) {
71+
embeddingArray[j] = feat[start + j]
72+
}
73+
embeddings.add(embeddingArray)
7574
}
76-
embeddings.add(embeddingArray)
77-
}
7875

79-
return@withContext embeddings
76+
Log.d(TAG, "Finish encoding image!")
77+
return@withContext embeddings
78+
}
8079
}
8180
}
8281
}

app/src/main/java/me/grey/picquery/feature/TextEncoderONNX.kt

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,26 +39,30 @@ abstract class TextEncoderONNX(private val context: Context) : TextEncoder {
3939
ortSession = ortEnv.createSession(AssetUtil.assetFilePath(context, modelPath), options)
4040
}
4141

42-
val inputName = ortSession?.inputNames?.iterator()?.next()
43-
ortEnv.use { env ->
44-
45-
val tensor = when (modelType) {
46-
0 -> OnnxTensor.createTensor(env, intBuffer, shape)
47-
1 -> {
48-
val longBuffer = LongBuffer.allocate(intBuffer.capacity()).apply {
49-
while (intBuffer.hasRemaining()) {
50-
put(intBuffer.get().toLong())
51-
}
52-
flip()
42+
val session = checkNotNull(ortSession) { "ONNX text encoder session is closed." }
43+
val inputName = session.inputNames.iterator().next()
44+
val tensor = when (modelType) {
45+
0 -> OnnxTensor.createTensor(ortEnv, intBuffer, shape)
46+
1 -> {
47+
val longBuffer = LongBuffer.allocate(intBuffer.capacity()).apply {
48+
while (intBuffer.hasRemaining()) {
49+
put(intBuffer.get().toLong())
5350
}
54-
OnnxTensor.createTensor(env, longBuffer, shape)
51+
flip()
5552
}
53+
OnnxTensor.createTensor(ortEnv, longBuffer, shape)
54+
}
5655

57-
else -> throw IllegalArgumentException("Unknown buffer type")
56+
else -> throw IllegalArgumentException("Unknown buffer type")
57+
}
58+
tensor.use {
59+
session.run(mapOf(Pair(inputName, tensor))).use { output ->
60+
val resultBuffer = output.get(0) as OnnxTensor
61+
val floatBuffer = resultBuffer.floatBuffer
62+
val result = FloatArray(floatBuffer.remaining())
63+
floatBuffer.get(result)
64+
return result
5865
}
59-
val output = ortSession?.run(mapOf(Pair(inputName!!, tensor)))
60-
val resultBuffer = output?.get(0) as OnnxTensor
61-
return (resultBuffer.floatBuffer).array()
6266
}
6367
}
6468
}

app/src/main/java/me/grey/picquery/feature/clip/ImageEncoderCLIP.kt

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,27 +26,28 @@ class ImageEncoderCLIP(
2626
}
2727

2828
override suspend fun encodeBatch(bitmaps: List<Bitmap>): List<FloatArray> {
29-
val inputName = ortSession?.inputNames?.iterator()?.next()
29+
val session = checkNotNull(ortSession) { "ONNX image encoder session is closed." }
30+
val inputName = session.inputNames.iterator().next()
3031

31-
ortEnv.use {
32-
val floatBuffer = (preprocessor.preprocessBatch(bitmaps)).array()!!
33-
val buffers = splitFloatBuffer(FloatBuffer.wrap(floatBuffer), bitmaps.size)
32+
val floatBuffer = (preprocessor.preprocessBatch(bitmaps)).array()!!
33+
val buffers = splitFloatBuffer(FloatBuffer.wrap(floatBuffer), bitmaps.size)
3434

35-
// Correct shape calculation
36-
val shape: LongArray = longArrayOf(1, 3, INPUT.toLong(), INPUT.toLong())
37-
val res = mutableListOf<FloatArray>()
38-
for (i in bitmaps.indices) {
39-
val tensor = OnnxTensor.createTensor(ortEnv, buffers[i], shape)
40-
val output = ortSession?.run(Collections.singletonMap(inputName, tensor))
35+
// Correct shape calculation
36+
val shape: LongArray = longArrayOf(1, 3, INPUT.toLong(), INPUT.toLong())
37+
val res = mutableListOf<FloatArray>()
38+
for (i in bitmaps.indices) {
39+
OnnxTensor.createTensor(ortEnv, buffers[i], shape).use { tensor ->
40+
session.run(Collections.singletonMap(inputName, tensor)).use { output ->
4141

42-
@Suppress("UNCHECKED_CAST")
43-
val rawOutput =
44-
((output?.get(0)?.value) as Array<FloatArray>)[0]
45-
res.add(rawOutput)
42+
@Suppress("UNCHECKED_CAST")
43+
val rawOutput =
44+
((output.get(0).value) as Array<FloatArray>)[0]
45+
res.add(rawOutput)
46+
}
4647
}
47-
48-
return res
4948
}
49+
50+
return res
5051
}
5152

5253
private fun splitFloatBuffer(buffer: FloatBuffer, parts: Int): List<FloatBuffer> {
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package me.grey.picquery.feature
2+
3+
import org.junit.Assert.assertEquals
4+
import org.junit.Test
5+
6+
class BPETokenizerTest {
7+
@Test
8+
fun utf8BytesAreTreatedAsUnsignedValues() {
9+
val bytes = utf8ByteValues("米")
10+
11+
assertEquals(listOf(231, 177, 179), bytes.toList())
12+
}
13+
}

0 commit comments

Comments
 (0)