Skip to content

Commit ca2fc85

Browse files
authored
Fix FP32→BF16 redundant allocation during model init (#228)
* Fix FP32→BF16 redundant allocation during model init * Improve comments in from_config method
1 parent d562e50 commit ca2fc85

2 files changed

Lines changed: 12 additions & 6 deletions

File tree

scripts/train_eagle3_online.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,18 +293,19 @@ def main():
293293
if draft_model_last_checkpoint:
294294
draft_model = (
295295
AutoEagle3DraftModel.from_pretrained(
296-
draft_model_last_checkpoint, attention_backend=args.attention_backend
296+
draft_model_last_checkpoint, attention_backend=args.attention_backend,
297+
torch_dtype=torch.bfloat16
297298
)
298299
.cuda()
299-
.to(torch.bfloat16)
300+
300301
)
301302
else:
302303
draft_model = (
303304
AutoEagle3DraftModel.from_config(
304-
draft_model_config, attention_backend=args.attention_backend
305+
draft_model_config, attention_backend=args.attention_backend,
306+
torch_dtype=torch.bfloat16
305307
)
306308
.cuda()
307-
.to(torch.bfloat16)
308309
)
309310
draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key)
310311
draft_model.freeze_embedding()

specforge/modeling/auto.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class AutoEagle3DraftModel(AutoModelForCausalLMBase):
3939
}
4040

4141
@classmethod
42-
def from_config(cls, config: PretrainedConfig, **config_kwargs):
42+
def from_config(cls, config: PretrainedConfig, torch_dtype=None, **config_kwargs):
4343
"""
4444
This class method takes a configuration object and create its model based on the
4545
_model_mapping class variable.
@@ -52,7 +52,12 @@ def from_config(cls, config: PretrainedConfig, **config_kwargs):
5252
"""
5353
# get the model class from the
5454
_model_cls = cls._model_mapping[type(config)]
55-
return _model_cls(config, **config_kwargs)
55+
model = _model_cls(config, **config_kwargs)
56+
57+
# Convert model to specified dtype if provided
58+
if torch_dtype is not None:
59+
model = model.to(dtype=torch_dtype)
60+
return model
5661

5762
@classmethod
5863
def from_pretrained(

0 commit comments

Comments
 (0)