Skip to content

Commit 9e02042

Browse files
committed
Simplified code
1 parent 0c9cd70 commit 9e02042

1 file changed

Lines changed: 11 additions & 25 deletions

File tree

src/model.rs

Lines changed: 11 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -87,11 +87,9 @@ fn match_hub_layout(
8787
}
8888

8989
fn resolve_local(folder: &Path) -> Option<ResolvedPaths> {
90-
// Native model2vec.
9190
if let r @ Some(_) = match_local_layout(folder, folder, "config.json", ModelLayout::Native) {
9291
return r;
9392
}
94-
// Sentence Transformers root layout.
9593
if let r @ Some(_) = match_local_layout(
9694
folder,
9795
folder,
@@ -121,11 +119,9 @@ fn resolve_local(folder: &Path) -> Option<ResolvedPaths> {
121119
}
122120

123121
fn resolve_hub(repo: &ApiRepo, prefix: &str) -> Result<ResolvedPaths> {
124-
// Native model2vec.
125122
if let Some(r) = match_hub_layout(repo, prefix, prefix, "config.json", ModelLayout::Native) {
126123
return r;
127124
}
128-
// Sentence Transformers root layout.
129125
if let Some(r) = match_hub_layout(
130126
repo,
131127
prefix,
@@ -291,6 +287,13 @@ impl StaticModel {
291287
/// * `normalize` - Whether to L2-normalize output embeddings
292288
/// * `weights` - Optional per-token weights for quantized models
293289
/// * `token_mapping` - Optional token ID mapping for quantized models
290+
fn check_shape(len: usize, rows: usize, cols: usize) -> Result<()> {
291+
if len != rows * cols {
292+
return Err(anyhow!("embeddings length {} != rows {} * cols {}", len, rows, cols));
293+
}
294+
Ok(())
295+
}
296+
294297
pub fn from_owned(
295298
tokenizer: Tokenizer,
296299
embeddings: Vec<f32>,
@@ -300,14 +303,7 @@ impl StaticModel {
300303
weights: Option<Vec<f32>>,
301304
token_mapping: Option<Vec<usize>>,
302305
) -> Result<Self> {
303-
if embeddings.len() != rows * cols {
304-
return Err(anyhow!(
305-
"embeddings length {} != rows {} * cols {}",
306-
embeddings.len(),
307-
rows,
308-
cols
309-
));
310-
}
306+
Self::check_shape(embeddings.len(), rows, cols)?;
311307

312308
let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
313309

@@ -345,14 +341,7 @@ impl StaticModel {
345341
weights: Option<&'static [f32]>,
346342
token_mapping: Option<&'static [usize]>,
347343
) -> Result<Self> {
348-
if embeddings.len() != rows * cols {
349-
return Err(anyhow!(
350-
"embeddings length {} != rows {} * cols {}",
351-
embeddings.len(),
352-
rows,
353-
cols
354-
));
355-
}
344+
Self::check_shape(embeddings.len(), rows, cols)?;
356345

357346
let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
358347

@@ -375,10 +364,7 @@ impl StaticModel {
375364
lens.sort_unstable();
376365
let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
377366

378-
let spec_json = tokenizer
379-
.to_string(false)
380-
.map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
381-
let spec: Value = serde_json::from_str(&spec_json)?;
367+
let spec: Value = serde_json::to_value(tokenizer).context("failed to serialize tokenizer")?;
382368
let unk_token = spec
383369
.get("model")
384370
.and_then(|m| m.get("unk_token"))
@@ -430,7 +416,7 @@ impl StaticModel {
430416
.tokenizer
431417
.encode_batch_fast::<String>(
432418
truncated.into_iter().map(Into::into).collect(),
433-
/* add_special_tokens = */ false,
419+
false,
434420
)
435421
.expect("tokenization failed");
436422
for encoding in encodings {

0 commit comments

Comments
 (0)