Skip to content

Commit 1a22018

Browse files
authored
z-image support npu
1 parent 7b107d3 commit 1a22018

1 file changed

Lines changed: 54 additions & 18 deletions

File tree

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ...models.attention_processor import Attention
2525
from ...models.modeling_utils import ModelMixin
2626
from ...models.normalization import RMSNorm
27+
from ...utils import is_torch_npu_available
2728
from ...utils.torch_utils import maybe_allow_in_graph
2829
from ..attention_dispatch import dispatch_attention_fn
2930
from ..modeling_outputs import Transformer2DModelOutput
@@ -322,37 +323,72 @@ def __init__(
322323
self.axes_lens = axes_lens
323324
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
324325
self.freqs_cis = None
326+
self.freqs_real = None
327+
self.freqs_imag = None
325328

326329
@staticmethod
327330
def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0):
328331
with torch.device("cpu"):
329-
freqs_cis = []
330-
for i, (d, e) in enumerate(zip(dim, end)):
331-
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
332-
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
333-
freqs = torch.outer(timestep, freqs).float()
334-
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
335-
freqs_cis.append(freqs_cis_i)
336-
337-
return freqs_cis
332+
if is_torch_npu_available:
333+
freqs_real_list = []
334+
freqs_imag_list = []
335+
for i, (d, e) in enumerate(zip(dim, end)):
336+
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
337+
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
338+
freqs = torch.outer(timestep, freqs).float()
339+
freqs_real = torch.cos(freqs)
340+
freqs_imag = torch.sin(freqs)
341+
freqs_real_list.append(freqs_real.to(torch.float32))
342+
freqs_imag_list.append(freqs_imag.to(torch.float32))
343+
344+
return freqs_real_list, freqs_imag_list
345+
else:
346+
freqs_cis = []
347+
for i, (d, e) in enumerate(zip(dim, end)):
348+
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
349+
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
350+
freqs = torch.outer(timestep, freqs).float()
351+
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
352+
freqs_cis.append(freqs_cis_i)
353+
return freqs_cis
338354

339355
def __call__(self, ids: torch.Tensor):
340356
assert ids.ndim == 2
341357
assert ids.shape[-1] == len(self.axes_dims)
342358
device = ids.device
343359

344-
if self.freqs_cis is None:
345-
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
346-
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
360+
if is_torch_npu_available:
361+
if self.freqs_real is None or self.freqs_imag is None:
362+
freqs_real, freqs_imag = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
363+
self.freqs_real = [fr.to(device) for fr in freqs_real]
364+
self.freqs_imag = [fi.to(device) for fi in freqs_imag]
365+
else:
366+
# Ensure freqs_cis are on the same device as ids
367+
if self.freqs_real[0].device != device:
368+
self.freqs_real = [fr.to(device) for fr in freqs_real]
369+
self.freqs_imag = [fi.to(device) for fi in freqs_imag]
370+
371+
result = []
372+
for i in range(len(self.axes_dims)):
373+
index = ids[:, i]
374+
real_part = self.freqs_real[i][index]
375+
imag_part = self.freqs_imag[i][index]
376+
complex_part = torch.complex(real_part, imag_part)
377+
result.append(complex_part)
347378
else:
348-
# Ensure freqs_cis are on the same device as ids
349-
if self.freqs_cis[0].device != device:
379+
if self.freqs_cis is None:
380+
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
350381
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
382+
else:
383+
# Ensure freqs_cis are on the same device as ids
384+
if self.freqs_cis[0].device != device:
385+
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
386+
387+
result = []
388+
for i in range(len(self.axes_dims)):
389+
index = ids[:, i]
390+
result.append(self.freqs_cis[i][index])
351391

352-
result = []
353-
for i in range(len(self.axes_dims)):
354-
index = ids[:, i]
355-
result.append(self.freqs_cis[i][index])
356392
return torch.cat(result, dim=-1)
357393

358394

0 commit comments

Comments
 (0)