Skip to content

Commit e8da97d

Browse files
committed
fix: minor update
1 parent c47ec6e commit e8da97d

5 files changed

Lines changed: 87 additions & 10 deletions

File tree

README.md

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,9 @@ assert state.is_nil()
195195
### `GreedyTokenizer`
196196

197197
```python
198-
vocab = ['歌曲', '聆听歌曲', '播放歌曲', '歌词', '查看歌词', '听歌', '曲折']
198+
from general_sam import GeneralSAM, GreedyTokenizer, build_trie_from_chars
199+
200+
vocab = ['a', 'ab', 'b', 'bc', 'c', 'd', 'e', 'f', 'cd', 'abcde']
199201
trie, token_to_trie_node = build_trie_from_chars(vocab)
200202

201203
trie_node_to_token = [-1] * trie.num_of_nodes()
@@ -208,12 +210,9 @@ tokenizer = GreedyTokenizer.from_sam_and_trie(sam, trie)
208210
def tokenize(s: str):
209211
return [(trie_node_to_token[i], j) for i, j in tokenizer.tokenize_str(s)]
210212

211-
assert tokenize('歌曲折') == [(0, 2), (-1, 1)]
212-
assert tokenize('听歌曲') == [(5, 2), (-1, 1)]
213-
assert tokenize('听歌曲折') == [(5, 2), (6, 2)]
214-
assert tokenize('聆听歌曲折') == [(1, 4), (-1, 1)]
215-
assert tokenize('查看歌词歌曲') == [(4, 4), (0, 2)]
216-
assert tokenize('一起播放歌曲并共享歌词') == [(-1, 2), (2, 4), (-1, 3), (3, 2)]
213+
assert tokenize('abcde') == [(9, 5)]
214+
assert tokenize('abcdf') == [(1, 2), (8, 2), (7, 1)]
215+
assert tokenize('abca') == [(1, 2), (4, 1), (0, 1)]
217216
```
218217

219218
## License

general_sam/general_sam.pyi

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class Trie:
3939
) -> TrieNode: ...
4040

4141
class GeneralSAMState:
42-
def is_in_str(self) -> bool: ...
42+
def is_in_chars(self) -> bool: ...
4343
def is_in_bytes(self) -> bool: ...
4444
def get_node_id(self) -> GeneralSAMNodeID: ...
4545
def is_nil(self) -> bool: ...
@@ -79,7 +79,7 @@ class GeneralSAM:
7979
def from_bytes(s: bytes) -> 'GeneralSAM': ...
8080
@staticmethod
8181
def from_trie(trie: Trie) -> 'GeneralSAM': ...
82-
def is_in_str(self) -> bool: ...
82+
def is_in_chars(self) -> bool: ...
8383
def is_in_bytes(self) -> bool: ...
8484
def num_of_nodes(self) -> int: ...
8585
def get_root_state(self) -> GeneralSAMState: ...
@@ -89,5 +89,8 @@ class GeneralSAM:
8989
class GreedyTokenizer:
9090
@staticmethod
9191
def from_sam_and_trie(sam: GeneralSAM, trie: Trie) -> 'GreedyTokenizer': ...
92+
def get_sam(self) -> GeneralSAM: ...
93+
def is_in_chars(self) -> bool: ...
94+
def is_in_bytes(self) -> bool: ...
9295
def tokenize_str(self, s: str) -> Sequence[Tuple[TrieNodeID, int]]: ...
9396
def tokenize_bytes(self, s: bytes) -> Sequence[Tuple[TrieNodeID, int]]: ...

src/tokenizer.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,18 @@ pub struct GreedyTokenizer(pub Arc<SharedGreedyTokenizer>);
3838

3939
#[pymethods]
4040
impl GreedyTokenizer {
41+
pub fn get_sam(&self) -> GeneralSAM {
42+
GeneralSAM(self.0.borrow_sam().0.clone())
43+
}
44+
45+
pub fn is_in_chars(&self) -> bool {
46+
self.0.borrow_sam().is_in_chars()
47+
}
48+
49+
pub fn is_in_bytes(&self) -> bool {
50+
self.0.borrow_sam().is_in_bytes()
51+
}
52+
4153
#[staticmethod]
4254
pub fn from_sam_and_trie(sam: &GeneralSAM, trie: &Trie) -> PyResult<Self> {
4355
SharedGreedyTokenizer::from_sam_and_trie(sam, trie)

tests/test_general_sam.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
def test_bytes_abcbc():
55
sam = GeneralSAM.from_bytes(b'abcbc')
6+
assert sam.is_in_bytes()
67

78
state = sam.get_root_state()
89
state.feed_bytes(b'cbc')
@@ -15,6 +16,8 @@ def test_bytes_abcbc():
1516

1617
def test_chars_abcbc():
1718
sam = GeneralSAM.from_chars('abcbc')
19+
assert sam.is_in_chars()
20+
1821
state = sam.get_root_state()
1922

2023
state.feed_chars('b')
@@ -30,6 +33,7 @@ def test_chars_abcbc():
3033
def test_simple_sam_from_trie():
3134
trie, _ = build_trie_from_chars(['hello', 'Chielo'])
3235
sam = GeneralSAM.from_trie(trie)
36+
assert trie.is_in_chars() and sam.is_in_chars()
3337

3438
def fetch_state(s: str) -> GeneralSAMState:
3539
state = sam.get_root_state()

tests/test_greedy_tokenizer.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,29 @@
1-
from general_sam import GeneralSAM, GreedyTokenizer, build_trie_from_chars
1+
from general_sam import (
2+
GeneralSAM,
3+
GreedyTokenizer,
4+
build_trie_from_bytes,
5+
build_trie_from_chars,
6+
)
7+
8+
9+
def test_english_chars_tokenize():
10+
vocab = ['a', 'ab', 'b', 'bc', 'c', 'd', 'e', 'f', 'cd', 'abcde']
11+
trie, token_to_trie_node = build_trie_from_chars(vocab)
12+
13+
trie_node_to_token = [-1] * trie.num_of_nodes()
14+
for i, j in enumerate(token_to_trie_node):
15+
trie_node_to_token[j] = i
16+
17+
sam = GeneralSAM.from_trie(trie)
18+
tokenizer = GreedyTokenizer.from_sam_and_trie(sam, trie)
19+
assert tokenizer.is_in_chars()
20+
21+
def tokenize(s: str):
22+
return [(trie_node_to_token[i], j) for i, j in tokenizer.tokenize_str(s)]
23+
24+
assert tokenize('abcde') == [(9, 5)]
25+
assert tokenize('abcdf') == [(1, 2), (8, 2), (7, 1)]
26+
assert tokenize('abca') == [(1, 2), (4, 1), (0, 1)]
227

328

429
def test_chinese_chars_tokenize():
@@ -11,6 +36,7 @@ def test_chinese_chars_tokenize():
1136

1237
sam = GeneralSAM.from_trie(trie)
1338
tokenizer = GreedyTokenizer.from_sam_and_trie(sam, trie)
39+
assert tokenizer.is_in_chars()
1440

1541
def tokenize(s: str):
1642
return [(trie_node_to_token[i], j) for i, j in tokenizer.tokenize_str(s)]
@@ -21,3 +47,36 @@ def tokenize(s: str):
2147
assert tokenize('聆听歌曲折') == [(1, 4), (-1, 1)]
2248
assert tokenize('查看歌词歌曲') == [(4, 4), (0, 2)]
2349
assert tokenize('一起播放歌曲并共享歌词') == [(-1, 2), (2, 4), (-1, 3), (3, 2)]
50+
51+
52+
def test_chinese_bytes_tokenize():
53+
vocab = ['歌曲', '聆听歌曲', '播放歌曲', '歌词', '查看歌词', '听歌', '曲折']
54+
vocab = [i.encode() for i in vocab]
55+
trie, token_to_trie_node = build_trie_from_bytes(vocab)
56+
57+
trie_node_to_token = [-1] * trie.num_of_nodes()
58+
for i, j in enumerate(token_to_trie_node):
59+
trie_node_to_token[j] = i
60+
61+
sam = GeneralSAM.from_trie(trie)
62+
tokenizer = GreedyTokenizer.from_sam_and_trie(sam, trie)
63+
assert tokenizer.is_in_bytes()
64+
65+
def tokenize_str(s: str):
66+
return [trie_node_to_token[i] for i, _ in tokenizer.tokenize_str(s)]
67+
68+
def tokenize_bytes(s: str):
69+
return [trie_node_to_token[i] for i, _ in tokenizer.tokenize_bytes(s.encode())]
70+
71+
def tokenize(s: str):
72+
a = tokenize_str(s)
73+
b = tokenize_bytes(s)
74+
assert a == b
75+
return a
76+
77+
assert tokenize('歌曲折') == [0, -1]
78+
assert tokenize('听歌曲') == [5, -1]
79+
assert tokenize('听歌曲折') == [5, 6]
80+
assert tokenize('聆听歌曲折') == [1, -1]
81+
assert tokenize('查看歌词歌曲') == [4, 0]
82+
assert tokenize('一起播放歌曲并共享歌词') == [-1, 2, -1, 3]

0 commit comments

Comments
 (0)