Skip to content

Commit ce5630f

Browse files
committed
rename torch_dtype to dtype
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 4d327de commit ce5630f

5 files changed

Lines changed: 11 additions & 13 deletions

File tree

models/amplify/export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
# Smoke test that the model can be loaded.
3737
model_te = AutoModelForMaskedLM.from_pretrained(
3838
f"./checkpoint_export/{tag}",
39-
torch_dtype=torch.bfloat16,
39+
dtype=torch.bfloat16,
4040
trust_remote_code=True,
4141
)
4242
del model_te

models/amplify/src/amplify/amplify_te.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,17 +147,15 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
147147
config.padded_vocab_size,
148148
config.hidden_size,
149149
padding_idx=config.pad_token_id,
150-
dtype=config.torch_dtype,
150+
dtype=config.dtype,
151151
)
152152

153153
if config.layer_norm_after_embedding:
154154
self.layer_norm_1 = (
155-
transformer_engine.pytorch.RMSNorm(
156-
config.hidden_size, config.norm_eps, params_dtype=config.torch_dtype
157-
)
155+
transformer_engine.pytorch.RMSNorm(config.hidden_size, config.norm_eps, params_dtype=config.dtype)
158156
if config.rms_norm
159157
else transformer_engine.pytorch.LayerNorm(
160-
config.hidden_size, config.norm_eps, params_dtype=config.torch_dtype
158+
config.hidden_size, config.norm_eps, params_dtype=config.dtype
161159
)
162160
)
163161

@@ -197,7 +195,7 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
197195
window_size=(-1, -1),
198196
rotary_pos_interleaved=True,
199197
seq_length=config.max_length,
200-
params_dtype=config.torch_dtype,
198+
params_dtype=config.dtype,
201199
)
202200
)
203201

@@ -278,7 +276,7 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
278276
config.hidden_size,
279277
config.padded_vocab_size,
280278
config.norm_eps,
281-
params_dtype=config.torch_dtype,
279+
params_dtype=config.dtype,
282280
normalization="RMSNorm" if config.rms_norm else "LayerNorm",
283281
init_method=lambda x: torch.nn.init.uniform_(
284282
x, -self.config.decoder_init_range, self.config.decoder_init_range
@@ -287,7 +285,7 @@ def __init__(self, config: AMPLIFYConfig, **kwargs):
287285

288286
else:
289287
self.decoder = transformer_engine.pytorch.Linear(
290-
config.hidden_size, config.vocab_size, params_dtype=config.torch_dtype
288+
config.hidden_size, config.vocab_size, params_dtype=config.dtype
291289
)
292290

293291
def forward(

models/amplify/src/amplify/state_dict_convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def convert_amplify_hf_to_te(model_hf: nn.Module, **config_kwargs) -> nn.Module:
4646
"""
4747
te_config = AMPLIFYConfig(**model_hf.config.to_dict(), **config_kwargs)
4848
with init_empty_weights():
49-
model_te = AMPLIFYForMaskedLM(te_config, torch_dtype=te_config.torch_dtype)
49+
model_te = AMPLIFYForMaskedLM(te_config, dtype=te_config.dtype)
5050

5151
output_model = io.apply_transforms(
5252
model_hf,

models/amplify/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def tokenizer():
3636
@pytest.fixture
3737
def config():
3838
config = AutoConfig.from_pretrained("chandar-lab/AMPLIFY_120M", trust_remote_code=True)
39-
config.torch_dtype = torch.bfloat16
39+
config.dtype = torch.bfloat16
4040
return config
4141

4242

models/amplify/tests/test_encoder_block.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def data(self) -> torch.Tensor:
5757
@pytest.fixture
5858
def config():
5959
config = AutoConfig.from_pretrained("chandar-lab/AMPLIFY_120M", trust_remote_code=True)
60-
config.torch_dtype = torch.bfloat16
60+
config.dtype = torch.bfloat16
6161
return config
6262

6363

@@ -169,7 +169,7 @@ def test_encoder_block_forward(inputs, config):
169169
window_size=(-1, -1),
170170
rotary_pos_interleaved=True,
171171
seq_length=config.max_length,
172-
params_dtype=config.torch_dtype,
172+
params_dtype=config.dtype,
173173
).to("cuda", dtype=torch.bfloat16)
174174

175175
state_dict_mapping = {

0 commit comments

Comments
 (0)