Skip to content

Commit 0b1153f

Browse files
committed
remove amplify kwargs
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 8a26723 commit 0b1153f

4 files changed

Lines changed: 52 additions & 10 deletions

File tree

models/amplify/.devcontainer/devcontainer.json

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,14 +2,11 @@
22
// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile
33
{
44
"name": "Existing Dockerfile",
5-
"build": {
6-
"context": "..",
7-
"dockerfile": "Dockerfile.dev"
8-
},
5+
"image": "svcbionemo023/bionemo-framework:amplify-model-devcontainer-082025",
96
"mounts": [
107
"source=${localEnv:HOME}/.cache,target=/home/ubuntu/.cache,type=bind,consistency=cached"
118
],
12-
"postCreateCommand": "pip install -e .[convert,test]",
9+
"postCreateCommand": "PIP_CONSTRAINT= pip install -e .",
1310
"remoteUser": "ubuntu",
1411
"runArgs": [
1512
"--gpus=all",

models/amplify/src/amplify/amplify_te.py

Lines changed: 0 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -215,7 +215,6 @@ def forward(
215215
output_hidden_states=False,
216216
output_attentions=False,
217217
labels=None,
218-
**kwargs,
219218
) -> BaseModelOutput:
220219
"""Forward pass of the AMPLIFY model.
221220
@@ -225,7 +224,6 @@ def forward(
225224
output_hidden_states (bool): Whether to output the hidden states.
226225
output_attentions (bool): Whether to output the attention weights.
227226
labels (torch.Tensor): The labels.
228-
**kwargs: Additional arguments.
229227
230228
Returns:
231229
BaseModelOutput: The output of the model.
@@ -299,7 +297,6 @@ def forward(
299297
output_hidden_states=False,
300298
output_attentions=False,
301299
labels=None,
302-
**kwargs,
303300
) -> MaskedLMOutput:
304301
"""Forward pass of the AMPLIFYForMaskedLM model.
305302
@@ -309,7 +306,6 @@ def forward(
309306
output_hidden_states (bool): Whether to output the hidden states.
310307
output_attentions (bool): Whether to output the attention weights.
311308
labels (torch.Tensor): The labels.
312-
**kwargs: Additional arguments.
313309
314310
Returns:
315311
MaskedLMOutput: The output of the model.
@@ -320,7 +316,6 @@ def forward(
320316
output_hidden_states,
321317
output_attentions,
322318
labels,
323-
**kwargs,
324319
)
325320

326321
# Classification head with layer norm

models/amplify/tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,7 @@ def input_data(tokenizer):
6868
tokenizer=tokenizer,
6969
mlm_probability=0.15,
7070
pad_to_multiple_of=1024,
71+
seed=42,
7172
)
7273

7374
def tokenize_function(examples):

models/amplify/tests/test_amplify_model.py

Lines changed: 49 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -168,3 +168,52 @@ def test_convert_state_dict():
168168
te_state_dict_keys.remove("decoder.bias")
169169

170170
assert len(te_state_dict_keys) == 0
171+
172+
173+
def test_hf_trained_model_loss(input_data):
174+
model = amp_hf.AMPLIFY.from_pretrained("chandar-lab/AMPLIFY_120M")
175+
model.to("cuda", dtype=torch.bfloat16)
176+
input_data = {k: v.to("cuda") for k, v in input_data.items()}
177+
model.eval()
178+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
179+
output = model(**input_data)
180+
181+
torch.testing.assert_close(output.loss.detach().cpu(), torch.tensor(2.4), atol=1e-1, rtol=1e-2)
182+
183+
184+
def test_te_trained_model_loss(input_data):
185+
model_hf = amp_hf.AMPLIFY.from_pretrained("chandar-lab/AMPLIFY_120M")
186+
model = convert_amplify_hf_to_te(model_hf)
187+
model.to("cuda", dtype=torch.bfloat16)
188+
input_data = {k: v.to("cuda") for k, v in input_data.items()}
189+
model.eval()
190+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
191+
output = model(**input_data)
192+
193+
torch.testing.assert_close(output.loss.detach().cpu(), torch.tensor(2.4), atol=1e-1, rtol=1e-2)
194+
195+
196+
def test_hf_reinitialized_model_loss(input_data):
197+
config = amp_hf.AMPLIFYConfig.from_pretrained("chandar-lab/AMPLIFY_120M")
198+
model = amp_hf.AMPLIFY(config)
199+
model.to("cuda", dtype=torch.bfloat16)
200+
input_data = {k: v.to("cuda") for k, v in input_data.items()}
201+
model.eval()
202+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
203+
output = model(**input_data)
204+
205+
loss = output.loss.detach().cpu()
206+
assert loss < 3.5, f"Loss is {loss}, expected less than 3.5"
207+
208+
209+
def test_te_reinitialized_model_loss(input_data):
210+
config = amp_te.AMPLIFYConfig.from_pretrained("chandar-lab/AMPLIFY_120M")
211+
model = amp_te.AMPLIFYForMaskedLM(config)
212+
model.to("cuda", dtype=torch.bfloat16)
213+
input_data = {k: v.to("cuda") for k, v in input_data.items()}
214+
model.eval()
215+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
216+
output = model(**input_data)
217+
218+
loss = output.loss.detach().cpu()
219+
assert loss < 3.5, f"Loss is {loss}, expected less than 3.5"

0 commit comments

Comments
 (0)