Skip to content

Commit 3e56b8a

Browse files
authored
Remove kwargs in amplify forward pass (#1141)
Having a **kwargs in model.forward leads to some odd complications with accelerate, where it sums rather than averages loss across parallel processes. Also does some other fixes in the amplify model since we'll need to push a new version to the HF hub <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * New Features * Data collator now supports a seed option for deterministic masking. * Refactor * Standardized dtype handling to a single dtype setting across embeddings, norms, and layers. * Ensured intermediate size is always defined when activation is not swiglu. * Simplified model forward APIs by removing unused keyword passthroughs. * Tests * Added loss verification tests for pretrained and reinitialized models across implementations. * Chores * Updated development container to use a prebuilt image, increased shared memory, and simplified dependency installation. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 664a9b9 commit 3e56b8a

9 files changed

Lines changed: 68 additions & 27 deletions

File tree

models/amplify/.devcontainer/devcontainer.json

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22
// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile
33
{
44
"name": "Existing Dockerfile",
5-
"build": {
6-
"context": "..",
7-
"dockerfile": "Dockerfile.dev"
8-
},
5+
"image": "svcbionemo023/bionemo-framework:amplify-model-devcontainer-082025",
96
"mounts": [
107
"source=${localEnv:HOME}/.cache,target=/home/ubuntu/.cache,type=bind,consistency=cached"
118
],
12-
"postCreateCommand": "pip install -e .[convert,test]",
9+
"postCreateCommand": "PIP_CONSTRAINT= pip install -e .",
1310
"remoteUser": "ubuntu",
1411
"runArgs": [
1512
"--gpus=all",

models/amplify/Dockerfile

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ FROM nvcr.io/nvidia/pytorch:25.01-py3
44
RUN MAX_JOBS=4 pip --disable-pip-version-check --no-cache-dir install -v git+https://github.com/facebookresearch/xformers.git@v0.0.29.post1#egg=xformers
55
RUN PIP_CONSTRAINT= NVTE_FRAMEWORK=pytorch MAX_JOBS=4 pip --disable-pip-version-check --no-cache-dir install -v git+https://github.com/nvidia/TransformerEngine.git@v2.4
66

7-
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
87
WORKDIR /workspace/bionemo
98
COPY . .
10-
RUN --mount=type=cache,target=/root/.cache/uv \
9+
RUN --mount=type=cache,target=/root/.cache/pip \
1110
PIP_CONSTRAINT= pip install -e .

models/amplify/export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
# Smoke test that the model can be loaded.
3737
model_te = AutoModelForMaskedLM.from_pretrained(
3838
f"./checkpoint_export/{tag}",
39-
torch_dtype=torch.bfloat16,
39+
dtype=torch.bfloat16,
4040
trust_remote_code=True,
4141
)
4242
del model_te

models/amplify/src/amplify/amplify_te.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,17 +147,15 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
147147
config.padded_vocab_size,
148148
config.hidden_size,
149149
padding_idx=config.pad_token_id,
150-
dtype=config.torch_dtype,
150+
dtype=config.dtype,
151151
)
152152

153153
if config.layer_norm_after_embedding:
154154
self.layer_norm_1 = (
155-
transformer_engine.pytorch.RMSNorm(
156-
config.hidden_size, config.norm_eps, params_dtype=config.torch_dtype
157-
)
155+
transformer_engine.pytorch.RMSNorm(config.hidden_size, config.norm_eps, params_dtype=config.dtype)
158156
if config.rms_norm
159157
else transformer_engine.pytorch.LayerNorm(
160-
config.hidden_size, config.norm_eps, params_dtype=config.torch_dtype
158+
config.hidden_size, config.norm_eps, params_dtype=config.dtype
161159
)
162160
)
163161

@@ -169,6 +167,9 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
169167
intermediate_size = int(2 * config.intermediate_size / 3)
170168
intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
171169

170+
else:
171+
intermediate_size = config.intermediate_size
172+
172173
self.transformer_encoder = nn.ModuleList()
173174
for layer_num in range(config.num_hidden_layers):
174175
self.transformer_encoder.append(
@@ -194,7 +195,7 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
194195
window_size=(-1, -1),
195196
rotary_pos_interleaved=True,
196197
seq_length=config.max_length,
197-
params_dtype=config.torch_dtype,
198+
params_dtype=config.dtype,
198199
)
199200
)
200201

@@ -212,7 +213,6 @@ def forward(
212213
output_hidden_states=False,
213214
output_attentions=False,
214215
labels=None,
215-
**kwargs,
216216
) -> BaseModelOutput:
217217
"""Forward pass of the AMPLIFY model.
218218
@@ -222,7 +222,6 @@ def forward(
222222
output_hidden_states (bool): Whether to output the hidden states.
223223
output_attentions (bool): Whether to output the attention weights.
224224
labels (torch.Tensor): The labels.
225-
**kwargs: Additional arguments.
226225
227226
Returns:
228227
BaseModelOutput: The output of the model.
@@ -277,7 +276,7 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
277276
config.hidden_size,
278277
config.padded_vocab_size,
279278
config.norm_eps,
280-
params_dtype=config.torch_dtype,
279+
params_dtype=config.dtype,
281280
normalization="RMSNorm" if config.rms_norm else "LayerNorm",
282281
init_method=lambda x: torch.nn.init.uniform_(
283282
x, -self.config.decoder_init_range, self.config.decoder_init_range
@@ -286,7 +285,7 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
286285

287286
else:
288287
self.decoder = transformer_engine.pytorch.Linear(
289-
config.hidden_size, config.vocab_size, params_dtype=config.torch_dtype
288+
config.hidden_size, config.vocab_size, params_dtype=config.dtype
290289
)
291290

292291
def forward(
@@ -296,7 +295,6 @@ def forward(
296295
output_hidden_states=False,
297296
output_attentions=False,
298297
labels=None,
299-
**kwargs,
300298
) -> MaskedLMOutput:
301299
"""Forward pass of the AMPLIFYForMaskedLM model.
302300
@@ -306,7 +304,6 @@ def forward(
306304
output_hidden_states (bool): Whether to output the hidden states.
307305
output_attentions (bool): Whether to output the attention weights.
308306
labels (torch.Tensor): The labels.
309-
**kwargs: Additional arguments.
310307
311308
Returns:
312309
MaskedLMOutput: The output of the model.
@@ -317,7 +314,6 @@ def forward(
317314
output_hidden_states,
318315
output_attentions,
319316
labels,
320-
**kwargs,
321317
)
322318

323319
# Classification head with layer norm

models/amplify/src/amplify/state_dict_convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def convert_amplify_hf_to_te(model_hf: nn.Module, **config_kwargs) -> nn.Module:
4646
"""
4747
te_config = AMPLIFYConfig(**model_hf.config.to_dict(), **config_kwargs)
4848
with init_empty_weights():
49-
model_te = AMPLIFYForMaskedLM(te_config, torch_dtype=te_config.torch_dtype)
49+
model_te = AMPLIFYForMaskedLM(te_config, dtype=te_config.dtype)
5050

5151
output_model = io.apply_transforms(
5252
model_hf,

models/amplify/tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def tokenizer():
3636
@pytest.fixture
3737
def config():
3838
config = AutoConfig.from_pretrained("chandar-lab/AMPLIFY_120M", trust_remote_code=True)
39-
config.torch_dtype = torch.bfloat16
39+
config.dtype = torch.bfloat16
4040
return config
4141

4242

@@ -68,6 +68,7 @@ def input_data(tokenizer):
6868
tokenizer=tokenizer,
6969
mlm_probability=0.15,
7070
pad_to_multiple_of=1024,
71+
seed=42,
7172
)
7273

7374
def tokenize_function(examples):

models/amplify/tests/test_amplify_model.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,52 @@ def test_convert_state_dict():
168168
te_state_dict_keys.remove("decoder.bias")
169169

170170
assert len(te_state_dict_keys) == 0
171+
172+
173+
def test_hf_trained_model_loss(input_data):
174+
model = amp_hf.AMPLIFY.from_pretrained("chandar-lab/AMPLIFY_120M")
175+
model.to("cuda", dtype=torch.bfloat16)
176+
input_data = {k: v.to("cuda") for k, v in input_data.items()}
177+
model.eval()
178+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
179+
output = model(**input_data)
180+
181+
torch.testing.assert_close(output.loss.detach().cpu(), torch.tensor(2.4), atol=1e-1, rtol=1e-2)
182+
183+
184+
def test_te_trained_model_loss(input_data):
185+
model_hf = amp_hf.AMPLIFY.from_pretrained("chandar-lab/AMPLIFY_120M")
186+
model = convert_amplify_hf_to_te(model_hf)
187+
model.to("cuda", dtype=torch.bfloat16)
188+
input_data = {k: v.to("cuda") for k, v in input_data.items()}
189+
model.eval()
190+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
191+
output = model(**input_data)
192+
193+
torch.testing.assert_close(output.loss.detach().cpu(), torch.tensor(2.4), atol=1e-1, rtol=1e-2)
194+
195+
196+
def test_hf_reinitialized_model_loss(input_data):
197+
config = amp_hf.AMPLIFYConfig.from_pretrained("chandar-lab/AMPLIFY_120M")
198+
model = amp_hf.AMPLIFY(config)
199+
model.to("cuda", dtype=torch.bfloat16)
200+
input_data = {k: v.to("cuda") for k, v in input_data.items()}
201+
model.eval()
202+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
203+
output = model(**input_data)
204+
205+
loss = output.loss.detach().cpu()
206+
assert loss < 3.5, f"Loss is {loss}, expected less than 3.5"
207+
208+
209+
def test_te_reinitialized_model_loss(input_data):
210+
config = amp_te.AMPLIFYConfig.from_pretrained("chandar-lab/AMPLIFY_120M")
211+
model = amp_te.AMPLIFYForMaskedLM(config)
212+
model.to("cuda", dtype=torch.bfloat16)
213+
input_data = {k: v.to("cuda") for k, v in input_data.items()}
214+
model.eval()
215+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
216+
output = model(**input_data)
217+
218+
loss = output.loss.detach().cpu()
219+
assert loss < 3.5, f"Loss is {loss}, expected less than 3.5"

models/amplify/tests/test_encoder_block.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def data(self) -> torch.Tensor:
5757
@pytest.fixture
5858
def config():
5959
config = AutoConfig.from_pretrained("chandar-lab/AMPLIFY_120M", trust_remote_code=True)
60-
config.torch_dtype = torch.bfloat16
60+
config.dtype = torch.bfloat16
6161
return config
6262

6363

@@ -169,7 +169,7 @@ def test_encoder_block_forward(inputs, config):
169169
window_size=(-1, -1),
170170
rotary_pos_interleaved=True,
171171
seq_length=config.max_length,
172-
params_dtype=config.torch_dtype,
172+
params_dtype=config.dtype,
173173
).to("cuda", dtype=torch.bfloat16)
174174

175175
state_dict_mapping = {

models/esm2/Dockerfile

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
FROM nvcr.io/nvidia/pytorch:25.06-py3
2-
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
32
WORKDIR /workspace/bionemo
43
COPY . .
5-
RUN --mount=type=cache,target=/root/.cache/uv \
4+
RUN --mount=type=cache,target=/root/.cache/pip \
65
PIP_CONSTRAINT= pip install -e .

0 commit comments

Comments
 (0)