Skip to content

Commit d69a72d

Browse files
committed
StepFun 3.5 MTP
1 parent d14ce3d commit d69a72d

12 files changed

Lines changed: 706 additions & 34 deletions

File tree

common/speculative.cpp

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -635,6 +635,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
635635
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
636636
}
637637

638+
// First draft step uses the first MTP block (step 0). Archs with a
639+
// single MTP block ignore this; multi-block archs (Step-3.5-Flash) use
640+
// it to round-robin across their N MTP layers.
641+
llama_set_mtp_step(ctx_dft, 0);
642+
638643
int ret = llama_decode(ctx_dft, batch);
639644
if (ret != 0) {
640645
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
@@ -699,6 +704,10 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
699704
break;
700705
}
701706

707+
// Step i+1: feed the i-th sampled draft token into the (i+1)-th
708+
// MTP block. Multi-block archs round-robin via mtp_step % N.
709+
llama_set_mtp_step(ctx_dft, (uint32_t)(i + 1));
710+
702711
// evaluate the drafted tokens on the draft model
703712
ret = llama_decode(ctx_dft, batch);
704713
if (ret != 0) {
@@ -709,6 +718,10 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
709718
++i;
710719
}
711720

721+
// Reset MTP step so a subsequent non-MTP decode on this context doesn't
722+
// inherit a stale offset.
723+
llama_set_mtp_step(ctx_dft, 0);
724+
712725
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
713726
auto & dp = dparams[seq_id];
714727
if (!dp.drafting) {

conversion/step3.py

Lines changed: 99 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,24 @@ class Step3VLTextModel(Qwen3Model):
9999
class Step35Model(TextModel):
100100
model_arch = gguf.MODEL_ARCH.STEP35
101101

102+
# --mtp / --no-mtp toggles (see convert_hf_to_gguf.py main()).
103+
# Unlike Qwen3.5 which stores MTP under a `mtp.*` namespace, Step3.5 just
104+
# appends MTP layers at `model.layers.{num_hidden_layers + i}`; these flags
105+
# filter by layer index instead of by name prefix.
106+
no_mtp: bool = False
107+
mtp_only: bool = False
108+
109+
def __init__(self, *args, **kwargs):
110+
super().__init__(*args, **kwargs)
111+
# NextN/MTP layers are appended past num_hidden_layers; extend the
112+
# tensor map to cover them so the MTP block's tensors get correctly
113+
# indexed names. When --no-mtp drops the MTP blocks, fall back to the
114+
# base num_hidden_layers so we don't reserve unused slots.
115+
n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0))
116+
if n_nextn > 0 and not self.no_mtp:
117+
self.block_count = int(self.hparams["num_hidden_layers"]) + n_nextn
118+
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
119+
102120
def set_gguf_parameters(self):
103121
rope_theta = self.hparams.get("rope_theta")
104122
if isinstance(rope_theta, list):
@@ -119,8 +137,25 @@ def set_gguf_parameters(self):
119137
n_head_swa = attn_other.get("num_attention_heads", n_head_base)
120138
n_kv_swa = attn_other.get("num_attention_groups", n_kv_base)
121139

122-
layer_types = layer_types[: self.block_count]
123-
partial_rotary_factors = partial_rotary_factors[: self.block_count]
140+
n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0))
141+
142+
# The Step3p5 HF checkpoint stores layer_types/partial_rotary_factors
143+
# entries for the MTP blocks past num_hidden_layers; preserve them so
144+
# the MTP layer's attention shape, SWA flag, and partial RoPE dim are
145+
# set correctly. Pad with full-attention defaults if the checkpoint
146+
# truncated them.
147+
def _pad(arr, n, default):
148+
arr = list(arr)
149+
if len(arr) < n:
150+
arr = arr + [default] * (n - len(arr))
151+
return arr[:n]
152+
153+
layer_types = _pad(layer_types, self.block_count, "full_attention")
154+
partial_rotary_factors = _pad(
155+
partial_rotary_factors,
156+
self.block_count,
157+
0.5, # full_attention default for Step3p5
158+
)
124159
assert [1.0 if lt == "sliding_attention" else 0.5 for lt in layer_types] == partial_rotary_factors
125160
head_arr = [n_head_swa if lt == "sliding_attention" else n_head_base for lt in layer_types]
126161
kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types]
@@ -157,14 +192,25 @@ def set_gguf_parameters(self):
157192

158193
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
159194

160-
# Optional per-layer SwiGLU clamps.
195+
# Optional per-layer SwiGLU clamps. MTP layers default to no clamping (0.0).
161196
if (limits := self.hparams.get("swiglu_limits")) is not None:
162-
limits_f = [0.0 if v is None else float(v) for v in limits[: self.block_count]]
197+
limits_f = _pad(
198+
[0.0 if v is None else float(v) for v in limits],
199+
self.block_count,
200+
0.0,
201+
)
163202
self.gguf_writer.add_swiglu_clamp_exp(limits_f)
164203
if (limits_shared := self.hparams.get("swiglu_limits_shared")) is not None:
165-
limits_shared_f = [0.0 if v is None else float(v) for v in limits_shared[: self.block_count]]
204+
limits_shared_f = _pad(
205+
[0.0 if v is None else float(v) for v in limits_shared],
206+
self.block_count,
207+
0.0,
208+
)
166209
self.gguf_writer.add_swiglu_clamp_shexp(limits_shared_f)
167210

211+
if n_nextn > 0 and not self.no_mtp:
212+
self.gguf_writer.add_nextn_predict_layers(n_nextn)
213+
168214
@classmethod
169215
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
170216
name, gen = item
@@ -175,13 +221,41 @@ def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Ca
175221

176222
return super().filter_tensors((name, gen))
177223

224+
def _is_mtp_layer(self, bid: int | None) -> bool:
225+
if bid is None:
226+
return False
227+
n_main = int(self.hparams.get("num_hidden_layers", self.block_count))
228+
return bid >= n_main
229+
178230
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
179-
# remove mtp layers
180-
if (m := re.match(r"model\.layers\.(\d+)\.", name)) is not None:
181-
il = int(m.group(1))
182-
n_main = int(self.hparams.get("num_hidden_layers", self.block_count))
183-
if il >= n_main:
231+
is_mtp = self._is_mtp_layer(bid)
232+
233+
# --no-mtp: drop the appended MTP block(s) entirely.
234+
if is_mtp and self.no_mtp:
235+
return
236+
# --mtp: keep ONLY MTP-block tensors plus the shared embeddings/norm/lm_head
237+
# (so the resulting GGUF carries just the draft head).
238+
if self.mtp_only and not is_mtp and bid is not None:
239+
return
240+
if self.mtp_only and bid is None:
241+
# Top-level tensors: keep only shared embeddings/norm/lm_head.
242+
keep = name in (
243+
"model.embed_tokens.weight", "model.norm.weight", "lm_head.weight",
244+
)
245+
if not keep:
184246
return
247+
248+
# The checkpoint nests the per-MTP-layer shared head under
249+
# `model.layers.{N+i}.transformer.shared_head.{norm,output}.weight`;
250+
# strip the `transformer.` infix and rename `output` → `head` so the
251+
# existing NEXTN_SHARED_HEAD_{NORM,HEAD} tensor mapping picks them up.
252+
# Mirrors vllm's `_rewrite_spec_layer_name` (step3p5_mtp.py).
253+
if is_mtp:
254+
if ".transformer." in name:
255+
name = name.replace(".transformer.", ".")
256+
if "shared_head.output" in name:
257+
name = name.replace("shared_head.output", "shared_head.head")
258+
185259
if name.endswith("norm.weight"):
186260
data_torch += 1.0
187261

@@ -190,6 +264,21 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
190264

191265
yield from super().modify_tensors(data_torch, name, bid)
192266

267+
def prepare_metadata(self, vocab_only: bool):
268+
from_dir = self.fname_out.is_dir()
269+
super().prepare_metadata(vocab_only=vocab_only)
270+
271+
# Mirror Qwen3.5's behavior: when emitting a draft-only file into a
272+
# directory, prefix with "mtp-" so it doesn't collide with the trunk.
273+
if not self.mtp_only or not from_dir:
274+
return
275+
276+
output_type: str = self.ftype.name.partition("_")[2]
277+
fname_default: str = gguf.naming_convention(
278+
self.metadata.name, self.metadata.basename, self.metadata.finetune,
279+
self.metadata.version, size_label=None, output_type=output_type, model_type=None)
280+
self.fname_out = self.fname_out.parent / f"mtp-{fname_default}.gguf"
281+
193282
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
194283
# Step35 can optionally use Llama-3 style RoPE scaling (HF: rope_scaling.rope_type == "llama3").
195284
# llama.cpp represents this via a single extra tensor: "rope_freqs.weight" (aka MODEL_TENSOR.ROPE_FREQS).

convert_hf_to_gguf.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -247,8 +247,9 @@ def main() -> None:
247247

248248
if args.mtp or args.no_mtp:
249249
from conversion.qwen import _Qwen35MtpMixin
250-
if not issubclass(model_class, _Qwen35MtpMixin):
251-
logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 text variants today")
250+
from conversion.step3 import Step35Model
251+
if not (issubclass(model_class, _Qwen35MtpMixin) or issubclass(model_class, Step35Model)):
252+
logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 and Step3.5 text variants today")
252253
sys.exit(1)
253254
if args.no_mtp:
254255
model_class.no_mtp = True

gguf-py/gguf/constants.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3935,6 +3935,13 @@ class MODEL_TENSOR(IntEnum):
39353935
MODEL_TENSOR.FFN_GATE_SHEXP,
39363936
MODEL_TENSOR.FFN_DOWN_SHEXP,
39373937
MODEL_TENSOR.FFN_EXP_PROBS_B,
3938+
# NextN/MTP tensors (Step3p5 draft head)
3939+
MODEL_TENSOR.NEXTN_EH_PROJ,
3940+
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
3941+
MODEL_TENSOR.NEXTN_ENORM,
3942+
MODEL_TENSOR.NEXTN_HNORM,
3943+
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
3944+
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
39383945
],
39393946
MODEL_ARCH.LLAMA_EMBED: [
39403947
MODEL_TENSOR.TOKEN_EMBD,

0 commit comments

Comments
 (0)