|
21 | 21 | from .normalization import RMSNorm |
22 | 22 |
|
23 | 23 |
|
| 24 | +def _prepare_fir_kernel( |
| 25 | + kernel: torch.Tensor | None, |
| 26 | + *, |
| 27 | + factor: int, |
| 28 | + gain: float, |
| 29 | + device: torch.device, |
| 30 | + dtype: torch.dtype, |
| 31 | + upsample: bool = False, |
| 32 | +) -> torch.Tensor: |
| 33 | + if kernel is None: |
| 34 | + kernel = [1] * factor |
| 35 | + |
| 36 | + kernel = torch.as_tensor(kernel, device=device, dtype=torch.float32) |
| 37 | + if kernel.ndim == 1: |
| 38 | + kernel = torch.outer(kernel, kernel) |
| 39 | + kernel = kernel / torch.sum(kernel) |
| 40 | + |
| 41 | + scale = gain * (factor**2) if upsample else gain |
| 42 | + return (kernel * scale).to(device=device, dtype=dtype) |
| 43 | + |
| 44 | + |
24 | 45 | class Upsample1D(nn.Module): |
25 | 46 | """A 1D upsampling layer with an optional convolution. |
26 | 47 |
|
@@ -253,17 +274,14 @@ def _upsample_2d( |
253 | 274 |
|
254 | 275 | assert isinstance(factor, int) and factor >= 1 |
255 | 276 |
|
256 | | - # Setup filter kernel. |
257 | | - if kernel is None: |
258 | | - kernel = [1] * factor |
259 | | - |
260 | | - # setup kernel |
261 | | - kernel = torch.tensor(kernel, dtype=torch.float32) |
262 | | - if kernel.ndim == 1: |
263 | | - kernel = torch.outer(kernel, kernel) |
264 | | - kernel /= torch.sum(kernel) |
265 | | - |
266 | | - kernel = kernel * (gain * (factor**2)) |
| 277 | + kernel = _prepare_fir_kernel( |
| 278 | + kernel, |
| 279 | + factor=factor, |
| 280 | + gain=gain, |
| 281 | + device=hidden_states.device, |
| 282 | + dtype=hidden_states.dtype, |
| 283 | + upsample=True, |
| 284 | + ) |
267 | 285 |
|
268 | 286 | if self.use_conv: |
269 | 287 | convH = weight.shape[2] |
@@ -300,14 +318,14 @@ def _upsample_2d( |
300 | 318 |
|
301 | 319 | output = upfirdn2d_native( |
302 | 320 | inverse_conv, |
303 | | - torch.tensor(kernel, device=inverse_conv.device), |
| 321 | + kernel.to(device=inverse_conv.device, dtype=inverse_conv.dtype), |
304 | 322 | pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), |
305 | 323 | ) |
306 | 324 | else: |
307 | 325 | pad_value = kernel.shape[0] - factor |
308 | 326 | output = upfirdn2d_native( |
309 | 327 | hidden_states, |
310 | | - torch.tensor(kernel, device=hidden_states.device), |
| 328 | + kernel, |
311 | 329 | up=factor, |
312 | 330 | pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), |
313 | 331 | ) |
@@ -496,19 +514,18 @@ def upsample_2d( |
496 | 514 | Tensor of the shape `[N, C, H * factor, W * factor]` |
497 | 515 | """ |
498 | 516 | assert isinstance(factor, int) and factor >= 1 |
499 | | - if kernel is None: |
500 | | - kernel = [1] * factor |
501 | | - |
502 | | - kernel = torch.tensor(kernel, dtype=torch.float32) |
503 | | - if kernel.ndim == 1: |
504 | | - kernel = torch.outer(kernel, kernel) |
505 | | - kernel /= torch.sum(kernel) |
506 | | - |
507 | | - kernel = kernel * (gain * (factor**2)) |
| 517 | + kernel = _prepare_fir_kernel( |
| 518 | + kernel, |
| 519 | + factor=factor, |
| 520 | + gain=gain, |
| 521 | + device=hidden_states.device, |
| 522 | + dtype=hidden_states.dtype, |
| 523 | + upsample=True, |
| 524 | + ) |
508 | 525 | pad_value = kernel.shape[0] - factor |
509 | 526 | output = upfirdn2d_native( |
510 | 527 | hidden_states, |
511 | | - kernel.to(device=hidden_states.device), |
| 528 | + kernel, |
512 | 529 | up=factor, |
513 | 530 | pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), |
514 | 531 | ) |
|
0 commit comments