From e4b363f8f5d75845e1161b532fd66e1bf6b82a37 Mon Sep 17 00:00:00 2001 From: Zijie Xia Date: Tue, 17 Mar 2026 22:11:24 -0700 Subject: [PATCH 1/9] feat: add EAGLE3 support and new template for step3.5 --- configs/step-3.5-flash-eagle3.json | 31 ++++++++++ patches/sglang_step3p5_eagle3.patch | 91 +++++++++++++++++++++++++++++ specforge/data/template.py | 10 ++++ 3 files changed, 132 insertions(+) create mode 100644 configs/step-3.5-flash-eagle3.json create mode 100644 patches/sglang_step3p5_eagle3.patch diff --git a/configs/step-3.5-flash-eagle3.json b/configs/step-3.5-flash-eagle3.json new file mode 100644 index 00000000..66008288 --- /dev/null +++ b/configs/step-3.5-flash-eagle3.json @@ -0,0 +1,31 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": [ + 1, + 22, + 44 + ], + "use_aux_hidden_state": true + }, + "hidden_size": 4096, + "num_attention_heads": 64, + "num_key_value_heads": 8, + "intermediate_size": 11264, + "hidden_act": "silu", + "max_position_embeddings": 262144, + "vocab_size": 128896, + "draft_vocab_size": 32000, + "num_hidden_layers": 1, + "bos_token_id": 0, + "eos_token_id": 1, + "pad_token_id": 2, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": false, + "use_cache": true, + "model_type": "llama", + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0" +} diff --git a/patches/sglang_step3p5_eagle3.patch b/patches/sglang_step3p5_eagle3.patch new file mode 100644 index 00000000..851046a5 --- /dev/null +++ b/patches/sglang_step3p5_eagle3.patch @@ -0,0 +1,91 @@ +diff --git a/python/sglang/srt/models/step3p5.py b/python/sglang/srt/models/step3p5.py +index b3f82b916..969b5252a 100644 +--- a/python/sglang/srt/models/step3p5.py ++++ b/python/sglang/srt/models/step3p5.py +@@ -708,6 +708,9 @@ class Step3p5Model(nn.Module): + else: + self.norm = PPMissingLayer(return_tuple=True) + ++ # For EAGLE3 support ++ self.layers_to_capture = [] ++ + def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor: + if hasattr(self.config, "scale_emb"): + return self.get_input_embeddings()(input_ids) * self.config.scale_emb +@@ -736,7 +739,12 @@ class Step3p5Model(nn.Module): + hidden_states = pp_proxy_tensors["hidden_states"] + residual = pp_proxy_tensors["residual"] + ++ aux_hidden_states = [] + for i in range(self.start_layer, self.end_layer): ++ if i in self.layers_to_capture: ++ aux_hidden_states.append( ++ hidden_states + residual if residual is not None else hidden_states ++ ) + layer = self.layers[i] + hidden_states, residual = layer( + positions, +@@ -771,6 +779,8 @@ class Step3p5Model(nn.Module): + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) ++ if len(aux_hidden_states) > 0: ++ return hidden_states, hidden_states_before_norm, aux_hidden_states + return hidden_states, hidden_states_before_norm + + +@@ -843,6 +853,9 @@ class Step3p5ForCausalLM(nn.Module): + + self.logits_processor = LogitsProcessor(config) + ++ # For EAGLE3 support ++ self.capture_aux_hidden_states = False ++ + def get_input_embeddings(self) -> nn.Embedding: + return self.model.get_input_embeddings() + +@@ -855,13 +868,18 @@ class Step3p5ForCausalLM(nn.Module): + input_embeds: torch.Tensor = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> torch.Tensor: +- hidden_states, hidden_states_before_norm = self.model( ++ model_out = self.model( + input_ids, + positions, + forward_batch, + input_embeds, + pp_proxy_tensors=pp_proxy_tensors, + ) ++ aux_hidden_states = None ++ if self.capture_aux_hidden_states and isinstance(model_out, tuple) and len(model_out) == 3: ++ hidden_states, hidden_states_before_norm, aux_hidden_states = model_out ++ else: ++ hidden_states, hidden_states_before_norm = model_out + + if self.pp_group.is_last_rank: + return self.logits_processor( +@@ -869,6 +887,7 @@ class Step3p5ForCausalLM(nn.Module): + hidden_states, + self.lm_head, + forward_batch, ++ aux_hidden_states, + hidden_states_before_norm=hidden_states_before_norm, + ) + else: +@@ -882,6 +901,16 @@ class Step3p5ForCausalLM(nn.Module): + def end_layer(self): + return self.model.end_layer + ++ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None) -> None: ++ if not self.pp_group.is_last_rank: ++ return ++ self.capture_aux_hidden_states = True ++ if layer_ids is None: ++ num_layers = self.config.num_hidden_layers ++ self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3] ++ else: ++ self.model.layers_to_capture = [val + 1 for val in layer_ids] ++ + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False): + # NOTE: + # Step3p5 HF checkpoints (e.g. MTP/nextn variants) may include an extra diff --git a/specforge/data/template.py b/specforge/data/template.py index bda8812f..80365d71 100644 --- a/specforge/data/template.py +++ b/specforge/data/template.py @@ -127,6 +127,16 @@ def get_all_template_names(self) -> List[str]: ), ) +TEMPLATE_REGISTRY.register( + name="step3.5", + template=ChatTemplate( + assistant_header="<|im_start|>assistant\n", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>\n", + ), +) + TEMPLATE_REGISTRY.register( name="phi3", template=ChatTemplate( From 581cdfd820da0acb7c350ee4146d80e74c1a4890 Mon Sep 17 00:00:00 2001 From: Zijie Xia Date: Wed, 18 Mar 2026 21:52:59 -0700 Subject: [PATCH 2/9] add training script --- configs/step-3.5-flash-eagle3.json | 8 -- examples/run_step3p5_flash_eagle3_online.sh | 34 ++++++ patches/sglang_step3p5_eagle3.patch | 121 ++++++++++++++++---- 3 files changed, 130 insertions(+), 33 deletions(-) create mode 100644 examples/run_step3p5_flash_eagle3_online.sh diff --git a/configs/step-3.5-flash-eagle3.json b/configs/step-3.5-flash-eagle3.json index 66008288..72a0e9ee 100644 --- a/configs/step-3.5-flash-eagle3.json +++ b/configs/step-3.5-flash-eagle3.json @@ -2,14 +2,6 @@ "architectures": [ "LlamaForCausalLMEagle3" ], - "eagle_config": { - "eagle_aux_hidden_state_layer_ids": [ - 1, - 22, - 44 - ], - "use_aux_hidden_state": true - }, "hidden_size": 4096, "num_attention_heads": 64, "num_key_value_heads": 8, diff --git a/examples/run_step3p5_flash_eagle3_online.sh b/examples/run_step3p5_flash_eagle3_online.sh new file mode 100644 index 00000000..3dc34bef --- /dev/null +++ b/examples/run_step3p5_flash_eagle3_online.sh @@ -0,0 +1,34 @@ + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) + +# train eagle3 for step-3.5-flash +NUM_GPUS=${1:-4} +TP_SIZE=${2:-4} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +# train eagle3 online +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path stepfun-ai/Step-3.5-Flash \ + --draft-model-config configs/step-3.5-flash-eagle3.json \ + --train-data-path cache/dataset/perfectblend_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/step-3.5-flash-eagle3-perfectblend-online \ + --tp-size $TP_SIZE \ + --target-model-backend sglang \ + --trust-remote-code \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template step3.5 \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --dist-timeout 60 \ + --sglang-mem-fraction-static 0.8 \ + --report-to wandb \ + --wandb-project specforge-step3p5-flash \ + --wandb-name specforge-step3p5-flash-perfectblend diff --git a/patches/sglang_step3p5_eagle3.patch b/patches/sglang_step3p5_eagle3.patch index 851046a5..a37cd623 100644 --- a/patches/sglang_step3p5_eagle3.patch +++ b/patches/sglang_step3p5_eagle3.patch @@ -1,8 +1,27 @@ diff --git a/python/sglang/srt/models/step3p5.py b/python/sglang/srt/models/step3p5.py -index b3f82b916..969b5252a 100644 +index b3f82b916..eb91f0a9d 100644 --- a/python/sglang/srt/models/step3p5.py +++ b/python/sglang/srt/models/step3p5.py -@@ -708,6 +708,9 @@ class Step3p5Model(nn.Module): +@@ -1,6 +1,6 @@ + import logging + import os +-from typing import Any, Dict, Iterable, Optional, Tuple, Union ++from typing import Any, Dict, Iterable, Optional, Tuple, Union, List + + import torch + import torch.nn.functional as F +@@ -634,6 +634,10 @@ class Step3p5DecoderLayer(nn.Module): + ) + self._dump_tensor("attn_output", hidden_states, dump_step) + # Fully Connected ++ # NOTE: prepare_mlp is intentionally bypassed here — step3p5 uses a ++ # non-standard residual pattern where the attn output has already been ++ # summed with the residual in prepare_attn, so we perform the addition ++ # manually rather than going through the LayerCommunicator. + # hidden_states, residual = self.layer_communicator.prepare_mlp( + # hidden_states, + # residual, +@@ -708,6 +712,9 @@ class Step3p5Model(nn.Module): else: self.norm = PPMissingLayer(return_tuple=True) @@ -12,7 +31,7 @@ index b3f82b916..969b5252a 100644 def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor: if hasattr(self.config, "scale_emb"): return self.get_input_embeddings()(input_ids) * self.config.scale_emb -@@ -736,7 +739,12 @@ class Step3p5Model(nn.Module): +@@ -736,7 +743,12 @@ class Step3p5Model(nn.Module): hidden_states = pp_proxy_tensors["hidden_states"] residual = pp_proxy_tensors["residual"] @@ -25,16 +44,45 @@ index b3f82b916..969b5252a 100644 layer = self.layers[i] hidden_states, residual = layer( positions, -@@ -771,6 +779,8 @@ class Step3p5Model(nn.Module): - hidden_states = self.norm(hidden_states) - else: - hidden_states, _ = self.norm(hidden_states, residual) -+ if len(aux_hidden_states) > 0: -+ return hidden_states, hidden_states_before_norm, aux_hidden_states - return hidden_states, hidden_states_before_norm +@@ -752,26 +764,18 @@ class Step3p5Model(nn.Module): + "residual": residual, + } + ) +- else: +- hidden_states_before_norm = None +- if not self.pp_group.is_last_rank: +- return PPProxyTensors( +- { +- "hidden_states": hidden_states, +- "residual": residual, +- } +- ) ++ hidden_states_before_norm = None ++ if hidden_states.shape[0] > 0: ++ hidden_states_before_norm = ( ++ hidden_states if residual is None else hidden_states + residual ++ ) ++ if residual is None: ++ hidden_states = self.norm(hidden_states) + else: +- if hidden_states.shape[0] > 0: +- # if forward_batch.return_hidden_states_before_norm: +- hidden_states_before_norm = ( +- hidden_states if residual is None else hidden_states + residual +- ) +- if residual is None: +- hidden_states = self.norm(hidden_states) +- else: +- hidden_states, _ = self.norm(hidden_states, residual) +- return hidden_states, hidden_states_before_norm ++ hidden_states, _ = self.norm(hidden_states, residual) ++ if len(aux_hidden_states) > 0: ++ return hidden_states, hidden_states_before_norm, aux_hidden_states ++ return hidden_states, hidden_states_before_norm -@@ -843,6 +853,9 @@ class Step3p5ForCausalLM(nn.Module): + class Step3p5ForCausalLM(nn.Module): +@@ -843,6 +847,9 @@ class Step3p5ForCausalLM(nn.Module): self.logits_processor = LogitsProcessor(config) @@ -44,7 +92,7 @@ index b3f82b916..969b5252a 100644 def get_input_embeddings(self) -> nn.Embedding: return self.model.get_input_embeddings() -@@ -855,13 +868,18 @@ class Step3p5ForCausalLM(nn.Module): +@@ -855,24 +862,40 @@ class Step3p5ForCausalLM(nn.Module): input_embeds: torch.Tensor = None, pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> torch.Tensor: @@ -56,23 +104,46 @@ index b3f82b916..969b5252a 100644 input_embeds, pp_proxy_tensors=pp_proxy_tensors, ) ++ if not self.pp_group.is_last_rank: ++ return model_out + +- if self.pp_group.is_last_rank: +- return self.logits_processor( +- input_ids, +- hidden_states, +- self.lm_head, +- forward_batch, +- hidden_states_before_norm=hidden_states_before_norm, +- ) + aux_hidden_states = None -+ if self.capture_aux_hidden_states and isinstance(model_out, tuple) and len(model_out) == 3: ++ if ( ++ self.capture_aux_hidden_states ++ and isinstance(model_out, tuple) ++ and len(model_out) == 3 ++ ): + hidden_states, hidden_states_before_norm, aux_hidden_states = model_out -+ else: + else: +- return hidden_states + hidden_states, hidden_states_before_norm = model_out ++ ++ if aux_hidden_states is not None: ++ # Null out hidden_states_before_norm so LogitsProcessor uses the EAGLE3 ++ # aux captures instead (LogitsProcessor prefers hidden_states_before_norm ++ # when both are provided, which would incorrectly discard aux captures). ++ hidden_states_before_norm = None ++ ++ return self.logits_processor( ++ input_ids, ++ hidden_states, ++ self.lm_head, ++ forward_batch, ++ aux_hidden_states, ++ hidden_states_before_norm=hidden_states_before_norm, ++ ) - if self.pp_group.is_last_rank: - return self.logits_processor( -@@ -869,6 +887,7 @@ class Step3p5ForCausalLM(nn.Module): - hidden_states, - self.lm_head, - forward_batch, -+ aux_hidden_states, - hidden_states_before_norm=hidden_states_before_norm, - ) - else: -@@ -882,6 +901,16 @@ class Step3p5ForCausalLM(nn.Module): + @property + def start_layer(self): +@@ -882,6 +905,16 @@ class Step3p5ForCausalLM(nn.Module): def end_layer(self): return self.model.end_layer From f73907b2358be678ec0d7c46051bfd1c355d0014 Mon Sep 17 00:00:00 2001 From: Zijie Xia Date: Thu, 19 Mar 2026 01:38:28 -0700 Subject: [PATCH 3/9] reproduce 03 --- configs/step-3.5-flash-eagle3.json | 8 ++++++++ examples/run_step3p5_flash_eagle3_online.sh | 11 ++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/configs/step-3.5-flash-eagle3.json b/configs/step-3.5-flash-eagle3.json index 72a0e9ee..b237162d 100644 --- a/configs/step-3.5-flash-eagle3.json +++ b/configs/step-3.5-flash-eagle3.json @@ -2,6 +2,14 @@ "architectures": [ "LlamaForCausalLMEagle3" ], + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": [ + 1, + 22, + 42 + ], + "use_aux_hidden_state": true + }, "hidden_size": 4096, "num_attention_heads": 64, "num_key_value_heads": 8, diff --git a/examples/run_step3p5_flash_eagle3_online.sh b/examples/run_step3p5_flash_eagle3_online.sh index 3dc34bef..45ec9175 100644 --- a/examples/run_step3p5_flash_eagle3_online.sh +++ b/examples/run_step3p5_flash_eagle3_online.sh @@ -1,6 +1,7 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels # train eagle3 for step-3.5-flash NUM_GPUS=${1:-4} @@ -14,21 +15,21 @@ torchrun \ $ROOT_DIR/scripts/train_eagle3.py \ --target-model-path stepfun-ai/Step-3.5-Flash \ --draft-model-config configs/step-3.5-flash-eagle3.json \ - --train-data-path cache/dataset/perfectblend_train.jsonl \ + --train-data-path cache/dataset/ultrachat_train.jsonl \ --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ - --output-dir $ROOT_DIR/outputs/step-3.5-flash-eagle3-perfectblend-online \ + --output-dir $ROOT_DIR/outputs/step-3.5-flash-eagle3-ultrachat-online \ --tp-size $TP_SIZE \ --target-model-backend sglang \ --trust-remote-code \ --num-epochs 10 \ --batch-size 1 \ --learning-rate 1e-4 \ - --max-length 4096 \ + --max-length 8196 \ --chat-template step3.5 \ --cache-dir $ROOT_DIR/cache \ --embedding-key model.embed_tokens.weight \ --dist-timeout 60 \ - --sglang-mem-fraction-static 0.8 \ + --sglang-mem-fraction-static 0.75 \ --report-to wandb \ --wandb-project specforge-step3p5-flash \ - --wandb-name specforge-step3p5-flash-perfectblend + --wandb-name specforge-step3p5-flash-ultrachat From 352a647daa684aa75d60436ed46ec66ee9ef5a32 Mon Sep 17 00:00:00 2001 From: Zijie Xia Date: Thu, 26 Mar 2026 22:58:55 +0000 Subject: [PATCH 4/9] use regen train the draft model --- configs/step-3.5-flash-eagle3.json | 6 +++--- examples/run_step3p5_flash_eagle3_online.sh | 15 ++++++++------- specforge/args.py | 2 +- specforge/data/template.py | 2 ++ 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/configs/step-3.5-flash-eagle3.json b/configs/step-3.5-flash-eagle3.json index b237162d..3bd33d5c 100644 --- a/configs/step-3.5-flash-eagle3.json +++ b/configs/step-3.5-flash-eagle3.json @@ -4,9 +4,9 @@ ], "eagle_config": { "eagle_aux_hidden_state_layer_ids": [ - 1, - 22, - 42 + 4, + 20, + 40 ], "use_aux_hidden_state": true }, diff --git a/examples/run_step3p5_flash_eagle3_online.sh b/examples/run_step3p5_flash_eagle3_online.sh index 45ec9175..c6385020 100644 --- a/examples/run_step3p5_flash_eagle3_online.sh +++ b/examples/run_step3p5_flash_eagle3_online.sh @@ -4,7 +4,7 @@ ROOT_DIR=$(dirname $SCRIPT_DIR) export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels # train eagle3 for step-3.5-flash -NUM_GPUS=${1:-4} +NUM_GPUS=${1:-8} TP_SIZE=${2:-4} BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} @@ -15,21 +15,22 @@ torchrun \ $ROOT_DIR/scripts/train_eagle3.py \ --target-model-path stepfun-ai/Step-3.5-Flash \ --draft-model-config configs/step-3.5-flash-eagle3.json \ - --train-data-path cache/dataset/ultrachat_train.jsonl \ + --train-data-path cache/dataset/ultrachat_train_regen.jsonl \ --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ - --output-dir $ROOT_DIR/outputs/step-3.5-flash-eagle3-ultrachat-online \ + --output-dir $ROOT_DIR/outputs/step-3.5-flash-eagle3-ultrachat-regen-online \ --tp-size $TP_SIZE \ + --sglang-ep-size $TP_SIZE \ --target-model-backend sglang \ --trust-remote-code \ --num-epochs 10 \ --batch-size 1 \ - --learning-rate 1e-4 \ - --max-length 8196 \ + --learning-rate 5e-5 \ + --max-length 4096 \ + --sglang-attention-backend fa3 \ --chat-template step3.5 \ --cache-dir $ROOT_DIR/cache \ - --embedding-key model.embed_tokens.weight \ --dist-timeout 60 \ --sglang-mem-fraction-static 0.75 \ --report-to wandb \ --wandb-project specforge-step3p5-flash \ - --wandb-name specforge-step3p5-flash-ultrachat + --wandb-name specforge-step3p5-flash-ultrachat-regen diff --git a/specforge/args.py b/specforge/args.py index fd6de14c..0c7d94e0 100644 --- a/specforge/args.py +++ b/specforge/args.py @@ -181,7 +181,7 @@ def from_args(args: argparse.Namespace) -> "SGLangBackendArgs": args.target_batch_size if hasattr(args, "target_batch_size") else None ), sglang_max_total_tokens=( - args.target_batch_size * args.max_length + int(args.target_batch_size * args.max_length * 1.2) if hasattr(args, "target_batch_size") and hasattr(args, "max_length") else None ), diff --git a/specforge/data/template.py b/specforge/data/template.py index 80365d71..9b3b426b 100644 --- a/specforge/data/template.py +++ b/specforge/data/template.py @@ -134,6 +134,8 @@ def get_all_template_names(self) -> List[str]: user_header="<|im_start|>user\n", system_prompt="You are a helpful assistant.", end_of_turn_token="<|im_end|>\n", + parser_type="thinking", + enable_thinking=True, ), ) From 2d28645ca33a7335648bea836d157ee42c7b79a8 Mon Sep 17 00:00:00 2001 From: Zijie Xia Date: Tue, 31 Mar 2026 09:25:40 -0700 Subject: [PATCH 5/9] feat: add support for smoltalk-chinese dataset processing --- scripts/prepare_data.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index 4e63658f..61444e85 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -47,6 +47,7 @@ def parse_args(): "perfectblend-llama4-scout-instruct", "perfectblend-llama4-maverick-instruct", "magpie-qwen2.5-pro-1m-v0.1", + "smoltalk-chinese", "sharegpt4v", "allava4v", "opc", @@ -167,6 +168,32 @@ def process_ultrachat_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, in return row, 0 +def process_smoltalk_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]: + """Process a row from the opencsg/smoltalk-chinese dataset. + + The function expects a row with the following schema: + { + "conversations": [ + { + "role": "user" | "assistant", + "content": str + } + ] + } + """ + conversations = row["conversations"] # smoltalk uses "conversations", not "messages" + formatted_conversations = [] + for message in conversations: + role = message["role"] # already "user" or "assistant" — no mapping needed + content = message["content"] + assert role in ["user", "assistant"] + formatted_conversations.append({"role": role, "content": content}) + row_id = hashlib.md5( + "".join(m["content"] for m in conversations).encode() + ).hexdigest() + return {"id": row_id, "conversations": formatted_conversations}, 0 + + def process_sharegpt_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]: """ sharegpt dataset schema: @@ -575,6 +602,9 @@ def main(): ds = load_dataset("Magpie-Align/Magpie-Qwen2.5-Pro-1M-v0.1")["train"] ds = ds.rename_column("uuid", "id") proc_fn = process_sharegpt_row + elif args.dataset == "smoltalk-chinese": + ds = load_dataset("opencsg/smoltalk-chinese")["train"] + proc_fn = process_smoltalk_row elif args.dataset == "sharegpt4v": ds = load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")["train"] raise Exception("Not supported sharegpt4v now") From 8f7f391bb3fbb46660e5a333308517f4f386f0c3 Mon Sep 17 00:00:00 2001 From: Zijie Xia Date: Mon, 6 Apr 2026 10:31:49 -0700 Subject: [PATCH 6/9] update smoltalk-chinese --- scripts/prepare_data.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index 528be6cf..5ac3c851 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -188,10 +188,7 @@ def process_smoltalk_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int content = message["content"] assert role in ["user", "assistant"] formatted_conversations.append({"role": role, "content": content}) - row_id = hashlib.md5( - "".join(m["content"] for m in conversations).encode() - ).hexdigest() - return {"id": row_id, "conversations": formatted_conversations}, 0 + return {"id": row["id"], "conversations": formatted_conversations}, 0 def process_sharegpt_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int]: @@ -603,7 +600,7 @@ def main(): ds = ds.rename_column("uuid", "id") proc_fn = process_sharegpt_row elif args.dataset == "smoltalk-chinese": - ds = load_dataset("opencsg/smoltalk-chinese")["train"] + ds = load_dataset("zjxia/smoltalk-chinese")["train"] proc_fn = process_smoltalk_row elif args.dataset == "sharegpt4v": ds = load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")["train"] From 119ec6ab28578903bf04d5009e81d91ee7ce3b54 Mon Sep 17 00:00:00 2001 From: Zijie Xia Date: Mon, 13 Apr 2026 09:09:51 -0700 Subject: [PATCH 7/9] fix: correct train data path in run_step3p5_flash_eagle3_online.sh --- examples/run_step3p5_flash_eagle3_online.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/run_step3p5_flash_eagle3_online.sh b/examples/run_step3p5_flash_eagle3_online.sh index c6385020..58093d2f 100644 --- a/examples/run_step3p5_flash_eagle3_online.sh +++ b/examples/run_step3p5_flash_eagle3_online.sh @@ -15,7 +15,7 @@ torchrun \ $ROOT_DIR/scripts/train_eagle3.py \ --target-model-path stepfun-ai/Step-3.5-Flash \ --draft-model-config configs/step-3.5-flash-eagle3.json \ - --train-data-path cache/dataset/ultrachat_train_regen.jsonl \ + --train-data-path $ROOT_DIR/cache/dataset/ultrachat_train_regen.jsonl \ --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ --output-dir $ROOT_DIR/outputs/step-3.5-flash-eagle3-ultrachat-regen-online \ --tp-size $TP_SIZE \ From de9b1d8372565e00261f11ccf1ff8af50ac03943 Mon Sep 17 00:00:00 2001 From: Zijie Xia Date: Mon, 13 Apr 2026 09:12:18 -0700 Subject: [PATCH 8/9] lint --- scripts/prepare_data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index 5ac3c851..4b9fbdf0 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -181,7 +181,9 @@ def process_smoltalk_row(row: Dict, dataset_name: str = None) -> Tuple[Dict, int ] } """ - conversations = row["conversations"] # smoltalk uses "conversations", not "messages" + conversations = row[ + "conversations" + ] # smoltalk uses "conversations", not "messages" formatted_conversations = [] for message in conversations: role = message["role"] # already "user" or "assistant" — no mapping needed From c271055a3232602c933e11e38da298bd37c230f7 Mon Sep 17 00:00:00 2001 From: Zijie Xia Date: Mon, 13 Apr 2026 09:13:51 -0700 Subject: [PATCH 9/9] remove unused file --- patches/sglang_step3p5_eagle3.patch | 162 ---------------------------- 1 file changed, 162 deletions(-) delete mode 100644 patches/sglang_step3p5_eagle3.patch diff --git a/patches/sglang_step3p5_eagle3.patch b/patches/sglang_step3p5_eagle3.patch deleted file mode 100644 index a37cd623..00000000 --- a/patches/sglang_step3p5_eagle3.patch +++ /dev/null @@ -1,162 +0,0 @@ -diff --git a/python/sglang/srt/models/step3p5.py b/python/sglang/srt/models/step3p5.py -index b3f82b916..eb91f0a9d 100644 ---- a/python/sglang/srt/models/step3p5.py -+++ b/python/sglang/srt/models/step3p5.py -@@ -1,6 +1,6 @@ - import logging - import os --from typing import Any, Dict, Iterable, Optional, Tuple, Union -+from typing import Any, Dict, Iterable, Optional, Tuple, Union, List - - import torch - import torch.nn.functional as F -@@ -634,6 +634,10 @@ class Step3p5DecoderLayer(nn.Module): - ) - self._dump_tensor("attn_output", hidden_states, dump_step) - # Fully Connected -+ # NOTE: prepare_mlp is intentionally bypassed here — step3p5 uses a -+ # non-standard residual pattern where the attn output has already been -+ # summed with the residual in prepare_attn, so we perform the addition -+ # manually rather than going through the LayerCommunicator. - # hidden_states, residual = self.layer_communicator.prepare_mlp( - # hidden_states, - # residual, -@@ -708,6 +712,9 @@ class Step3p5Model(nn.Module): - else: - self.norm = PPMissingLayer(return_tuple=True) - -+ # For EAGLE3 support -+ self.layers_to_capture = [] -+ - def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor: - if hasattr(self.config, "scale_emb"): - return self.get_input_embeddings()(input_ids) * self.config.scale_emb -@@ -736,7 +743,12 @@ class Step3p5Model(nn.Module): - hidden_states = pp_proxy_tensors["hidden_states"] - residual = pp_proxy_tensors["residual"] - -+ aux_hidden_states = [] - for i in range(self.start_layer, self.end_layer): -+ if i in self.layers_to_capture: -+ aux_hidden_states.append( -+ hidden_states + residual if residual is not None else hidden_states -+ ) - layer = self.layers[i] - hidden_states, residual = layer( - positions, -@@ -752,26 +764,18 @@ class Step3p5Model(nn.Module): - "residual": residual, - } - ) -- else: -- hidden_states_before_norm = None -- if not self.pp_group.is_last_rank: -- return PPProxyTensors( -- { -- "hidden_states": hidden_states, -- "residual": residual, -- } -- ) -+ hidden_states_before_norm = None -+ if hidden_states.shape[0] > 0: -+ hidden_states_before_norm = ( -+ hidden_states if residual is None else hidden_states + residual -+ ) -+ if residual is None: -+ hidden_states = self.norm(hidden_states) - else: -- if hidden_states.shape[0] > 0: -- # if forward_batch.return_hidden_states_before_norm: -- hidden_states_before_norm = ( -- hidden_states if residual is None else hidden_states + residual -- ) -- if residual is None: -- hidden_states = self.norm(hidden_states) -- else: -- hidden_states, _ = self.norm(hidden_states, residual) -- return hidden_states, hidden_states_before_norm -+ hidden_states, _ = self.norm(hidden_states, residual) -+ if len(aux_hidden_states) > 0: -+ return hidden_states, hidden_states_before_norm, aux_hidden_states -+ return hidden_states, hidden_states_before_norm - - - class Step3p5ForCausalLM(nn.Module): -@@ -843,6 +847,9 @@ class Step3p5ForCausalLM(nn.Module): - - self.logits_processor = LogitsProcessor(config) - -+ # For EAGLE3 support -+ self.capture_aux_hidden_states = False -+ - def get_input_embeddings(self) -> nn.Embedding: - return self.model.get_input_embeddings() - -@@ -855,24 +862,40 @@ class Step3p5ForCausalLM(nn.Module): - input_embeds: torch.Tensor = None, - pp_proxy_tensors: Optional[PPProxyTensors] = None, - ) -> torch.Tensor: -- hidden_states, hidden_states_before_norm = self.model( -+ model_out = self.model( - input_ids, - positions, - forward_batch, - input_embeds, - pp_proxy_tensors=pp_proxy_tensors, - ) -+ if not self.pp_group.is_last_rank: -+ return model_out - -- if self.pp_group.is_last_rank: -- return self.logits_processor( -- input_ids, -- hidden_states, -- self.lm_head, -- forward_batch, -- hidden_states_before_norm=hidden_states_before_norm, -- ) -+ aux_hidden_states = None -+ if ( -+ self.capture_aux_hidden_states -+ and isinstance(model_out, tuple) -+ and len(model_out) == 3 -+ ): -+ hidden_states, hidden_states_before_norm, aux_hidden_states = model_out - else: -- return hidden_states -+ hidden_states, hidden_states_before_norm = model_out -+ -+ if aux_hidden_states is not None: -+ # Null out hidden_states_before_norm so LogitsProcessor uses the EAGLE3 -+ # aux captures instead (LogitsProcessor prefers hidden_states_before_norm -+ # when both are provided, which would incorrectly discard aux captures). -+ hidden_states_before_norm = None -+ -+ return self.logits_processor( -+ input_ids, -+ hidden_states, -+ self.lm_head, -+ forward_batch, -+ aux_hidden_states, -+ hidden_states_before_norm=hidden_states_before_norm, -+ ) - - @property - def start_layer(self): -@@ -882,6 +905,16 @@ class Step3p5ForCausalLM(nn.Module): - def end_layer(self): - return self.model.end_layer - -+ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None) -> None: -+ if not self.pp_group.is_last_rank: -+ return -+ self.capture_aux_hidden_states = True -+ if layer_ids is None: -+ num_layers = self.config.num_hidden_layers -+ self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3] -+ else: -+ self.model.layers_to_capture = [val + 1 for val in layer_ids] -+ - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False): - # NOTE: - # Step3p5 HF checkpoints (e.g. MTP/nextn variants) may include an extra