Skip to content

Commit bd98dfb

Browse files
committed
Fix NomicBertConfig & FlashNomicBertConfig
Signed-off-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
1 parent d2deb20 commit bd98dfb

2 files changed

Lines changed: 31 additions & 16 deletions

File tree

backends/candle/src/models/flash_nomic.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ pub struct FlashNomicBertModel {
196196
pool: Pool,
197197
pub device: Device,
198198

199-
max_trained_positions: u32,
199+
max_position_embeddings: u32,
200200
rotary_cache: (Tensor, Tensor),
201201
scaled_rotary_cache: Option<(Tensor, Tensor)>,
202202

@@ -233,14 +233,21 @@ impl FlashNomicBertModel {
233233
let embeddings = NomicBertEmbeddings::load(vb.clone(), config)?;
234234
let encoder = NomicBertEncoder::load(vb.pp("encoder"), config)?;
235235

236+
let max_position_embeddings = match config.max_position_embeddings {
237+
Some(max_position_embeddings) => max_position_embeddings,
238+
None => match config.max_trained_positions {
239+
Some(max_trained_positions) => max_trained_positions,
240+
None => 2048,
241+
},
242+
};
243+
236244
let rotary_dim = encoder.layers[0].attention.attention_head_size;
237245
let inv_freqs = get_inv_freqs(rotary_dim, config.rotary_emb_base, vb.device(), None)?;
238246
let rotary_cache = get_cos_sin(config.n_positions, &inv_freqs, vb.dtype(), false)?;
239247

240248
let scaled_rotary_cache = if let Some(scaling_factor) = config.rotary_scaling_factor {
241249
let new_base = (config.rotary_emb_base
242-
* ((scaling_factor * config.n_positions as f32
243-
/ config.max_trained_positions as f32)
250+
* ((scaling_factor * config.n_positions as f32 / max_position_embeddings as f32)
244251
- (scaling_factor - 1.0)))
245252
.powi((rotary_dim as f32 / (rotary_dim as f32 - 2.0)) as i32);
246253
let inv_freqs = get_inv_freqs(rotary_dim, new_base, vb.device(), None)?;
@@ -258,7 +265,7 @@ impl FlashNomicBertModel {
258265
embeddings,
259266
encoder,
260267
pool,
261-
max_trained_positions: config.max_trained_positions as u32,
268+
max_position_embeddings: max_position_embeddings as u32,
262269
rotary_cache,
263270
scaled_rotary_cache,
264271
device: vb.device().clone(),
@@ -283,7 +290,7 @@ impl FlashNomicBertModel {
283290
)?;
284291

285292
let (cos, sin) = if self.scaled_rotary_cache.is_some()
286-
&& batch.max_length > self.max_trained_positions
293+
&& batch.max_length > self.max_position_embeddings
287294
{
288295
let cos = index_select(
289296
&self.scaled_rotary_cache.as_ref().unwrap().0,

backends/candle/src/models/nomic.rs

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,13 @@ pub struct NomicConfig {
1818
pub mlp_fc1_bias: bool,
1919
pub mlp_fc2_bias: bool,
2020
pub rotary_scaling_factor: Option<f32>,
21-
#[serde(default = "default_max_trained_positions")]
22-
pub max_trained_positions: usize,
21+
22+
// NOTE: `max_trained_positions` is specific for NomicBERT when it required custom code, but
23+
// since Transformers v5 it's no longer required, and it now defines `max_position_embeddings`
24+
// in the `config.json` instead. Not included as an `alias` since both can be present at the
25+
// same time, see https://huggingface.co/nomic-ai/nomic-embed-text-v1.5/blob/e9b6763023c676ca8431644204f50c2b100d9aab/config.json#L33-L34
26+
pub max_trained_positions: Option<usize>,
27+
pub max_position_embeddings: Option<usize>,
2328

2429
pub moe_every_n_layers: Option<usize>,
2530
pub moe_normalize_expert_weights: Option<bool>,
@@ -39,10 +44,6 @@ pub struct NomicConfig {
3944
pub layer_norm_epsilon: f32,
4045
}
4146

42-
fn default_max_trained_positions() -> usize {
43-
2048
44-
}
45-
4647
impl NomicConfig {
4748
// For now, we only support these parameters
4849
pub fn valid(&self) -> bool {
@@ -668,7 +669,7 @@ pub struct NomicBertModel {
668669
dtype: DType,
669670

670671
rotary_dim: usize,
671-
max_trained_positions: u32,
672+
max_position_embeddings: u32,
672673
rotary_cache: (Tensor, Tensor),
673674
scaled_rotary_cache: Option<(Tensor, Tensor)>,
674675

@@ -702,15 +703,22 @@ impl NomicBertModel {
702703
let embeddings = NomicBertEmbeddings::load(vb.clone(), config)?;
703704
let encoder = NomicBertEncoder::load(vb.pp("encoder"), config)?;
704705

706+
let max_position_embeddings = match config.max_position_embeddings {
707+
Some(max_position_embeddings) => max_position_embeddings,
708+
None => match config.max_trained_positions {
709+
Some(max_trained_positions) => max_trained_positions,
710+
None => 2048,
711+
},
712+
};
713+
705714
let rotary_dim = encoder.layers[0].attention.attention_head_size;
706715
let inv_freqs_tensor =
707716
get_inv_freqs(rotary_dim, config.rotary_emb_base, vb.device(), None)?;
708717
let rotary_cache = get_cos_sin(config.n_positions, &inv_freqs_tensor, vb.dtype(), true)?;
709718

710719
let scaled_rotary_cache = if let Some(scaling_factor) = config.rotary_scaling_factor {
711720
let new_base = (config.rotary_emb_base
712-
* ((scaling_factor * config.n_positions as f32
713-
/ config.max_trained_positions as f32)
721+
* ((scaling_factor * config.n_positions as f32 / max_position_embeddings as f32)
714722
- (scaling_factor - 1.0)))
715723
.powi((rotary_dim as f32 / (rotary_dim as f32 - 2.0)) as i32);
716724
let inv_freqs_tensor = get_inv_freqs(rotary_dim, new_base, vb.device(), None)?;
@@ -729,7 +737,7 @@ impl NomicBertModel {
729737
encoder,
730738
pool,
731739
rotary_dim,
732-
max_trained_positions: config.max_trained_positions as u32,
740+
max_position_embeddings: max_position_embeddings as u32,
733741
rotary_cache,
734742
scaled_rotary_cache,
735743
num_attention_heads: config.n_head,
@@ -855,7 +863,7 @@ impl NomicBertModel {
855863
Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;
856864

857865
let (cos, sin) = if self.scaled_rotary_cache.is_some()
858-
&& batch.max_length > self.max_trained_positions
866+
&& batch.max_length > self.max_position_embeddings
859867
{
860868
let cos = self
861869
.scaled_rotary_cache

0 commit comments

Comments
 (0)