@@ -24,6 +24,12 @@ pub struct Qwen3Config {
2424 pub sliding_window : Option < usize > ,
2525 pub use_sliding_window : bool ,
2626 pub eos_token_id : usize ,
27+ // TODO(alvarobartt): Migrate to `is_causal` instead
28+ // https://github.com/huggingface/transformers/pull/43705
29+ #[ serde( default ) ]
30+ pub use_bidirectional_attention : Option < bool > ,
31+ #[ serde( default ) ]
32+ pub num_labels : Option < usize > ,
2733}
2834
2935struct Qwen3Attention {
@@ -379,11 +385,14 @@ pub struct Qwen3Model {
379385 embeddings : Embedding ,
380386 layers : Vec < Qwen3Layer > ,
381387 norm : RMSNorm ,
388+ // TODO(alvarobartt): Eventually extend Qwen3 for Voyage instead of adding `projection` here
389+ projection : Option < Linear > ,
382390 rotary_cache : ( Tensor , Tensor ) ,
383391 rotary_dim : usize ,
384392 pool : Pool ,
385393 num_attention_heads : usize ,
386394 pad_token_id : u32 ,
395+ use_bidirectional_attention : bool ,
387396
388397 dtype : DType ,
389398 device : Device ,
@@ -402,23 +411,44 @@ impl Qwen3Model {
402411
403412 // The Qwen3-Reranker models contain the `model` key
404413 // https://huggingface.co/collections/Qwen/qwen3-reranker-6841b22d0192d7ade9cdefea
405- let vb = if vb. contains_tensor ( "model.embed_tokens.weight" ) {
406- vb . pp ( "model" )
414+ let model_prefix = if vb. contains_tensor ( "model.embed_tokens.weight" ) {
415+ "model."
407416 } else {
408- vb
417+ ""
409418 } ;
410419
411420 let embeddings = Embedding :: new (
412- vb. pp ( " embed_tokens")
421+ vb. pp ( format ! ( "{model_prefix} embed_tokens") )
413422 . get ( ( config. vocab_size , config. hidden_size ) , "weight" ) ?,
414423 config. hidden_size ,
415424 ) ;
416425
417426 let layers = ( 0 ..config. num_hidden_layers )
418- . map ( |index| Qwen3Layer :: load ( vb. pp ( format ! ( "layers.{index}" ) ) , config) )
427+ . map ( |index| Qwen3Layer :: load ( vb. pp ( format ! ( "{model_prefix} layers.{index}" ) ) , config) )
419428 . collect :: < Result < Vec < _ > > > ( ) ?;
420429
421- let norm = RMSNorm :: load ( vb. pp ( "norm" ) , config. hidden_size , config. rms_norm_eps ) ?;
430+ let norm = RMSNorm :: load (
431+ vb. pp ( format ! ( "{model_prefix}norm" ) ) ,
432+ config. hidden_size ,
433+ config. rms_norm_eps ,
434+ ) ?;
435+
436+ let projection = if let Some ( num_labels) = config. num_labels {
437+ if vb. contains_tensor ( "linear.weight" ) {
438+ let projection_weight =
439+ vb. get ( ( num_labels, config. hidden_size ) , "linear.weight" ) ?;
440+ Some ( Linear :: new ( projection_weight, None , None ) )
441+ } else {
442+ tracing:: warn!(
443+ "num_labels is set but linear.weight not found, skipping projection layer"
444+ ) ;
445+ None
446+ }
447+ } else {
448+ None
449+ } ;
450+
451+ let use_bidirectional_attention = config. use_bidirectional_attention . unwrap_or ( false ) ;
422452
423453 let rotary_dim = config
424454 . head_dim
@@ -433,11 +463,13 @@ impl Qwen3Model {
433463 embeddings,
434464 layers,
435465 norm,
466+ projection,
436467 rotary_cache,
437468 rotary_dim,
438469 pool,
439470 pad_token_id : config. eos_token_id as u32 ,
440471 num_attention_heads : config. num_attention_heads ,
472+ use_bidirectional_attention,
441473 dtype : vb. dtype ( ) ,
442474 device : vb. device ( ) . clone ( ) ,
443475 span : tracing:: span!( tracing:: Level :: TRACE , "model" ) ,
@@ -555,7 +587,9 @@ impl Qwen3Model {
555587 ( input_ids, position_ids, input_lengths, Some ( attention_bias) )
556588 } ;
557589
558- let attention_bias = if let Some ( attn_bias) = attention_bias {
590+ let attention_bias = if self . use_bidirectional_attention {
591+ attention_bias
592+ } else if let Some ( attn_bias) = attention_bias {
559593 Some ( self . get_causal_attention_bias ( attn_bias) ?)
560594 } else {
561595 None
@@ -581,6 +615,12 @@ impl Qwen3Model {
581615
582616 let ( outputs, _) = self . norm . forward ( & hidden_states, None ) ?;
583617
618+ let outputs = if let Some ( ref projection) = self . projection {
619+ projection. forward ( & outputs) ?
620+ } else {
621+ outputs
622+ } ;
623+
584624 let has_pooling_requests = !batch. pooled_indices . is_empty ( ) ;
585625 let has_raw_requests = !batch. raw_indices . is_empty ( ) ;
586626
0 commit comments