diff --git a/configs/platforms/iluvatar_cuda/z_image_turbo_t2i.json b/configs/platforms/iluvatar_cuda/z_image_turbo_t2i.json new file mode 100755 index 000000000..8be358d58 --- /dev/null +++ b/configs/platforms/iluvatar_cuda/z_image_turbo_t2i.json @@ -0,0 +1,11 @@ +{ + "aspect_ratio": "16:9", + "num_channels_latents": 16, + "infer_steps": 9, + "attn_type": "iluvatar_flash_attn", + "enable_cfg": false, + "sample_guide_scale": 0.0, + "patch_size": 2, + "rope_type":"iluvatar_wan_rope", + "rms_norm_type":"iluvatar_rms_norm" +} diff --git a/configs/platforms/iluvatar_cuda/z_image_turbo_t2i_int8.json b/configs/platforms/iluvatar_cuda/z_image_turbo_t2i_int8.json new file mode 100755 index 000000000..b4cab2fad --- /dev/null +++ b/configs/platforms/iluvatar_cuda/z_image_turbo_t2i_int8.json @@ -0,0 +1,14 @@ +{ + "aspect_ratio": "16:9", + "num_channels_latents": 16, + "infer_steps": 9, + "attn_type": "iluvatar_flash_attn", + "enable_cfg": false, + "sample_guide_scale": 0.0, + "patch_size": 2, + "rope_type":"iluvatar_wan_rope", + "rms_norm_type":"iluvatar_rms_norm", + "dit_quantized": true, + "dit_quant_scheme": "int8-iluvatar", + "dit_quantized_ckpt": "" +} diff --git a/lightx2v/models/input_encoders/hf/wan/t5/model.py b/lightx2v/models/input_encoders/hf/wan/t5/model.py index 88fb3b453..a8f842204 100755 --- a/lightx2v/models/input_encoders/hf/wan/t5/model.py +++ b/lightx2v/models/input_encoders/hf/wan/t5/model.py @@ -31,6 +31,8 @@ ) from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8 # noqa E402 from lightx2v_platform.ops.mm.ascend_npu.npu_q_linear import NpuQuantLinearInt8 # noqa E402 +from lightx2v_platform.ops.mm.iluvatar_cuda.q_linear import IluvatarQuantLinearInt8 # noqa E402 + from lightx2v.models.input_encoders.hf.wan.t5.tokenizer import HuggingfaceTokenizer # noqa E402 from lightx2v.utils.envs import * # noqa E402 from lightx2v.utils.registry_factory import ( # noqa E402 @@ -226,8 +228,10 @@ def __init__( linear_cls = MluQuantLinearInt8 elif quant_scheme == "int8-npu": linear_cls = NpuQuantLinearInt8 + elif quant_scheme == "int8-iluvatar": + linear_cls = IluvatarQuantLinearInt8 else: - NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}") + raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}") else: linear_cls = nn.Linear @@ -309,8 +313,10 @@ def __init__( linear_cls = MluQuantLinearInt8 elif quant_scheme == "int8-npu": linear_cls = NpuQuantLinearInt8 + elif quant_scheme == "int8-iluvatar": + linear_cls = IluvatarQuantLinearInt8 else: - NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}") + raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}") else: linear_cls = nn.Linear # layers diff --git a/lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py b/lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py index 697219c02..1de88ca61 100755 --- a/lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py +++ b/lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py @@ -24,6 +24,7 @@ from lightx2v.utils.utils import load_weights from lightx2v_platform.base.global_var import AI_DEVICE from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8 +from lightx2v_platform.ops.mm.iluvatar_cuda.q_linear import IluvatarQuantLinearInt8 __all__ = [ "XLMRobertaCLIP", @@ -91,8 +92,10 @@ def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout= linear_cls = TritonQuantLinearFp8 elif quant_scheme == "int8-tmo": linear_cls = MluQuantLinearInt8 + elif quant_scheme == 
"int8-iluvatar": + linear_cls = IluvatarQuantLinearInt8 else: - NotImplementedError(f"Unsupported CLip quant scheme: {quant_scheme}") + raise NotImplementedError(f"Unsupported CLip quant scheme: {quant_scheme}") else: linear_cls = nn.Linear @@ -181,8 +184,10 @@ def __init__( linear_cls = TritonQuantLinearFp8 elif quant_scheme == "int8-tmo": linear_cls = MluQuantLinearInt8 + elif quant_scheme == "int8-iluvatar": + linear_cls = IluvatarQuantLinearInt8 else: - NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}") + raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}") else: linear_cls = nn.Linear diff --git a/lightx2v/models/networks/base_model.py b/lightx2v/models/networks/base_model.py index 69da23c02..a5a5a11d8 100644 --- a/lightx2v/models/networks/base_model.py +++ b/lightx2v/models/networks/base_model.py @@ -131,6 +131,7 @@ def _check_dit_quantized(self): "gguf-Q3_K_M", "int8-npu", "fp8-intel-xpu", + "int8-iluvatar", ] @abstractmethod diff --git a/lightx2v/models/networks/wan/audio_model.py b/lightx2v/models/networks/wan/audio_model.py index 47036fd8d..879c18ef0 100755 --- a/lightx2v/models/networks/wan/audio_model.py +++ b/lightx2v/models/networks/wan/audio_model.py @@ -29,7 +29,7 @@ def _load_adapter_ckpt(self): if self.config.get("adapter_quantized", False): if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f", "fp8-vllm", "fp8-sgl", "fp8-torchao", "fp8-triton"]: adapter_model_name = "audio_adapter_model_fp8.safetensors" - elif self.config.get("adapter_quant_scheme", None) in ["int8", "int8-q8f", "int8-vllm", "int8-torchao", "int8-sgl", "int8-triton", "int8-tmo", "int8-npu"]: + elif self.config.get("adapter_quant_scheme", None) in ["int8", "int8-q8f", "int8-vllm", "int8-torchao", "int8-sgl", "int8-triton", "int8-tmo", "int8-npu", "int8-iluvatar"]: adapter_model_name = "audio_adapter_model_int8.safetensors" elif self.config.get("adapter_quant_scheme", None) in ["mxfp4"]: adapter_model_name = "audio_adapter_model_mxfp4.safetensors" diff --git a/lightx2v/models/networks/z_image/infer/transformer_infer.py b/lightx2v/models/networks/z_image/infer/transformer_infer.py index 41145f3f7..7e883d74b 100755 --- a/lightx2v/models/networks/z_image/infer/transformer_infer.py +++ b/lightx2v/models/networks/z_image/infer/transformer_infer.py @@ -2,6 +2,7 @@ import torch.nn.functional as F from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer +from lightx2v.utils.registry_factory import ROPE_REGISTER from .utils import apply_rotary_emb_qwen, apply_wan_rope_with_flashinfer @@ -20,10 +21,26 @@ def __init__(self, config): self.seq_p_group = None self.seq_p_fp8_comm = False self.seq_p_fp4_comm = False - if self.config.get("rope_type", "flashinfer") == "flashinfer": - self.apply_rope_func = apply_wan_rope_with_flashinfer + + rope_funcs = { + "flashinfer": apply_wan_rope_with_flashinfer, + "torch_naive": apply_rotary_emb_qwen, + } + + rope_type = self.config.get("rope_type", "flashinfer") + if rope_type in ROPE_REGISTER: + rope_class = ROPE_REGISTER[rope_type] + self.rope_instance = rope_class() + + # Create a wrapper function that matches the expected signature + def rope_wrapper(xq, xk, cos_sin_cache): + return self.rope_instance.apply(xq, xk, cos_sin_cache) + + rope_func = rope_wrapper else: - self.apply_rope_func = apply_rotary_emb_qwen + # Fallback to hardcoded functions + rope_func = rope_funcs.get(rope_type, apply_rotary_emb_qwen) + self.apply_rope_func = rope_func def set_scheduler(self, scheduler): self.scheduler = 
scheduler diff --git a/lightx2v_platform/base/__init__.py b/lightx2v_platform/base/__init__.py index 7544e897d..87b4d8c7f 100755 --- a/lightx2v_platform/base/__init__.py +++ b/lightx2v_platform/base/__init__.py @@ -8,6 +8,7 @@ from lightx2v_platform.base.nvidia import CudaDevice from lightx2v_platform.base.enflame_gcu import EnflameGcuDevice from lightx2v_platform.base.intel_xpu import IntelXpuDevice +from lightx2v_platform.base.iluvatar_cuda import IluvatarDevice __all__ = [ "init_ai_device", @@ -21,4 +22,5 @@ "MusaDevice", "EnflameGcuDevice", "IntelXpuDevice", + "IluvatarDevice", ] diff --git a/lightx2v_platform/base/iluvatar_cuda.py b/lightx2v_platform/base/iluvatar_cuda.py new file mode 100644 index 000000000..1ac255b82 --- /dev/null +++ b/lightx2v_platform/base/iluvatar_cuda.py @@ -0,0 +1,40 @@ +import torch +import torch.distributed as dist + +from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER + +try: + from torch.distributed import ProcessGroupNCCL +except ImportError: + ProcessGroupNCCL = None + + +@PLATFORM_DEVICE_REGISTER("iluvatar_cuda") +class IluvatarDevice: + name = "iluvatar_cuda" + + @staticmethod + def init_device_env(): + pass + + @staticmethod + def is_available() -> bool: + try: + import torch + + return torch.cuda.is_available() + except ImportError: + return False + + @staticmethod + def get_device() -> str: + return "cuda" + + @staticmethod + def init_parallel_env(): + if ProcessGroupNCCL is None: + raise RuntimeError("ProcessGroupNCCL is not available. Please check your runtime environment.") + pg_options = ProcessGroupNCCL.Options() + pg_options.is_high_priority_stream = True + dist.init_process_group(backend="nccl", pg_options=pg_options) + torch.cuda.set_device(dist.get_rank()) diff --git a/lightx2v_platform/ops/__init__.py b/lightx2v_platform/ops/__init__.py index b0326841c..b56ad794a 100755 --- a/lightx2v_platform/ops/__init__.py +++ b/lightx2v_platform/ops/__init__.py @@ -24,3 +24,8 @@ elif PLATFORM == "intel_xpu": from .attn.intel_xpu import * from .mm.intel_xpu import * +elif PLATFORM == "iluvatar_cuda": + from .attn.iluvatar_cuda import * + from .mm.iluvatar_cuda import * + from .norm.iluvatar_cuda import * + from .rope.iluvatar_cuda import * diff --git a/lightx2v_platform/ops/attn/iluvatar_cuda/__init__.py b/lightx2v_platform/ops/attn/iluvatar_cuda/__init__.py new file mode 100644 index 000000000..9ecd99131 --- /dev/null +++ b/lightx2v_platform/ops/attn/iluvatar_cuda/__init__.py @@ -0,0 +1 @@ +from .flash_attn import * diff --git a/lightx2v_platform/ops/attn/iluvatar_cuda/flash_attn.py b/lightx2v_platform/ops/attn/iluvatar_cuda/flash_attn.py new file mode 100644 index 000000000..ff3f634b7 --- /dev/null +++ b/lightx2v_platform/ops/attn/iluvatar_cuda/flash_attn.py @@ -0,0 +1,64 @@ +import math + +import torch + +from lightx2v_platform.ops.attn.template import AttnWeightTemplate +from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER + +try: + from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func +except ImportError: + flash_attn_varlen_func = None + + +@PLATFORM_ATTN_WEIGHT_REGISTER("iluvatar_flash_attn") +class IluvatarFlashAttnWeight(AttnWeightTemplate): + def __init__(self): + self.config = {} + assert flash_attn_varlen_func is not None, "iluvatar ixformer is not installed." 
+ + def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, **kwds): + half_dtypes = (torch.float16, torch.bfloat16) + device = q.device + dtype = q.dtype + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + if len(q.shape) == 3: + bs = 1 + elif len(q.shape) == 4: + bs, lq, lk = q.size(0), q.size(1), k.size(1) + # preprocess query + if cu_seqlens_q is None: + q = half(q.flatten(0, 1)) + cu_seqlens_q = torch.tensor([lq] * bs, dtype=torch.int32).to(device=q.device, non_blocking=True) + cu_seqlens_q = torch.cat([cu_seqlens_q.new_zeros([1]), cu_seqlens_q]).cumsum(0, dtype=torch.int32) + else: + q = half(torch.cat([u[:v] for u, v in zip(q, cu_seqlens_q)])) + # preprocess key, value + if cu_seqlens_kv is None: + k = half(k.flatten(0, 1)) + v = half(v.flatten(0, 1)) + cu_seqlens_kv = torch.tensor([lk] * bs, dtype=torch.int32).to(device=k.device, non_blocking=True) + cu_seqlens_kv = torch.cat([cu_seqlens_kv.new_zeros([1]), cu_seqlens_kv]).cumsum(0, dtype=torch.int32) + else: + k = half(torch.cat([u[:v] for u, v in zip(k, cu_seqlens_kv)])) + v = half(torch.cat([u[:v] for u, v in zip(v, cu_seqlens_kv)])) + + q = q.to(v.dtype) + k = k.to(v.dtype) + softmax_scale = 1 / math.sqrt(q.shape[-1]) + x = flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q.to(device), # cu_seqlens_q, + cu_seqlens_k=cu_seqlens_kv.to(device), # cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_kv, + softmax_scale=softmax_scale, + return_softmax_lse=False, + causal=False, + ) + return x.reshape(bs * max_seqlen_q, -1) diff --git a/lightx2v_platform/ops/mm/iluvatar_cuda/__init__.py b/lightx2v_platform/ops/mm/iluvatar_cuda/__init__.py new file mode 100644 index 000000000..3f9898101 --- /dev/null +++ b/lightx2v_platform/ops/mm/iluvatar_cuda/__init__.py @@ -0,0 +1 @@ +from .mm_weight import * diff --git a/lightx2v_platform/ops/mm/iluvatar_cuda/mm_weight.py b/lightx2v_platform/ops/mm/iluvatar_cuda/mm_weight.py new file mode 100644 index 000000000..be81fc5e0 --- /dev/null +++ b/lightx2v_platform/ops/mm/iluvatar_cuda/mm_weight.py @@ -0,0 +1,100 @@ +import torch + +from lightx2v.utils.quant_utils import IntegerQuantizer +from lightx2v_platform.ops.mm.template import MMWeightQuantTemplate +from lightx2v_platform.registry_factory import PLATFORM_MM_WEIGHT_REGISTER + +try: + import ixformer.inference.functions as ixf +except ImportError: + ixf = None + + +@PLATFORM_MM_WEIGHT_REGISTER("int8-iluvatar") +class MMWeightWint8channelAint8channeldynamicIluvatar(MMWeightQuantTemplate): + """ + Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-iluvatar + + Quant MM: + Weight: int8 perchannel sym + Act: int8 perchannel dynamic sym + Kernel: iluvatar + """ + + def __init__( + self, + weight_name, + bias_name, + create_cuda_buffer=False, + create_cpu_buffer=False, + lazy_load=False, + lazy_load_file=None, + is_post_adapter=False, + lora_prefix="diffusion_model.blocks", + lora_path="", + ): + super().__init__( + weight_name, + bias_name, + create_cuda_buffer, + create_cpu_buffer, + lazy_load, + lazy_load_file, + is_post_adapter, + lora_prefix, + lora_path, + ) + assert ixf is not None, "iluvatar ixformer is not installed." + self.load_func = self.load_int8_perchannel_sym + self.weight_need_transpose = False + self.act_quant_func = self.act_quant_int8_perchannel_sym_iluvatar + + def _ensure_int8_weight_and_scale(self, weight_dict): + """Fill missing weight_scale (or int8 weight) so load_quantized can run. 
+
+        Some quantized checkpoints omit per-layer scales (e.g. adaLN) or use alternate
+        key names; others keep a few layers in float — per-channel int8 + scale is
+        then derived to match ixformer w8a8.
+        """
+        if self.lazy_load:
+            return
+        if self.weight_name not in weight_dict:
+            return
+        if self.weight_scale_name in weight_dict:
+            return
+        base = self.weight_name.removesuffix(".weight")
+        for alt in (f"{base}.scale", f"{self.weight_name}_scale"):
+            if alt in weight_dict:
+                weight_dict[self.weight_scale_name] = weight_dict[alt].float()
+                return
+        w = weight_dict[self.weight_name]
+        if w.dtype in (torch.float16, torch.bfloat16, torch.float32):
+            w_float = w.to(torch.float32)
+            w_quantizer = IntegerQuantizer(8, True, "per_channel")
+            qw, scale, _ = w_quantizer.real_quant_tensor(w_float)
+            dev = w.device
+            weight_dict[self.weight_name] = qw.to(torch.int8).to(dev)
+            weight_dict[self.weight_scale_name] = scale.to(torch.float32).to(dev)
+
+    def load(self, weight_dict):
+        self._ensure_int8_weight_and_scale(weight_dict)
+        super().load(weight_dict)
+
+    def act_quant_int8_perchannel_sym_iluvatar(self, x):
+        device = x.device
+        input_tensor_quant = torch.empty(x.shape, dtype=torch.int8, device=device)
+        input_tensor_scale = torch.empty(x.shape[:-1], dtype=torch.float32, device=device)
+        ixf.dynamic_scaled_int8_quant(output=input_tensor_quant, input=x, scale=input_tensor_scale)
+        return input_tensor_quant, input_tensor_scale
+
+    def apply(self, input_tensor):
+        squeeze_output = False
+        dtype = input_tensor.dtype
+        if input_tensor.dim() == 3 and input_tensor.shape[0] == 1:
+            input_tensor = input_tensor.squeeze(0)
+            squeeze_output = True
+        input_tensor_quant, input_tensor_scale = self.act_quant_int8_perchannel_sym_iluvatar(input_tensor)
+        output = ixf.w8a8(input=input_tensor_quant, weight=self.weight, i_scales=input_tensor_scale, w_scales=self.weight_scale.reshape(-1), bias=self.bias, out_dtype=dtype)
+        if squeeze_output:
+            output = output.unsqueeze(0)
+        return output
diff --git a/lightx2v_platform/ops/mm/iluvatar_cuda/q_linear.py b/lightx2v_platform/ops/mm/iluvatar_cuda/q_linear.py
new file mode 100644
index 000000000..96488a48f
--- /dev/null
+++ b/lightx2v_platform/ops/mm/iluvatar_cuda/q_linear.py
@@ -0,0 +1,89 @@
+"""
+Iluvatar GPU quantized linear layers for text encoders (T5, CLIP, etc.)
+
+These are nn.Module-based quantized linear layers optimized for Iluvatar GPU
+"""
+
+import torch
+import torch.nn as nn
+
+try:
+    import ixformer.inference.functions as ixf
+except ImportError:
+    ixf = None
+
+
+class IluvatarQuantLinearInt8(nn.Module):
+    """
+    Iluvatar GPU INT8 quantized linear layer for text encoders.
+
+    Strategy:
+    - Storage: INT8 - saves 50% memory
+    - Computation: INT8 W8A8 matmul via ixformer (ixf.w8a8)
+    - Activations are quantized to INT8 per token dynamically during the forward pass
+
+    Usage:
+    Used in T5 text encoder when config has:
+    {
+        "t5_quantized": true,
+        "t5_quant_scheme": "int8-iluvatar"
+    }
+    """
+
+    def __init__(self, in_features, out_features, bias=True, dtype=torch.float16):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.dtype = dtype
+        assert ixf is not None, "iluvatar ixformer is not installed."
+ # Register INT8 weight buffer + self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) + + # Register FP32 scale buffer (per-channel) + self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float32)) + + # Register bias buffer + if bias: + self.register_buffer("bias", torch.empty(out_features, dtype=dtype)) + else: + self.register_buffer("bias", None) + + def act_quant_func(self, x): + input_tensor_quant, input_tensor_scale = ixf.dynamic_scaled_int8_quant(x) + return input_tensor_quant, input_tensor_scale + + def forward(self, input_tensor): + """ + Forward pass with INT8 → FP16 dequantization + """ + # Handle T5-style input + squeeze_output = False + dtype = input_tensor.dtype + if input_tensor.dim() == 3 and input_tensor.shape[0] == 1: + input_tensor = input_tensor.squeeze(0) + squeeze_output = True + + input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor) + output = ixf.w8a8(input=input_tensor_quant, weight=self.weight, i_scales=input_tensor_scale, w_scales=self.weight_scale.reshape(-1), bias=self.bias, out_dtype=dtype) + + if squeeze_output: + output = output.unsqueeze(0) + return output + + def _apply(self, fn): + for module in self.children(): + module._apply(fn) + + def maybe_cast(t): + if t is not None and t.device != fn(t).device: + return fn(t) + return t + + self.weight = maybe_cast(self.weight) + self.weight_scale = maybe_cast(self.weight_scale) + self.bias = maybe_cast(self.bias) + + return self + + def __repr__(self): + return f"IluvatarQuantLinearInt8(in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}, dtype={self.dtype})" diff --git a/lightx2v_platform/ops/norm/iluvatar_cuda/__init__.py b/lightx2v_platform/ops/norm/iluvatar_cuda/__init__.py new file mode 100644 index 000000000..c801cfc17 --- /dev/null +++ b/lightx2v_platform/ops/norm/iluvatar_cuda/__init__.py @@ -0,0 +1,3 @@ +from .iluvatar_rms_norm import IluvatarRmsNormWeight + +__all__ = ["IluvatarRmsNormWeight"] diff --git a/lightx2v_platform/ops/norm/iluvatar_cuda/iluvatar_rms_norm.py b/lightx2v_platform/ops/norm/iluvatar_cuda/iluvatar_rms_norm.py new file mode 100644 index 000000000..7e07c4e8b --- /dev/null +++ b/lightx2v_platform/ops/norm/iluvatar_cuda/iluvatar_rms_norm.py @@ -0,0 +1,18 @@ +from lightx2v_platform.ops.norm.norm_template import RMSWeightTemplate +from lightx2v_platform.registry_factory import PLATFORM_RMS_WEIGHT_REGISTER + +try: + import ixformer.inference.functions as ixf +except ImportError: + ixf = None + + +@PLATFORM_RMS_WEIGHT_REGISTER("iluvatar_rms_norm") +class IluvatarRmsNormWeight(RMSWeightTemplate): + def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=0.000001): + super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps) + assert ixf is not None, "iluvatar ixformer is not installed." 
+ + def apply(self, input_tensor): + output = ixf.rms_norm(input_tensor.contiguous(), self.weight, self.eps) + return output diff --git a/lightx2v_platform/ops/rope/iluvatar_cuda/__init__.py b/lightx2v_platform/ops/rope/iluvatar_cuda/__init__.py new file mode 100644 index 000000000..4e478ba24 --- /dev/null +++ b/lightx2v_platform/ops/rope/iluvatar_cuda/__init__.py @@ -0,0 +1,3 @@ +from .wan_rope import IluvatarWanRope + +__all__ = ["IluvatarWanRope"] diff --git a/lightx2v_platform/ops/rope/iluvatar_cuda/wan_rope.py b/lightx2v_platform/ops/rope/iluvatar_cuda/wan_rope.py new file mode 100644 index 000000000..7a8d70b1c --- /dev/null +++ b/lightx2v_platform/ops/rope/iluvatar_cuda/wan_rope.py @@ -0,0 +1,217 @@ +from typing import Optional, Union + +import torch +import triton +import triton.language as tl + +from lightx2v_platform.ops.rope.rope_template import RopeTemplate +from lightx2v_platform.registry_factory import PLATFORM_ROPE_REGISTER + + +@PLATFORM_ROPE_REGISTER("iluvatar_wan_rope") +class IluvatarWanRope(RopeTemplate): + def __init__(self): + super().__init__() + + def apply(self, xq: torch.Tensor, xk: torch.Tensor, cos_sin_cache: torch.Tensor): + """ + Apply WAN RoPE using triton operations, optimized for Iluvatar cuda. + Args: + xq: Query tensor + xk: Key tensor + cos_sin_cache: Cosine and sine cache for rotary embedding + + Returns: + Tuple of (xq, xk) with rotary embedding applied + """ + assert torch.is_complex(cos_sin_cache), "cos_sin_cache must be complex tensor" + cos, sin = cos_sin_cache.real.contiguous(), cos_sin_cache.imag.contiguous() + if xq.dim() == 3: + xq = xq.unsqueeze(0) + xk = xk.unsqueeze(0) + xq = apply_rotary(xq, cos, sin, interleaved=True) + xk = apply_rotary(xk, cos, sin, interleaved=True) + if xq.dim() == 4: + xq = xq.squeeze(0) + xk = xk.squeeze(0) + return xq.to(self.infer_dtype), xk.to(self.infer_dtype) + + +@triton.jit +def rotary_kernel( + OUT, # Pointers to matrices + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, # this could be int or a pointer + # Matrix dimensions + seqlen, + nheads, + seqlen_ro, + # strides + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + # Meta-parameters + # We want ROTARY_DIM to be constexpr, otherwise the triton compiler doesn't know that + # the mask is constant every 8 elements, and it will generate LDG.16 instead of LDG.128 + ROTARY_DIM: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_M: tl.constexpr, +): + BLOCK_K: tl.constexpr = ROTARY_DIM + ROTARY_DIM_HALF = ROTARY_DIM // 2 + pid_head = tl.program_id(axis=0) + pid_m = tl.program_id(axis=1) + pid_batch = tl.program_id(axis=2) + + if not IS_VARLEN: + X = X + pid_batch * stride_x_batch + OUT = OUT + pid_batch * stride_out_batch + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X = X + start_idx * stride_x_seqlen + OUT = OUT + start_idx * stride_out_seqlen + + if pid_m * BLOCK_M >= seqlen: + return + + rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + + rk_half = tl.arange(0, BLOCK_K // 2) + COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) + SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + 
rk_half[None, :]) + mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF) + cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32) + sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + + if not INTERLEAVED: + # Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT + X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk_half[None, None, :] * stride_x_headdim) + OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk_half[None, None, :] * stride_out_headdim) + mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk_half[None, None, :] < ROTARY_DIM_HALF) + x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load( + X + ROTARY_DIM_HALF * stride_x_headdim, + mask=mask, + other=0.0, + ).to(tl.float32) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + tl.store(OUT, o0, mask=mask) + tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask) + else: + rk = tl.arange(0, BLOCK_K) + X = X + (rh[:, None, None] * stride_x_nheads + rm[None, :, None] * stride_x_seqlen + rk[None, None, :] * stride_x_headdim) + OUT = OUT + (rh[:, None, None] * stride_out_nheads + rm[None, :, None] * stride_out_seqlen + rk[None, None, :] * stride_out_headdim) + mask = (rh[:, None, None] < nheads) & (rm[None, :, None] < seqlen) & (rk[None, None, :] < ROTARY_DIM) + x = tl.load(X, mask=mask, other=0.0).to(tl.float32) + x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2])) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K]) + tl.store(OUT, o, mask=mask) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). 
+ cos: (seqlen_ro, rotary_dim / 2) + sin: (seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Returns: + y: (batch, seqlen, nheads, headdim) + """ + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" + total_seqlen, nheads, headdim = x.shape + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + grid = lambda META: (triton.cdiv(nheads, META["BLOCK_H"]), triton.cdiv(seqlen, META["BLOCK_M"]), batch) # noqa + BLOCK_M = 8 if rotary_dim <= 128 else 4 + + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) + + rotary_kernel[grid]( + output, # data ptrs + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, # shapes + nheads, + seqlen_ro, + output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + output.stride(-3), # seqlen_stride or total_seqlen_stride + output.stride(-2), # nheads_stride + output.stride(-1), # headdim_stride + x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + x.stride(-3), # seqlen stride or total_seqlen_stride + x.stride(-2), # nheads stride + x.stride(-1), # headdim stride + rotary_dim, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + BLOCK_M=BLOCK_M, + BLOCK_H=2, + ) + return output diff --git a/scripts/platforms/iluvatar_cuda/z_image_turbo_t2i.sh b/scripts/platforms/iluvatar_cuda/z_image_turbo_t2i.sh new file mode 100755 index 000000000..830ad9476 --- /dev/null +++ b/scripts/platforms/iluvatar_cuda/z_image_turbo_t2i.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# set path firstly +lightx2v_path=$(pwd)/LightX2V +model_path=Tongyi-MAI/Z-Image-Turbo/ + +export CUDA_VISIBLE_DEVICES=0 +export PLATFORM="iluvatar_cuda" + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python3 -m lightx2v.infer \ +--model_cls z_image \ +--task t2i \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/platforms/iluvatar_cuda/z_image_turbo_t2i.json \ +--prompt 'Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.' \ +--negative_prompt " " \ +--save_result_path ${lightx2v_path}/save_results/z_image_turbo.png \ +--seed 42 \ +--aspect_ratio "16:9"
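Note on the int8-iluvatar quantization path: both MMWeightWint8channelAint8channeldynamicIluvatar and IluvatarQuantLinearInt8 store per-channel symmetric INT8 weights and quantize activations per token at runtime, then call ixformer's fused W8A8 GEMM (ixf.dynamic_scaled_int8_quant / ixf.w8a8). The snippet below is only a plain-PyTorch sketch of that arithmetic for reference; it is not the ixformer API, and the helper names are illustrative.

import torch

def quant_int8_per_channel(w: torch.Tensor):
    # w: (out_features, in_features), float -> int8 weight + per-channel scale
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    qw = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return qw, scale.squeeze(1).float()            # scale: (out_features,)

def quant_int8_per_token(x: torch.Tensor):
    # x: (tokens, in_features), float -> int8 activations + per-token scale
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    qx = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return qx, scale.squeeze(-1).float()           # scale: (tokens,)

def w8a8_reference(qx, x_scale, qw, w_scale, bias=None, out_dtype=torch.float16):
    # The real kernel accumulates in int32; float32 is exact here because
    # int8 products summed over this K fit well inside fp32's integer range.
    acc = qx.float() @ qw.float().t()
    out = acc * x_scale[:, None] * w_scale[None, :]
    if bias is not None:
        out = out + bias.float()
    return out.to(out_dtype)

# quick sanity check against the unquantized matmul
x = torch.randn(4, 64)
w = torch.randn(128, 64)
qw, ws = quant_int8_per_channel(w)
qx, xs = quant_int8_per_token(x)
ref = x @ w.t()
approx = w8a8_reference(qx, xs, qw, ws)
print((ref - approx.float()).abs().max())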
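Note on the attention preprocessing: IluvatarFlashAttnWeight.apply builds cu_seqlens itself when none are passed, flattening a fixed-length (bs, seqlen, heads, dim) batch to (bs * seqlen, heads, dim) and prepending a zero before the cumulative sum, then hands the result to ixformer's flash_attn_varlen_func. A small plain-PyTorch sketch of just that preprocessing (the kernel call itself is omitted):

import torch

bs, lq, heads, dim = 2, 5, 3, 8
q = torch.randn(bs, lq, heads, dim)

q_flat = q.flatten(0, 1)                                    # (bs * lq, heads, dim)
lens = torch.tensor([lq] * bs, dtype=torch.int32)
cu_seqlens_q = torch.cat([lens.new_zeros(1), lens]).cumsum(0, dtype=torch.int32)
print(q_flat.shape, cu_seqlens_q)                           # tensor([0, 5, 10], dtype=torch.int32)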
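Note on the RoPE dispatch: transformer_infer.py now resolves "rope_type" through ROPE_REGISTER, so "iluvatar_wan_rope" selects IluvatarWanRope and its apply(xq, xk, cos_sin_cache) is wrapped into self.apply_rope_func, with the flashinfer / torch_naive functions kept as fallbacks. The sketch below mirrors that dispatch shape with a stand-in registry and a plain-PyTorch interleaved rotary reference; the class and registry names here are hypothetical, and this is not the Triton kernel in wan_rope.py.

import torch

ROPE_REGISTER = {}                      # stand-in for the project registry

def register(name):
    def deco(cls):
        ROPE_REGISTER[name] = cls
        return cls
    return deco

@register("torch_interleaved_rope")
class TorchInterleavedRope:
    def apply(self, xq, xk, cos_sin_cache):
        # cos_sin_cache: complex (seqlen, head_dim // 2), as asserted upstream
        cos, sin = cos_sin_cache.real, cos_sin_cache.imag

        def rot(x):
            # x: (seqlen, nheads, head_dim) with interleaved pairs (x0, x1, x0, x1, ...)
            x0, x1 = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
            c = cos[:, None, :]
            s = sin[:, None, :]
            o0 = x0 * c - x1 * s
            o1 = x0 * s + x1 * c
            return torch.stack((o0, o1), dim=-1).flatten(-2).to(x.dtype)

        return rot(xq), rot(xk)

# dispatch exactly as the config-driven lookup would do it
rope = ROPE_REGISTER["torch_interleaved_rope"]()
seq, heads, dim = 16, 4, 64
freqs = torch.polar(torch.ones(seq, dim // 2), torch.randn(seq, dim // 2))
q = torch.randn(seq, heads, dim)
k = torch.randn(seq, heads, dim)
q_out, k_out = rope.apply(q, k, freqs)
print(q_out.shape, k_out.shape)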