Skip to content

Commit f12b8eb

Browse files
committed
format
1 parent 67aa0e7 commit f12b8eb

16 files changed

Lines changed: 550 additions & 230 deletions

examples/run_qwen3.5_27b_dflash_online.sh

100644 → 100755
File mode changed.

scripts/launch_target_server.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,6 @@
2626
import sys
2727
import threading
2828

29-
import torch
3029
import torch.distributed as dist
3130

3231
# Add parent dir so 'specforge' is importable when running as a standalone script.
@@ -125,8 +124,11 @@ def main():
125124
# The SGLang target model wrapper calls specforge.distributed.get_tp_group()
126125
# which requires init_process_group. Always run via torchrun.
127126
from specforge.distributed import init_distributed
127+
128128
init_distributed(timeout=30, tp_size=args.tp_size)
129-
logger.info("Distributed initialized (tp_size=%d, rank=%d)", args.tp_size, dist.get_rank())
129+
logger.info(
130+
"Distributed initialized (tp_size=%d, rank=%d)", args.tp_size, dist.get_rank()
131+
)
130132

131133
# ---- Load model via TargetModelServer ----
132134
from specforge.modeling.target.remote_target_server import (
@@ -162,7 +164,12 @@ def main():
162164
if rank == 0:
163165
# ---- Start HTTP server (rank 0 only) ----
164166
httpd = create_http_server(server_app, args.host, args.port)
165-
logger.info("Target model server listening on %s:%d (mode=%s)", args.host, args.port, args.mode)
167+
logger.info(
168+
"Target model server listening on %s:%d (mode=%s)",
169+
args.host,
170+
args.port,
171+
args.mode,
172+
)
166173

167174
def shutdown(signum, frame):
168175
logger.info("Received signal %d, shutting down...", signum)
@@ -182,7 +189,9 @@ def shutdown(signum, frame):
182189
server_app.close()
183190
# Signal worker ranks to exit
184191
if dist.get_world_size() > 1:
185-
from specforge.modeling.target.remote_target_server import _SENTINEL_EXIT
192+
from specforge.modeling.target.remote_target_server import (
193+
_SENTINEL_EXIT,
194+
)
186195

187196
dist.broadcast_object_list([_SENTINEL_EXIT], src=0)
188197
logger.info("Server stopped.")

scripts/train_dflash.py

Lines changed: 40 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -28,9 +28,9 @@
2828
from specforge.distributed import destroy_distributed, get_dp_group, init_distributed
2929
from specforge.modeling.draft.dflash import DFlashDraftModel
3030
from specforge.modeling.target.dflash_target_model import (
31+
VLM_MODEL_TYPES,
3132
DFlashTargetModel,
3233
HFDFlashTargetModel,
33-
VLM_MODEL_TYPES,
3434
get_dflash_target_model,
3535
)
3636
from specforge.modeling.target.target_utils import TargetEmbeddingsAndHead
@@ -82,9 +82,9 @@ def _build_target_layer_ids(
8282

8383
if layer_types is not None:
8484
eligible = [
85-
i for i, lt in enumerate(layer_types)
86-
if lt in ("full_attention", "attention")
87-
and start_layer <= i <= end_layer
85+
i
86+
for i, lt in enumerate(layer_types)
87+
if lt in ("full_attention", "attention") and start_layer <= i <= end_layer
8888
]
8989
if len(eligible) == 0:
9090
raise ValueError(
@@ -128,9 +128,10 @@ def _resolve_draft_config(target_config):
128128
):
129129
if not hasattr(target_config, attr_name):
130130
continue
131-
if not hasattr(draft_config, attr_name) or getattr(
132-
draft_config, attr_name
133-
) is None:
131+
if (
132+
not hasattr(draft_config, attr_name)
133+
or getattr(draft_config, attr_name) is None
134+
):
134135
setattr(draft_config, attr_name, getattr(target_config, attr_name))
135136
return draft_config
136137
return copy.deepcopy(target_config)
@@ -156,9 +157,11 @@ def _ensure_layer_types(draft_config) -> None:
156157
if max_window_layers is None:
157158
max_window_layers = num_hidden_layers
158159
draft_config.layer_types = [
159-
"sliding_attention"
160-
if sliding_window is not None and layer_idx >= max_window_layers
161-
else "full_attention"
160+
(
161+
"sliding_attention"
162+
if sliding_window is not None and layer_idx >= max_window_layers
163+
else "full_attention"
164+
)
162165
for layer_idx in range(num_hidden_layers)
163166
]
164167

@@ -259,7 +262,9 @@ def parse_args():
259262

260263
training_group = parser.add_argument_group("training")
261264
training_group.add_argument("--num-epochs", type=int, default=6)
262-
training_group.add_argument("--max-num-steps", type=int, default=None, help="Max steps (None = full epoch)")
265+
training_group.add_argument(
266+
"--max-num-steps", type=int, default=None, help="Max steps (None = full epoch)"
267+
)
263268
training_group.add_argument("--batch-size", type=int, default=1)
264269
training_group.add_argument("--learning-rate", type=float, default=6e-4)
265270
training_group.add_argument("--max-length", type=int, default=3072)
@@ -301,10 +306,8 @@ def parse_args():
301306
return parser.parse_args()
302307

303308

304-
from specforge.utils import (
305-
maybe_fetch_remote_config as _maybe_fetch_remote_config,
306-
resolve_local_model_path as _resolve_local_model_path,
307-
)
309+
from specforge.utils import maybe_fetch_remote_config as _maybe_fetch_remote_config
310+
from specforge.utils import resolve_local_model_path as _resolve_local_model_path
308311

309312

310313
def build_models(
@@ -670,13 +673,14 @@ def main():
670673
if args.target_prefetch_depth < 0:
671674
raise ValueError("--target-prefetch-depth must be non-negative.")
672675
if args.target_prefetch_depth > 0 and args.target_model_backend != "remote":
673-
raise ValueError("--target-prefetch-depth is only supported with --target-model-backend remote.")
676+
raise ValueError(
677+
"--target-prefetch-depth is only supported with --target-model-backend remote."
678+
)
674679
set_seed(args.seed)
675680

676681
init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size)
677682
print_with_rank("Initialized distributed")
678683

679-
680684
# Fetch remote config early so _server_model_path is cached before any
681685
# _resolve_local_model_path() call.
682686
_maybe_fetch_remote_config(args)
@@ -790,9 +794,15 @@ def main():
790794
print_on_rank0(f"Total training steps: {total_steps}")
791795

792796
print_on_rank0("Loading target embeddings and head...")
793-
resolved_embed_key, resolved_lm_head_key = _resolve_target_weight_keys(target_config)
794-
embed_key = args.embedding_key if args.embedding_key is not None else resolved_embed_key
795-
lm_head_key = args.lm_head_key if args.lm_head_key is not None else resolved_lm_head_key
797+
resolved_embed_key, resolved_lm_head_key = _resolve_target_weight_keys(
798+
target_config
799+
)
800+
embed_key = (
801+
args.embedding_key if args.embedding_key is not None else resolved_embed_key
802+
)
803+
lm_head_key = (
804+
args.lm_head_key if args.lm_head_key is not None else resolved_lm_head_key
805+
)
796806
print_on_rank0(
797807
f"Loading target embeddings/head with keys: embed='{embed_key}', head='{lm_head_key}'"
798808
)
@@ -870,7 +880,9 @@ def next_prefetch_boundary() -> int | float:
870880
boundary = min(boundary, args.max_num_steps)
871881
return boundary
872882

873-
def train_one_dflash_batch(epoch: int, progress_bar, data: dict, target_output=None) -> bool:
883+
def train_one_dflash_batch(
884+
epoch: int, progress_bar, data: dict, target_output=None
885+
) -> bool:
874886
nonlocal global_step, last_time
875887
global_step += 1
876888
if target_output is None:
@@ -959,14 +971,19 @@ def fill_prefetch_queue(pending_current: int) -> None:
959971
not data_exhausted
960972
and len(prefetch_queue) < args.target_prefetch_depth
961973
):
962-
assigned_step = global_step + pending_current + len(prefetch_queue) + 1
974+
assigned_step = (
975+
global_step + pending_current + len(prefetch_queue) + 1
976+
)
963977
if assigned_step > boundary:
964978
return
965979
batch = next_data()
966980
if batch is None:
967981
return
968982
prefetch_queue.append(
969-
(batch, submit_dflash_target_async(target_model, batch, is_vlm))
983+
(
984+
batch,
985+
submit_dflash_target_async(target_model, batch, is_vlm),
986+
)
970987
)
971988

972989
fill_prefetch_queue(pending_current=0)

0 commit comments

Comments
 (0)