Skip to content

Commit 3f573ae

Browse files
committed
feat: Add support for AncientGreekBert in the EmbeddingProvider
1 parent f6ee0ed commit 3f573ae

3 files changed

Lines changed: 108 additions & 28 deletions

File tree

core/src/main/kotlin/dev/paulee/core/data/analysis/Indexer.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ internal class Indexer(path: Path, dataInfo: DataInfo) : Closeable {
180180
val searcher = IndexSearcher(this.reader)
181181

182182
val embedding =
183-
EmbeddingProvider.createEmbeddings(model, true, listOf(query)).firstOrNull() ?: return emptyList()
183+
EmbeddingProvider.createEmbeddings(model, listOf(query), true).firstOrNull() ?: return emptyList()
184184

185185
val query = FloatVectorSimilarityQuery("$field.vec", embedding, similarity)
186186

@@ -214,7 +214,7 @@ internal class Indexer(path: Path, dataInfo: DataInfo) : Closeable {
214214
}
215215

216216
embeddingFields[id]?.let { model ->
217-
val embedding = EmbeddingProvider.createEmbeddings(model, false, listOf(value)).first()
217+
val embedding = EmbeddingProvider.createEmbeddings(model, listOf(value)).first()
218218

219219
add(KnnFloatVectorField("$id.vec", embedding, VectorSimilarityFunction.COSINE))
220220
}

core/src/main/kotlin/dev/paulee/core/data/provider/EmbeddingProvider.kt

Lines changed: 105 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ import java.net.http.HttpClient
1616
import java.net.http.HttpRequest
1717
import java.net.http.HttpResponse
1818
import java.nio.file.Path
19+
import java.text.Normalizer
1920
import java.time.Duration
2021
import kotlin.coroutines.cancellation.CancellationException
2122
import kotlin.io.DEFAULT_BUFFER_SIZE
2223
import kotlin.io.path.*
24+
import kotlin.math.sqrt
2325
import kotlin.use
2426

2527
internal object EmbeddingProvider {
@@ -28,7 +30,9 @@ internal object EmbeddingProvider {
2830

2931
private const val HF_URL = "https://huggingface.co/%s/resolve/main"
3032

31-
private val env = OrtEnvironment.getEnvironment()
33+
private val env = OrtEnvironment.getEnvironment().apply {
34+
setTelemetry(false)
35+
}
3236

3337
private val tokenizer = mutableMapOf<Embedding.Model, HuggingFaceTokenizer>()
3438

@@ -49,9 +53,7 @@ internal object EmbeddingProvider {
4953
HuggingFaceTokenizer.builder()
5054
.optTokenizerConfigPath(modelPath.resolve(model.modelData.tokenizerConfig).toString())
5155
.optTokenizerPath(modelPath.resolve(model.modelData.tokenizer))
52-
53-
// TODO store/read values instead of fixed ones
54-
.optMaxLength(2048)
56+
.optMaxLength(model.modelData.maxLength)
5557
.optTruncation(true)
5658
.optPadding(true)
5759
.build()
@@ -63,11 +65,16 @@ internal object EmbeddingProvider {
6365
}
6466
}
6567

66-
fun createEmbeddings(model: Embedding.Model, query: Boolean, values: List<String>): Array<FloatArray> {
68+
fun createEmbeddings(model: Embedding.Model, values: List<String>, query: Boolean = false): Array<FloatArray> {
6769
val embeddings = when (model) {
6870
Embedding.Model.EmbeddingGemma -> {
69-
val texts =
70-
values.map { if (query) "task: search result | query: $it" else "title: none | text: $it" }
71+
val texts = values.map { if (query) "task: search result | query: $it" else "title: none | text: $it" }
72+
73+
createRawEmbeddings(model, texts)
74+
}
75+
76+
Embedding.Model.AncientGreekBert -> {
77+
val texts = values.map { it.stripAccentsAndLowercase() }
7178

7279
createRawEmbeddings(model, texts)
7380
}
@@ -237,33 +244,93 @@ internal object EmbeddingProvider {
237244

238245
if (session == null) return emptyArray()
239246

240-
return OnnxTensor.createTensor(env, inputIds).use { idsTensor ->
241-
OnnxTensor.createTensor(env, attentionMask).use { maskTensor ->
242-
val inputs = mapOf(
243-
"input_ids" to idsTensor,
244-
"attention_mask" to maskTensor
245-
)
246-
session.run(inputs).use { result ->
247-
val outName = session.outputNames.firstOrNull { it.contains("sentence_embedding") }
248-
249-
@Suppress("UNCHECKED_CAST")
250-
val embeddings: Array<FloatArray> = when {
251-
outName != null -> {
252-
val ov = result.get(outName)
253-
.orElseThrow { IllegalStateException("No output named $outName") }
254-
(ov as OnnxTensor).value as Array<FloatArray>
247+
fun runSession(sessionInputs: Map<String, OnnxTensor>): Array<FloatArray> {
248+
session.run(sessionInputs).use { result ->
249+
return when (model) {
250+
Embedding.Model.AncientGreekBert -> {
251+
val ov = result.get("last_hidden_state")
252+
.orElseThrow { IllegalStateException("No output named last_hidden_state") }
253+
254+
@Suppress("UNCHECKED_CAST")
255+
val lastHidden = (ov as OnnxTensor).value as Array<Array<FloatArray>>
256+
257+
val batchSize = lastHidden.size
258+
val seqLen = lastHidden[0].size
259+
val dim = lastHidden[0][0].size
260+
261+
val embeddings = Array(batchSize) { FloatArray(dim) }
262+
263+
for (i in 0 until batchSize) {
264+
var validTokens = 0f
265+
266+
for (j in 0 until seqLen) {
267+
if (attentionMask[i][j] != 0L) {
268+
val tok = lastHidden[i][j]
269+
270+
for (k in 0 until dim)
271+
embeddings[i][k] += tok[k]
272+
273+
validTokens += 1f
274+
}
275+
}
276+
277+
if (validTokens == 0f)
278+
validTokens = 1f
279+
280+
for (k in 0 until dim)
281+
embeddings[i][k] /= validTokens
282+
283+
var normSq = 0.0
284+
for (k in 0 until dim) {
285+
val v = embeddings[i][k]
286+
287+
normSq += (v * v).toDouble()
288+
}
289+
290+
val norm = sqrt(normSq).coerceAtLeast(1e-12)
291+
292+
for (k in 0 until dim)
293+
embeddings[i][k] = (embeddings[i][k] / norm).toFloat()
255294
}
256295

257-
else -> {
296+
embeddings
297+
}
298+
299+
else -> {
300+
val outName = session.outputNames.firstOrNull { it.contains("sentence_embedding") }
301+
302+
@Suppress("UNCHECKED_CAST") if (outName != null) {
303+
val ov =
304+
result.get(outName).orElseThrow { IllegalStateException("No output named $outName") }
305+
(ov as OnnxTensor).value as Array<FloatArray>
306+
} else {
258307
val ov = result.get(1)
259308
(ov as OnnxTensor).value as Array<FloatArray>
260309
}
261310
}
262-
263-
embeddings
264311
}
265312
}
266313
}
314+
315+
return OnnxTensor.createTensor(env, inputIds).use { idsTensor ->
316+
OnnxTensor.createTensor(env, attentionMask).use { maskTensor ->
317+
val inputs = mutableMapOf(
318+
"input_ids" to idsTensor, "attention_mask" to maskTensor
319+
)
320+
321+
val expectsTokenTypes = session.inputNames.any { it.contains("token_type_ids") }
322+
323+
if (expectsTokenTypes) {
324+
val tokenTypes = Array(inputIds.size) { LongArray(inputIds[0].size) { 0L } }
325+
326+
OnnxTensor.createTensor(env, tokenTypes).use { typeTensor ->
327+
inputs["token_type_ids"] = typeTensor
328+
329+
runSession(inputs)
330+
}
331+
} else runSession(inputs)
332+
}
333+
}
267334
}
268335

269336
private fun createSession(model: Embedding.Model): OrtSession? {
@@ -279,4 +346,17 @@ internal object EmbeddingProvider {
279346
null
280347
}
281348
}
349+
350+
private fun String.stripAccentsAndLowercase(): String {
351+
val normalized = Normalizer.normalize(this, Normalizer.Form.NFD)
352+
val nonSpacingMark = Character.NON_SPACING_MARK.toInt()
353+
354+
val withoutAccents = buildString {
355+
for (c in normalized) {
356+
if (Character.getType(c) != nonSpacingMark) append(c)
357+
}
358+
}
359+
360+
return withoutAccents.lowercase()
361+
}
282362
}

gradle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ tokenizers.version=0.34.0
1313
onnx.version=1.22.0
1414

1515
api.version=1.12.3
16-
core.version=1.15.5
16+
core.version=1.15.6
1717
ui.version=1.16.4
1818
app.version=1.5.0

0 commit comments

Comments
 (0)