2828from specforge .distributed import destroy_distributed , get_dp_group , init_distributed
2929from specforge .modeling .draft .dflash import DFlashDraftModel
3030from specforge .modeling .target .dflash_target_model import (
31+ VLM_MODEL_TYPES ,
3132 DFlashTargetModel ,
3233 HFDFlashTargetModel ,
33- VLM_MODEL_TYPES ,
3434 get_dflash_target_model ,
3535)
3636from specforge .modeling .target .target_utils import TargetEmbeddingsAndHead
@@ -82,9 +82,9 @@ def _build_target_layer_ids(
8282
8383 if layer_types is not None :
8484 eligible = [
85- i for i , lt in enumerate ( layer_types )
86- if lt in ( "full_attention" , "attention" )
87- and start_layer <= i <= end_layer
85+ i
86+ for i , lt in enumerate ( layer_types )
87+ if lt in ( "full_attention" , "attention" ) and start_layer <= i <= end_layer
8888 ]
8989 if len (eligible ) == 0 :
9090 raise ValueError (
@@ -128,9 +128,10 @@ def _resolve_draft_config(target_config):
128128 ):
129129 if not hasattr (target_config , attr_name ):
130130 continue
131- if not hasattr (draft_config , attr_name ) or getattr (
132- draft_config , attr_name
133- ) is None :
131+ if (
132+ not hasattr (draft_config , attr_name )
133+ or getattr (draft_config , attr_name ) is None
134+ ):
134135 setattr (draft_config , attr_name , getattr (target_config , attr_name ))
135136 return draft_config
136137 return copy .deepcopy (target_config )
@@ -156,9 +157,11 @@ def _ensure_layer_types(draft_config) -> None:
156157 if max_window_layers is None :
157158 max_window_layers = num_hidden_layers
158159 draft_config .layer_types = [
159- "sliding_attention"
160- if sliding_window is not None and layer_idx >= max_window_layers
161- else "full_attention"
160+ (
161+ "sliding_attention"
162+ if sliding_window is not None and layer_idx >= max_window_layers
163+ else "full_attention"
164+ )
162165 for layer_idx in range (num_hidden_layers )
163166 ]
164167
@@ -259,7 +262,9 @@ def parse_args():
259262
260263 training_group = parser .add_argument_group ("training" )
261264 training_group .add_argument ("--num-epochs" , type = int , default = 6 )
262- training_group .add_argument ("--max-num-steps" , type = int , default = None , help = "Max steps (None = full epoch)" )
265+ training_group .add_argument (
266+ "--max-num-steps" , type = int , default = None , help = "Max steps (None = full epoch)"
267+ )
263268 training_group .add_argument ("--batch-size" , type = int , default = 1 )
264269 training_group .add_argument ("--learning-rate" , type = float , default = 6e-4 )
265270 training_group .add_argument ("--max-length" , type = int , default = 3072 )
@@ -301,10 +306,8 @@ def parse_args():
301306 return parser .parse_args ()
302307
303308
304- from specforge .utils import (
305- maybe_fetch_remote_config as _maybe_fetch_remote_config ,
306- resolve_local_model_path as _resolve_local_model_path ,
307- )
309+ from specforge .utils import maybe_fetch_remote_config as _maybe_fetch_remote_config
310+ from specforge .utils import resolve_local_model_path as _resolve_local_model_path
308311
309312
310313def build_models (
@@ -670,13 +673,14 @@ def main():
670673 if args .target_prefetch_depth < 0 :
671674 raise ValueError ("--target-prefetch-depth must be non-negative." )
672675 if args .target_prefetch_depth > 0 and args .target_model_backend != "remote" :
673- raise ValueError ("--target-prefetch-depth is only supported with --target-model-backend remote." )
676+ raise ValueError (
677+ "--target-prefetch-depth is only supported with --target-model-backend remote."
678+ )
674679 set_seed (args .seed )
675680
676681 init_distributed (timeout = args .dist_timeout , tp_size = args .tp_size )
677682 print_with_rank ("Initialized distributed" )
678683
679-
680684 # Fetch remote config early so _server_model_path is cached before any
681685 # _resolve_local_model_path() call.
682686 _maybe_fetch_remote_config (args )
@@ -790,9 +794,15 @@ def main():
790794 print_on_rank0 (f"Total training steps: { total_steps } " )
791795
792796 print_on_rank0 ("Loading target embeddings and head..." )
793- resolved_embed_key , resolved_lm_head_key = _resolve_target_weight_keys (target_config )
794- embed_key = args .embedding_key if args .embedding_key is not None else resolved_embed_key
795- lm_head_key = args .lm_head_key if args .lm_head_key is not None else resolved_lm_head_key
797+ resolved_embed_key , resolved_lm_head_key = _resolve_target_weight_keys (
798+ target_config
799+ )
800+ embed_key = (
801+ args .embedding_key if args .embedding_key is not None else resolved_embed_key
802+ )
803+ lm_head_key = (
804+ args .lm_head_key if args .lm_head_key is not None else resolved_lm_head_key
805+ )
796806 print_on_rank0 (
797807 f"Loading target embeddings/head with keys: embed='{ embed_key } ', head='{ lm_head_key } '"
798808 )
@@ -870,7 +880,9 @@ def next_prefetch_boundary() -> int | float:
870880 boundary = min (boundary , args .max_num_steps )
871881 return boundary
872882
873- def train_one_dflash_batch (epoch : int , progress_bar , data : dict , target_output = None ) -> bool :
883+ def train_one_dflash_batch (
884+ epoch : int , progress_bar , data : dict , target_output = None
885+ ) -> bool :
874886 nonlocal global_step , last_time
875887 global_step += 1
876888 if target_output is None :
@@ -959,14 +971,19 @@ def fill_prefetch_queue(pending_current: int) -> None:
959971 not data_exhausted
960972 and len (prefetch_queue ) < args .target_prefetch_depth
961973 ):
962- assigned_step = global_step + pending_current + len (prefetch_queue ) + 1
974+ assigned_step = (
975+ global_step + pending_current + len (prefetch_queue ) + 1
976+ )
963977 if assigned_step > boundary :
964978 return
965979 batch = next_data ()
966980 if batch is None :
967981 return
968982 prefetch_queue .append (
969- (batch , submit_dflash_target_async (target_model , batch , is_vlm ))
983+ (
984+ batch ,
985+ submit_dflash_target_async (target_model , batch , is_vlm ),
986+ )
970987 )
971988
972989 fill_prefetch_queue (pending_current = 0 )
0 commit comments