11 changes: 11 additions & 0 deletions configs/platforms/iluvatar_cuda/z_image_turbo_t2i.json
@@ -0,0 +1,11 @@
{
"aspect_ratio": "16:9",
"num_channels_latents": 16,
"infer_steps": 9,
"attn_type": "iluvatar_flash_attn",
"enable_cfg": false,
"sample_guide_scale": 0.0,
"patch_size": 2,
"rope_type":"iluvatar_wan_rope",
"rms_norm_type":"iluvatar_rms_norm"
}
14 changes: 14 additions & 0 deletions configs/platforms/iluvatar_cuda/z_image_turbo_t2i_int8.json
@@ -0,0 +1,14 @@
{
"aspect_ratio": "16:9",
"num_channels_latents": 16,
"infer_steps": 9,
"attn_type": "iluvatar_flash_attn",
"enable_cfg": false,
"sample_guide_scale": 0.0,
"patch_size": 2,
"rope_type":"iluvatar_wan_rope",
"rms_norm_type":"iluvatar_rms_norm",
"dit_quantized": true,
"dit_quant_scheme": "int8-iluvatar",
"dit_quantized_ckpt": ""
}
10 changes: 8 additions & 2 deletions lightx2v/models/input_encoders/hf/wan/t5/model.py
@@ -31,6 +31,8 @@
)
from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8 # noqa E402
from lightx2v_platform.ops.mm.ascend_npu.npu_q_linear import NpuQuantLinearInt8 # noqa E402
from lightx2v_platform.ops.mm.iluvatar_cuda.q_linear import IluvatarQuantLinearInt8 # noqa E402

from lightx2v.models.input_encoders.hf.wan.t5.tokenizer import HuggingfaceTokenizer # noqa E402
from lightx2v.utils.envs import * # noqa E402
from lightx2v.utils.registry_factory import ( # noqa E402
@@ -226,8 +228,10 @@ def __init__(
linear_cls = MluQuantLinearInt8
elif quant_scheme == "int8-npu":
linear_cls = NpuQuantLinearInt8
elif quant_scheme == "int8-iluvatar":
linear_cls = IluvatarQuantLinearInt8
else:
NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
else:
linear_cls = nn.Linear

@@ -309,8 +313,10 @@ def __init__(
linear_cls = MluQuantLinearInt8
elif quant_scheme == "int8-npu":
linear_cls = NpuQuantLinearInt8
elif quant_scheme == "int8-iluvatar":
linear_cls = IluvatarQuantLinearInt8
else:
NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
else:
linear_cls = nn.Linear
# layers
9 changes: 7 additions & 2 deletions lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py
@@ -24,6 +24,7 @@
from lightx2v.utils.utils import load_weights
from lightx2v_platform.base.global_var import AI_DEVICE
from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8
from lightx2v_platform.ops.mm.iluvatar_cuda.q_linear import IluvatarQuantLinearInt8

__all__ = [
"XLMRobertaCLIP",
@@ -91,8 +92,10 @@ def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=
linear_cls = TritonQuantLinearFp8
elif quant_scheme == "int8-tmo":
linear_cls = MluQuantLinearInt8
elif quant_scheme == "int8-iluvatar":
linear_cls = IluvatarQuantLinearInt8
else:
NotImplementedError(f"Unsupported CLip quant scheme: {quant_scheme}")
raise NotImplementedError(f"Unsupported CLip quant scheme: {quant_scheme}")
else:
linear_cls = nn.Linear

@@ -181,8 +184,10 @@ def __init__(
linear_cls = TritonQuantLinearFp8
elif quant_scheme == "int8-tmo":
linear_cls = MluQuantLinearInt8
elif quant_scheme == "int8-iluvatar":
linear_cls = IluvatarQuantLinearInt8
else:
NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
else:
linear_cls = nn.Linear

1 change: 1 addition & 0 deletions lightx2v/models/networks/base_model.py
@@ -131,6 +131,7 @@ def _check_dit_quantized(self):
"gguf-Q3_K_M",
"int8-npu",
"fp8-intel-xpu",
"int8-iluvatar",
]

@abstractmethod
2 changes: 1 addition & 1 deletion lightx2v/models/networks/wan/audio_model.py
@@ -29,7 +29,7 @@ def _load_adapter_ckpt(self):
if self.config.get("adapter_quantized", False):
if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f", "fp8-vllm", "fp8-sgl", "fp8-torchao", "fp8-triton"]:
adapter_model_name = "audio_adapter_model_fp8.safetensors"
elif self.config.get("adapter_quant_scheme", None) in ["int8", "int8-q8f", "int8-vllm", "int8-torchao", "int8-sgl", "int8-triton", "int8-tmo", "int8-npu"]:
elif self.config.get("adapter_quant_scheme", None) in ["int8", "int8-q8f", "int8-vllm", "int8-torchao", "int8-sgl", "int8-triton", "int8-tmo", "int8-npu", "int8-iluvatar"]:
adapter_model_name = "audio_adapter_model_int8.safetensors"
elif self.config.get("adapter_quant_scheme", None) in ["mxfp4"]:
adapter_model_name = "audio_adapter_model_mxfp4.safetensors"
23 changes: 20 additions & 3 deletions lightx2v/models/networks/z_image/infer/transformer_infer.py
@@ -2,6 +2,7 @@
import torch.nn.functional as F

from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.registry_factory import ROPE_REGISTER

from .utils import apply_rotary_emb_qwen, apply_wan_rope_with_flashinfer

@@ -20,10 +21,26 @@ def __init__(self, config):
self.seq_p_group = None
self.seq_p_fp8_comm = False
self.seq_p_fp4_comm = False
if self.config.get("rope_type", "flashinfer") == "flashinfer":
self.apply_rope_func = apply_wan_rope_with_flashinfer

rope_funcs = {
"flashinfer": apply_wan_rope_with_flashinfer,
"torch_naive": apply_rotary_emb_qwen,
}

rope_type = self.config.get("rope_type", "flashinfer")
if rope_type in ROPE_REGISTER:
rope_class = ROPE_REGISTER[rope_type]
self.rope_instance = rope_class()

# Create a wrapper function that matches the expected signature
def rope_wrapper(xq, xk, cos_sin_cache):
return self.rope_instance.apply(xq, xk, cos_sin_cache)

rope_func = rope_wrapper
else:
self.apply_rope_func = apply_rotary_emb_qwen
# Fallback to hardcoded functions
rope_func = rope_funcs.get(rope_type, apply_rotary_emb_qwen)
self.apply_rope_func = rope_func

def set_scheduler(self, scheduler):
self.scheduler = scheduler
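The refactor above routes RoPE dispatch through ROPE_REGISTER when the configured rope_type is registered (e.g. iluvatar_wan_rope from the new platform configs), falling back to the hardcoded flashinfer / torch-naive functions otherwise. A minimal sketch of the interface a registered rope class is assumed to expose, inferred from the rope_wrapper call above (the class and registry key here are hypothetical):

@ROPE_REGISTER("my_custom_rope")  # hypothetical key; this PR's configs use "iluvatar_wan_rope"
class MyCustomRope:
    def apply(self, xq, xk, cos_sin_cache):
        # Must accept (xq, xk, cos_sin_cache), matching the call in rope_wrapper.
        return xq, xk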
2 changes: 2 additions & 0 deletions lightx2v_platform/base/__init__.py
@@ -8,6 +8,7 @@
from lightx2v_platform.base.nvidia import CudaDevice
from lightx2v_platform.base.enflame_gcu import EnflameGcuDevice
from lightx2v_platform.base.intel_xpu import IntelXpuDevice
from lightx2v_platform.base.iluvatar_cuda import IluvatarDevice

__all__ = [
"init_ai_device",
@@ -21,4 +22,5 @@
"MusaDevice",
"EnflameGcuDevice",
"IntelXpuDevice",
"IluvatarDevice",
]
40 changes: 40 additions & 0 deletions lightx2v_platform/base/iluvatar_cuda.py
@@ -0,0 +1,40 @@
import torch
import torch.distributed as dist

from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER

try:
from torch.distributed import ProcessGroupNCCL
except ImportError:
ProcessGroupNCCL = None


@PLATFORM_DEVICE_REGISTER("iluvatar_cuda")
class IluvatarDevice:
name = "iluvatar_cuda"

@staticmethod
def init_device_env():
pass

@staticmethod
def is_available() -> bool:
try:
import torch

return torch.cuda.is_available()
except ImportError:
return False

@staticmethod
def get_device() -> str:
return "cuda"

@staticmethod
def init_parallel_env():
if ProcessGroupNCCL is None:
raise RuntimeError("ProcessGroupNCCL is not available. Please check your runtime environment.")
pg_options = ProcessGroupNCCL.Options()
pg_options.is_high_priority_stream = True
dist.init_process_group(backend="nccl", pg_options=pg_options)
torch.cuda.set_device(dist.get_rank())
5 changes: 5 additions & 0 deletions lightx2v_platform/ops/__init__.py
@@ -24,3 +24,8 @@
elif PLATFORM == "intel_xpu":
from .attn.intel_xpu import *
from .mm.intel_xpu import *
elif PLATFORM == "iluvatar_cuda":
from .attn.iluvatar_cuda import *
from .mm.iluvatar_cuda import *
from .norm.iluvatar_cuda import *
from .rope.iluvatar_cuda import *
1 change: 1 addition & 0 deletions lightx2v_platform/ops/attn/iluvatar_cuda/__init__.py
@@ -0,0 +1 @@
from .flash_attn import *
64 changes: 64 additions & 0 deletions lightx2v_platform/ops/attn/iluvatar_cuda/flash_attn.py
@@ -0,0 +1,64 @@
import math

import torch

from lightx2v_platform.ops.attn.template import AttnWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

try:
from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None


@PLATFORM_ATTN_WEIGHT_REGISTER("iluvatar_flash_attn")
class IluvatarFlashAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
assert flash_attn_varlen_func is not None, "iluvatar ixformer is not installed."

def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, **kwds):
half_dtypes = (torch.float16, torch.bfloat16)
device = q.device
dtype = q.dtype

def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
Contributor comment (severity: high):

There are two issues here:

1. dtype is undefined if len(q.shape) == 3 because it is only assigned inside the elif len(q.shape) == 4 block (line 30). This will cause a NameError if half() is ever called for 3D inputs.
2. The logic x if x.dtype in half_dtypes else x.to(dtype) does not actually ensure half precision if the input q is float32 (as dtype would be float32). Flash attention kernels typically require float16 or bfloat16.

Author reply: This has been fixed.
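For reference on issue 2, a minimal sketch of a half() helper that always returns a half-precision tensor (assuming bfloat16 as the fallback dtype; an illustration, not necessarily the fix that was applied):

def half(x):
    # Coerce non-half inputs (e.g. float32) to bfloat16 so the
    # flash-attention kernel always receives a supported dtype.
    return x if x.dtype in (torch.float16, torch.bfloat16) else x.to(torch.bfloat16)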


if len(q.shape) == 3:
bs = 1
elif len(q.shape) == 4:
bs, lq, lk = q.size(0), q.size(1), k.size(1)
# preprocess query
if cu_seqlens_q is None:
q = half(q.flatten(0, 1))
cu_seqlens_q = torch.tensor([lq] * bs, dtype=torch.int32).to(device=q.device, non_blocking=True)
cu_seqlens_q = torch.cat([cu_seqlens_q.new_zeros([1]), cu_seqlens_q]).cumsum(0, dtype=torch.int32)
else:
q = half(torch.cat([u[:v] for u, v in zip(q, cu_seqlens_q)]))
# preprocess key, value
if cu_seqlens_kv is None:
k = half(k.flatten(0, 1))
v = half(v.flatten(0, 1))
cu_seqlens_kv = torch.tensor([lk] * bs, dtype=torch.int32).to(device=k.device, non_blocking=True)
cu_seqlens_kv = torch.cat([cu_seqlens_kv.new_zeros([1]), cu_seqlens_kv]).cumsum(0, dtype=torch.int32)
else:
k = half(torch.cat([u[:v] for u, v in zip(k, cu_seqlens_kv)]))
v = half(torch.cat([u[:v] for u, v in zip(v, cu_seqlens_kv)]))

q = q.to(v.dtype)
k = k.to(v.dtype)
softmax_scale = 1 / math.sqrt(q.shape[-1])
x = flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q.to(device), # cu_seqlens_q,
cu_seqlens_k=cu_seqlens_kv.to(device), # cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_kv,
softmax_scale=softmax_scale,
return_softmax_lse=False,
causal=False,
)
return x.reshape(bs * max_seqlen_q, -1)
Contributor comment (severity: high):

The reshape operation will fail for variable-length 4D inputs. When cu_seqlens_q is provided, the output x from flash_attn_varlen_func is a packed tensor with shape [sum(seqlens), nheads, head_dim]. Reshaping it to [bs * max_seqlen_q, -1] (the padded size) will raise a RuntimeError because the total number of elements won't match if any padding was present in the original 4D input.
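A shape-aware alternative for the final reshape, sketched from the comment above (an illustration, not code from this PR):

# Derive the packed row count from cu_seqlens_q instead of assuming the
# padded size; the last cumulative entry is the sum of the true sequence lengths.
total_q = int(cu_seqlens_q[-1].item())
return x.reshape(total_q, -1)  # [sum(seqlens), nheads * head_dim]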

1 change: 1 addition & 0 deletions lightx2v_platform/ops/mm/iluvatar_cuda/__init__.py
@@ -0,0 +1 @@
from .mm_weight import *
100 changes: 100 additions & 0 deletions lightx2v_platform/ops/mm/iluvatar_cuda/mm_weight.py
@@ -0,0 +1,100 @@
import torch

from lightx2v.utils.quant_utils import IntegerQuantizer
from lightx2v_platform.ops.mm.template import MMWeightQuantTemplate
from lightx2v_platform.registry_factory import PLATFORM_MM_WEIGHT_REGISTER

try:
import ixformer.inference.functions as ixf
except ImportError:
ixf = None


@PLATFORM_MM_WEIGHT_REGISTER("int8-iluvatar")
class MMWeightWint8channelAint8channeldynamicIluvatar(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-iluvatar

Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: iluvatar
"""

def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
assert ixf is not None, "iluvatar ixformer is not installed."
self.load_func = self.load_int8_perchannel_sym
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_int8_perchannel_sym_iluvatar

def _ensure_int8_weight_and_scale(self, weight_dict):
"""Fill missing weight_scale (or int8 weight) so load_quantized can run.

Some quantized checkpoints omit per-layer scales (e.g. adaLN) or use alternate
key names; others keep a few layers in float — per-channel int8 + scale is
then derived to match ixformer w8a8.
"""
if self.lazy_load:
return
if self.weight_name not in weight_dict:
return
if self.weight_scale_name in weight_dict:
return
base = self.weight_name.removesuffix(".weight")
for alt in (f"{base}.scale", f"{self.weight_name}_scale"):
if alt in weight_dict:
weight_dict[self.weight_scale_name] = weight_dict[alt].float()
return
w = weight_dict[self.weight_name]
if w.dtype in (torch.float16, torch.bfloat16, torch.float32):
w_float = w.to(torch.float32)
w_quantizer = IntegerQuantizer(8, True, "per_channel")
qw, scale, _ = w_quantizer.real_quant_tensor(w_float)
dev = w.device
weight_dict[self.weight_name] = qw.to(torch.int8).to(dev)
weight_dict[self.weight_scale_name] = scale.to(torch.float32).to(dev)

def load(self, weight_dict):
self._ensure_int8_weight_and_scale(weight_dict)
super().load(weight_dict)

def act_quant_int8_perchannel_sym_iluvatar(self, x):
device = x.device
input_tensor_quant = torch.empty(x.shape, dtype=torch.int8, device=device)
input_tensor_scale = torch.empty(x.shape[:-1], dtype=torch.float32, device=device)
ixf.dynamic_scaled_int8_quant(output=input_tensor_quant, input=x, scale=input_tensor_scale)
Contributor comment (severity: high):

The API usage of ixf.dynamic_scaled_int8_quant is inconsistent with its usage in lightx2v_platform/ops/mm/iluvatar_cuda/q_linear.py. Here it is used as an in-place function with keyword arguments (output=, input=, scale=), while in q_linear.py it is used as a function returning two values with a single positional argument. Please verify the correct ixformer API and ensure consistency.

Author reply: On Iluvatar, the output tensors are pre-allocated before the call.
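For clarity, the two call styles the comment contrasts look roughly like this (signatures as they appear in this PR and in the comment; the actual ixformer API should be verified):

# Style used here: the caller pre-allocates the outputs and the call fills them in place.
ixf.dynamic_scaled_int8_quant(output=input_tensor_quant, input=x, scale=input_tensor_scale)
# Style described for q_linear.py: the function allocates and returns both tensors.
input_tensor_quant, input_tensor_scale = ixf.dynamic_scaled_int8_quant(x)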

return input_tensor_quant, input_tensor_scale

def apply(self, input_tensor):
squeeze_output = False
dtype = input_tensor.dtype
if input_tensor.dim() == 3 and input_tensor.shape[0] == 1:
input_tensor = input_tensor.squeeze(0)
squeeze_output = True
input_tensor_quant, input_tensor_scale = self.act_quant_int8_perchannel_sym_iluvatar(input_tensor)
output = ixf.w8a8(input=input_tensor_quant, weight=self.weight, i_scales=input_tensor_scale, w_scales=self.weight_scale.reshape(-1), bias=self.bias, out_dtype=dtype)
if squeeze_output:
output = output.unsqueeze(0)
return output