Skip to content

Commit d5601f2

Browse files
committed
feat: add greedy tokenizer, update general-sam to 0.5.3
1 parent 05ceb49 commit d5601f2

9 files changed

Lines changed: 219 additions & 56 deletions

File tree

Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "general-sam-py"
3-
version = "0.5.2-dev"
3+
version = "0.5.3"
44
edition = "2021"
55
license = "MIT OR Apache-2.0"
66
description = "Python bindings for general-sam and some utilities"
@@ -14,10 +14,11 @@ name = "general_sam"
1414
crate-type = ["cdylib"]
1515

1616
[dependencies]
17-
general-sam = { version = "0.5.2", features = ["all"] }
17+
general-sam = { version = "0.5.3", features = ["all"] }
1818
pyo3 = { version = "0.20.0", features = [
1919
"extension-module",
2020
"abi3-py38",
2121
"generate-import-lib",
2222
] }
2323
either = "1.9.0"
24+
ouroboros = "0.18.0"

general_sam/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .general_sam import (
22
GeneralSAM,
33
GeneralSAMState,
4+
GreedyTokenizer,
45
Trie,
56
TrieNode,
67
)
@@ -21,6 +22,7 @@
2122
__all__ = [
2223
'GeneralSAM',
2324
'GeneralSAMState',
25+
'GreedyTokenizer',
2426
'Trie',
2527
'TrieNode',
2628
'CountInfo',

general_sam/general_sam.pyi

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
from typing import Callable, Mapping, Optional, Sequence, Union
1+
from typing import Callable, Mapping, Optional, Sequence, Tuple, Union
22

33
ByteOrChar = Union[str, int]
4+
TrieNodeID = int
5+
GeneralSAMNodeID = int
46

57
class TrieNode:
68
def is_in_chars(self) -> bool: ...
79
def is_in_bytes(self) -> bool: ...
8-
def get_node_id(self) -> int: ...
10+
def get_node_id(self) -> TrieNodeID: ...
911
def is_accepting(self) -> bool: ...
10-
def get_trans(self) -> Mapping[ByteOrChar, int]: ...
11-
def get_parent(self) -> int: ...
12+
def get_trans(self) -> Mapping[ByteOrChar, TrieNodeID]: ...
13+
def get_parent(self) -> TrieNodeID: ...
1214

1315
class Trie:
1416
@staticmethod
@@ -18,33 +20,33 @@ class Trie:
1820
def is_in_chars(self) -> bool: ...
1921
def is_in_bytes(self) -> bool: ...
2022
def num_of_nodes(self) -> int: ...
21-
def insert_chars(self, s: str) -> int: ...
22-
def insert_bytes(self, s: bytes) -> int: ...
23-
def get_bfs_order(self) -> Sequence[int]: ...
23+
def insert_chars(self, s: str) -> TrieNodeID: ...
24+
def insert_bytes(self, s: bytes) -> TrieNodeID: ...
25+
def get_bfs_order(self) -> Sequence[TrieNodeID]: ...
2426
def get_root(self) -> TrieNode: ...
25-
def get_node(self, node_id: int) -> Optional[TrieNode]: ...
27+
def get_node(self, node_id: TrieNodeID) -> Optional[TrieNode]: ...
2628
def dfs_travel(
2729
self,
28-
in_stack_callback: Callable[[int, Optional[ByteOrChar]], None],
29-
out_stack_callback: Callable[[int], None],
30-
root_node_id: Optional[int] = None,
30+
in_stack_callback: Callable[[TrieNodeID, Optional[ByteOrChar]], None],
31+
out_stack_callback: Callable[[TrieNodeID], None],
32+
root_node_id: Optional[TrieNodeID] = None,
3133
) -> TrieNode: ...
3234
def bfs_travel(
3335
self,
34-
in_queue_callback: Callable[[int, Optional[ByteOrChar]], None],
35-
out_queue_callback: Callable[[int], None],
36-
root_node_id: Optional[int] = None,
36+
in_queue_callback: Callable[[TrieNodeID, Optional[ByteOrChar]], None],
37+
out_queue_callback: Callable[[TrieNodeID], None],
38+
root_node_id: Optional[TrieNodeID] = None,
3739
) -> TrieNode: ...
3840

3941
class GeneralSAMState:
4042
def is_in_str(self) -> bool: ...
4143
def is_in_bytes(self) -> bool: ...
42-
def get_node_id(self) -> int: ...
44+
def get_node_id(self) -> GeneralSAMNodeID: ...
4345
def is_nil(self) -> bool: ...
4446
def is_root(self) -> bool: ...
4547
def is_accepting(self) -> bool: ...
46-
def get_trans(self) -> Mapping[ByteOrChar, int]: ...
47-
def get_suffix_parent_id(self) -> int: ...
48+
def get_trans(self) -> Mapping[ByteOrChar, GeneralSAMNodeID]: ...
49+
def get_suffix_parent_id(self) -> GeneralSAMNodeID: ...
4850
def copy(self) -> 'GeneralSAMState': ...
4951
def goto_suffix_parent(self) -> None: ...
5052
def goto_char(self, t: str) -> None: ...
@@ -55,19 +57,19 @@ class GeneralSAMState:
5557
self,
5658
trie: Trie,
5759
in_stack_callback: Callable[
58-
['GeneralSAMState', int, Optional[ByteOrChar]], None
60+
['GeneralSAMState', TrieNodeID, Optional[ByteOrChar]], None
5961
],
60-
out_stack_callback: Callable[['GeneralSAMState', int], None],
61-
trie_node_id: Optional[int] = None,
62+
out_stack_callback: Callable[['GeneralSAMState', TrieNodeID], None],
63+
trie_node_id: Optional[TrieNodeID] = None,
6264
) -> TrieNode: ...
6365
def bfs_along(
6466
self,
6567
trie: Trie,
6668
in_queue_callback: Callable[
67-
['GeneralSAMState', int, Optional[ByteOrChar]], None
69+
['GeneralSAMState', TrieNodeID, Optional[ByteOrChar]], None
6870
],
69-
out_queue_callback: Callable[['GeneralSAMState', int], None],
70-
trie_node_id: Optional[int] = None,
71+
out_queue_callback: Callable[['GeneralSAMState', TrieNodeID], None],
72+
trie_node_id: Optional[TrieNodeID] = None,
7173
) -> TrieNode: ...
7274

7375
class GeneralSAM:
@@ -81,5 +83,11 @@ class GeneralSAM:
8183
def is_in_bytes(self) -> bool: ...
8284
def num_of_nodes(self) -> int: ...
8385
def get_root_state(self) -> GeneralSAMState: ...
84-
def get_state(self, node_id: int) -> GeneralSAMState: ...
86+
def get_state(self, node_id: GeneralSAMNodeID) -> GeneralSAMState: ...
8587
def get_topo_and_suf_len_sorted_states(self) -> Sequence[GeneralSAMState]: ...
88+
class GreedyTokenizer:
    """Greedy tokenizer built from a ``GeneralSAM`` and a vocabulary ``Trie``."""

    @staticmethod
    def from_sam_and_trie(sam: GeneralSAM, trie: Trie) -> 'GreedyTokenizer':
        """Build a tokenizer from ``sam`` and ``trie``.

        Raises ``TypeError`` when one of them is char-based and the other is
        byte-based.
        """
        ...
    def tokenize_str(
        self,
        s: str,
        unk_token_id: Optional[TrieNodeID] = None,
    ) -> Sequence[Tuple[TrieNodeID, int]]:
        """Tokenize a ``str``.

        ``unk_token_id`` is forwarded to the native tokenizer; when omitted the
        trie's nil node id is used. Each result is ``(token_id, n)``; ``n`` is
        presumably the length of the matched span -- confirm against the
        ``general-sam`` documentation.
        """
        ...
    def tokenize_bytes(
        self,
        s: bytes,
        unk_token_id: Optional[TrieNodeID] = None,
    ) -> Sequence[Tuple[TrieNodeID, int]]:
        """Tokenize a ``bytes`` object.

        For a char-based tokenizer ``s`` must be valid UTF-8. ``unk_token_id``
        and the result tuples have the same meaning as in ``tokenize_str``.
        """
        ...

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pub mod sam;
2+
pub mod tokenizer;
23
pub mod trie;
34
pub mod utils;
45

@@ -10,5 +11,6 @@ fn general_sam(_py: Python, m: &PyModule) -> PyResult<()> {
1011
m.add_class::<trie::Trie>()?;
1112
m.add_class::<sam::GeneralSAMState>()?;
1213
m.add_class::<sam::GeneralSAM>()?;
14+
m.add_class::<tokenizer::GreedyTokenizer>()?;
1315
Ok(())
1416
}

src/sam.rs

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
extern crate general_sam as general_sam_rs;
22

3-
use crate::trie::Trie;
4-
use crate::utils::{char_or_byte_type, for_both, ByteSide, CharSide};
3+
use std::{str::from_utf8, sync::Arc};
54

65
use general_sam_rs::{
76
sam as sam_rs, trie as trie_rs, BTreeTransTable, BoxBisectTable, TransitionTable, TravelEvent,
87
};
8+
use pyo3::exceptions::PyTypeError;
99
use pyo3::{prelude::*, types::PyDict};
10-
use std::{str::from_utf8, sync::Arc};
1110

12-
type RustBoxBisectGeneralSAM<T> = sam_rs::GeneralSAM<BoxBisectTable<T>>;
13-
type RustBoxBisectGeneralSAMState<'s, T> = sam_rs::GeneralSAMState<'s, BoxBisectTable<T>>;
14-
type RustGeneralSAM = char_or_byte_type!(RustBoxBisectGeneralSAM);
15-
type RustGeneralSAMState<'s> = char_or_byte_type!(RustBoxBisectGeneralSAMState; 's);
11+
use crate::for_both_and_wrap;
12+
use crate::trie::Trie;
13+
use crate::utils::{
14+
char_or_byte_type, for_both, get_char_or_byte_variant_name, ByteSide, CharSide,
15+
};
16+
17+
pub(crate) type RustBoxBisectGeneralSAM<T> = sam_rs::GeneralSAM<BoxBisectTable<T>>;
18+
pub(crate) type RustBoxBisectGeneralSAMState<'s, T> =
19+
sam_rs::GeneralSAMState<'s, BoxBisectTable<T>>;
20+
pub(crate) type RustGeneralSAM = char_or_byte_type!(RustBoxBisectGeneralSAM);
21+
pub(crate) type RustGeneralSAMState<'s> = char_or_byte_type!(RustBoxBisectGeneralSAMState; 's);
1622

1723
#[pyclass]
1824
pub struct GeneralSAM(pub Arc<RustGeneralSAM>);
@@ -112,17 +118,18 @@ impl GeneralSAMState {
112118
}
113119
}
114120

115-
pub fn feed_bytes(&mut self, s: &[u8]) {
121+
pub fn feed_bytes(&mut self, s: &[u8]) -> PyResult<()> {
116122
match self.get_state() {
117123
CharSide(state_chars) => {
118-
let state_chars = state_chars.feed_iter(from_utf8(s).unwrap().chars());
124+
let state_chars = state_chars.feed_iter(from_utf8(s)?.chars());
119125
self.1 = state_chars.node_id;
120126
}
121127
ByteSide(state_bytes) => {
122128
let state_bytes = state_bytes.feed_ref_iter(s.iter());
123129
self.1 = state_bytes.node_id;
124130
}
125131
}
132+
Ok(())
126133
}
127134

128135
#[pyo3(signature = (trie, in_stack_callback, out_stack_callback, trie_node_id=None))]
@@ -132,12 +139,16 @@ impl GeneralSAMState {
132139
in_stack_callback: PyObject,
133140
out_stack_callback: PyObject,
134141
trie_node_id: Option<usize>,
135-
) -> Result<(), PyErr> {
136-
assert!(trie.is_in_chars() == self.is_in_chars());
137-
let sam_state_and_trie = self.get_state().map_either(
138-
|x| (x, trie.0.as_ref().left().unwrap()),
139-
|x| (x, trie.0.as_ref().right().unwrap()),
140-
);
142+
) -> PyResult<()> {
143+
let sam_state_and_trie = for_both_and_wrap!(self.get_state(), &trie.0; (s, t) => (s, t))
144+
.map_err(|e| {
145+
PyTypeError::new_err(format!(
146+
"{}, {} vs {}",
147+
e,
148+
get_char_or_byte_variant_name(self.0.as_ref()),
149+
get_char_or_byte_variant_name(&trie.0)
150+
))
151+
})?;
141152
for_both!(sam_state_and_trie, (sam_state, trie) => {
142153
let tn = trie.get_state(trie_node_id.unwrap_or(trie_rs::TRIE_ROOT_NODE_ID));
143154
sam_state.dfs_along(tn, |event| {
@@ -178,11 +189,15 @@ impl GeneralSAMState {
178189
out_stack_callback: PyObject,
179190
trie_node_id: Option<usize>,
180191
) -> Result<(), PyErr> {
181-
assert!(trie.is_in_chars() == self.is_in_chars());
182-
let sam_state_and_trie = self.get_state().map_either(
183-
|x| (x, trie.0.as_ref().left().unwrap()),
184-
|x| (x, trie.0.as_ref().right().unwrap()),
185-
);
192+
let sam_state_and_trie = for_both_and_wrap!(self.get_state(), &trie.0; (s, t) => (s, t))
193+
.map_err(|e| {
194+
PyTypeError::new_err(format!(
195+
"{}, {} vs {}",
196+
e,
197+
get_char_or_byte_variant_name(self.0.as_ref()),
198+
get_char_or_byte_variant_name(&trie.0)
199+
))
200+
})?;
186201
for_both!(sam_state_and_trie, (sam_state, trie) => {
187202
let tn = trie.get_state(trie_node_id.unwrap_or(trie_rs::TRIE_ROOT_NODE_ID));
188203
sam_state.bfs_along(tn, |event| {

src/tokenizer.rs

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
use std::{str::from_utf8, sync::Arc};
2+
3+
use general_sam::{utils::tokenize as tokenize_rs, BoxBisectTable, TrieNodeID, TRIE_NIL_NODE_ID};
4+
use ouroboros::self_referencing;
5+
use pyo3::{exceptions::PyTypeError, prelude::*};
6+
7+
use crate::{
8+
char_or_byte_type, for_both_and_wrap,
9+
sam::GeneralSAM,
10+
trie::Trie,
11+
utils::{get_char_or_byte_variant_name, ByteSide, CharSide, InconsistentCharOrByte},
12+
};
13+
14+
pub(crate) type RustBoxBisectGreedyTokenizer<'s, T> =
15+
tokenize_rs::GreedyTokenizer<'s, BoxBisectTable<T>, TrieNodeID>;
16+
pub(crate) type RustGreedyTokenizer<'s> = char_or_byte_type!(RustBoxBisectGreedyTokenizer; 's);
17+
18+
#[self_referencing]
19+
pub struct SharedGreedyTokenizer {
20+
sam: GeneralSAM,
21+
#[borrows(sam)]
22+
#[covariant]
23+
inner: RustGreedyTokenizer<'this>,
24+
}
25+
26+
impl SharedGreedyTokenizer {
27+
fn from_sam_and_trie(sam: &GeneralSAM, trie: &Trie) -> Result<Self, InconsistentCharOrByte> {
28+
Self::try_new(GeneralSAM(sam.0.clone()), |sam: &GeneralSAM| {
29+
for_both_and_wrap!(sam.0.as_ref(), trie.0.as_ref(); (sam, trie) => {
30+
RustBoxBisectGreedyTokenizer::build_from_trie(sam, trie.get_root_state())
31+
})
32+
})
33+
}
34+
}
35+
36+
#[pyclass]
37+
pub struct GreedyTokenizer(pub Arc<SharedGreedyTokenizer>);
38+
39+
#[pymethods]
40+
impl GreedyTokenizer {
41+
#[staticmethod]
42+
pub fn from_sam_and_trie(sam: &GeneralSAM, trie: &Trie) -> PyResult<Self> {
43+
SharedGreedyTokenizer::from_sam_and_trie(sam, trie)
44+
.map(|x| Self(Arc::new(x)))
45+
.map_err(|e| {
46+
PyTypeError::new_err(format!(
47+
"{}, {} vs {}",
48+
e,
49+
get_char_or_byte_variant_name(sam.0.as_ref()),
50+
get_char_or_byte_variant_name(&trie.0)
51+
))
52+
})
53+
}
54+
55+
#[pyo3(signature = (s, unk_token_id=None))]
56+
pub fn tokenize_str(
57+
&mut self,
58+
s: &str,
59+
unk_token_id: Option<TrieNodeID>,
60+
) -> Vec<(TrieNodeID, usize)> {
61+
let unk_token_id = unk_token_id.unwrap_or(TRIE_NIL_NODE_ID);
62+
match self.0.borrow_inner() {
63+
CharSide(inner) => inner.tokenize(s.chars(), &unk_token_id),
64+
ByteSide(inner) => inner.tokenize(s.bytes(), &unk_token_id),
65+
}
66+
}
67+
68+
#[pyo3(signature = (s, unk_token_id=None))]
69+
pub fn tokenize_bytes(
70+
&mut self,
71+
s: &[u8],
72+
unk_token_id: Option<TrieNodeID>,
73+
) -> PyResult<Vec<(TrieNodeID, usize)>> {
74+
let unk_token_id = unk_token_id.unwrap_or(TRIE_NIL_NODE_ID);
75+
Ok(match self.0.borrow_inner() {
76+
CharSide(inner) => inner.tokenize(from_utf8(s)?.chars(), &unk_token_id),
77+
ByteSide(inner) => inner.tokenize(s.iter().copied(), &unk_token_id),
78+
})
79+
}
80+
}

src/trie.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
extern crate general_sam as general_sam_rs;
22

3-
use crate::utils::{char_or_byte_type, for_both, ByteSide, CharSide};
3+
use std::{convert::Infallible, str::from_utf8};
44

55
use general_sam_rs::{trie as trie_rs, BTreeTransTable, TravelEvent, TrieNodeAlike};
66
use pyo3::prelude::*;
7-
use std::{convert::Infallible, str::from_utf8};
87

9-
type RustBTreeTrie<T> = trie_rs::Trie<BTreeTransTable<T>>;
10-
type RustBTreeTrieNode<T> = trie_rs::TrieNode<BTreeTransTable<T>>;
11-
type RustTrie = char_or_byte_type!(RustBTreeTrie);
12-
type RustTrieNode = char_or_byte_type!(RustBTreeTrieNode);
8+
use crate::utils::{char_or_byte_type, for_both, ByteSide, CharSide};
9+
10+
pub(crate) type RustBTreeTrie<T> = trie_rs::Trie<BTreeTransTable<T>>;
11+
pub(crate) type RustBTreeTrieNode<T> = trie_rs::TrieNode<BTreeTransTable<T>>;
12+
pub(crate) type RustTrie = char_or_byte_type!(RustBTreeTrie);
13+
pub(crate) type RustTrieNode = char_or_byte_type!(RustBTreeTrieNode);
1314

1415
#[pyclass]
1516
pub struct Trie(pub RustTrie);
@@ -79,11 +80,11 @@ impl Trie {
7980
}
8081
}
8182

82-
pub fn insert_bytes(&mut self, b: &[u8]) -> usize {
83-
match self.0.as_mut() {
84-
CharSide(trie_chars) => trie_chars.insert_iter(from_utf8(b).unwrap().chars()),
83+
pub fn insert_bytes(&mut self, b: &[u8]) -> PyResult<usize> {
84+
Ok(match self.0.as_mut() {
85+
CharSide(trie_chars) => trie_chars.insert_iter(from_utf8(b)?.chars()),
8586
ByteSide(trie_bytes) => trie_bytes.insert_ref_iter(b.iter()),
86-
}
87+
})
8788
}
8889

8990
pub fn get_bfs_order(&self) -> Vec<usize> {

0 commit comments

Comments
 (0)