Skip to content

Commit 3d12f3b

Browse files
authored
update a general solution for supporting both NPU and GPU
1 parent 1a22018 commit 3d12f3b

1 file changed

Lines changed: 18 additions & 54 deletions

File tree

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 18 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from ...models.attention_processor import Attention
2525
from ...models.modeling_utils import ModelMixin
2626
from ...models.normalization import RMSNorm
27-
from ...utils import is_torch_npu_available
2827
from ...utils.torch_utils import maybe_allow_in_graph
2928
from ..attention_dispatch import dispatch_attention_fn
3029
from ..modeling_outputs import Transformer2DModelOutput
@@ -323,72 +322,37 @@ def __init__(
323322
self.axes_lens = axes_lens
324323
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
325324
self.freqs_cis = None
326-
self.freqs_real = None
327-
self.freqs_imag = None
328325

329326
@staticmethod
330327
def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0):
331328
with torch.device("cpu"):
332-
if is_torch_npu_available:
333-
freqs_real_list = []
334-
freqs_imag_list = []
335-
for i, (d, e) in enumerate(zip(dim, end)):
336-
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
337-
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
338-
freqs = torch.outer(timestep, freqs).float()
339-
freqs_real = torch.cos(freqs)
340-
freqs_imag = torch.sin(freqs)
341-
freqs_real_list.append(freqs_real.to(torch.float32))
342-
freqs_imag_list.append(freqs_imag.to(torch.float32))
343-
344-
return freqs_real_list, freqs_imag_list
345-
else:
346-
freqs_cis = []
347-
for i, (d, e) in enumerate(zip(dim, end)):
348-
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
349-
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
350-
freqs = torch.outer(timestep, freqs).float()
351-
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
352-
freqs_cis.append(freqs_cis_i)
353-
return freqs_cis
329+
freqs_cis = []
330+
for i, (d, e) in enumerate(zip(dim, end)):
331+
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
332+
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
333+
freqs = torch.outer(timestep, freqs).float()
334+
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
335+
freqs_cis.append(freqs_cis_i)
336+
337+
return freqs_cis
354338

355339
def __call__(self, ids: torch.Tensor):
356340
assert ids.ndim == 2
357341
assert ids.shape[-1] == len(self.axes_dims)
358342
device = ids.device
359343

360-
if is_torch_npu_available:
361-
if self.freqs_real is None or self.freqs_imag is None:
362-
freqs_real, freqs_imag = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
363-
self.freqs_real = [fr.to(device) for fr in freqs_real]
364-
self.freqs_imag = [fi.to(device) for fi in freqs_imag]
365-
else:
366-
# Ensure freqs_cis are on the same device as ids
367-
if self.freqs_real[0].device != device:
368-
self.freqs_real = [fr.to(device) for fr in freqs_real]
369-
self.freqs_imag = [fi.to(device) for fi in freqs_imag]
370-
371-
result = []
372-
for i in range(len(self.axes_dims)):
373-
index = ids[:, i]
374-
real_part = self.freqs_real[i][index]
375-
imag_part = self.freqs_imag[i][index]
376-
complex_part = torch.complex(real_part, imag_part)
377-
result.append(complex_part)
344+
if self.freqs_cis is None:
345+
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
346+
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
378347
else:
379-
if self.freqs_cis is None:
380-
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
348+
# Ensure freqs_cis are on the same device as ids
349+
if self.freqs_cis[0].device != device:
381350
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
382-
else:
383-
# Ensure freqs_cis are on the same device as ids
384-
if self.freqs_cis[0].device != device:
385-
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
386-
387-
result = []
388-
for i in range(len(self.axes_dims)):
389-
index = ids[:, i]
390-
result.append(self.freqs_cis[i][index])
391351

352+
result = []
353+
for i in range(len(self.axes_dims)):
354+
index = ids[:, i]
355+
result.append(torch.complex(self.freqs_cis[i].real[index], self.freqs_cis[i].imag[index]))
392356
return torch.cat(result, dim=-1)
393357

394358

0 commit comments

Comments
 (0)