77
88
99class MultiPeriodDiscriminator (torch .nn .Module ):
10- """
11- Multi-period discriminator.
10+ """Multi-period discriminator.
1211
1312 This class implements a multi-period discriminator, which is used to
1413 discriminate between real and fake audio signals. The discriminator
@@ -20,7 +19,9 @@ class MultiPeriodDiscriminator(torch.nn.Module):
2019 def __init__ (self , checkpointing : bool = False ):
2120 super ().__init__ ()
2221 self .checkpointing = checkpointing
23- self .discriminators = torch .nn .ModuleList ([DiscriminatorS ()] + [DiscriminatorP (period ) for period in [2 , 3 , 5 , 7 , 11 , 17 , 23 , 37 ]]) # periods
22+ self .discriminators = torch .nn .ModuleList (
23+ [DiscriminatorS ()] + [DiscriminatorP (period ) for period in [2 , 3 , 5 , 7 , 11 , 17 , 23 , 37 ]]
24+ ) # periods
2425
2526 def forward (self , y , y_hat ):
2627 y_d_rs , y_d_gs , fmap_rs , fmap_gs = [], [], [], []
@@ -40,8 +41,7 @@ def forward(self, y, y_hat):
4041
4142
4243class DiscriminatorS (torch .nn .Module ):
43- """
44- Discriminator for the short-term component.
44+ """Discriminator for the short-term component.
4545
4646 This class implements a discriminator for the short-term component
4747 of the audio signal. The discriminator is composed of a series of
@@ -58,7 +58,7 @@ def __init__(self):
5858 weight_norm (torch .nn .Conv1d (256 , 1024 , 41 , 4 , groups = 64 , padding = 20 )),
5959 weight_norm (torch .nn .Conv1d (1024 , 1024 , 41 , 4 , groups = 256 , padding = 20 )),
6060 weight_norm (torch .nn .Conv1d (1024 , 1024 , 5 , 1 , padding = 2 )),
61- ]
61+ ],
6262 )
6363 self .conv_post = weight_norm (torch .nn .Conv1d (1024 , 1 , 3 , 1 , padding = 1 ))
6464 self .lrelu = torch .nn .LeakyReLU (LRELU_SLOPE )
@@ -75,8 +75,7 @@ def forward(self, x):
7575
7676
7777class DiscriminatorP (torch .nn .Module ):
78- """
79- Discriminator for the long-term component.
78+ """Discriminator for the long-term component.
8079
8180 This class implements a discriminator for the long-term component
8281 of the audio signal. The discriminator is composed of a series of
@@ -86,6 +85,7 @@ class DiscriminatorP(torch.nn.Module):
8685 Args:
8786 period (int): Period of the discriminator.
8887 kernel_size (int): Kernel size of the convolutional layers. Defaults to 5.
88+
8989 """
9090
9191 def __init__ (self , period : int , kernel_size : int = 5 ):
@@ -100,14 +100,14 @@ def __init__(self, period: int, kernel_size: int = 5):
100100 (kernel_size , 1 ),
101101 (stride , 1 ),
102102 padding = (get_padding (kernel_size , 1 ), 0 ),
103- )
103+ ),
104104 )
105105 for input_channel , output_channel , stride in zip (
106106 [1 , 32 , 128 , 512 , 1024 ], # input_channels
107107 [32 , 128 , 512 , 1024 , 1024 ], # output_channels
108108 [3 , 3 , 3 , 3 , 1 ], # strides
109109 )
110- ]
110+ ],
111111 )
112112
113113 self .conv_post = weight_norm (torch .nn .Conv2d (1024 , 1 , (3 , 1 ), 1 , padding = (1 , 0 )))
0 commit comments