@@ -31,8 +31,10 @@ use crate::models::{
3131use crate :: models:: {
3232 FlashBertModel , FlashDistilBertModel , FlashGTEModel , FlashJinaBertModel ,
3333 FlashJinaCodeBertModel , FlashMistralModel , FlashModernBertModel , FlashNomicBertModel ,
34- FlashQwen2Model , FlashQwen3Model ,
34+ FlashQwen2Model , FlashQwen3Model , LoraWeights ,
3535} ;
36+ #[ cfg( feature = "cuda" ) ]
37+ use std:: { env, fs} ;
3638
3739/// This enum is needed to be able to differentiate between jina models that also use
3840/// the `bert` model type and valid Bert models.
@@ -88,6 +90,106 @@ impl<'de> Deserialize<'de> for BertConfigWrapper {
8890 }
8991}
9092
93+ #[ derive( Debug , Clone , Deserialize ) ]
94+ struct JinaV4Config {
95+ #[ serde( default ) ]
96+ task_names : Vec < String > ,
97+ text_config : Qwen2Config ,
98+ }
99+
100+ fn is_jina_v4_config ( value : & serde_json:: Value ) -> bool {
101+ value
102+ . get ( "architectures" )
103+ . and_then ( |v| v. as_array ( ) )
104+ . map ( |items| {
105+ items
106+ . iter ( )
107+ . any ( |item| item. as_str ( ) == Some ( "JinaEmbeddingsV4Model" ) )
108+ } )
109+ . unwrap_or ( false )
110+ }
111+
/// Attempts to load the Jina v4 LoRA adapter that ships next to the base
/// model (`<model>/adapters/adapter_config.json` plus
/// `adapter_model.safetensors`).
///
/// Best-effort: every failure (missing files, unreadable/invalid JSON,
/// unloadable weights, unexpected tensor names) is logged as a warning and
/// yields `None`, so the caller can still start the base model without LoRA.
///
/// The active task comes from the `JINA_V4_TASK` env var; when it is unset,
/// or set to a value not listed in `config.task_names`, the first configured
/// task is used (falling back to `"retrieval"` if the list is empty).
#[cfg(feature = "cuda")]
fn load_jina_v4_lora(
    model_path: &Path,
    device: &Device,
    dtype: DType,
    config: &JinaV4Config,
) -> Option<LoraWeights> {
    // Minimal view of adapter_config.json: only rank and scaling are needed.
    #[derive(Deserialize)]
    struct AdapterConfig {
        r: usize,
        lora_alpha: f32,
    }

    let adapter_dir = model_path.join("adapters");
    let adapter_config_path = adapter_dir.join("adapter_config.json");
    let adapter_model_path = adapter_dir.join("adapter_model.safetensors");

    if !adapter_config_path.exists() || !adapter_model_path.exists() {
        tracing::warn!("Jina v4 adapters not found; LoRA will be skipped.");
        return None;
    }

    // Read and parse the adapter hyper-parameters; any I/O or JSON error is
    // downgraded to a warning so a broken adapter file never blocks startup.
    let adapter_config = match fs::read_to_string(&adapter_config_path) {
        Ok(content) => match serde_json::from_str::<AdapterConfig>(&content) {
            Ok(config) => config,
            Err(err) => {
                tracing::warn!("Failed to parse Jina v4 adapter_config.json: {err}");
                return None;
            }
        },
        Err(err) => {
            tracing::warn!("Failed to read Jina v4 adapter_config.json: {err}");
            return None;
        }
    };

    // Shared fallback for both the "env var unset" and "unknown task" paths
    // (previously duplicated verbatim in the two branches below).
    let default_task = || {
        config
            .task_names
            .first()
            .cloned()
            .unwrap_or_else(|| "retrieval".to_string())
    };

    let mut task = env::var("JINA_V4_TASK").unwrap_or_default();
    if task.is_empty() {
        task = default_task();
    } else if !config.task_names.is_empty() && !config.task_names.contains(&task) {
        tracing::warn!(
            "JINA_V4_TASK={task} is not in config.task_names; defaulting to the first entry."
        );
        task = default_task();
    }

    // SAFETY: mmap-ing the safetensors file is sound as long as the file is
    // not truncated or modified while mapped — acceptable for local model
    // files owned by this process's deployment.
    let adapter_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[adapter_model_path], dtype, device) };
    let adapter_vb = match adapter_vb.s() {
        Ok(vb) => vb,
        Err(err) => {
            tracing::warn!("Failed to load Jina v4 adapter weights: {err}");
            return None;
        }
    };

    // Sanity-check one well-known tensor name (layer 0 q_proj lora_A for the
    // selected task) before handing the weights over.
    let lora_prefix = "base_model.model.model.language_model".to_string();
    let lora_check = format!("{lora_prefix}.layers.0.self_attn.q_proj.lora_A.{task}.weight");
    if !adapter_vb.contains_tensor(&lora_check) {
        tracing::warn!("Jina v4 adapter weights missing expected keys; LoRA will be skipped.");
        return None;
    }

    Some(LoraWeights::new(
        adapter_vb,
        task,
        adapter_config.r,
        adapter_config.lora_alpha,
        lora_prefix,
    ))
}
192+
91193#[ derive( Deserialize ) ]
92194#[ serde( tag = "model_type" , rename_all = "kebab-case" ) ]
93195enum Config {
@@ -111,6 +213,7 @@ enum Config {
111213 Qwen2 ( Qwen2Config ) ,
112214 #[ allow( dead_code) ]
113215 Qwen3 ( Qwen3Config ) ,
216+ JinaV4 ( JinaV4Config ) ,
114217 Roberta ( BertConfig ) ,
115218 XlmRoberta ( BertConfig ) ,
116219}
@@ -180,9 +283,28 @@ impl CandleBackend {
180283 let config: String = std:: fs:: read_to_string ( model_path. join ( "config.json" ) )
181284 . context ( "Unable to read config file" )
182285 . map_err ( |err| BackendError :: Start ( format ! ( "{err:?}" ) ) ) ?;
183- let config : Config = serde_json:: from_str ( & config)
286+ let config_value : serde_json :: Value = serde_json:: from_str ( & config)
184287 . context ( "Model is not supported" )
185288 . map_err ( |err| BackendError :: Start ( format ! ( "{err:?}" ) ) ) ?;
289+ let config: Config = if is_jina_v4_config ( & config_value) {
290+ if config_value
291+ . get ( "text_config" )
292+ . and_then ( |text| text. get ( "rope_scaling" ) )
293+ . is_some ( )
294+ {
295+ tracing:: warn!(
296+ "Jina v4 rope_scaling is not supported in Candle; using base rope instead."
297+ ) ;
298+ }
299+ let jina_config: JinaV4Config = serde_json:: from_value ( config_value. clone ( ) )
300+ . context ( "Model is not supported" )
301+ . map_err ( |err| BackendError :: Start ( format ! ( "{err:?}" ) ) ) ?;
302+ Config :: JinaV4 ( jina_config)
303+ } else {
304+ serde_json:: from_value ( config_value)
305+ . context ( "Model is not supported" )
306+ . map_err ( |err| BackendError :: Start ( format ! ( "{err:?}" ) ) ) ?
307+ } ;
186308
187309 // Get candle device
188310 let device = if candle:: utils:: cuda_is_available ( ) {
@@ -301,6 +423,10 @@ impl CandleBackend {
301423 "Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
302424 . to_string ( ) ,
303425 ) ) ,
426+ ( Config :: JinaV4 ( _) , Device :: Cpu | Device :: Metal ( _) ) => Err ( BackendError :: Start (
427+ "Jina v4 is only supported on Cuda devices in fp16 with flash attention enabled"
428+ . to_string ( ) ,
429+ ) ) ,
304430 ( Config :: Qwen3 ( config) , Device :: Cpu | Device :: Metal ( _) ) => {
305431 tracing:: info!( "Starting Qwen3 model on {:?}" , device) ;
306432 Ok ( Box :: new ( Qwen3Model :: load ( vb, & config, model_type) . s ( ) ?) )
@@ -488,7 +614,34 @@ impl CandleBackend {
488614 }
489615 tracing:: info!( "Starting FlashQwen2 model on {:?}" , device) ;
490616 Ok ( Box :: new (
491- FlashQwen2Model :: load ( vb, & config, model_type) . s ( ) ?,
617+ FlashQwen2Model :: load ( vb, & config, model_type, None , false ) . s ( ) ?,
618+ ) )
619+ }
620+ #[ cfg( feature = "cuda" ) ]
621+ ( Config :: JinaV4 ( config) , Device :: Cuda ( _) ) => {
622+ if dtype != DType :: F16
623+ || !cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
624+ || & std:: env:: var ( "USE_FLASH_ATTENTION" )
625+ . unwrap_or ( "True" . to_string ( ) )
626+ . to_lowercase ( )
627+ != "true"
628+ {
629+ return Err ( BackendError :: Start (
630+ "Jina v4 is only supported on Cuda devices in fp16 with flash attention v2 enabled" . to_string ( ) ,
631+ ) ) ;
632+ }
633+
634+ let lora = load_jina_v4_lora ( model_path, device, dtype, & config) ;
635+ tracing:: info!( "Starting Jina v4 model on {:?}" , device) ;
636+ Ok ( Box :: new (
637+ FlashQwen2Model :: load (
638+ vb,
639+ & config. text_config ,
640+ model_type,
641+ lora,
642+ true ,
643+ )
644+ . s ( ) ?,
492645 ) )
493646 }
494647 #[ cfg( feature = "cuda" ) ]
0 commit comments