[Feature] iluvatar platforms support #1045

base: main
Changes from all commits
New file (`@@ -0,0 +1,11 @@`):

```json
{
    "aspect_ratio": "16:9",
    "num_channels_latents": 16,
    "infer_steps": 9,
    "attn_type": "iluvatar_flash_attn",
    "enable_cfg": false,
    "sample_guide_scale": 0.0,
    "patch_size": 2,
    "rope_type": "iluvatar_wan_rope",
    "rms_norm_type": "iluvatar_rms_norm"
}
```
New file (`@@ -0,0 +1,14 @@`):

```json
{
    "aspect_ratio": "16:9",
    "num_channels_latents": 16,
    "infer_steps": 9,
    "attn_type": "iluvatar_flash_attn",
    "enable_cfg": false,
    "sample_guide_scale": 0.0,
    "patch_size": 2,
    "rope_type": "iluvatar_wan_rope",
    "rms_norm_type": "iluvatar_rms_norm",
    "dit_quantized": true,
    "dit_quant_scheme": "int8-iluvatar",
    "dit_quantized_ckpt": ""
}
```
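This second config only differs from the first by the three `dit_quantized*` keys, which switch the DiT matmuls to the int8 iluvatar kernel. A minimal sketch of how a runner might consume these keys; the file name and the exact dispatch code are assumptions for illustration, not the repo's confirmed entry point:

```python
# Hypothetical illustration of how the quantization keys could be consumed.
import json

with open("wan_t2v_iluvatar_int8.json") as f:  # hypothetical file name
    cfg = json.load(f)

if cfg.get("dit_quantized", False):
    # "int8-iluvatar" selects MMWeightWint8channelAint8channeldynamicIluvatar
    # via PLATFORM_MM_WEIGHT_REGISTER (see mm_weight below). An empty
    # "dit_quantized_ckpt" suggests weights are quantized on the fly
    # (see _ensure_int8_weight_and_scale in the mm_weight file below).
    mm_scheme = cfg["dit_quant_scheme"]  # "int8-iluvatar"
```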
New file (`@@ -0,0 +1,40 @@`):

```python
import torch
import torch.distributed as dist

from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER

try:
    from torch.distributed import ProcessGroupNCCL
except ImportError:
    ProcessGroupNCCL = None


@PLATFORM_DEVICE_REGISTER("iluvatar_cuda")
class IluvatarDevice:
    name = "iluvatar_cuda"

    @staticmethod
    def init_device_env():
        # No extra environment setup is needed on iluvatar.
        pass

    @staticmethod
    def is_available() -> bool:
        try:
            import torch

            return torch.cuda.is_available()
        except ImportError:
            return False

    @staticmethod
    def get_device() -> str:
        return "cuda"

    @staticmethod
    def init_parallel_env():
        if ProcessGroupNCCL is None:
            raise RuntimeError("ProcessGroupNCCL is not available. Please check your runtime environment.")
        # Run NCCL collectives on a high-priority stream.
        pg_options = ProcessGroupNCCL.Options()
        pg_options.is_high_priority_stream = True
        dist.init_process_group(backend="nccl", pg_options=pg_options)
        torch.cuda.set_device(dist.get_rank())
```
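A hedged usage sketch of this device class; the dict-style registry lookup is an assumption about `PLATFORM_DEVICE_REGISTER`, not a confirmed API:

```python
# Hypothetical usage, launched as `torchrun --nproc_per_node=8 run.py`
# so that RANK/WORLD_SIZE are set in the environment.
from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER

device_cls = PLATFORM_DEVICE_REGISTER["iluvatar_cuda"]  # assumed dict-style access
if device_cls.is_available():
    device_cls.init_device_env()
    device_cls.init_parallel_env()  # NCCL process group + per-rank CUDA device
```

Note that `init_parallel_env` maps the global rank directly to a local device index via `torch.cuda.set_device(dist.get_rank())`, which implicitly assumes a single-node launch.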
New file (`@@ -0,0 +1 @@`):

```python
from .flash_attn import *
```
New file (`@@ -0,0 +1,64 @@`):

```python
import math

import torch

from lightx2v_platform.ops.attn.template import AttnWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

try:
    from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func
except ImportError:
    flash_attn_varlen_func = None


@PLATFORM_ATTN_WEIGHT_REGISTER("iluvatar_flash_attn")
class IluvatarFlashAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}
        assert flash_attn_varlen_func is not None, "iluvatar ixformer is not installed."

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, **kwds):
        half_dtypes = (torch.float16, torch.bfloat16)
        device = q.device
        dtype = q.dtype

        def half(x):
            return x if x.dtype in half_dtypes else x.to(dtype)

        if len(q.shape) == 3:
            # 3D inputs are assumed to be already packed, with cu_seqlens provided.
            bs = 1
        elif len(q.shape) == 4:
            bs, lq, lk = q.size(0), q.size(1), k.size(1)
        # preprocess query
        if cu_seqlens_q is None:
            q = half(q.flatten(0, 1))
            cu_seqlens_q = torch.tensor([lq] * bs, dtype=torch.int32).to(device=q.device, non_blocking=True)
            cu_seqlens_q = torch.cat([cu_seqlens_q.new_zeros([1]), cu_seqlens_q]).cumsum(0, dtype=torch.int32)
        else:
            q = half(torch.cat([u[:v] for u, v in zip(q, cu_seqlens_q)]))
        # preprocess key, value
        if cu_seqlens_kv is None:
            k = half(k.flatten(0, 1))
            v = half(v.flatten(0, 1))
            cu_seqlens_kv = torch.tensor([lk] * bs, dtype=torch.int32).to(device=k.device, non_blocking=True)
            cu_seqlens_kv = torch.cat([cu_seqlens_kv.new_zeros([1]), cu_seqlens_kv]).cumsum(0, dtype=torch.int32)
        else:
            k = half(torch.cat([u[:v] for u, v in zip(k, cu_seqlens_kv)]))
            v = half(torch.cat([u[:v] for u, v in zip(v, cu_seqlens_kv)]))

        q = q.to(v.dtype)
        k = k.to(v.dtype)
        softmax_scale = 1 / math.sqrt(q.shape[-1])
        x = flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q.to(device),
            cu_seqlens_k=cu_seqlens_kv.to(device),
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_kv,
            softmax_scale=softmax_scale,
            return_softmax_lse=False,
            causal=False,
        )
        return x.reshape(bs * max_seqlen_q, -1)
```
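For reference, a hedged usage sketch of this attention op (shapes and values are illustrative; assumes an iluvatar CUDA device and ixformer installed):

```python
import torch

attn = IluvatarFlashAttnWeight()
bs, lq, lk, heads, dim = 2, 128, 128, 16, 64
q = torch.randn(bs, lq, heads, dim, dtype=torch.float16, device="cuda")
k = torch.randn(bs, lk, heads, dim, dtype=torch.float16, device="cuda")
v = torch.randn(bs, lk, heads, dim, dtype=torch.float16, device="cuda")

# With cu_seqlens left as None, apply() packs the batch itself and builds
# cumulative sequence-length tensors [0, lq, 2*lq, ...] for the varlen kernel.
out = attn.apply(q, k, v, max_seqlen_q=lq, max_seqlen_kv=lk)
# out has shape (bs * lq, heads * dim) after the final reshape.
```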
New file (`@@ -0,0 +1 @@`):

```python
from .mm_weight import *
```
New file (`@@ -0,0 +1,100 @@`):

```python
import torch

from lightx2v.utils.quant_utils import IntegerQuantizer
from lightx2v_platform.ops.mm.template import MMWeightQuantTemplate
from lightx2v_platform.registry_factory import PLATFORM_MM_WEIGHT_REGISTER

try:
    import ixformer.inference.functions as ixf
except ImportError:
    ixf = None


@PLATFORM_MM_WEIGHT_REGISTER("int8-iluvatar")
class MMWeightWint8channelAint8channeldynamicIluvatar(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-iluvatar

    Quant MM:
        Weight: int8 per-channel sym
        Act: int8 per-channel dynamic sym
        Kernel: iluvatar
    """

    def __init__(
        self,
        weight_name,
        bias_name,
        create_cuda_buffer=False,
        create_cpu_buffer=False,
        lazy_load=False,
        lazy_load_file=None,
        is_post_adapter=False,
        lora_prefix="diffusion_model.blocks",
        lora_path="",
    ):
        super().__init__(
            weight_name,
            bias_name,
            create_cuda_buffer,
            create_cpu_buffer,
            lazy_load,
            lazy_load_file,
            is_post_adapter,
            lora_prefix,
            lora_path,
        )
        assert ixf is not None, "iluvatar ixformer is not installed."
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_int8_perchannel_sym_iluvatar

    def _ensure_int8_weight_and_scale(self, weight_dict):
        """Fill in a missing weight_scale (or int8 weight) so load_quantized can run.

        Some quantized checkpoints omit per-layer scales (e.g. adaLN) or use
        alternate key names; others keep a few layers in float. In those cases a
        per-channel int8 weight and scale are derived here to match ixformer w8a8.
        """
        if self.lazy_load:
            return
        if self.weight_name not in weight_dict:
            return
        if self.weight_scale_name in weight_dict:
            return
        # Check alternate scale key names used by some checkpoints.
        base = self.weight_name.removesuffix(".weight")
        for alt in (f"{base}.scale", f"{self.weight_name}_scale"):
            if alt in weight_dict:
                weight_dict[self.weight_scale_name] = weight_dict[alt].float()
                return
        # Float weight without a scale: quantize it to per-channel int8 here.
        w = weight_dict[self.weight_name]
        if w.dtype in (torch.float16, torch.bfloat16, torch.float32):
            w_float = w.to(torch.float32)
            w_quantizer = IntegerQuantizer(8, True, "per_channel")
            qw, scale, _ = w_quantizer.real_quant_tensor(w_float)
            dev = w.device
            weight_dict[self.weight_name] = qw.to(torch.int8).to(dev)
            weight_dict[self.weight_scale_name] = scale.to(torch.float32).to(dev)

    def load(self, weight_dict):
        self._ensure_int8_weight_and_scale(weight_dict)
        super().load(weight_dict)

    def act_quant_int8_perchannel_sym_iluvatar(self, x):
        device = x.device
        input_tensor_quant = torch.empty(x.shape, dtype=torch.int8, device=device)
        input_tensor_scale = torch.empty(x.shape[:-1], dtype=torch.float32, device=device)
        ixf.dynamic_scaled_int8_quant(output=input_tensor_quant, input=x, scale=input_tensor_scale)
```
Contributor: The API usage of …

Author: On iluvatar, the output tensors are pre-allocated before the call.
```python
        return input_tensor_quant, input_tensor_scale

    def apply(self, input_tensor):
        squeeze_output = False
        dtype = input_tensor.dtype
        # Fold a leading batch dim of 1 so the kernel sees a 2D activation.
        if input_tensor.dim() == 3 and input_tensor.shape[0] == 1:
            input_tensor = input_tensor.squeeze(0)
            squeeze_output = True
        input_tensor_quant, input_tensor_scale = self.act_quant_int8_perchannel_sym_iluvatar(input_tensor)
        output = ixf.w8a8(
            input=input_tensor_quant,
            weight=self.weight,
            i_scales=input_tensor_scale,
            w_scales=self.weight_scale.reshape(-1),
            bias=self.bias,
            out_dtype=dtype,
        )
        if squeeze_output:
            output = output.unsqueeze(0)
        return output
```
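For intuition, a minimal pure-PyTorch sketch of the per-token symmetric int8 quantization that `dynamic_scaled_int8_quant` performs here. This is a reference re-implementation under the assumption of symmetric absmax scaling, not the ixformer kernel itself:

```python
import torch

def dynamic_int8_quant_ref(x: torch.Tensor):
    """Reference per-token symmetric int8 quantization (assumed semantics).

    For each row (token): scale = absmax / 127, q = round(x / scale).
    Mirrors the output/scale shapes allocated in
    act_quant_int8_perchannel_sym_iluvatar: q has x.shape, scale has x.shape[:-1].
    """
    scale = x.abs().amax(dim=-1) / 127.0  # shape x.shape[:-1]
    scale = scale.clamp(min=1e-8)         # avoid division by zero
    q = torch.round(x / scale.unsqueeze(-1)).clamp(-127, 127).to(torch.int8)
    return q, scale.to(torch.float32)

# The dequantized matmul then matches the w8a8 kernel up to rounding:
# y ≈ (q.float() * scale[..., None]) @ (W_int8.float() * w_scale[:, None]).T + bias
```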
Contributor (on the flash_attn `apply` above): There are two issues here:

1. `dtype` is undefined if `len(q.shape) == 3`, because it is only assigned inside the `elif len(q.shape) == 4` block (line 30). This will cause a `NameError` if `half()` is ever called for 3D inputs.
2. `x if x.dtype in half_dtypes else x.to(dtype)` does not actually ensure half precision if the input `q` is `float32` (as `dtype` would then be `float32`). Flash attention kernels typically require `float16` or `bfloat16`.

Author: Fixed.
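One possible shape of the fix for the second point, as a fragment meant to slot into `apply()` above. This is illustrative only and not necessarily what the author committed:

```python
half_dtypes = (torch.float16, torch.bfloat16)
# Hypothetical fix: fall back to bfloat16 when q arrives in float32, so half()
# always yields a dtype the flash-attention kernel accepts, on 3D and 4D paths.
dtype = q.dtype if q.dtype in half_dtypes else torch.bfloat16

def half(x):
    return x if x.dtype in half_dtypes else x.to(dtype)
```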