Skip to content

Commit f40e4b0

Browse files
committed
update amplify model with some loss tests
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent f6e3139 commit f40e4b0

4 files changed

Lines changed: 68 additions & 9 deletions

File tree

models/amplify/.devcontainer/devcontainer.json

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22
// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile
33
{
44
"name": "Existing Dockerfile",
5-
"build": {
6-
"context": "..",
7-
"dockerfile": "Dockerfile.dev"
8-
},
5+
"image": "svcbionemo023/bionemo-framework:amplify-model-devcontainer-082025",
96
"mounts": [
107
"source=${localEnv:HOME}/.cache,target=/home/ubuntu/.cache,type=bind,consistency=cached"
118
],
12-
"postCreateCommand": "pip install -e .[convert,test]",
9+
"postCreateCommand": "PIP_CONSTRAINT= pip install -e .",
1310
"remoteUser": "ubuntu",
1411
"runArgs": [
1512
"--gpus=all",

models/amplify/tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def input_data(tokenizer):
6868
tokenizer=tokenizer,
6969
mlm_probability=0.15,
7070
pad_to_multiple_of=1024,
71+
seed=42,
7172
)
7273

7374
def tokenize_function(examples):

models/amplify/tests/test_amplify_model.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,22 @@
2323
from conftest import requires_fp8
2424
from transformer_engine.common.recipe import DelayedScaling, Format
2525

26-
import amplify.amplify_hf as amp_hf
2726
import amplify.amplify_te as amp_te
2827
from amplify.state_dict_convert import convert_amplify_hf_to_te
2928

3029

30+
try:
31+
import xformers
32+
except ImportError:
33+
xformers = None
34+
35+
if xformers is not None:
36+
import amplify.amplify_hf as amp_hf
37+
else:
38+
amp_hf = None
39+
40+
41+
@pytest.mark.skipif(amp_hf is None, reason="xformers is not installed")
3142
def test_amplify_hf_model(config, input_data):
3243
model = amp_hf.AMPLIFY(config)
3344
model.to("cuda")
@@ -67,6 +78,7 @@ def test_te_model_has_all_te_layers(config):
6778
assert not isinstance(module, nn.RMSNorm), f"Vanilla RMSNorm layer found in {name}"
6879

6980

81+
@pytest.mark.skipif(amp_hf is None, reason="xformers is not installed")
7082
def test_models_have_identical_outputs(input_data):
7183
model_hf = amp_hf.AMPLIFY.from_pretrained("chandar-lab/AMPLIFY_120M")
7284
model_te = convert_amplify_hf_to_te(model_hf)
@@ -84,6 +96,7 @@ def test_models_have_identical_outputs(input_data):
8496
torch.testing.assert_close(outputs_hf.loss, outputs_te.loss, atol=1e-2, rtol=1e-3)
8597

8698

99+
@pytest.mark.skipif(amp_hf is None, reason="xformers is not installed")
87100
def test_converted_model_roundtrip(input_data, tmp_path):
88101
model_hf = amp_hf.AMPLIFY.from_pretrained("chandar-lab/AMPLIFY_120M")
89102
model_te = convert_amplify_hf_to_te(model_hf)
@@ -107,6 +120,7 @@ def test_converted_model_roundtrip(input_data, tmp_path):
107120
torch.testing.assert_close(outputs_hf.loss, outputs_te.loss, atol=1e-2, rtol=1e-3)
108121

109122

123+
@pytest.mark.skipif(amp_hf is None, reason="xformers is not installed")
110124
def test_convert_state_dict():
111125
model_hf = amp_hf.AMPLIFY.from_pretrained("chandar-lab/AMPLIFY_120M")
112126
model_te = convert_amplify_hf_to_te(model_hf)
@@ -168,3 +182,52 @@ def test_convert_state_dict():
168182
te_state_dict_keys.remove("decoder.bias")
169183

170184
assert len(te_state_dict_keys) == 0
185+
186+
187+
@pytest.mark.skipif(amp_hf is None, reason="xformers is not installed")
def test_hf_trained_model_loss(input_data):
    """Check the masked-LM loss of the pretrained HF AMPLIFY checkpoint.

    Loads ``chandar-lab/AMPLIFY_120M`` through the HF implementation, runs a
    single bf16 forward pass on the ``input_data`` fixture and compares the
    resulting loss against a fixed reference value.

    The test is skipped when xformers is unavailable: in that case the
    module-level ``amp_hf`` is ``None`` and any attribute access on it would
    raise ``AttributeError`` instead of reporting a clean skip.

    Args:
        input_data: mapping of model keyword-argument names to tensors, as
            produced by the ``input_data`` fixture.
    """
    model = amp_hf.AMPLIFY.from_pretrained("chandar-lab/AMPLIFY_120M")
    model.to("cuda", dtype=torch.bfloat16)
    input_data = {k: v.to("cuda") for k, v in input_data.items()}
    model.eval()
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(**input_data)

    # NOTE(review): 2.4 is an empirical reference loss; it presumably relies on
    # the fixture's masking being deterministic (seeded collator) -- confirm.
    torch.testing.assert_close(output.loss.detach().cpu(), torch.tensor(2.4), atol=1e-1, rtol=1e-2)
196+
197+
198+
@pytest.mark.skipif(amp_hf is None, reason="xformers is not installed")
def test_te_trained_model_loss(input_data):
    """Check the masked-LM loss of the pretrained checkpoint after TE conversion.

    Loads ``chandar-lab/AMPLIFY_120M`` through the HF implementation, converts
    it with ``convert_amplify_hf_to_te``, runs a single bf16 forward pass on
    the ``input_data`` fixture and compares the loss against the same fixed
    reference value used for the unconverted HF model.

    The test is skipped when xformers is unavailable, because the source
    weights are loaded via ``amp_hf``, which is ``None`` in that case.

    Args:
        input_data: mapping of model keyword-argument names to tensors, as
            produced by the ``input_data`` fixture.
    """
    model_hf = amp_hf.AMPLIFY.from_pretrained("chandar-lab/AMPLIFY_120M")
    model = convert_amplify_hf_to_te(model_hf)
    model.to("cuda", dtype=torch.bfloat16)
    input_data = {k: v.to("cuda") for k, v in input_data.items()}
    model.eval()
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(**input_data)

    # NOTE(review): 2.4 is an empirical reference loss; it presumably relies on
    # the fixture's masking being deterministic (seeded collator) -- confirm.
    torch.testing.assert_close(output.loss.detach().cpu(), torch.tensor(2.4), atol=1e-1, rtol=1e-2)
208+
209+
210+
@pytest.mark.skipif(amp_hf is None, reason="xformers is not installed")
def test_hf_reinitialized_model_loss(input_data):
    """Check that an HF AMPLIFY model built from config alone has a bounded loss.

    Only the configuration of ``chandar-lab/AMPLIFY_120M`` is loaded; the model
    is constructed directly from it (no pretrained weights are loaded here),
    run once in bf16 on the ``input_data`` fixture, and its loss is required
    to stay below 3.5.

    The test is skipped when xformers is unavailable, because ``amp_hf`` is
    ``None`` in that case and cannot provide the config or model classes.

    Args:
        input_data: mapping of model keyword-argument names to tensors, as
            produced by the ``input_data`` fixture.
    """
    config = amp_hf.AMPLIFYConfig.from_pretrained("chandar-lab/AMPLIFY_120M")
    model = amp_hf.AMPLIFY(config)
    model.to("cuda", dtype=torch.bfloat16)
    input_data = {k: v.to("cuda") for k, v in input_data.items()}
    model.eval()
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(**input_data)

    # NOTE(review): 3.5 is presumably a sanity ceiling for a freshly
    # initialized model (roughly a uniform guess over the vocabulary) -- confirm.
    loss = output.loss.detach().cpu()
    assert loss < 3.5, f"Loss is {loss}, expected less than 3.5"
221+
222+
223+
def test_te_reinitialized_model_loss(input_data):
    """Check that a TE AMPLIFY model built from config alone has a bounded loss.

    Only the configuration of ``chandar-lab/AMPLIFY_120M`` is fetched; the
    ``AMPLIFYForMaskedLM`` model is constructed directly from it, run once in
    bf16 on the ``input_data`` fixture, and its loss must stay below 3.5.

    Args:
        input_data: mapping of model keyword-argument names to tensors, as
            produced by the ``input_data`` fixture.
    """
    te_config = amp_te.AMPLIFYConfig.from_pretrained("chandar-lab/AMPLIFY_120M")
    te_model = amp_te.AMPLIFYForMaskedLM(te_config)
    te_model.to("cuda", dtype=torch.bfloat16)
    te_model.eval()

    # Move every tensor of the batch onto the GPU before the forward pass.
    cuda_batch = {name: tensor.to("cuda") for name, tensor in input_data.items()}

    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        result = te_model(**cuda_batch)

    loss = result.loss.detach().cpu()
    assert loss < 3.5, f"Loss is {loss}, expected less than 3.5"

recipes/esm2_accelerate/train.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,11 @@ def main(args: DictConfig):
4646
)
4747

4848
config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True)
49-
config.max_seq_length = args.max_seq_length
50-
config.micro_batch_size = args.trainer.per_device_train_batch_size
5149
model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True, torch_dtype=torch.bfloat16)
5250

5351
train_dataset, eval_dataset, data_collator = create_datasets_and_collator(
5452
tokenizer_name=args.model_tag,
55-
max_length=config.max_seq_length,
53+
max_length=args.max_seq_length,
5654
)
5755

5856
training_args = TrainingArguments(**args.trainer)

0 commit comments

Comments
 (0)