Skip to content

Commit 54247e3

Browse files
Merge pull request #205 from arnaudbriche/fix-tokenizer-config-with-addedd-token
fix: support AddedToken objects in tokenizer config special tokens
2 parents 273bf9d + d2a5f9e commit 54247e3

2 files changed

Lines changed: 25 additions & 4 deletions

File tree

rust/src/embeddings/local/bert.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,29 @@ pub trait BertEmbed {
3030
late_chunking: Option<bool>,
3131
) -> Result<Vec<EmbeddingResult>, anyhow::Error>;
3232
}
33+
// HuggingFace tokenizer configs represent special tokens either as a plain
34+
// string or as an AddedToken object. Both forms must be accepted.
35+
#[derive(Debug, Deserialize, Clone)]
36+
#[serde(untagged)]
37+
pub enum SpecialToken {
38+
String(String),
39+
Added { content: String },
40+
}
41+
42+
impl SpecialToken {
43+
pub fn content(&self) -> &str {
44+
match self {
45+
SpecialToken::String(s) => s,
46+
SpecialToken::Added { content } => content,
47+
}
48+
}
49+
}
50+
3351
#[derive(Debug, Deserialize, Clone)]
3452
pub struct TokenizerConfig {
3553
pub max_length: Option<usize>,
3654
pub model_max_length: Option<usize>,
37-
pub mask_token: Option<String>,
55+
pub mask_token: Option<SpecialToken>,
3856
pub added_tokens_decoder: Option<HashMap<String, AddedToken>>,
3957
}
4058

rust/src/embeddings/local/colbert.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,12 @@ impl OrtColbertEmbedder {
100100
};
101101

102102
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
103-
let mask_token = tokenizer_config.clone().mask_token;
104-
let pad_id = match mask_token.clone() {
105-
Some(mask_token) => tokenizer_config.get_token_id_from_token(&mask_token),
103+
let mask_token: Option<String> = tokenizer_config
104+
.mask_token
105+
.as_ref()
106+
.map(|t| t.content().to_owned());
107+
let pad_id = match mask_token.as_deref() {
108+
Some(mask_token) => tokenizer_config.get_token_id_from_token(mask_token),
106109
None => None,
107110
};
108111
let document_marker_token_id = tokenizer_config.get_token_id_from_token("[DocumentMarker]");

0 commit comments

Comments
 (0)