@@ -72,6 +72,18 @@ def test_convert_hf_to_te_with_bf16():
7272 convert_llama_hf_to_te (model_hf )
7373
7474
75+ def test_convert_hf_to_te_with_bf16_tied_weights ():
76+ config = AutoConfig .from_pretrained (
77+ "nvidia/Llama-3.1-8B-Instruct-FP8" ,
78+ dtype = torch .bfloat16 ,
79+ num_hidden_layers = 2 ,
80+ tie_word_embeddings = True ,
81+ )
82+ model_hf = LlamaForCausalLM (config )
83+ model_hf .to (dtype = torch .bfloat16 ) # I think the original llama3 model doesn't initialize in bf16.
84+ convert_llama_hf_to_te (model_hf )
85+
86+
7587def test_convert_te_to_hf_with_bf16 ():
7688 config = NVLlamaConfig .from_pretrained (
7789 "nvidia/Llama-3.1-8B-Instruct-FP8" , dtype = torch .bfloat16 , num_hidden_layers = 2
@@ -81,6 +93,18 @@ def test_convert_te_to_hf_with_bf16():
8193 convert_llama_te_to_hf (model_te )
8294
8395
96+ def test_convert_te_to_hf_with_bf16_tied_weights ():
97+ config = NVLlamaConfig .from_pretrained (
98+ "nvidia/Llama-3.1-8B-Instruct-FP8" ,
99+ dtype = torch .bfloat16 ,
100+ num_hidden_layers = 2 ,
101+ tie_word_embeddings = True ,
102+ )
103+ model_te = NVLlamaForCausalLM (config )
104+ model_te .to (dtype = torch .float32 ) # I think the original llama3 model doesn't initialize in bf16.
105+ convert_llama_te_to_hf (model_te )
106+
107+
84108@pytest .mark .skipif (os .getenv ("CI" , "false" ) == "true" , reason = "Skipping test in CI not download llama3 models." )
85109@pytest .mark .parametrize (
86110 "upstream_model_name" , ["meta-llama/Llama-3.2-1B-Instruct" , "meta-llama/Llama-3.1-8B-Instruct" ]
0 commit comments