|
24 | 24 | from ...models.attention_processor import Attention |
25 | 25 | from ...models.modeling_utils import ModelMixin |
26 | 26 | from ...models.normalization import RMSNorm |
| 27 | +from ...utils import is_torch_npu_available |
27 | 28 | from ...utils.torch_utils import maybe_allow_in_graph |
28 | 29 | from ..attention_dispatch import dispatch_attention_fn |
29 | 30 | from ..modeling_outputs import Transformer2DModelOutput |
@@ -322,37 +323,72 @@ def __init__( |
322 | 323 | self.axes_lens = axes_lens |
323 | 324 | assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length" |
324 | 325 | self.freqs_cis = None |
| 326 | + self.freqs_real = None |
| 327 | + self.freqs_imag = None |
325 | 328 |
|
@staticmethod
def precompute_freqs_cis(dim: list[int], end: list[int], theta: float = 256.0):
    """Build the per-axis rotary frequency lookup tables on the CPU.

    For every ``(d, e)`` pair taken from ``dim`` and ``end`` an angle table is
    computed as the outer product of the positions ``0 .. e - 1`` with the
    inverse frequencies ``1 / theta ** (k / d)`` for ``k = 0, 2, 4, ... < d``.
    The angles are computed in float64 and then cast to float32.

    Args:
        dim: Rotary dimension of each axis.
        end: Number of positions to tabulate for each axis.
        theta: Base of the inverse-frequency progression.

    Returns:
        When ``is_torch_npu_available()`` is true, a tuple
        ``(freqs_real_list, freqs_imag_list)`` holding the float32 cosine and
        sine of each angle table. Otherwise a single list of complex64 tensors
        ``exp(1j * angle)``. Both forms contain one entry per axis.
    """
    with torch.device("cpu"):
        # The angle tables are identical for both return forms, so compute them once.
        angles = []
        for d, e in zip(dim, end):
            inv_freq = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
            positions = torch.arange(e, device=inv_freq.device, dtype=torch.float64)
            angles.append(torch.outer(positions, inv_freq).float())

        # `is_torch_npu_available` is a function: it has to be *called*. Testing the
        # bare name is always truthy and would make the complex branch unreachable.
        if is_torch_npu_available():
            # NOTE(review): real and imaginary parts are kept as separate float32 tables
            # here, presumably because complex tables are a problem on NPU -- confirm.
            freqs_real_list = [torch.cos(angle) for angle in angles]
            freqs_imag_list = [torch.sin(angle) for angle in angles]
            return freqs_real_list, freqs_imag_list

        return [torch.polar(torch.ones_like(angle), angle).to(torch.complex64) for angle in angles]

def __call__(self, ids: torch.Tensor):
    """Look up the rotary frequencies for a batch of multi-axis position ids.

    The lookup tables are built lazily on the first call and cached on the
    instance; they are moved whenever ``ids`` lives on a different device than
    the cached tables.

    Args:
        ids: 2-D tensor whose last dimension has one column per axis
            (``len(self.axes_dims)``). Column ``i`` is used to index the
            table of axis ``i``.

    Returns:
        A complex tensor made of the per-axis table rows selected by ``ids``,
        concatenated along the last dimension.
    """
    assert ids.ndim == 2
    assert ids.shape[-1] == len(self.axes_dims)
    device = ids.device
    num_axes = len(self.axes_dims)

    if is_torch_npu_available():
        if self.freqs_real is None or self.freqs_imag is None:
            freqs_real, freqs_imag = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
            self.freqs_real = [fr.to(device) for fr in freqs_real]
            self.freqs_imag = [fi.to(device) for fi in freqs_imag]
        elif self.freqs_real[0].device != device:
            # Move the *cached* tables to the device of ids. The locals `freqs_real` /
            # `freqs_imag` only exist in the branch above, so the instance attributes
            # must be used here.
            self.freqs_real = [fr.to(device) for fr in self.freqs_real]
            self.freqs_imag = [fi.to(device) for fi in self.freqs_imag]

        # Gather the real and imaginary rows first, then combine them into complex values.
        result = []
        for i in range(num_axes):
            index = ids[:, i]
            result.append(torch.complex(self.freqs_real[i][index], self.freqs_imag[i][index]))
    else:
        if self.freqs_cis is None:
            self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
            self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
        elif self.freqs_cis[0].device != device:
            # Ensure the cached tables are on the same device as ids.
            self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]

        result = [self.freqs_cis[i][ids[:, i]] for i in range(num_axes)]

    return torch.cat(result, dim=-1)
357 | 393 |
|
358 | 394 |
|
|
0 commit comments