Skip to content

Commit 436555c

Browse files
committed
Add initial scribbles of JLama port for tokenizers
1 parent 22bacec commit 436555c

5 files changed

Lines changed: 773 additions & 2 deletions

File tree

build.gradle.kts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ plugins {
44
alias(libs.plugins.jetbrainsKotlinJvm) apply false
55
alias(libs.plugins.binaryCompatibility) apply false
66
alias(libs.plugins.modulegraph.souza) apply true
7-
8-
7+
alias(libs.plugins.spotless) apply false
98
}
109

1110
allprojects {
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
package sk.ai.net.io.tokenizer
2+
3+
/**
4+
* Byte Pair Encoding tokenizer
5+
*/
6+
class BPETokenizer protected constructor(modelRoot: TokenizerModel) : Tokenizer {
7+
override val model: TokenizerModel = modelRoot
8+
//protected val decodeBuffer: ByteBuffer = ByteBuffer.allocate(4)
9+
10+
fun getModel(): TokenizerModel {
11+
return model
12+
}
13+
14+
fun tokenize(sentence: String): List<String?> {
15+
if (sentence.isEmpty()) return emptyList()
16+
17+
18+
19+
20+
if (model.preTokenizer() == null && model.addedTokenPattern() == null) Collections.singletonList(sentence)
21+
22+
val sentencePieces: List<String> = ArrayList()
23+
if (model.addedTokenPattern() != null) {
24+
// Split the sentence into pieces using the added token pattern
25+
// Any non-added token is split into pieces using the pre-tokenizer
26+
val pieces: Array<String> = TokenizerModel.Companion.split(model.addedTokenPattern(), sentence, 0, true)
27+
for (piece in pieces) {
28+
if (!piece.isEmpty()) {
29+
if (model.addedTokens().containsKey(piece)) sentencePieces.add(piece)
30+
else if (model.preTokenizer() != null) sentencePieces.addAll(
31+
model.preTokenizer().pretokenize(piece)
32+
)
33+
else sentencePieces.add(piece)
34+
}
35+
}
36+
} else if (model.preTokenizer() != null) {
37+
sentencePieces.addAll(model.preTokenizer().pretokenize(sentence))
38+
} else {
39+
sentencePieces.add(sentence)
40+
}
41+
42+
return sentencePieces
43+
44+
45+
return emptyList()
46+
}
47+
48+
protected fun preProcess(sentence: String?): String? {
49+
return sentence
50+
}
51+
52+
fun encode(rawSentence: String): LongArray {
53+
val sentencePieces = tokenize(rawSentence)
54+
val allTokens: List<Long?> = ArrayList()
55+
56+
for (sentence in sentencePieces) {
57+
var sentence: String = sentence!!
58+
if (model.addedTokens() != null && model.addedTokens().containsKey(sentence)) {
59+
allTokens.add(model.addedTokens().get(sentence))
60+
continue
61+
}
62+
val tokens: List<Long?> = ArrayList()
63+
sentence = preProcess(sentence)!!
64+
val codes: IntArray = sentence.codePoints().toArray()
65+
for (i in codes.indices) {
66+
val c: String? = Character.toString(codes[i])
67+
val id: Long? = model.vocabLookup.get(c)
68+
if (id != null) {
69+
// we found this codepoint in vocab, add it as a token
70+
// logger.debug("{} -> {}", c, id);
71+
tokens.add(id)
72+
} else {
73+
if (model.byteFallback) {
74+
// byte_fallback encoding: just encode each byte as a token
75+
val code = Character.toString(codes[i])
76+
val chars: ByteArray = code.getBytes(StandardCharsets.UTF_8)
77+
for (k in chars.indices) {
78+
val token = encodeCharacterAsToken(chars[k])
79+
// logger.debug("byte {} -> {}", Byte.toUnsignedInt(chars[k]), token);
80+
tokens.add(token)
81+
}
82+
} else {
83+
if (model.unkToken != null) {
84+
tokens.add(model.vocabLookup.get(model.unkToken))
85+
}
86+
}
87+
}
88+
}
89+
90+
// merge the best consecutive tuple each iteration,
91+
// until we can't find any more pairs to merge
92+
while (true) {
93+
var bestId: Long = -1
94+
var bestIdx: Long = -1
95+
var bestRank = Long.MAX_VALUE
96+
97+
for (i in 0..<tokens.size() - 1) {
98+
// check if we can merge the pair (tokens[i], tokens[i+1])
99+
val token1 = decodeInternal(tokens.get(i)!!)
100+
val token2 = decodeInternal(tokens.get(i + 1)!!)
101+
102+
val merge2: String? = String.format("%s %s", token1, token2)
103+
val merge3: String? = String.format("%s%s", token1, token2)
104+
105+
if (model.merges.containsKey(merge2)) {
106+
val id: Long? = model.vocabLookup.get(merge3)
107+
if (id != null) {
108+
// Check if this merge has a better rank (i.e., lower rank number)
109+
val rank: Long = model.merges.get(merge2)!!
110+
if (rank < bestRank) {
111+
// this merge pair exists in vocab! record its position
112+
bestId = id
113+
bestIdx = i.toLong()
114+
bestRank = rank
115+
}
116+
}
117+
}
118+
}
119+
120+
if (bestIdx == -1L) {
121+
break // we couldn't find any more pairs to merge, so we're done
122+
}
123+
124+
// merge the consecutive pair (best_idx, best_idx+1) into new token best_id
125+
tokens.set(bestIdx.toInt(), bestId)
126+
// delete token at position best_idx+1, shift the entire sequence back 1
127+
tokens.remove(bestIdx.toInt() + 1)
128+
}
129+
130+
allTokens.addAll(tokens)
131+
}
132+
133+
return allTokens.stream().mapToLong({ s -> s }).toArray()
134+
}
135+
136+
protected fun postProcessToken(decoded: String?): String? {
137+
var decoded = decoded
138+
if (decoded == null) decoded = model.unkToken
139+
140+
return decoded
141+
}
142+
143+
override fun tokenize(sentence: String?): List<String?>? {
144+
return if (sentence == null) null else tokenize(sentence)
145+
}
146+
147+
override fun encode(sentence: String?): LongArray? {
148+
TODO("Not yet implemented")
149+
}
150+
151+
override fun decode(id: Long): String {
152+
return maybeDecodeTokenAsCharacter(id).map({ c ->
153+
// We have a continuation byte or are buffering them
154+
if (Character.isUnicodeIdentifierPart(c) || decodeBuffer.remaining() < 4) {
155+
decodeBuffer.put(c.charValue() as Byte)
156+
157+
// Unicode symbol is ready
158+
if (decodeBuffer.remaining() === 0) {
159+
val s = String(decodeBuffer.array())
160+
decodeBuffer.rewind()
161+
return@map s
162+
}
163+
164+
return@map ""
165+
}
166+
Character.toString(c)
167+
}).orElseGet({ postProcessToken(model.vocabLookup.inverse().get(id)) })
168+
}
169+
170+
protected abstract fun encodeCharacterAsToken(c: Byte): Long
171+
172+
protected abstract fun maybeDecodeTokenAsCharacter(id: Long): Optional<Character?>?
173+
174+
// Only used for merging
175+
protected fun decodeInternal(id: Long): String {
176+
return maybeDecodeTokenAsCharacter(id).map(Object::toString).orElseGet({
177+
var s: String? = model.vocabLookup.inverse().get(id)
178+
if (s == null) s = model.unkToken
179+
s
180+
})
181+
}
182+
183+
protected fun postProcess(sentence: String?): String? {
184+
return sentence
185+
}
186+
187+
override fun decode(ids: LongArray?): String? {
188+
return ""
189+
// return postProcess(Arrays.stream(ids).mapToObj(this::decode).collect(Collectors.joining()))
190+
}
191+
192+
companion object {
193+
var alteredBytes: BiMap<Int?, Int?>? // Codepoint and Token mapping needed for legacy mode
194+
195+
init {
196+
// https://github.com/openai/gpt-2/blob/master/src/encoder.py#L19
197+
val tmpAlteredBytes: BiMap<Integer?, Integer?> = HashBiMap.create()
198+
var i = 0
199+
for (c in 0..255) {
200+
if ((c < '!'.code || c > '~'.code) && (c < '¡'.code || c > '¬'.code) && (c < '®'.code || c > 'ÿ'.code)) {
201+
val codepoint = (i++ + 256)
202+
tmpAlteredBytes.put(c, codepoint)
203+
}
204+
}
205+
206+
alteredBytes =
207+
ImmutableBiMap.copyOf(tmpAlteredBytes)
208+
}
209+
}
210+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package sk.ai.net.io.tokenizer
2+
3+
/**
 * Contract for converting between raw text, token strings, and numeric
 * token ids.
 */
interface Tokenizer {

    /**
     * Split [sentence] into its token strings.
     *
     * @param sentence text to tokenize
     * @return list of token strings
     */
    fun tokenize(sentence: String?): List<String?>?

    /**
     * Encode [sentence] into numeric token ids.
     *
     * @param sentence text to encode
     * @return token ids
     */
    fun encode(sentence: String?): LongArray?

    /**
     * Decode a single token [id] back into its string representation.
     *
     * @param id token id
     * @return token string
     */
    fun decode(id: Long): String?

    /**
     * Decode a sequence of token [ids] back into text.
     *
     * @param ids token ids
     * @return decoded text
     */
    fun decode(ids: LongArray?): String?

    /** The model backing this tokenizer (expert mode). */
    val model: TokenizerModel?
}

0 commit comments

Comments
 (0)