@@ -67,6 +67,10 @@ defmodule Bumblebee.Text.Bert do
6767 default: 0.02 ,
6868 doc:
6969 "the standard deviation of the normal initializer used for initializing kernel parameters"
70+ ] ,
71+ tie_word_embeddings: [
72+ default: true ,
73+ doc: "whether to tie input and output embedding weights"
7074 ]
7175 ] ++ Shared . common_options ( [ :use_cross_attention , :num_labels , :id_to_label ] )
7276
@@ -606,15 +610,16 @@ defmodule Bumblebee.Text.Bert do
606610 attention_dropout_rate: { "attention_probs_dropout_prob" , number ( ) } ,
607611 classifier_dropout_rate: { "classifier_dropout" , optional ( number ( ) ) } ,
608612 layer_norm_epsilon: { "layer_norm_eps" , number ( ) } ,
609- initializer_scale: { "initializer_range" , number ( ) }
613+ initializer_scale: { "initializer_range" , number ( ) } ,
614+ tie_word_embeddings: { "tie_word_embeddings" , boolean ( ) }
610615 ) ++ Shared . common_options_from_transformers ( data , spec )
611616
612617 @ for . config ( spec , opts )
613618 end
614619 end
615620
616621 defimpl Bumblebee.HuggingFace.Transformers.Model do
617- def params_mapping ( _spec ) do
622+ def params_mapping ( spec ) do
618623 % {
619624 "embedder.token_embedding" => "bert.embeddings.word_embeddings" ,
620625 "embedder.position_embedding" => "bert.embeddings.position_embeddings" ,
@@ -645,7 +650,11 @@ defmodule Bumblebee.Text.Bert do
645650 "pooler.output" => "bert.pooler.dense" ,
646651 "language_modeling_head.dense" => "cls.predictions.transform.dense" ,
647652 "language_modeling_head.norm" => "cls.predictions.transform.LayerNorm" ,
648- "language_modeling_head.output" => "cls.predictions.decoder" ,
653+ "language_modeling_head.output" =>
654+ if ( spec . tie_word_embeddings ,
655+ do: "bert.embeddings.word_embeddings" ,
656+ else: "cls.predictions.decoder"
657+ ) ,
649658 "language_modeling_head.bias" => "cls.predictions" ,
650659 "next_sentence_prediction_head.output" => "cls.seq_relationship" ,
651660 "sequence_classification_head.output" => "classifier" ,
0 commit comments