File tree Expand file tree Collapse file tree
rust/src/embeddings/local Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -30,11 +30,29 @@ pub trait BertEmbed {
3030 late_chunking : Option < bool > ,
3131 ) -> Result < Vec < EmbeddingResult > , anyhow:: Error > ;
3232}
33+ // HuggingFace tokenizer configs represent special tokens either as a plain
34+ // string or as an AddedToken object. Both forms must be accepted.
35+ #[ derive( Debug , Deserialize , Clone ) ]
36+ #[ serde( untagged) ]
37+ pub enum SpecialToken {
38+ String ( String ) ,
39+ Added { content : String } ,
40+ }
41+
42+ impl SpecialToken {
43+ pub fn content ( & self ) -> & str {
44+ match self {
45+ SpecialToken :: String ( s) => s,
46+ SpecialToken :: Added { content } => content,
47+ }
48+ }
49+ }
50+
3351#[ derive( Debug , Deserialize , Clone ) ]
3452pub struct TokenizerConfig {
3553 pub max_length : Option < usize > ,
3654 pub model_max_length : Option < usize > ,
37- pub mask_token : Option < String > ,
55+ pub mask_token : Option < SpecialToken > ,
3856 pub added_tokens_decoder : Option < HashMap < String , AddedToken > > ,
3957}
4058
Original file line number Diff line number Diff line change @@ -100,9 +100,12 @@ impl OrtColbertEmbedder {
100100 } ;
101101
102102 let mut tokenizer = Tokenizer :: from_file ( tokenizer_filename) . map_err ( E :: msg) ?;
103- let mask_token = tokenizer_config. clone ( ) . mask_token ;
104- let pad_id = match mask_token. clone ( ) {
105- Some ( mask_token) => tokenizer_config. get_token_id_from_token ( & mask_token) ,
103+ let mask_token: Option < String > = tokenizer_config
104+ . mask_token
105+ . as_ref ( )
106+ . map ( |t| t. content ( ) . to_owned ( ) ) ;
107+ let pad_id = match mask_token. as_deref ( ) {
108+ Some ( mask_token) => tokenizer_config. get_token_id_from_token ( mask_token) ,
106109 None => None ,
107110 } ;
108111 let document_marker_token_id = tokenizer_config. get_token_id_from_token ( "[DocumentMarker]" ) ;
You can’t perform that action at this time.
0 commit comments