|
24 | 24 | from ...models.attention_processor import Attention |
25 | 25 | from ...models.modeling_utils import ModelMixin |
26 | 26 | from ...models.normalization import RMSNorm |
27 | | -from ...utils import is_torch_npu_available |
28 | 27 | from ...utils.torch_utils import maybe_allow_in_graph |
29 | 28 | from ..attention_dispatch import dispatch_attention_fn |
30 | 29 | from ..modeling_outputs import Transformer2DModelOutput |
@@ -323,72 +322,37 @@ def __init__( |
323 | 322 | self.axes_lens = axes_lens |
324 | 323 | assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" |
325 | 324 | self.freqs_cis = None |
326 | | - self.freqs_real = None |
327 | | - self.freqs_imag = None |
328 | 325 |
|
@staticmethod
def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0) -> list[torch.Tensor]:
    """Build one complex rotary-frequency table per axis.

    For every pair ``(d, e)`` taken from ``zip(dim, end)`` the table holds
    ``exp(1j * p * w_k)`` for positions ``p = 0 .. e - 1`` and inverse
    frequencies ``w_k = 1 / theta ** (k / d)`` with ``k = 0, 2, 4, ... < d``.

    Args:
        dim: Per-axis size ``d``; each table gets ``ceil(d / 2)`` columns.
        end: Per-axis number of positions ``e``; each table gets ``e`` rows.
        theta: Base of the geometric frequency progression.

    Returns:
        A list with one ``torch.complex64`` tensor of shape
        ``(e, ceil(d / 2))`` per axis, created on the CPU.
    """
    freqs_cis = []
    # Tables are always built on the CPU; callers move them to the target device.
    with torch.device("cpu"):
        for d, e in zip(dim, end):
            # Angles are computed in float64 and only then narrowed to float32.
            inv_freq = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
            positions = torch.arange(e, device=inv_freq.device, dtype=torch.float64)
            angles = torch.outer(positions, inv_freq).float()
            # Unit-magnitude complex numbers: cos(angle) + 1j * sin(angle).
            freqs_cis.append(torch.polar(torch.ones_like(angles), angles).to(torch.complex64))

    return freqs_cis
354 | 338 |
|
def __call__(self, ids: torch.Tensor) -> torch.Tensor:
    """Gather the cached rotary frequencies for a batch of position ids.

    Args:
        ids: 2-D integer tensor whose last dimension has one column per
            entry of ``self.axes_dims``; column ``i`` indexes the rows of
            the ``i``-th frequency table.

    Returns:
        A complex tensor formed by concatenating, along the last dimension,
        the rows selected from each axis table.

    Raises:
        AssertionError: If ``ids`` is not 2-D or its last dimension does not
            match ``len(self.axes_dims)``.
    """
    assert ids.ndim == 2
    assert ids.shape[-1] == len(self.axes_dims)
    device = ids.device

    # Build the tables lazily on first use; they are cached on the instance.
    if self.freqs_cis is None:
        self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)

    # Ensure freqs_cis are on the same device as ids.
    if self.freqs_cis[0].device != device:
        self.freqs_cis = [table.to(device) for table in self.freqs_cis]

    result = []
    for axis in range(len(self.axes_dims)):
        table = self.freqs_cis[axis]
        index = ids[:, axis]
        # Index the real and imaginary parts separately and recombine them.
        # The values equal ``table[index]``.
        # NOTE(review): presumably kept for backends that cannot index complex
        # tensors directly (the removed NPU branch gathered the same way) - confirm.
        result.append(torch.complex(table.real[index], table.imag[index]))
    return torch.cat(result, dim=-1)
393 | 357 |
|
394 | 358 |
|
|
0 commit comments