Skip to content

Commit 622b6ed

Browse files
committed
Support loading bert with tied weights
1 parent d0774e8 commit 622b6ed

1 file changed

Lines changed: 12 additions & 3 deletions

File tree

lib/bumblebee/text/bert.ex

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ defmodule Bumblebee.Text.Bert do
6767
default: 0.02,
6868
doc:
6969
"the standard deviation of the normal initializer used for initializing kernel parameters"
70+
],
71+
tie_word_embeddings: [
72+
default: true,
73+
doc: "whether to tie input and output embedding weights"
7074
]
7175
] ++ Shared.common_options([:use_cross_attention, :num_labels, :id_to_label])
7276

@@ -606,15 +610,16 @@ defmodule Bumblebee.Text.Bert do
606610
attention_dropout_rate: {"attention_probs_dropout_prob", number()},
607611
classifier_dropout_rate: {"classifier_dropout", optional(number())},
608612
layer_norm_epsilon: {"layer_norm_eps", number()},
609-
initializer_scale: {"initializer_range", number()}
613+
initializer_scale: {"initializer_range", number()},
614+
tie_word_embeddings: {"tie_word_embeddings", boolean()}
610615
) ++ Shared.common_options_from_transformers(data, spec)
611616

612617
@for.config(spec, opts)
613618
end
614619
end
615620

616621
defimpl Bumblebee.HuggingFace.Transformers.Model do
617-
def params_mapping(_spec) do
622+
def params_mapping(spec) do
618623
%{
619624
"embedder.token_embedding" => "bert.embeddings.word_embeddings",
620625
"embedder.position_embedding" => "bert.embeddings.position_embeddings",
@@ -645,7 +650,11 @@ defmodule Bumblebee.Text.Bert do
645650
"pooler.output" => "bert.pooler.dense",
646651
"language_modeling_head.dense" => "cls.predictions.transform.dense",
647652
"language_modeling_head.norm" => "cls.predictions.transform.LayerNorm",
648-
"language_modeling_head.output" => "cls.predictions.decoder",
653+
"language_modeling_head.output" =>
654+
if(spec.tie_word_embeddings,
655+
do: "bert.embeddings.word_embeddings",
656+
else: "cls.predictions.decoder"
657+
),
649658
"language_modeling_head.bias" => "cls.predictions",
650659
"next_sentence_prediction_head.output" => "cls.seq_relationship",
651660
"sequence_classification_head.output" => "classifier",

0 commit comments

Comments
 (0)