@@ -16,10 +16,12 @@ import java.net.http.HttpClient
1616import java.net.http.HttpRequest
1717import java.net.http.HttpResponse
1818import java.nio.file.Path
19+ import java.text.Normalizer
1920import java.time.Duration
2021import kotlin.coroutines.cancellation.CancellationException
2122import kotlin.io.DEFAULT_BUFFER_SIZE
2223import kotlin.io.path.*
24+ import kotlin.math.sqrt
2325import kotlin.use
2426
2527internal object EmbeddingProvider {
@@ -28,7 +30,9 @@ internal object EmbeddingProvider {
2830
2931 private const val HF_URL = " https://huggingface.co/%s/resolve/main"
3032
31- private val env = OrtEnvironment .getEnvironment()
33+ private val env = OrtEnvironment .getEnvironment().apply {
34+ setTelemetry(false )
35+ }
3236
3337 private val tokenizer = mutableMapOf<Embedding .Model , HuggingFaceTokenizer >()
3438
@@ -49,9 +53,7 @@ internal object EmbeddingProvider {
4953 HuggingFaceTokenizer .builder()
5054 .optTokenizerConfigPath(modelPath.resolve(model.modelData.tokenizerConfig).toString())
5155 .optTokenizerPath(modelPath.resolve(model.modelData.tokenizer))
52-
53- // TODO store/read values instead of fixed ones
54- .optMaxLength(2048 )
56+ .optMaxLength(model.modelData.maxLength)
5557 .optTruncation(true )
5658 .optPadding(true )
5759 .build()
@@ -63,11 +65,16 @@ internal object EmbeddingProvider {
6365 }
6466 }
6567
66- fun createEmbeddings (model : Embedding .Model , query : Boolean , values : List <String >): Array <FloatArray > {
68+ fun createEmbeddings (model : Embedding .Model , values : List <String >, query : Boolean = false ): Array <FloatArray > {
6769 val embeddings = when (model) {
6870 Embedding .Model .EmbeddingGemma -> {
69- val texts =
70- values.map { if (query) " task: search result | query: $it " else " title: none | text: $it " }
71+ val texts = values.map { if (query) " task: search result | query: $it " else " title: none | text: $it " }
72+
73+ createRawEmbeddings(model, texts)
74+ }
75+
76+ Embedding .Model .AncientGreekBert -> {
77+ val texts = values.map { it.stripAccentsAndLowercase() }
7178
7279 createRawEmbeddings(model, texts)
7380 }
@@ -237,33 +244,93 @@ internal object EmbeddingProvider {
237244
238245 if (session == null ) return emptyArray()
239246
240- return OnnxTensor .createTensor(env, inputIds).use { idsTensor ->
241- OnnxTensor .createTensor(env, attentionMask).use { maskTensor ->
242- val inputs = mapOf (
243- " input_ids" to idsTensor,
244- " attention_mask" to maskTensor
245- )
246- session.run (inputs).use { result ->
247- val outName = session.outputNames.firstOrNull { it.contains(" sentence_embedding" ) }
248-
249- @Suppress(" UNCHECKED_CAST" )
250- val embeddings: Array <FloatArray > = when {
251- outName != null -> {
252- val ov = result.get(outName)
253- .orElseThrow { IllegalStateException (" No output named $outName " ) }
254- (ov as OnnxTensor ).value as Array <FloatArray >
247+ fun runSession (sessionInputs : Map <String , OnnxTensor >): Array <FloatArray > {
248+ session.run (sessionInputs).use { result ->
249+ return when (model) {
250+ Embedding .Model .AncientGreekBert -> {
251+ val ov = result.get(" last_hidden_state" )
252+ .orElseThrow { IllegalStateException (" No output named last_hidden_state" ) }
253+
254+ @Suppress(" UNCHECKED_CAST" )
255+ val lastHidden = (ov as OnnxTensor ).value as Array <Array <FloatArray >>
256+
257+ val batchSize = lastHidden.size
258+ val seqLen = lastHidden[0 ].size
259+ val dim = lastHidden[0 ][0 ].size
260+
261+ val embeddings = Array (batchSize) { FloatArray (dim) }
262+
263+ for (i in 0 until batchSize) {
264+ var validTokens = 0f
265+
266+ for (j in 0 until seqLen) {
267+ if (attentionMask[i][j] != 0L ) {
268+ val tok = lastHidden[i][j]
269+
270+ for (k in 0 until dim)
271+ embeddings[i][k] + = tok[k]
272+
273+ validTokens + = 1f
274+ }
275+ }
276+
277+ if (validTokens == 0f )
278+ validTokens = 1f
279+
280+ for (k in 0 until dim)
281+ embeddings[i][k] / = validTokens
282+
283+ var normSq = 0.0
284+ for (k in 0 until dim) {
285+ val v = embeddings[i][k]
286+
287+ normSq + = (v * v).toDouble()
288+ }
289+
290+ val norm = sqrt(normSq).coerceAtLeast(1e- 12 )
291+
292+ for (k in 0 until dim)
293+ embeddings[i][k] = (embeddings[i][k] / norm).toFloat()
255294 }
256295
257- else -> {
296+ embeddings
297+ }
298+
299+ else -> {
300+ val outName = session.outputNames.firstOrNull { it.contains(" sentence_embedding" ) }
301+
302+ @Suppress(" UNCHECKED_CAST" ) if (outName != null ) {
303+ val ov =
304+ result.get(outName).orElseThrow { IllegalStateException (" No output named $outName " ) }
305+ (ov as OnnxTensor ).value as Array <FloatArray >
306+ } else {
258307 val ov = result.get(1 )
259308 (ov as OnnxTensor ).value as Array <FloatArray >
260309 }
261310 }
262-
263- embeddings
264311 }
265312 }
266313 }
314+
315+ return OnnxTensor .createTensor(env, inputIds).use { idsTensor ->
316+ OnnxTensor .createTensor(env, attentionMask).use { maskTensor ->
317+ val inputs = mutableMapOf (
318+ " input_ids" to idsTensor, " attention_mask" to maskTensor
319+ )
320+
321+ val expectsTokenTypes = session.inputNames.any { it.contains(" token_type_ids" ) }
322+
323+ if (expectsTokenTypes) {
324+ val tokenTypes = Array (inputIds.size) { LongArray (inputIds[0 ].size) { 0L } }
325+
326+ OnnxTensor .createTensor(env, tokenTypes).use { typeTensor ->
327+ inputs[" token_type_ids" ] = typeTensor
328+
329+ runSession(inputs)
330+ }
331+ } else runSession(inputs)
332+ }
333+ }
267334 }
268335
269336 private fun createSession (model : Embedding .Model ): OrtSession ? {
@@ -279,4 +346,17 @@ internal object EmbeddingProvider {
279346 null
280347 }
281348 }
349+
350+ private fun String.stripAccentsAndLowercase (): String {
351+ val normalized = Normalizer .normalize(this , Normalizer .Form .NFD )
352+ val nonSpacingMark = Character .NON_SPACING_MARK .toInt()
353+
354+ val withoutAccents = buildString {
355+ for (c in normalized) {
356+ if (Character .getType(c) != nonSpacingMark) append(c)
357+ }
358+ }
359+
360+ return withoutAccents.lowercase()
361+ }
282362}
0 commit comments