3636from tqdm import tqdm
3737from transformers import AutoConfig , AutoProcessor , AutoTokenizer
3838
39+ from specforge .args import SGLangBackendArgs , TrackerArgs
3940from specforge .data import build_eagle3_dataset , prepare_dp_dataloaders
4041from specforge .distributed import (
4142 destroy_distributed ,
@@ -100,6 +101,7 @@ def parse_args():
100101 default = 2000 ,
101102 help = "Number of files per subdirectory." ,
102103 )
104+ SGLangBackendArgs .add_args (parser )
103105 return parser .parse_args ()
104106
105107
@@ -119,20 +121,29 @@ def build_target_model(
119121 target_model = (
120122 Qwen2_5_VLForConditionalGeneration .from_pretrained (
121123 pretrained_model_name_or_path = args .target_model_path ,
122- torch_dtype = torch .bfloat16 ,
124+ torch_dtype = (
125+ model_config .dtype
126+ if hasattr (model_config , "dtype" )
127+ else model_config .torch_dtype
128+ ),
123129 )
124130 .eval ()
125131 .cuda ()
126132 )
127133 else :
134+ target_model_kwargs = SGLangBackendArgs .from_args (args ).to_kwargs ()
128135 target_model = get_eagle3_target_model (
129136 pretrained_model_name_or_path = args .target_model_path ,
130137 backend = "sglang" , # we set this as the default backend to minimize precision mismatch in training and serving
131- torch_dtype = torch .bfloat16 ,
138+ torch_dtype = (
139+ model_config .dtype
140+ if hasattr (model_config , "dtype" )
141+ else model_config .torch_dtype
142+ ),
132143 device = "cuda" ,
133144 cache_dir = args .cache_dir ,
145+ ** target_model_kwargs ,
134146 )
135-
136147 # Set auxiliary hidden states layers if specified
137148 target_model .set_aux_hidden_states_layers (args .aux_hidden_states_layers )
138149
0 commit comments