Skip to content

Commit 268341f

Browse files
committed
refactor: embedding 모델을 hugging face -> openai로 수정한다
1 parent c694067 commit 268341f

File tree

5 files changed

+54
-15
lines changed

5 files changed

+54
-15
lines changed

src/main/kotlin/org/gitanimals/quiz/infra/similarity/EsKnnTextSimilarityChecker.kt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package org.gitanimals.quiz.infra.similarity
22

33
import org.gitanimals.quiz.app.SimilarityResponse
44
import org.gitanimals.quiz.app.TextSimilarityChecker
5+
import org.slf4j.LoggerFactory
56
import org.springframework.data.elasticsearch.client.elc.NativeQuery
67
import org.springframework.data.elasticsearch.core.ElasticsearchOperations
78
import org.springframework.stereotype.Component
@@ -12,13 +13,17 @@ class EsKnnTextSimilarityChecker(
1213
private val tokenizer: Tokenizer,
1314
) : TextSimilarityChecker {
1415

16+
private val logger = LoggerFactory.getLogger(this::class.simpleName)
17+
1518
override fun getSimilarity(text: String): SimilarityResponse {
16-
val tokenizedText = tokenizer.tokenize(Tokenizer.Request(text))
19+
val embeddingResponse = tokenizer.embed(Tokenizer.Request.from(text))
20+
21+
logger.info("[EsKnnTextSimilarityChecker] Embedding Success total token: ${embeddingResponse.usage.totalToken}, prompt token: ${embeddingResponse.usage.promptToken}")
1722

1823
val knnQuery = NativeQuery.builder()
1924
.withKnnSearches {
2025
it.field(QuizSimilarity::vector.name)
21-
it.queryVector(tokenizedText)
26+
it.queryVector(embeddingResponse.data.embedding)
2227
it.similarity(0.75F)
2328
it.k(MAX_RETURN_KNN_SIZE)
2429
it.numCandidates(MAX_RETURN_KNN_SIZE * 5)

src/main/kotlin/org/gitanimals/quiz/infra/similarity/NewQuizCreatedEventListener.kt

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,25 @@ class NewQuizCreatedEventListener(
1818
fun addQuizVectorToEs(newQuizCreated: NewQuizCreated) {
1919
gracefulLaunch {
2020
val tokenizedQuizText = runCatching {
21-
tokenizer.tokenize(Tokenizer.Request(newQuizCreated.problem))
21+
tokenizer.embed(Tokenizer.Request.from(newQuizCreated.problem))
22+
}.onSuccess {
23+
logger.info("[NewQuizCreatedEventListener] Embedding Success total token: ${it.usage.totalToken}, prompt token: ${it.usage.promptToken}")
2224
}.getOrElse {
23-
logger.error("Tokenize fail. id: ${newQuizCreated.id}, cause: ${it.message}", it)
25+
logger.error(
26+
"[NewQuizCreatedEventListener] Tokenize fail. id: ${newQuizCreated.id}, cause: ${it.message}",
27+
it
28+
)
2429
throw it
25-
}
30+
}.data.embedding
2631

2732
runCatching {
2833
val quizSimilarity = QuizSimilarity.from(newQuizCreated.id, tokenizedQuizText)
2934
quizSimilarityRepository.save(quizSimilarity)
3035
}.onSuccess {
31-
logger.info("Tokenize and save success. quizId: ${it.quizId}")
36+
logger.info("[NewQuizCreatedEventListener] Tokenize and save success. quizId: ${it.quizId}")
3237
}.onFailure {
3338
logger.error(
34-
"Tokenize success but, Fail to save es. quizId: ${newQuizCreated.id}, cause: ${it.message}",
39+
"[NewQuizCreatedEventListener] Tokenize success but, Fail to save es. quizId: ${newQuizCreated.id}, cause: ${it.message}",
3540
it
3641
)
3742
}

src/main/kotlin/org/gitanimals/quiz/infra/similarity/QuizSimilarity.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import org.springframework.data.elasticsearch.annotations.Document
66
import org.springframework.data.elasticsearch.annotations.Field
77
import org.springframework.data.elasticsearch.annotations.FieldType
88

9-
private const val BERT_LARGE_DIMS = 768
9+
private const val OPEN_AI_SMALL_DIMS = 1536
1010

1111
@Document(indexName = "quiz_similarity", createIndex = true)
1212
class QuizSimilarity(
@@ -16,7 +16,7 @@ class QuizSimilarity(
1616
@Field(name = "quiz_id")
1717
val quizId: Long,
1818

19-
@Field(name = "vector", type = FieldType.Dense_Vector, dims = BERT_LARGE_DIMS)
19+
@Field(name = "vector", type = FieldType.Dense_Vector, dims = OPEN_AI_SMALL_DIMS)
2020
val vector: List<Float>,
2121
) {
2222

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,43 @@
11
package org.gitanimals.quiz.infra.similarity
22

3+
import com.fasterxml.jackson.annotation.JsonProperty
34
import org.springframework.web.bind.annotation.RequestBody
45
import org.springframework.web.service.annotation.PostExchange
56

67
fun interface Tokenizer {
78

8-
@PostExchange("/pipeline/feature-extraction/sentence-transformers/all-mpnet-base-v2")
9-
fun tokenize(@RequestBody request: Request): List<Float>
9+
@PostExchange("/v1/embeddings")
10+
fun embed(@RequestBody request: Request): Response
1011

1112
data class Request(
12-
val inputs: String,
13-
)
13+
val input: String,
14+
val model: String = "text-embedding-3-small",
15+
@JsonProperty("encoding_format")
16+
val format: String = "float",
17+
) {
18+
19+
companion object {
20+
fun from(input: String): Request {
21+
return Request(input = input)
22+
}
23+
}
24+
}
25+
26+
data class Response(
27+
val usage: Usage,
28+
val model: String,
29+
val data: Data,
30+
) {
31+
data class Usage(
32+
@JsonProperty("prompt_token")
33+
val promptToken: Int,
34+
@JsonProperty("total_token")
35+
val totalToken: Int,
36+
)
37+
38+
data class Data(
39+
val `object`: String,
40+
val embedding: List<Float>,
41+
)
42+
}
1443
}

src/main/kotlin/org/gitanimals/quiz/infra/similarity/TokenizerHttpClientConfigurer.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import org.springframework.web.service.invoker.HttpServiceProxyFactory
1212

1313
@Configuration
1414
class TokenizerHttpClientConfigurer(
15-
@Value("\${tokenizer.api.key}") private val apiKey: String,
15+
@Value("\${openai.key}") private val apiKey: String,
1616
) {
1717

1818
@Bean
@@ -23,7 +23,7 @@ class TokenizerHttpClientConfigurer(
2323
request.headers.add(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
2424
execution.execute(request, body)
2525
}
26-
.baseUrl("https://api-inference.huggingface.co")
26+
.baseUrl("https://api.openai.com")
2727
.defaultStatusHandler(tokenizerHttpClientErrorHandler())
2828
.build()
2929

0 commit comments

Comments
 (0)