Skip to content

Commit ba93c30

Browse files
jingyu-mlclaude
andcommitted
fastgen: delete unused wan22 plugin
wan22.py shipped the Wan 2.2 teacher feature-capture helpers, but with the Wan example config and recipe already removed the plugin is never exercised: the Qwen-Image plugin provides its own attach_feature_capture, and the DMD feature-capture path is duck-typed on ``_fastgen_captured``. Remove the module, drop its (only) import from plugins/__init__.py, and repoint the docstring / error-message references at the qwen_image plugin. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent 9754d5d commit ba93c30

7 files changed

Lines changed: 14 additions & 173 deletions

File tree

modelopt/torch/fastgen/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
3232
# If GAN is enabled, expose intermediate teacher features to the discriminator.
3333
if cfg.gan_loss_weight_gen > 0:
34-
mtf.plugins.wan22.attach_feature_capture(teacher, feature_indices=[15, 22, 29])
34+
mtf.plugins.qwen_image.attach_feature_capture(teacher, feature_indices=[30])
3535
3636
pipeline = mtf.DMDPipeline(student, teacher, fake_score, cfg, discriminator=disc)
3737
@@ -62,7 +62,7 @@
6262
from .pipeline import DistillationPipeline
6363

6464
# isort: off
65-
# Plugins must be imported after the core exports so the wan22 hooks can reference
65+
# Plugins must be imported after the core exports so the plugin hooks can reference
6666
# DMDPipeline if needed in the future; also matches the ordering used by
6767
# modelopt.torch.distill.
6868
from . import plugins

modelopt/torch/fastgen/discriminators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
a list of spatial feature tensors ``[B, C, H, W]`` and returns concatenated
2323
logits ``[B, num_heads]``. The model-specific work of producing those tensors
2424
(installing forward hooks, reshaping packed-token streams into spatial maps)
25-
lives in the per-model plugins (``plugins/qwen_image.py``, ``plugins/wan22.py``).
25+
lives in the per-model plugins (``plugins/qwen_image.py``).
2626
"""
2727

2828
from __future__ import annotations

modelopt/torch/fastgen/loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
"""YAML-driven configuration loading for fastgen distillation pipelines.
1717
18-
YAML is the first-class entry point for DMD-on-Wan configurations — the fastgen library
18+
YAML is the first-class entry point for DMD configurations — the fastgen library
1919
does not expect callers to hand-build Python dicts. Typical usage::
2020
2121
from modelopt.torch.fastgen import DMDConfig, load_dmd_config

modelopt/torch/fastgen/methods/dmd.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464

6565

6666
# ---------------------------------------------------------------------------- #
67-
# Feature capture helper (duck-typed so tests can bypass the wan22 plugin) #
67+
# Feature capture helper (duck-typed so tests can bypass the capture plugin) #
6868
# ---------------------------------------------------------------------------- #
6969

7070

@@ -76,8 +76,7 @@ def _drain_if_hooked(module: nn.Module) -> list[torch.Tensor] | None:
7676
call sites can drain unconditionally after every teacher forward — this prevents
7777
the buffer from growing across steps when hooks are attached but the GAN branch is
7878
disabled (e.g. an ablation). Callers that need the strict "did you forget to attach
79-
hooks?" failure mode should call :func:`_require_hooked` on the result, or use
80-
:func:`modelopt.torch.fastgen.plugins.wan22.pop_captured_features` directly.
79+
hooks?" failure mode should call :func:`_require_hooked` on the result.
8180
"""
8281
captured = getattr(module, "_fastgen_captured", None)
8382
if captured is None:
@@ -106,7 +105,7 @@ def _require_hooked(
106105
raise RuntimeError(
107106
f"Feature-capture hooks are required on the teacher ({which} branch): "
108107
"teacher._fastgen_captured is missing. Call "
109-
"modelopt.torch.fastgen.plugins.wan22.attach_feature_capture(teacher, ...) "
108+
"modelopt.torch.fastgen.plugins.qwen_image.attach_feature_capture(teacher, ...) "
110109
"before running this loss."
111110
)
112111
return features
@@ -127,7 +126,7 @@ class DMDPipeline(DistillationPipeline):
127126
object with a ``.sample`` attribute.
128127
teacher: Frozen reference module with the same call signature. If ``discriminator``
129128
is provided, feature-capture hooks must be attached to ``teacher`` before
130-
calling ``compute_*_loss`` — see :func:`modelopt.torch.fastgen.plugins.wan22.attach_feature_capture`.
129+
calling ``compute_*_loss`` — see :func:`modelopt.torch.fastgen.plugins.qwen_image.attach_feature_capture`.
131130
fake_score: Trainable auxiliary module (same signature as teacher/student). Used to
132131
approximate the student's generated distribution for the VSD gradient.
133132
config: :class:`~modelopt.torch.fastgen.config.DMDConfig` with the hyperparameters.

modelopt/torch/fastgen/plugins/__init__.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,13 @@
1515

1616
"""Optional plugins for the fastgen subpackage (gated via ``import_plugin``).
1717
18-
``wan22`` holds the forward-hook helpers for exposing intermediate teacher activations
19-
to the DMD2 GAN discriminator on Wan 2.2 models. The module itself only depends on
20-
``torch`` at runtime, but we still gate the import so environments that choose not to
21-
install any optional fastgen dependencies see a clean package import.
18+
``qwen_image`` holds the Qwen-Image pipeline plus the forward-hook helpers that expose
19+
intermediate teacher activations to the DMD2 GAN discriminator. The import is gated so
20+
environments that choose not to install the optional fastgen dependencies still see a
21+
clean package import.
2222
"""
2323

2424
from modelopt.torch.utils import import_plugin
2525

26-
with import_plugin("wan22"):
27-
from .wan22 import *
28-
2926
with import_plugin("qwen_image"):
3027
from .qwen_image import *

modelopt/torch/fastgen/plugins/qwen_image.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,9 @@ def _call_model(
257257
# GAN feature capture #
258258
# ---------------------------------------------------------------------------- #
259259

260-
# Attribute names match :mod:`modelopt.torch.fastgen.plugins.wan22` so the shared
260+
# These attribute names are what the shared
261261
# :func:`~modelopt.torch.fastgen.methods.dmd._drain_if_hooked` /
262-
# :func:`~modelopt.torch.fastgen.methods.dmd._require_hooked` helpers work
263-
# without modification.
262+
# :func:`~modelopt.torch.fastgen.methods.dmd._require_hooked` helpers look for.
264263
_CAPTURED_ATTR = "_fastgen_captured"
265264
_HANDLES_ATTR = "_fastgen_capture_handles"
266265
_INDICES_ATTR = "_fastgen_capture_indices"

modelopt/torch/fastgen/plugins/wan22.py

Lines changed: 0 additions & 154 deletions
This file was deleted.

0 commit comments

Comments
 (0)