|
1 | | -from typing import Optional, List |
| 1 | +from typing import Optional, List, Union |
2 | 2 | import math |
3 | 3 |
|
4 | 4 | import torch |
@@ -83,7 +83,7 @@ def __init__( |
83 | 83 | self.conv_pre = Conv1d( |
84 | 84 | initial_channel, upsample_initial_channel, 7, 1, padding=3 |
85 | 85 | ) |
86 | | - resblock = ResBlock1 if resblock == "1" else ResBlock2 |
| 86 | + resblockcls = ResBlock1 if resblock == "1" else ResBlock2 |
87 | 87 |
|
88 | 88 | self.ups = nn.ModuleList() |
89 | 89 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): |
@@ -114,12 +114,13 @@ def __init__( |
114 | 114 | self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) |
115 | 115 |
|
116 | 116 | self.resblocks = nn.ModuleList() |
| 117 | + ch = 0 |
117 | 118 | for i in range(len(self.ups)): |
118 | | - ch: int = upsample_initial_channel // (2 ** (i + 1)) |
| 119 | + ch = upsample_initial_channel // (2 ** (i + 1)) |
119 | 120 | for j, (k, d) in enumerate( |
120 | 121 | zip(resblock_kernel_sizes, resblock_dilation_sizes) |
121 | 122 | ): |
122 | | - self.resblocks.append(resblock(ch, k, d)) |
| 123 | + self.resblocks.append(resblockcls(ch, k, d)) |
123 | 124 |
|
124 | 125 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) |
125 | 126 | self.ups.apply(call_weight_data_normal_if_Conv) |
|
0 commit comments