Skip to content

Commit 5194f30

Browse files
hufangjian2017fangjian.hu
andauthored
[Feature] iluvatar platforms support (#1045)
Co-authored-by: fangjian.hu <fangjian.hu@iluvatar.ai>
1 parent 0684d35 commit 5194f30

20 files changed

Lines changed: 627 additions & 8 deletions

File tree

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"aspect_ratio": "16:9",
3+
"num_channels_latents": 16,
4+
"infer_steps": 9,
5+
"attn_type": "iluvatar_flash_attn",
6+
"enable_cfg": false,
7+
"sample_guide_scale": 0.0,
8+
"patch_size": 2,
9+
"rope_type":"iluvatar_wan_rope",
10+
"rms_norm_type":"iluvatar_rms_norm"
11+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"aspect_ratio": "16:9",
3+
"num_channels_latents": 16,
4+
"infer_steps": 9,
5+
"attn_type": "iluvatar_flash_attn",
6+
"enable_cfg": false,
7+
"sample_guide_scale": 0.0,
8+
"patch_size": 2,
9+
"rope_type":"iluvatar_wan_rope",
10+
"rms_norm_type":"iluvatar_rms_norm",
11+
"dit_quantized": true,
12+
"dit_quant_scheme": "int8-iluvatar",
13+
"dit_quantized_ckpt": ""
14+
}

lightx2v/models/input_encoders/hf/wan/t5/model.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
)
3232
from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8 # noqa E402
3333
from lightx2v_platform.ops.mm.ascend_npu.npu_q_linear import NpuQuantLinearInt8 # noqa E402
34+
from lightx2v_platform.ops.mm.iluvatar_cuda.q_linear import IluvatarQuantLinearInt8 # noqa E402
35+
3436
from lightx2v.models.input_encoders.hf.wan.t5.tokenizer import HuggingfaceTokenizer # noqa E402
3537
from lightx2v.utils.envs import * # noqa E402
3638
from lightx2v.utils.registry_factory import ( # noqa E402
@@ -226,8 +228,10 @@ def __init__(
226228
linear_cls = MluQuantLinearInt8
227229
elif quant_scheme == "int8-npu":
228230
linear_cls = NpuQuantLinearInt8
231+
elif quant_scheme == "int8-iluvatar":
232+
linear_cls = IluvatarQuantLinearInt8
229233
else:
230-
NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
234+
raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
231235
else:
232236
linear_cls = nn.Linear
233237

@@ -309,8 +313,10 @@ def __init__(
309313
linear_cls = MluQuantLinearInt8
310314
elif quant_scheme == "int8-npu":
311315
linear_cls = NpuQuantLinearInt8
316+
elif quant_scheme == "int8-iluvatar":
317+
linear_cls = IluvatarQuantLinearInt8
312318
else:
313-
NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
319+
raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
314320
else:
315321
linear_cls = nn.Linear
316322
# layers

lightx2v/models/input_encoders/hf/wan/xlm_roberta/model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from lightx2v.utils.utils import load_weights
2525
from lightx2v_platform.base.global_var import AI_DEVICE
2626
from lightx2v_platform.ops.mm.cambricon_mlu.q_linear import MluQuantLinearInt8
27+
from lightx2v_platform.ops.mm.iluvatar_cuda.q_linear import IluvatarQuantLinearInt8
2728

2829
__all__ = [
2930
"XLMRobertaCLIP",
@@ -91,8 +92,10 @@ def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=
9192
linear_cls = TritonQuantLinearFp8
9293
elif quant_scheme == "int8-tmo":
9394
linear_cls = MluQuantLinearInt8
95+
elif quant_scheme == "int8-iluvatar":
96+
linear_cls = IluvatarQuantLinearInt8
9497
else:
95-
NotImplementedError(f"Unsupported CLip quant scheme: {quant_scheme}")
98+
raise NotImplementedError(f"Unsupported CLip quant scheme: {quant_scheme}")
9699
else:
97100
linear_cls = nn.Linear
98101

@@ -181,8 +184,10 @@ def __init__(
181184
linear_cls = TritonQuantLinearFp8
182185
elif quant_scheme == "int8-tmo":
183186
linear_cls = MluQuantLinearInt8
187+
elif quant_scheme == "int8-iluvatar":
188+
linear_cls = IluvatarQuantLinearInt8
184189
else:
185-
NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
190+
raise NotImplementedError(f"Unsupported T5 quant scheme: {quant_scheme}")
186191
else:
187192
linear_cls = nn.Linear
188193

lightx2v/models/networks/base_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def _check_dit_quantized(self):
131131
"gguf-Q3_K_M",
132132
"int8-npu",
133133
"fp8-intel-xpu",
134+
"int8-iluvatar",
134135
]
135136

136137
@abstractmethod

lightx2v/models/networks/wan/audio_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _load_adapter_ckpt(self):
2929
if self.config.get("adapter_quantized", False):
3030
if self.config.get("adapter_quant_scheme", None) in ["fp8", "fp8-q8f", "fp8-vllm", "fp8-sgl", "fp8-torchao", "fp8-triton"]:
3131
adapter_model_name = "audio_adapter_model_fp8.safetensors"
32-
elif self.config.get("adapter_quant_scheme", None) in ["int8", "int8-q8f", "int8-vllm", "int8-torchao", "int8-sgl", "int8-triton", "int8-tmo", "int8-npu"]:
32+
elif self.config.get("adapter_quant_scheme", None) in ["int8", "int8-q8f", "int8-vllm", "int8-torchao", "int8-sgl", "int8-triton", "int8-tmo", "int8-npu", "int8-iluvatar"]:
3333
adapter_model_name = "audio_adapter_model_int8.safetensors"
3434
elif self.config.get("adapter_quant_scheme", None) in ["mxfp4"]:
3535
adapter_model_name = "audio_adapter_model_mxfp4.safetensors"

lightx2v/models/networks/z_image/infer/transformer_infer.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn.functional as F
33

44
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
5+
from lightx2v.utils.registry_factory import ROPE_REGISTER
56

67
from .utils import apply_rotary_emb_qwen, apply_wan_rope_with_flashinfer
78

@@ -20,10 +21,26 @@ def __init__(self, config):
2021
self.seq_p_group = None
2122
self.seq_p_fp8_comm = False
2223
self.seq_p_fp4_comm = False
23-
if self.config.get("rope_type", "flashinfer") == "flashinfer":
24-
self.apply_rope_func = apply_wan_rope_with_flashinfer
24+
25+
rope_funcs = {
26+
"flashinfer": apply_wan_rope_with_flashinfer,
27+
"torch_naive": apply_rotary_emb_qwen,
28+
}
29+
30+
rope_type = self.config.get("rope_type", "flashinfer")
31+
if rope_type in ROPE_REGISTER:
32+
rope_class = ROPE_REGISTER[rope_type]
33+
self.rope_instance = rope_class()
34+
35+
# Create a wrapper function that matches the expected signature
36+
def rope_wrapper(xq, xk, cos_sin_cache):
37+
return self.rope_instance.apply(xq, xk, cos_sin_cache)
38+
39+
rope_func = rope_wrapper
2540
else:
26-
self.apply_rope_func = apply_rotary_emb_qwen
41+
# Fallback to hardcoded functions
42+
rope_func = rope_funcs.get(rope_type, apply_rotary_emb_qwen)
43+
self.apply_rope_func = rope_func
2744

2845
def set_scheduler(self, scheduler):
2946
self.scheduler = scheduler

lightx2v_platform/base/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from lightx2v_platform.base.nvidia import CudaDevice
99
from lightx2v_platform.base.enflame_gcu import EnflameGcuDevice
1010
from lightx2v_platform.base.intel_xpu import IntelXpuDevice
11+
from lightx2v_platform.base.iluvatar_cuda import IluvatarDevice
1112

1213
__all__ = [
1314
"init_ai_device",
@@ -21,4 +22,5 @@
2122
"MusaDevice",
2223
"EnflameGcuDevice",
2324
"IntelXpuDevice",
25+
"IluvatarDevice",
2426
]
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import torch
2+
import torch.distributed as dist
3+
4+
from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER
5+
6+
try:
7+
from torch.distributed import ProcessGroupNCCL
8+
except ImportError:
9+
ProcessGroupNCCL = None
10+
11+
12+
@PLATFORM_DEVICE_REGISTER("iluvatar_cuda")
13+
class IluvatarDevice:
14+
name = "iluvatar_cuda"
15+
16+
@staticmethod
17+
def init_device_env():
18+
pass
19+
20+
@staticmethod
21+
def is_available() -> bool:
22+
try:
23+
import torch
24+
25+
return torch.cuda.is_available()
26+
except ImportError:
27+
return False
28+
29+
@staticmethod
30+
def get_device() -> str:
31+
return "cuda"
32+
33+
@staticmethod
34+
def init_parallel_env():
35+
if ProcessGroupNCCL is None:
36+
raise RuntimeError("ProcessGroupNCCL is not available. Please check your runtime environment.")
37+
pg_options = ProcessGroupNCCL.Options()
38+
pg_options.is_high_priority_stream = True
39+
dist.init_process_group(backend="nccl", pg_options=pg_options)
40+
torch.cuda.set_device(dist.get_rank())

lightx2v_platform/ops/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,8 @@
2424
elif PLATFORM == "intel_xpu":
2525
from .attn.intel_xpu import *
2626
from .mm.intel_xpu import *
27+
elif PLATFORM == "iluvatar_cuda":
28+
from .attn.iluvatar_cuda import *
29+
from .mm.iluvatar_cuda import *
30+
from .norm.iluvatar_cuda import *
31+
from .rope.iluvatar_cuda import *

0 commit comments

Comments
 (0)