Skip to content

Commit 2187e00

Browse files
pwilkinCISC
andauthored
StepFun 3.5 MTP (#23274)
* StepFun 3.5 MTP * Simplify to single layer * Rollback core changes * fix flake8 errors * Remove scripts * modify to convention * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * dos2unix --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent 0b71540 commit 2187e00

5 files changed

Lines changed: 418 additions & 26 deletions

File tree

conversion/step3.py

Lines changed: 103 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,34 @@ class Step3VLTextModel(Qwen3Model):
9999
class Step35Model(TextModel):
100100
model_arch = gguf.MODEL_ARCH.STEP35
101101

102+
# The --mtp / --no-mtp toggles are ModelBase.mtp_only / no_mtp (set in
103+
# convert_hf_to_gguf.py main()). Unlike Qwen3.5, which stores MTP under a
104+
# `mtp.*` namespace, Step3.5 appends MTP layers at
105+
# `model.layers.{num_hidden_layers + i}`, so we filter them by layer index.
106+
# The trunk layer count is captured before indexing so the classmethod
107+
# filter_tensors can tell the appended MTP block(s) apart from the trunk.
108+
_n_main_layers: int | None = None
109+
110+
def __init__(self, *args, **kwargs):
111+
super().__init__(*args, **kwargs)
112+
# NextN/MTP layers are appended past num_hidden_layers; extend the
113+
# tensor map to cover them so the MTP block's tensors get correctly
114+
# indexed names. When --no-mtp drops the MTP blocks, fall back to the
115+
# base num_hidden_layers so we don't reserve unused slots.
116+
n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0))
117+
if n_nextn > 0 and not self.no_mtp:
118+
self.block_count += n_nextn
119+
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
120+
121+
def index_tensors(self, remote_hf_model_id: str | None = None):
122+
# filter_tensors is a classmethod and can't reach self.hparams; stash
123+
# the trunk layer count here (before indexing runs) so it can detect
124+
# the appended MTP layers by index.
125+
hparams = {**self.hparams, **self.hparams.get("text_config", {})}
126+
key = next((k for k in ["n_layers", "num_hidden_layers", "n_layer", "num_layers"] if k in hparams), None)
127+
type(self)._n_main_layers = hparams.get(key)
128+
return super().index_tensors(remote_hf_model_id=remote_hf_model_id)
129+
102130
def set_gguf_parameters(self):
103131
rope_theta = self.hparams.get("rope_theta")
104132
if isinstance(rope_theta, list):
@@ -119,8 +147,25 @@ def set_gguf_parameters(self):
119147
n_head_swa = attn_other.get("num_attention_heads", n_head_base)
120148
n_kv_swa = attn_other.get("num_attention_groups", n_kv_base)
121149

122-
layer_types = layer_types[: self.block_count]
123-
partial_rotary_factors = partial_rotary_factors[: self.block_count]
150+
n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0))
151+
152+
# The Step3p5 HF checkpoint stores layer_types/partial_rotary_factors
153+
# entries for the MTP blocks past num_hidden_layers; preserve them so
154+
# the MTP layer's attention shape, SWA flag, and partial RoPE dim are
155+
# set correctly. Pad with full-attention defaults if the checkpoint
156+
# truncated them.
157+
def _pad(arr, n, default):
158+
arr = list(arr)
159+
if len(arr) < n:
160+
arr = arr + [default] * (n - len(arr))
161+
return arr[:n]
162+
163+
layer_types = _pad(layer_types, self.block_count, "full_attention")
164+
partial_rotary_factors = _pad(
165+
partial_rotary_factors,
166+
self.block_count,
167+
0.5, # full_attention default for Step3p5
168+
)
124169
assert [1.0 if lt == "sliding_attention" else 0.5 for lt in layer_types] == partial_rotary_factors
125170
head_arr = [n_head_swa if lt == "sliding_attention" else n_head_base for lt in layer_types]
126171
kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types]
@@ -157,31 +202,61 @@ def set_gguf_parameters(self):
157202

158203
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
159204

160-
# Optional per-layer SwiGLU clamps.
205+
# Optional per-layer SwiGLU clamps. MTP layers default to no clamping (0.0).
161206
if (limits := self.hparams.get("swiglu_limits")) is not None:
162-
limits_f = [0.0 if v is None else float(v) for v in limits[: self.block_count]]
207+
limits_f = _pad(
208+
[0.0 if v is None else float(v) for v in limits],
209+
self.block_count,
210+
0.0,
211+
)
163212
self.gguf_writer.add_swiglu_clamp_exp(limits_f)
164213
if (limits_shared := self.hparams.get("swiglu_limits_shared")) is not None:
165-
limits_shared_f = [0.0 if v is None else float(v) for v in limits_shared[: self.block_count]]
214+
limits_shared_f = _pad(
215+
[0.0 if v is None else float(v) for v in limits_shared],
216+
self.block_count,
217+
0.0,
218+
)
166219
self.gguf_writer.add_swiglu_clamp_shexp(limits_shared_f)
167220

221+
if n_nextn > 0 and not self.no_mtp:
222+
self.gguf_writer.add_nextn_predict_layers(n_nextn)
223+
168224
@classmethod
169225
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
170-
name, gen = item
226+
if (titem := super().filter_tensors(item)) is None:
227+
return None
228+
name, gen = titem
171229

172230
# Map router bias (expert selection bias) to a GGUF bias tensor
173231
if name.endswith(".moe.router_bias"):
174232
name += ".bias"
175233

176-
return super().filter_tensors((name, gen))
234+
# Step3.5 appends the MTP block(s) past num_hidden_layers.
235+
assert cls._n_main_layers is not None
236+
is_mtp = (m := re.match(r"model\.layers\.(\d+)\.", name)) is not None and int(m.group(1)) >= cls._n_main_layers
237+
238+
# --no-mtp: drop the appended MTP block(s) entirely.
239+
if is_mtp and cls.no_mtp:
240+
return None
241+
# --mtp: keep ONLY MTP-block tensors plus the shared embeddings/norm/
242+
# lm_head (so the resulting GGUF carries just the draft head).
243+
if cls.mtp_only and not is_mtp and name not in (
244+
"model.embed_tokens.weight", "model.norm.weight", "lm_head.weight",
245+
):
246+
return None
247+
248+
# The checkpoint nests the per-MTP-layer shared head under
249+
# `model.layers.{N+i}.transformer.shared_head.{norm,output}.weight`;
250+
# strip the `transformer.` infix and rename `output` → `head` so the
251+
# existing NEXTN_SHARED_HEAD_{NORM,HEAD} tensor mapping picks them up.
252+
# Mirrors vllm's `_rewrite_spec_layer_name` (step3p5_mtp.py).
253+
if is_mtp:
254+
name = name.replace(".transformer.", ".")
255+
name = name.replace("shared_head.output", "shared_head.head")
256+
257+
return name, gen
177258

178259
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
179-
# remove mtp layers
180-
if (m := re.match(r"model\.layers\.(\d+)\.", name)) is not None:
181-
il = int(m.group(1))
182-
n_main = int(self.hparams.get("num_hidden_layers", self.block_count))
183-
if il >= n_main:
184-
return
185260
if name.endswith("norm.weight"):
186261
data_torch += 1.0
187262

@@ -190,6 +265,21 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
190265

191266
yield from super().modify_tensors(data_torch, name, bid)
192267

268+
def prepare_metadata(self, vocab_only: bool):
269+
from_dir = self.fname_out.is_dir()
270+
super().prepare_metadata(vocab_only=vocab_only)
271+
272+
# Mirror Qwen3.5's behavior: when emitting a draft-only file into a
273+
# directory, prefix with "mtp-" so it doesn't collide with the trunk.
274+
if not self.mtp_only or not from_dir:
275+
return
276+
277+
output_type: str = self.ftype.name.partition("_")[2]
278+
fname_default: str = gguf.naming_convention(
279+
self.metadata.name, self.metadata.basename, self.metadata.finetune,
280+
self.metadata.version, size_label=None, output_type=output_type, model_type=None)
281+
self.fname_out = self.fname_out.parent / f"mtp-{fname_default}.gguf"
282+
193283
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
194284
# Step35 can optionally use Llama-3 style RoPE scaling (HF: rope_scaling.rope_type == "llama3").
195285
# llama.cpp represents this via a single extra tensor: "rope_freqs.weight" (aka MODEL_TENSOR.ROPE_FREQS).

convert_hf_to_gguf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,9 @@ def main() -> None:
251251

252252
if args.mtp or args.no_mtp:
253253
from conversion.qwen import _Qwen35MtpMixin
254-
if not issubclass(model_class, _Qwen35MtpMixin):
255-
logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 text variants today")
254+
from conversion.step3 import Step35Model
255+
if not (issubclass(model_class, _Qwen35MtpMixin) or issubclass(model_class, Step35Model)):
256+
logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 and Step3.5 text variants today")
256257
sys.exit(1)
257258
if args.no_mtp:
258259
model_class.no_mtp = True

gguf-py/gguf/constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3994,6 +3994,13 @@ class MODEL_TENSOR(IntEnum):
39943994
MODEL_TENSOR.FFN_GATE_SHEXP,
39953995
MODEL_TENSOR.FFN_DOWN_SHEXP,
39963996
MODEL_TENSOR.FFN_EXP_PROBS_B,
3997+
# NextN/MTP tensors (Step3p5 draft head)
3998+
MODEL_TENSOR.NEXTN_EH_PROJ,
3999+
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
4000+
MODEL_TENSOR.NEXTN_ENORM,
4001+
MODEL_TENSOR.NEXTN_HNORM,
4002+
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
4003+
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
39974004
],
39984005
MODEL_ARCH.LLAMA_EMBED: [
39994006
MODEL_TENSOR.TOKEN_EMBD,

src/models/models.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1913,5 +1913,9 @@ struct llama_model_step35 : public llama_model_base {
19131913
graph(const llama_model & model, const llm_graph_params & params);
19141914
};
19151915

1916+
struct graph_mtp : public llm_graph_context {
1917+
graph_mtp(const llama_model & model, const llm_graph_params & params);
1918+
};
1919+
19161920
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
19171921
};

0 commit comments

Comments
 (0)