Skip to content

Commit f633e30

Browse files
aryan5valexzmsmergify[bot]
authored
[feat]: add LongCat bidirectional finetuning support (hao-ai-lab#1244)
Co-authored-by: Aryan Kumar <aryan5v@users.noreply.github.com> Co-authored-by: alexzms <3036648523@qq.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 5ce4947 commit f633e30

7 files changed

Lines changed: 281 additions & 12 deletions

File tree

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# LongCat-Video T2V 13.6B bidirectional finetune.
2+
3+
models:
4+
student:
5+
_target_: fastvideo.train.models.longcat.LongCatModel
6+
init_from: FastVideo/LongCat-Video-T2V-Diffusers
7+
trainable: true
8+
9+
method:
10+
_target_: fastvideo.train.methods.fine_tuning.finetune.FineTuneMethod
11+
12+
training:
13+
model_path: FastVideo/LongCat-Video-T2V-Diffusers
14+
15+
distributed:
16+
num_gpus: 8
17+
sp_size: 1
18+
tp_size: 1
19+
hsdp_replicate_dim: 1
20+
hsdp_shard_dim: 8
21+
22+
data:
23+
data_path: data/LongCat-Syn
24+
dataloader_num_workers: 4
25+
train_batch_size: 1
26+
training_cfg_rate: 0.0
27+
seed: 1000
28+
num_latent_t: 20
29+
num_height: 480
30+
num_width: 848
31+
num_frames: 77
32+
33+
optimizer:
34+
learning_rate: 1.0e-6
35+
betas: [0.9, 0.999]
36+
weight_decay: 0.01
37+
lr_scheduler: constant
38+
lr_warmup_steps: 0
39+
40+
loop:
41+
max_train_steps: 4000
42+
gradient_accumulation_steps: 1
43+
44+
checkpoint:
45+
output_dir: outputs/longcat_finetune
46+
training_state_checkpointing_steps: 1000
47+
checkpoints_total_limit: 3
48+
49+
tracker:
50+
project_name: fastvideo
51+
run_name: longcat_finetune
52+
53+
model:
54+
enable_gradient_checkpointing_type: full
55+
56+
callbacks:
57+
grad_clip:
58+
_target_: fastvideo.train.callbacks.grad_clip.GradNormClipCallback
59+
max_grad_norm: 1.0
60+
validation:
61+
_target_: fastvideo.train.callbacks.validation.ValidationCallback
62+
pipeline_target: fastvideo.pipelines.basic.longcat.longcat_pipeline.LongCatPipeline
63+
dataset_file: data/validation_prompts.json
64+
every_steps: 100
65+
sampling_steps: [50]
66+
guidance_scale: 5.0
67+
68+
pipeline:
69+
# Match the released LongCat scheduler config. flow_shift=0.0 collapses
70+
# FlowMatch training timesteps to zero in FastVideo's scheduler.
71+
flow_shift: 12.0

fastvideo/models/dits/longcat.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,17 @@ def forward(
206206
encoder_attention_mask = encoder_attention_mask.squeeze(1).squeeze(1)
207207
elif len(encoder_attention_mask.shape) == 3:
208208
encoder_attention_mask = encoder_attention_mask.squeeze(1)
209+
210+
seq_len = int(y.shape[1])
211+
mask_len = int(encoder_attention_mask.shape[1])
212+
if mask_len < seq_len:
213+
encoder_attention_mask = F.pad(
214+
encoder_attention_mask,
215+
(0, seq_len - mask_len),
216+
value=0,
217+
)
218+
elif mask_len > seq_len:
219+
encoder_attention_mask = encoder_attention_mask[:, :seq_len]
209220

210221
# Zero out padded tokens if requested
211222
if self.text_tokens_zero_pad:

fastvideo/pipelines/stages/longcat_denoising.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,14 @@ def forward(
7070
pipeline.add_module("transformer", self.transformer)
7171
fastvideo_args.model_loaded["transformer"] = True
7272

73-
# Get transformer dtype
74-
if hasattr(self.transformer, 'module'):
75-
transformer_dtype = next(self.transformer.module.parameters()).dtype
76-
else:
77-
transformer_dtype = next(self.transformer.parameters()).dtype
78-
79-
target_dtype = transformer_dtype
73+
# Inference dtype. We hardcode bf16 (matching the WanDenoisingStage
74+
# pattern) rather than reading transformer.parameters().dtype: when
75+
# a model is loaded with default_dtype=fp32 but its FSDP-wrapped
76+
# submodules compute in bf16, the parameter-dtype heuristic
77+
# mismatches the Conv3d weight/bias dtype and the patch_embed
78+
# forward fails with "Input type (float) and bias type
79+
# (c10::BFloat16) should be the same".
80+
target_dtype = torch.bfloat16
8081
autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast
8182

8283
# Extract batch parameters

fastvideo/tests/transformers/test_cosmos.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,24 @@
2222
os.environ["MASTER_PORT"] = "29504"
2323

2424
BASE_MODEL_PATH = "nvidia/Cosmos-Predict2-2B-Video2World"
25-
MODEL_PATH = maybe_download_model(BASE_MODEL_PATH,
26-
local_dir=os.path.join(
27-
'data', BASE_MODEL_PATH))
25+
26+
27+
def _resolve_model_path() -> str:
28+
try:
29+
return maybe_download_model(
30+
BASE_MODEL_PATH,
31+
local_dir=os.path.join("data", BASE_MODEL_PATH),
32+
)
33+
except ValueError as exc:
34+
pytest.skip(
35+
"Skipping Cosmos transformer test because the configured "
36+
"HuggingFace token cannot access the gated Cosmos weights: "
37+
f"{exc}",
38+
allow_module_level=True,
39+
)
40+
41+
42+
MODEL_PATH = _resolve_model_path()
2843
TRANSFORMER_PATH = os.path.join(MODEL_PATH, "transformer")
2944

3045

@@ -131,4 +146,4 @@ def test_cosmos2_transformer():
131146
logger.info("Mean Diff: %s", mean_diff.item())
132147
assert max_diff < 1e-1, f"Maximum difference between outputs: {max_diff.item()}"
133148
# mean diff
134-
assert mean_diff < 1e-2, f"Mean difference between outputs: {mean_diff.item()}"
149+
assert mean_diff < 1e-2, f"Mean difference between outputs: {mean_diff.item()}"
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""LongCat model plugin package."""
3+
4+
from fastvideo.train.models.longcat.longcat import (
5+
LongCatModel as LongCatModel, )
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""LongCat model plugin (per-role instance)."""
3+
4+
from __future__ import annotations
5+
6+
from typing import Any, Literal, TYPE_CHECKING
7+
8+
import torch
9+
10+
from fastvideo.pipelines import TrainingBatch
11+
from fastvideo.train.models.wan.wan import WanModel
12+
13+
if TYPE_CHECKING:
14+
from fastvideo.train.utils.training_config import TrainingConfig
15+
16+
17+
class LongCatModel(WanModel):
18+
"""LongCat per-role model for training and distillation."""
19+
20+
_transformer_cls_name: str = "LongCatTransformer3DModel"
21+
22+
@staticmethod
23+
def _validate_flow_shift(flow_shift: float | None) -> float:
24+
if flow_shift is None:
25+
return 12.0
26+
27+
validated = float(flow_shift)
28+
if validated == 0.0:
29+
raise ValueError("LongCat training does not support flow_shift=0.0 because "
30+
"it collapses FlowMatch training timesteps. Use 12.0 to "
31+
"match the released LongCat scheduler config.")
32+
return validated
33+
34+
def __init__(
35+
self,
36+
*,
37+
init_from: str,
38+
training_config: TrainingConfig,
39+
trainable: bool = True,
40+
disable_custom_init_weights: bool = False,
41+
flow_shift: float = 12.0,
42+
enable_gradient_checkpointing_type: str | None = None,
43+
transformer_override_safetensor: str | None = None,
44+
) -> None:
45+
super().__init__(
46+
init_from=init_from,
47+
training_config=training_config,
48+
trainable=trainable,
49+
disable_custom_init_weights=disable_custom_init_weights,
50+
flow_shift=self._validate_flow_shift(flow_shift),
51+
enable_gradient_checkpointing_type=enable_gradient_checkpointing_type,
52+
transformer_override_safetensor=transformer_override_safetensor,
53+
)
54+
55+
def _init_timestep_mechanics(self) -> None:
56+
assert self.training_config is not None
57+
tc = self.training_config
58+
flow_shift = getattr(tc.pipeline_config, "flow_shift", None) # type: ignore[union-attr]
59+
self.timestep_shift = self._validate_flow_shift(flow_shift)
60+
self.num_train_timestep = int(self.noise_scheduler.num_train_timesteps)
61+
self.min_timestep = 0
62+
self.max_timestep = self.num_train_timestep
63+
64+
def _build_attention_metadata(self, training_batch: TrainingBatch) -> TrainingBatch:
65+
training_batch.attn_metadata = None
66+
return training_batch
67+
68+
def _build_distill_input_kwargs(
69+
self,
70+
noise_input: torch.Tensor,
71+
timestep: torch.Tensor,
72+
text_dict: dict[str, torch.Tensor] | None,
73+
) -> dict[str, Any]:
74+
if text_dict is None:
75+
raise ValueError("text_dict cannot be None for LongCat distillation")
76+
77+
batch_size = int(noise_input.shape[0])
78+
if timestep.ndim == 0:
79+
timestep = timestep.view(1).expand(batch_size)
80+
elif timestep.ndim == 1 and int(timestep.shape[0]) == 1 and batch_size > 1:
81+
timestep = timestep.expand(batch_size)
82+
83+
return {
84+
"hidden_states": noise_input.permute(0, 2, 1, 3, 4),
85+
"encoder_hidden_states": text_dict["encoder_hidden_states"],
86+
"encoder_attention_mask": text_dict["encoder_attention_mask"],
87+
"timestep": timestep,
88+
}
89+
90+
def predict_noise(
91+
self,
92+
noisy_latents: torch.Tensor,
93+
timestep: torch.Tensor,
94+
batch: TrainingBatch,
95+
*,
96+
conditional: bool,
97+
cfg_uncond: dict[str, Any] | None = None,
98+
attn_kind: Literal["dense", "vsa"] = "dense",
99+
) -> torch.Tensor:
100+
"""Adapt LongCat's sign convention to FineTuneMethod's target.
101+
102+
``LongCatTransformer3DModel`` is pretrained to output the
103+
``clean - noise`` direction; ``LongCatDenoisingStage`` (the
104+
bidirectional inference pipeline) explicitly negates the
105+
transformer output before handing it to
106+
``FlowMatchEulerDiscreteScheduler.step``. Training methods on
107+
the other hand (``FineTuneMethod``,
108+
``DiffusionForcingSFTMethod``) target ``noise - clean``
109+
directly (the standard rectified-flow velocity Wan uses).
110+
111+
Without the negation here, the loss MSE pushes the transformer
112+
toward ``noise - clean``, flipping its native output sign over
113+
training. Inference then applies its own negation on top, so
114+
the scheduler receives the wrong direction and produces noise
115+
even while the training loss is dropping. Verified empirically
116+
on a 100-step LongCat overfit run: step 0 generated meaningful
117+
video, step 100 was pure noise despite low loss.
118+
119+
Negating in ``predict_noise`` keeps the transformer's
120+
pretrained sign convention intact while presenting the
121+
training methods with a Wan-compatible
122+
``pred ≈ noise - clean`` for MSE.
123+
"""
124+
pred = super().predict_noise(
125+
noisy_latents,
126+
timestep,
127+
batch,
128+
conditional=conditional,
129+
cfg_uncond=cfg_uncond,
130+
attn_kind=attn_kind,
131+
)
132+
return -pred

fastvideo/training/trackers.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,45 @@
1717
from typing import Any
1818
from collections.abc import Iterable, Iterator
1919

20+
import torch
21+
2022
from fastvideo.logger import init_logger
2123

2224
logger = init_logger(__name__)
2325

2426

27+
def _sanitize_wandb_config(value: Any) -> Any:
28+
"""Best-effort conversion of nested config objects to W&B-safe values."""
29+
if value is None or isinstance(value, str | int | float | bool):
30+
return value
31+
if isinstance(value, Enum):
32+
return value.value
33+
if isinstance(value, pathlib.Path):
34+
return str(value)
35+
if isinstance(value, dict):
36+
return {str(k): _sanitize_wandb_config(v) for k, v in value.items()}
37+
if isinstance(value, list | tuple | set):
38+
return [_sanitize_wandb_config(v) for v in value]
39+
if isinstance(value, torch.dtype):
40+
return str(value)
41+
if isinstance(value, torch.Tensor):
42+
tensor = value.detach().cpu()
43+
if tensor.dtype == torch.bfloat16:
44+
tensor = tensor.to(dtype=torch.float32)
45+
if tensor.ndim == 0 or tensor.numel() == 1:
46+
return tensor.item()
47+
if tensor.numel() <= 256:
48+
return tensor.tolist()
49+
return {
50+
"_type": "tensor_summary",
51+
"shape": list(tensor.shape),
52+
"dtype": str(tensor.dtype),
53+
}
54+
if callable(value):
55+
return getattr(value, "__name__", repr(value))
56+
return repr(value)
57+
58+
2559
@dataclass
2660
class Timer:
2761
"""Simple timer utility used by the trackers."""
@@ -143,7 +177,7 @@ def __init__(
143177
self._run = wandb.init(
144178
project=experiment_name,
145179
dir=log_dir,
146-
config=config,
180+
config=(_sanitize_wandb_config(config) if config is not None else None),
147181
name=run_name,
148182
)
149183
logger.info("Initialized Weights & Biases tracker")

0 commit comments

Comments
 (0)