
Commit 54ff24b

feat: add Jina v4 candle support
1 parent: 6bc848e

5 files changed: 387 additions & 36 deletions

File: backends/candle/src/lib.rs (156 additions & 3 deletions)
@@ -31,8 +31,10 @@ use crate::models::{
 use crate::models::{
     FlashBertModel, FlashDistilBertModel, FlashGTEModel, FlashJinaBertModel,
     FlashJinaCodeBertModel, FlashMistralModel, FlashModernBertModel, FlashNomicBertModel,
-    FlashQwen2Model, FlashQwen3Model,
+    FlashQwen2Model, FlashQwen3Model, LoraWeights,
 };
+#[cfg(feature = "cuda")]
+use std::{env, fs};

 /// This enum is needed to be able to differentiate between jina models that also use
 /// the `bert` model type and valid Bert models.
@@ -88,6 +90,106 @@ impl<'de> Deserialize<'de> for BertConfigWrapper {
     }
 }

+#[derive(Debug, Clone, Deserialize)]
+struct JinaV4Config {
+    #[serde(default)]
+    task_names: Vec<String>,
+    text_config: Qwen2Config,
+}
+
+fn is_jina_v4_config(value: &serde_json::Value) -> bool {
+    value
+        .get("architectures")
+        .and_then(|v| v.as_array())
+        .map(|items| {
+            items
+                .iter()
+                .any(|item| item.as_str() == Some("JinaEmbeddingsV4Model"))
+        })
+        .unwrap_or(false)
+}
+
+#[cfg(feature = "cuda")]
+fn load_jina_v4_lora(
+    model_path: &Path,
+    device: &Device,
+    dtype: DType,
+    config: &JinaV4Config,
+) -> Option<LoraWeights> {
+    #[derive(Deserialize)]
+    struct AdapterConfig {
+        r: usize,
+        lora_alpha: f32,
+    }
+
+    let adapter_dir = model_path.join("adapters");
+    let adapter_config_path = adapter_dir.join("adapter_config.json");
+    let adapter_model_path = adapter_dir.join("adapter_model.safetensors");
+
+    if !adapter_config_path.exists() || !adapter_model_path.exists() {
+        tracing::warn!("Jina v4 adapters not found; LoRA will be skipped.");
+        return None;
+    }
+
+    let adapter_config = match fs::read_to_string(&adapter_config_path) {
+        Ok(content) => match serde_json::from_str::<AdapterConfig>(&content) {
+            Ok(config) => config,
+            Err(err) => {
+                tracing::warn!("Failed to parse Jina v4 adapter_config.json: {err}");
+                return None;
+            }
+        },
+        Err(err) => {
+            tracing::warn!("Failed to read Jina v4 adapter_config.json: {err}");
+            return None;
+        }
+    };
+
+    let mut task = env::var("JINA_V4_TASK").unwrap_or_default();
+    if task.is_empty() {
+        task = config
+            .task_names
+            .first()
+            .cloned()
+            .unwrap_or_else(|| "retrieval".to_string());
+    } else if !config.task_names.is_empty() && !config.task_names.contains(&task) {
+        tracing::warn!(
+            "JINA_V4_TASK={task} is not in config.task_names; defaulting to the first entry."
+        );
+        task = config
+            .task_names
+            .first()
+            .cloned()
+            .unwrap_or_else(|| "retrieval".to_string());
+    }
+
+    let adapter_vb = unsafe { VarBuilder::from_mmaped_safetensors(&[adapter_model_path], dtype, device) };
+    let adapter_vb = match adapter_vb.s() {
+        Ok(vb) => vb,
+        Err(err) => {
+            tracing::warn!("Failed to load Jina v4 adapter weights: {err}");
+            return None;
+        }
+    };
+
+    let lora_prefix = "base_model.model.model.language_model".to_string();
+    let lora_check = format!(
+        "{lora_prefix}.layers.0.self_attn.q_proj.lora_A.{task}.weight"
+    );
+    if !adapter_vb.contains_tensor(&lora_check) {
+        tracing::warn!("Jina v4 adapter weights missing expected keys; LoRA will be skipped.");
+        return None;
+    }
+
+    Some(LoraWeights::new(
+        adapter_vb,
+        task,
+        adapter_config.r,
+        adapter_config.lora_alpha,
+        lora_prefix,
+    ))
+}
+
 #[derive(Deserialize)]
 #[serde(tag = "model_type", rename_all = "kebab-case")]
 enum Config {
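
Note: the `r` and `lora_alpha` fields read from adapter_config.json parameterize the standard LoRA update, in which an adapted projection computes y = x·Wᵀ + (lora_alpha / r) · x·Aᵀ·Bᵀ. How FlashQwen2Model actually consumes LoraWeights is not visible in this file; the candle sketch below only illustrates that scaling convention, with a hypothetical function name and shapes.

use candle::{DType, Device, Result, Tensor};

// Illustrative LoRA-adapted linear layer: the frozen base product x·Wᵀ plus
// the low-rank correction x·Aᵀ·Bᵀ scaled by lora_alpha / r. This is not the
// code path FlashQwen2Model uses; that lives outside this diff.
fn lora_linear(
    x: &Tensor,      // [tokens, in_features]
    w: &Tensor,      // [out_features, in_features], frozen base weight
    lora_a: &Tensor, // [r, in_features]
    lora_b: &Tensor, // [out_features, r]
    r: usize,
    lora_alpha: f32,
) -> Result<Tensor> {
    let scale = (lora_alpha / r as f32) as f64;
    let base = x.matmul(&w.t()?)?;
    let delta = x.matmul(&lora_a.t()?)?.matmul(&lora_b.t()?)?;
    base.add(&delta.affine(scale, 0.0)?)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (tokens, dim, rank) = (4usize, 8usize, 2usize);
    let x = Tensor::randn(0f32, 1.0, (tokens, dim), &dev)?;
    let w = Tensor::randn(0f32, 1.0, (dim, dim), &dev)?;
    let a = Tensor::randn(0f32, 1.0, (rank, dim), &dev)?;
    // lora_B is conventionally zero-initialized, so a fresh adapter is a no-op.
    let b = Tensor::zeros((dim, rank), DType::F32, &dev)?;
    let y = lora_linear(&x, &w, &a, &b, rank, 16.0)?;
    println!("{:?}", y.shape());
    Ok(())
}
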
@@ -111,6 +213,7 @@ enum Config {
     Qwen2(Qwen2Config),
     #[allow(dead_code)]
     Qwen3(Qwen3Config),
+    JinaV4(JinaV4Config),
     Roberta(BertConfig),
     XlmRoberta(BertConfig),
 }
@@ -180,9 +283,28 @@ impl CandleBackend {
         let config: String = std::fs::read_to_string(model_path.join("config.json"))
             .context("Unable to read config file")
             .map_err(|err| BackendError::Start(format!("{err:?}")))?;
-        let config: Config = serde_json::from_str(&config)
+        let config_value: serde_json::Value = serde_json::from_str(&config)
             .context("Model is not supported")
             .map_err(|err| BackendError::Start(format!("{err:?}")))?;
+        let config: Config = if is_jina_v4_config(&config_value) {
+            if config_value
+                .get("text_config")
+                .and_then(|text| text.get("rope_scaling"))
+                .is_some()
+            {
+                tracing::warn!(
+                    "Jina v4 rope_scaling is not supported in Candle; using base rope instead."
+                );
+            }
+            let jina_config: JinaV4Config = serde_json::from_value(config_value.clone())
+                .context("Model is not supported")
+                .map_err(|err| BackendError::Start(format!("{err:?}")))?;
+            Config::JinaV4(jina_config)
+        } else {
+            serde_json::from_value(config_value)
+                .context("Model is not supported")
+                .map_err(|err| BackendError::Start(format!("{err:?}")))?
+        };

         // Get candle device
         let device = if candle::utils::cuda_is_available() {
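
Note: Jina v4 checkpoints are detected by the "architectures" array rather than by the "model_type" tag that drives the serde enum above, because their Qwen2-style settings are nested under "text_config" instead of sitting at the top level. A self-contained illustration follows; the helper is duplicated from the diff so the example runs alone, and the config field values are abridged and illustrative.

use serde_json::{json, Value};

// Same architecture check as in the diff, copied so this snippet compiles alone.
fn is_jina_v4_config(value: &Value) -> bool {
    value
        .get("architectures")
        .and_then(|v| v.as_array())
        .map(|items| {
            items
                .iter()
                .any(|item| item.as_str() == Some("JinaEmbeddingsV4Model"))
        })
        .unwrap_or(false)
}

fn main() {
    // Abridged Jina v4 config.json: the marker lives in `architectures`,
    // and the Qwen2-style settings sit under `text_config`.
    let jina = json!({
        "architectures": ["JinaEmbeddingsV4Model"],
        "task_names": ["retrieval"],
        "text_config": { "model_type": "qwen2" }
    });
    // A plain Qwen2 config keeps its top-level `model_type` tag and flows
    // through the existing tagged-enum deserialization instead.
    let qwen2 = json!({ "model_type": "qwen2", "architectures": ["Qwen2Model"] });

    assert!(is_jina_v4_config(&jina));
    assert!(!is_jina_v4_config(&qwen2));
}
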
@@ -301,6 +423,10 @@ impl CandleBackend {
                 "Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
             )),
+            (Config::JinaV4(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
+                "Jina v4 is only supported on Cuda devices in fp16 with flash attention enabled"
+                    .to_string(),
+            )),
             (Config::Qwen3(config), Device::Cpu | Device::Metal(_)) => {
                 tracing::info!("Starting Qwen3 model on {:?}", device);
                 Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))
@@ -488,7 +614,34 @@ impl CandleBackend {
                 }
                 tracing::info!("Starting FlashQwen2 model on {:?}", device);
                 Ok(Box::new(
-                    FlashQwen2Model::load(vb, &config, model_type).s()?,
+                    FlashQwen2Model::load(vb, &config, model_type, None, false).s()?,
+                ))
+            }
+            #[cfg(feature = "cuda")]
+            (Config::JinaV4(config), Device::Cuda(_)) => {
+                if dtype != DType::F16
+                    || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
+                    || &std::env::var("USE_FLASH_ATTENTION")
+                        .unwrap_or("True".to_string())
+                        .to_lowercase()
+                        != "true"
+                {
+                    return Err(BackendError::Start(
+                        "Jina v4 is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string(),
+                    ));
+                }
+
+                let lora = load_jina_v4_lora(model_path, device, dtype, &config);
+                tracing::info!("Starting Jina v4 model on {:?}", device);
+                Ok(Box::new(
+                    FlashQwen2Model::load(
+                        vb,
+                        &config.text_config,
+                        model_type,
+                        lora,
+                        true,
+                    )
+                    .s()?,
                 ))
             }
             #[cfg(feature = "cuda")]
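
Note: adapter loading is best-effort: if the adapters/ directory, its adapter_config.json, or the probed safetensors key is missing, the backend logs a warning and starts without LoRA instead of failing. Below is a small sketch of the expected layout and the key that is probed; the model path is illustrative, not taken from this commit.

use std::path::Path;

fn main() {
    // Expected on-disk layout next to config.json (from load_jina_v4_lora):
    //   <model>/adapters/adapter_config.json        -> r, lora_alpha
    //   <model>/adapters/adapter_model.safetensors  -> per-task LoRA tensors
    let model_path = Path::new("/data/jina-embeddings-v4"); // illustrative path
    let adapter_dir = model_path.join("adapters");

    // Task selection mirrors the loader: JINA_V4_TASK wins if set, then the
    // first entry of config.task_names, then the "retrieval" fallback.
    let task = std::env::var("JINA_V4_TASK").unwrap_or_else(|_| "retrieval".to_string());

    // One representative tensor key is probed before LoRA is enabled; the
    // per-task tensors are namespaced by the task name inside the file.
    let key = format!(
        "base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_A.{task}.weight"
    );
    println!("probing {} for {}", adapter_dir.display(), key);
}
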
