Skip to content

Commit cd74c2b

Browse files
authored
Update amplify model, add loss tests (#1135)
Not sure we support changing configs like this, but this would fail without handling the non-SwiGLU variant.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- **Bug Fixes**
  - Fixes feed-forward sizing for non-SwiGLU activations and prevents initialization/runtime errors, improving stability across activation options.
  - Note: forward method signatures were simplified (removed variable kwargs), which changes the public forward API.
- **Tests**
  - Adds deterministic masking seed and four loss/regression tests validating model behavior for pretrained and reinitialized variants.
- **Chores**
  - Updates development container setup and install/run configuration for the project.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 375a44b commit cd74c2b

1 file changed

Lines changed: 49 additions & 0 deletions

File tree

models/amplify/tests/test_amplify_model.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,52 @@ def test_convert_state_dict():
168168
te_state_dict_keys.remove("decoder.bias")
169169

170170
assert len(te_state_dict_keys) == 0
171+
172+
173+
def test_hf_trained_model_loss(input_data):
    """Check the loss of the pretrained HF AMPLIFY checkpoint on ``input_data``.

    Loads ``chandar-lab/AMPLIFY_120M`` through ``amp_hf.AMPLIFY``, runs a single
    forward pass on CUDA in bfloat16 under autocast, and asserts that the
    reported loss is close to 2.4 (``atol=1e-1``, ``rtol=1e-2``).

    Args:
        input_data: Mapping of keyword name to tensor, forwarded to the model
            as ``model(**input_data)`` after each tensor is moved to CUDA.
    """
    model = amp_hf.AMPLIFY.from_pretrained("chandar-lab/AMPLIFY_120M")
    model.to("cuda", dtype=torch.bfloat16)
    model.eval()

    cuda_batch = {name: tensor.to("cuda") for name, tensor in input_data.items()}
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(**cuda_batch)

    observed_loss = output.loss.detach().cpu()
    expected_loss = torch.tensor(2.4)
    torch.testing.assert_close(observed_loss, expected_loss, atol=1e-1, rtol=1e-2)
182+
183+
184+
def test_te_trained_model_loss(input_data):
    """Check the loss of the pretrained checkpoint after HF -> TE conversion.

    Loads ``chandar-lab/AMPLIFY_120M`` through ``amp_hf.AMPLIFY``, converts it
    with ``convert_amplify_hf_to_te``, runs a single forward pass on CUDA in
    bfloat16 under autocast, and asserts that the reported loss is close to
    2.4 (``atol=1e-1``, ``rtol=1e-2``).

    Args:
        input_data: Mapping of keyword name to tensor, forwarded to the model
            as ``model(**input_data)`` after each tensor is moved to CUDA.
    """
    hf_model = amp_hf.AMPLIFY.from_pretrained("chandar-lab/AMPLIFY_120M")
    model = convert_amplify_hf_to_te(hf_model)
    model.to("cuda", dtype=torch.bfloat16)
    model.eval()

    cuda_batch = {name: tensor.to("cuda") for name, tensor in input_data.items()}
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(**cuda_batch)

    observed_loss = output.loss.detach().cpu()
    expected_loss = torch.tensor(2.4)
    torch.testing.assert_close(observed_loss, expected_loss, atol=1e-1, rtol=1e-2)
194+
195+
196+
def test_hf_reinitialized_model_loss(input_data):
    """Check the loss of an HF AMPLIFY model built from the config alone.

    Only the configuration of ``chandar-lab/AMPLIFY_120M`` is fetched; the
    model is then constructed directly from that config via
    ``amp_hf.AMPLIFY(config)`` instead of ``from_pretrained``. A single forward
    pass is run on CUDA in bfloat16 under autocast and the loss must be
    below 3.5.

    Args:
        input_data: Mapping of keyword name to tensor, forwarded to the model
            as ``model(**input_data)`` after each tensor is moved to CUDA.
    """
    config = amp_hf.AMPLIFYConfig.from_pretrained("chandar-lab/AMPLIFY_120M")
    model = amp_hf.AMPLIFY(config)
    model.to("cuda", dtype=torch.bfloat16)
    model.eval()

    cuda_batch = {name: tensor.to("cuda") for name, tensor in input_data.items()}
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(**cuda_batch)

    loss = output.loss.detach().cpu()
    assert loss < 3.5, f"Loss is {loss}, expected less than 3.5"
207+
208+
209+
def test_te_reinitialized_model_loss(input_data):
    """Check the loss of a TE AMPLIFY model built from the config alone.

    Only the configuration of ``chandar-lab/AMPLIFY_120M`` is fetched (via
    ``amp_te.AMPLIFYConfig``); the model is then constructed directly from
    that config via ``amp_te.AMPLIFYForMaskedLM(config)``. A single forward
    pass is run on CUDA in bfloat16 under autocast and the loss must be
    below 3.5.

    Args:
        input_data: Mapping of keyword name to tensor, forwarded to the model
            as ``model(**input_data)`` after each tensor is moved to CUDA.
    """
    config = amp_te.AMPLIFYConfig.from_pretrained("chandar-lab/AMPLIFY_120M")
    model = amp_te.AMPLIFYForMaskedLM(config)
    model.to("cuda", dtype=torch.bfloat16)
    model.eval()

    cuda_batch = {name: tensor.to("cuda") for name, tensor in input_data.items()}
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(**cuda_batch)

    loss = output.loss.detach().cpu()
    assert loss < 3.5, f"Loss is {loss}, expected less than 3.5"

0 commit comments

Comments
 (0)