Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
118 commits
Select commit Hold shift + click to select a range
c5c5f7e
flux_2
Gnav3852 Jan 26, 2026
28f2c75
flux2-klien
Gnav3852 Jan 29, 2026
c135016
param fix
Gnav3852 Jan 29, 2026
cf1a8d6
more fixes
Gnav3852 Jan 29, 2026
683854b
qwen3
Gnav3852 Jan 30, 2026
02cb4ae
cuda change
Gnav3852 Jan 30, 2026
c902650
atten
Gnav3852 Jan 30, 2026
e251689
imageen
Gnav3852 Jan 30, 2026
def1363
verify
Gnav3852 Jan 30, 2026
a80becb
mu
Gnav3852 Feb 2, 2026
844d0af
isntance
Gnav3852 Feb 2, 2026
0ee9329
reshape
Gnav3852 Feb 2, 2026
904eb18
reshape
Gnav3852 Feb 2, 2026
851e48a
decode
Gnav3852 Feb 2, 2026
56a0b39
NaN
Gnav3852 Feb 2, 2026
726378c
NaN
Gnav3852 Feb 2, 2026
bbfa354
NaN checks
Gnav3852 Feb 4, 2026
5bf5d6d
nan fix
Gnav3852 Feb 4, 2026
bbf733f
test
Gnav3852 Feb 4, 2026
c7c7703
flux2
Gnav3852 Feb 4, 2026
dcd74b9
base
Gnav3852 Feb 5, 2026
75b51a6
image fix
Gnav3852 Feb 5, 2026
3fdee04
compare_dit
Gnav3852 Feb 5, 2026
fc989cb
almost there
Gnav3852 Feb 5, 2026
b65332b
training
Gnav3852 Feb 5, 2026
5e8089d
core init
Gnav3852 Feb 5, 2026
6af28bd
port
Gnav3852 Feb 5, 2026
2ed8a24
dtype
Gnav3852 Feb 5, 2026
9791ded
eee
Gnav3852 Feb 5, 2026
f32785d
dit blocks
Gnav3852 Feb 5, 2026
65bbf3e
block fix
Gnav3852 Feb 5, 2026
f00791f
hook fix
Gnav3852 Feb 5, 2026
5ea9795
comp diff
Gnav3852 Feb 5, 2026
5de73a7
RoPE
Gnav3852 Feb 5, 2026
76da032
cpu
Gnav3852 Feb 5, 2026
44a1497
forward
Gnav3852 Feb 5, 2026
a48d4d4
safe
Gnav3852 Feb 5, 2026
236b9ad
safe more
Gnav3852 Feb 5, 2026
2c22864
double0
Gnav3852 Feb 5, 2026
260e4ae
qk calc
Gnav3852 Feb 5, 2026
ea7da25
roatry changes
Gnav3852 Feb 5, 2026
bc6df8a
dim
Gnav3852 Feb 5, 2026
b3efe07
fix thetha
Gnav3852 Feb 5, 2026
688c306
timestp
Gnav3852 Feb 5, 2026
e146c1e
more tests
Gnav3852 Feb 5, 2026
77b0588
test
Gnav3852 Feb 5, 2026
abbb9c7
me
Gnav3852 Feb 5, 2026
7b250e4
roate
Gnav3852 Feb 5, 2026
532f158
sglang compare files
Gnav3852 Feb 8, 2026
7e431c5
test
Gnav3852 Feb 8, 2026
fa9b35d
repo
Gnav3852 Feb 8, 2026
c6f8b8e
q fix
Gnav3852 Feb 8, 2026
48f28b3
e
Gnav3852 Feb 8, 2026
e0923c6
vae
Gnav3852 Feb 8, 2026
7424d46
dump
Gnav3852 Feb 8, 2026
55ee4a2
one mo
Gnav3852 Feb 8, 2026
3ebb45e
e
Gnav3852 Feb 8, 2026
aa26784
sglang
Gnav3852 Feb 8, 2026
5378be8
text en
Gnav3852 Feb 8, 2026
8a1363f
more changes
Gnav3852 Feb 8, 2026
8b0faa5
change
Gnav3852 Feb 8, 2026
a377c04
sglang
Gnav3852 Feb 8, 2026
3453c10
init
Gnav3852 Feb 8, 2026
c3bf98f
todev
Gnav3852 Feb 8, 2026
6fe37e5
init part
Gnav3852 Feb 8, 2026
fc42836
server args
Gnav3852 Feb 8, 2026
67a4fef
FSDP
Gnav3852 Feb 8, 2026
1290c66
e
Gnav3852 Feb 8, 2026
46ac05d
channels
Gnav3852 Feb 8, 2026
daeb55f
channels
Gnav3852 Feb 8, 2026
c67e93c
forward context
Gnav3852 Feb 8, 2026
8151c22
comploader
Gnav3852 Feb 12, 2026
c8c6426
debug text
Gnav3852 Feb 12, 2026
502127d
e
Gnav3852 Feb 12, 2026
6053390
fix
Gnav3852 Feb 12, 2026
a4c48f9
em
Gnav3852 Feb 12, 2026
37dcdcc
e
Gnav3852 Feb 12, 2026
9412aa5
foreard
Gnav3852 Feb 12, 2026
bcdc920
more debug
Gnav3852 Feb 12, 2026
3d4241a
attention vals
Gnav3852 Feb 12, 2026
f17f669
fix
Gnav3852 Feb 12, 2026
fb008e3
sdpa
Gnav3852 Feb 16, 2026
d119420
fix
Gnav3852 Feb 16, 2026
b573ba1
oproj
Gnav3852 Feb 16, 2026
7f563fd
fixes
Gnav3852 Feb 16, 2026
e6f74f4
fullT
Gnav3852 Feb 16, 2026
1a5f099
sanity chck
Gnav3852 Feb 16, 2026
9a68a33
cpu offlaod
Gnav3852 Feb 16, 2026
f06b4e8
bf16
Gnav3852 Feb 16, 2026
040b20c
3way
Gnav3852 Feb 19, 2026
7eb10a8
test
Gnav3852 Feb 19, 2026
0e859da
registry
Gnav3852 Feb 23, 2026
b262c58
cahn ge back
Gnav3852 Feb 23, 2026
45ed49a
ssim
Gnav3852 Feb 23, 2026
316b576
sdpa
Gnav3852 Feb 23, 2026
ca5d38b
e
Gnav3852 Feb 23, 2026
6138354
allow text encode from preomcouted
Gnav3852 Feb 23, 2026
ce73c1b
run sep
Gnav3852 Feb 23, 2026
ff0b48e
eh
Gnav3852 Feb 24, 2026
4d7732c
block compare
Gnav3852 Feb 24, 2026
6f2ff4b
changes
Gnav3852 Feb 24, 2026
5de038d
e
Gnav3852 Feb 24, 2026
45f6f13
compare blocks
Gnav3852 Feb 24, 2026
8fc53b2
ff
Gnav3852 Feb 24, 2026
8154ffa
e
Gnav3852 Feb 24, 2026
68fe651
e
Gnav3852 Feb 24, 2026
d5c5430
registry changes
Gnav3852 Feb 26, 2026
828f27b
flux2 file change
Gnav3852 Feb 26, 2026
058fe03
stepvid
Mar 28, 2026
b89bc96
RoPE
Mar 28, 2026
3465088
double0
Mar 28, 2026
15fd12d
fixDIm
Mar 28, 2026
f34f295
e
Mar 28, 2026
a9ac0a5
more fixes
Mar 28, 2026
e03ac86
e
Mar 28, 2026
cc63ebc
sdpa
Mar 28, 2026
26a8694
more
Mar 28, 2026
c0bafd2
linear
Mar 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 295 additions & 0 deletions compare_flux2_context_branch_double0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
#!/usr/bin/env python3
"""
Compare Diffusers vs FastVideo **context (text) branch linears** in double block 0:
add_q_proj, add_k_proj, add_v_proj, to_add_out

Hooks the same modules on both sides after a step-0 forward. Inputs to block 0
should already match if the dump is consistent; this isolates ColumnParallelLinear
(add_*) and to_add_out vs nn.Linear.

ColumnParallelLinear may report .weight shape (0, 0) in named_parameters while
hooks still see correct matmul I/O at tp_size==1; treat empty weight as a
storage/inspection quirk, not missing parameters.

python compare_flux2_context_branch_double0.py [--dump PATH] [--device cuda]
Use --no-math-sdp to allow non-math SDPA backends.

Requires: flux2_step0_dump.pt (with text_ids, latent_ids), diffusers Flux2KleinPipeline,
editable FastVideo.
"""
from __future__ import annotations

import argparse
import os
import sys

import torch
import torch.nn.functional as F

# Default path of the step-0 activation dump this script consumes; must
# contain text_ids and latent_ids (see module docstring). Override with --dump.
DUMP_PATH = "flux2_step0_dump.pt"
# Hugging Face repo id of the FLUX.2 Klein checkpoint being compared.
MODEL_ID = "black-forest-labs/FLUX.2-klein-4B"


def _enable_torch_math_sdp() -> None:
if not torch.cuda.is_available():
return
try:
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
except Exception:
pass


def _get_transformer_path(model_id: str) -> str:
try:
from huggingface_hub import snapshot_download

root = snapshot_download(repo_id=model_id)
path = os.path.join(root, "transformer")
if os.path.isdir(path):
return path
except Exception:
pass
if os.path.isdir(model_id):
if os.path.exists(os.path.join(model_id, "transformer", "config.json")):
return os.path.join(model_id, "transformer")
if os.path.exists(os.path.join(model_id, "config.json")):
return model_id
raise FileNotFoundError(f"Could not find transformer for {model_id}")


def _register_io_hooks(attn_module, storage: dict, prefix: str) -> list:
handles = []

def wire(name: str, child: torch.nn.Module):
def pre(_m, args, _kwargs=None):
if args and args[0] is not None:
storage[f"{prefix}_{name}_in"] = args[0].detach().clone()

def post(_m, _args, out):
o = out[0] if isinstance(out, tuple) else out
storage[f"{prefix}_{name}_out"] = o.detach().clone()

handles.append(child.register_forward_pre_hook(pre, with_kwargs=True))
handles.append(child.register_forward_hook(post))

for name in ("add_q_proj", "add_k_proj", "add_v_proj", "to_add_out"):
if hasattr(attn_module, name):
wire(name, getattr(attn_module, name))
return handles


def _max_mean(a: torch.Tensor, b: torch.Tensor) -> tuple[float, float]:
a = a.float()
b = b.float()
if a.shape != b.shape:
return float("nan"), float("nan")
d = (a - b).abs()
return d.max().item(), d.mean().item()


def _linear_match_captured(
x_cap: torch.Tensor,
y_cap: torch.Tensor,
module: torch.nn.Module,
label: str,
) -> None:
"""Compare F.linear(x, weight, bias) to hook-captured output (no second forward)."""
w = getattr(module, "weight", None)
if w is None:
print(f" {label}: module has no .weight")
return
if w.numel() == 0:
print(f" {label}: .weight is empty; named_parameters:")
for n, p in module.named_parameters():
print(f" {n}: {tuple(p.shape)}")
return
b = getattr(module, "bias", None)
dev = w.device
dt = w.dtype
x2 = x_cap.to(device=dev, dtype=dt).reshape(-1, x_cap.shape[-1])
y2 = y_cap.to(device=dev, dtype=dt).reshape(-1, y_cap.shape[-1])
if w.shape[1] != x2.shape[1]:
print(
f" {label}: skip F.linear check (w.shape={tuple(w.shape)} "
f"vs x_in_features={x2.shape[1]})"
)
return
with torch.no_grad():
y_man = F.linear(
x2.float(),
w.float(),
None if b is None else b.float(),
)
d = (y_man - y2.float()).abs()
print(
f" {label}: F.linear vs captured out "
f"max={d.max().item():.6e} mean={d.mean().item():.6e} "
f"(x={tuple(x_cap.shape)} w={tuple(w.shape)} out={tuple(y_cap.shape)})"
)


def main() -> None:
    """Run one step-0 forward through Diffusers and FastVideo and diff the
    context-branch (text) linears of double block 0.

    Workflow: parse args -> optionally force math SDPA -> init a single-rank
    FastVideo distributed env -> load the step-0 dump -> hooked forward through
    the Diffusers pipeline -> hooked forward through the FastVideo transformer
    -> print per-linear input/output diffs, an F.linear recomputation sanity
    check, and a raw add_q_proj weight comparison.
    """
    parser = argparse.ArgumentParser(
        description="Compare context-branch linears (double block 0) FV vs Diffusers."
    )
    parser.add_argument("--dump", default=DUMP_PATH)
    parser.add_argument("--model-path", default=None)
    parser.add_argument("--device", default="cuda")
    parser.add_argument(
        "--no-math-sdp",
        action="store_true",
        help="Allow flash/mem-efficient SDPA (default: force math SDPA for parity).",
    )
    args = parser.parse_args()

    if not os.path.isfile(args.dump):
        print(f"Missing dump: {args.dump}")
        sys.exit(1)

    # Pin FastVideo to the TORCH_SDPA attention backend and (by default) force
    # the math SDPA kernel so both sides run comparable attention math.
    os.environ.setdefault("FASTVIDEO_ATTENTION_BACKEND", "TORCH_SDPA")
    if not args.no_math_sdp:
        _enable_torch_math_sdp()
    # Single-process "distributed" env so FastVideo's model-parallel init works
    # without a launcher; env vars must be set before the fastvideo import below.
    os.environ.setdefault("LOCAL_RANK", "0")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    from fastvideo.distributed import maybe_init_distributed_environment_and_model_parallel

    maybe_init_distributed_environment_and_model_parallel(tp_size=1, sp_size=1)

    # Load the step-0 dump on CPU; tensors are moved to the target device later.
    data = torch.load(args.dump, map_location="cpu", weights_only=True)
    latent = data["latent_model_input"]
    timestep_scaled = data["timestep_scaled"]
    prompt_embeds = data["prompt_embeds"]
    text_ids = data.get("text_ids")
    latent_ids = data.get("latent_ids")
    if text_ids is None or latent_ids is None:
        print("Dump needs text_ids and latent_ids.")
        sys.exit(1)

    device = args.device
    dtype = torch.bfloat16

    # Import path differs across diffusers versions; try the top-level export first.
    try:
        from diffusers import Flux2KleinPipeline
    except ImportError:
        from diffusers.pipelines.flux2 import Flux2KleinPipeline

    # --- Reference side: Diffusers pipeline with I/O hooks on block 0's attn ---
    print("Loading Diffusers pipeline ...")
    pipe = Flux2KleinPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)
    pipe = pipe.to(device)
    off_attn = pipe.transformer.transformer_blocks[0].attn
    off_store: dict = {}
    off_handles = _register_io_hooks(off_attn, off_store, "off")

    latent_d = latent.to(device, dtype=dtype)
    te_d = data["timestep_scaled"].to(device, dtype=dtype)
    if te_d.dim() == 1:
        # Broadcast a 1-element timestep to a per-batch vector for Diffusers.
        te_d = te_d.view(1).expand(latent_d.shape[0])
    pe_d = prompt_embeds.to(device, dtype=dtype)

    with torch.no_grad():
        with pipe.transformer.cache_context("cond"):
            pipe.transformer(
                hidden_states=latent_d,
                timestep=te_d,
                guidance=None,
                encoder_hidden_states=pe_d,
                txt_ids=text_ids.to(device),
                img_ids=latent_ids.to(device),
                return_dict=False,
            )
    # Detach hooks so later forwards (if any) don't overwrite the captures.
    for h in off_handles:
        h.remove()

    # Deferred imports: these pull in the FastVideo stack only after the
    # distributed env above is initialized.
    from fastvideo.fastvideo_args import FastVideoArgs
    from fastvideo.forward_context import set_forward_context
    from fastvideo.models.dits.flux_2 import compute_flux2_freqs_cis_from_ids
    from fastvideo.models.loader.component_loader import TransformerLoader

    # --- Candidate side: FastVideo transformer with the same hooks ---
    model_path = args.model_path or _get_transformer_path(MODEL_ID)
    print("Loading FastVideo transformer ...")
    fv_args = FastVideoArgs.from_kwargs(
        model_path=MODEL_ID,
        hsdp_shard_dim=1,
        hsdp_replicate_dim=1,
        num_gpus=1,
        inference_mode=True,
        use_fsdp_inference=False,
        dit_cpu_offload=False,
        pin_cpu_memory=False,
        dit_precision="bf16",
    )
    fv = TransformerLoader(device=device).load(model_path, fv_args).to(device)
    # Match input dtypes to whatever precision the loader actually produced.
    model_dtype = next(fv.parameters()).dtype
    latent_f = latent.to(device, dtype=model_dtype)
    te_f = timestep_scaled.to(device)
    pe_f = prompt_embeds.to(device, dtype=model_dtype)
    # RoPE frequencies are precomputed from the same ids used on the Diffusers side.
    freqs = compute_flux2_freqs_cis_from_ids(
        fv.rotary_emb, text_ids, latent_ids, device, dtype=model_dtype
    )

    fv_attn = fv.transformer_blocks[0].attn
    fv_store: dict = {}
    fv_handles = _register_io_hooks(fv_attn, fv_store, "fv")

    with torch.no_grad(), set_forward_context(current_timestep=0, attn_metadata=None):
        fv(latent_f, pe_f, te_f, guidance=None, freqs_cis=freqs)
    for h in fv_handles:
        h.remove()

    # --- Report: per-linear input/output diffs between the two captures ---
    print("\n--- Context branch linear I/O (double block 0) ---")
    for name in ("add_q_proj", "add_k_proj", "add_v_proj", "to_add_out"):
        ki, ko = f"off_{name}_in", f"off_{name}_out"
        fi, fo = f"fv_{name}_in", f"fv_{name}_out"
        if ki not in off_store or fi not in fv_store:
            print(f" {name}: missing hook data (off_in={ki in off_store}, fv_in={fi in fv_store})")
            continue
        mi, ai = _max_mean(off_store[ki], fv_store[fi])
        mo, ao = _max_mean(off_store[ko], fv_store[fo])
        print(
            f" {name}: input max|diff|={mi:.6f} mean={ai:.6f} | "
            f"output max|diff|={mo:.6f} mean={ao:.6f}"
        )

    # Recompute F.linear from FastVideo weights against the captured outputs.
    print("\n--- F.linear sanity (weights vs hook-captured outputs, no second forward) ---")
    if fv_store.get("fv_add_q_proj_in") is not None and fv_store.get(
        "fv_add_q_proj_out"
    ) is not None:
        _linear_match_captured(
            fv_store["fv_add_q_proj_in"],
            fv_store["fv_add_q_proj_out"],
            fv_attn.add_q_proj,
            "add_q_proj",
        )
    if fv_store.get("fv_to_add_out_in") is not None and fv_store.get(
        "fv_to_add_out_out"
    ) is not None:
        _linear_match_captured(
            fv_store["fv_to_add_out_in"],
            fv_store["fv_to_add_out_out"],
            fv_attn.to_add_out,
            "to_add_out",
        )

    # Direct weight comparison; empty (0, 0) weights are a ColumnParallelLinear
    # inspection quirk (see module docstring), so skip rather than diff them.
    print("\n--- Weight vs Diffusers (add_q_proj) ---")
    if hasattr(off_attn, "add_q_proj") and hasattr(fv_attn, "add_q_proj"):
        ow = off_attn.add_q_proj.weight.data
        fw = fv_attn.add_q_proj.weight.data
        if ow.numel() == 0 or fw.numel() == 0:
            print(f" skip (off weight {tuple(ow.shape)}, fv {tuple(fw.shape)})")
        elif ow.shape != fw.shape:
            print(f" skip shape mismatch off={tuple(ow.shape)} fv={tuple(fw.shape)}")
        else:
            d = (ow.float() - fw.float()).abs()
            print(
                f" weight max|diff|={d.max().item():.6e} "
                f"mean={d.mean().item():.6e}"
            )


# Script entry point: run the comparison only when executed directly.
if __name__ == "__main__":
    main()
Loading