-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy path: ConvModule.cs
More file actions
77 lines (70 loc) · 2.96 KB
/
ConvModule.cs
File metadata and controls
77 lines (70 loc) · 2.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
namespace Microsoft.ML.TorchSharp.AutoFormerV2
{
/// <summary>
/// The convolution and activation module.
/// </summary>
public class ConvModule : Module<Tensor, Tensor>
{
#pragma warning disable MSML_PrivateFieldName // Need to match TorchSharp model names.
    private readonly Conv2d conv;
    private readonly ReLU activation;
    private readonly bool useRelu;
    private bool _disposedValue;
#pragma warning restore MSML_PrivateFieldName

    /// <summary>
    /// Initializes a new instance of the <see cref="ConvModule"/> class.
    /// </summary>
    /// <param name="inChannel">The input channels of convolution layer.</param>
    /// <param name="outChannel">The output channels of convolution layer.</param>
    /// <param name="kernelSize">The kernel size of convolution layer.</param>
    /// <param name="stride">The stride of convolution layer.</param>
    /// <param name="padding">The padding of convolution layer.</param>
    /// <param name="dilation">The dilation of convolution layer.</param>
    /// <param name="bias">The bias of convolution layer.</param>
    /// <param name="useRelu">Whether use Relu activation function.</param>
    public ConvModule(int inChannel, int outChannel, int kernelSize, int stride = 1, int padding = 0, int dilation = 1, bool bias = true, bool useRelu = true)
        : base(nameof(ConvModule))
    {
        this.conv = nn.Conv2d(in_channels: inChannel, out_channels: outChannel, kernel_size: kernelSize, stride: stride, padding: padding, dilation: dilation, bias: bias);
        this.useRelu = useRelu;
        if (this.useRelu)
        {
            // Activation is only materialized when requested; stays null otherwise,
            // which is why Dispose uses the null-conditional on it.
            this.activation = nn.ReLU();
        }
    }

    /// <inheritdoc/>
    [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
    public override Tensor forward(Tensor x)
    {
        // The dispose scope frees intermediate tensors; only the final result
        // is moved out so the caller owns it.
        using (var scope = torch.NewDisposeScope())
        {
            x = this.conv.forward(x);
            if (this.useRelu)
            {
                x = this.activation.forward(x);
            }

            return x.MoveToOuterDisposeScope();
        }
    }

    /// <summary>
    /// Releases the submodules owned by this module.
    /// </summary>
    /// <param name="disposing">True when called from <see cref="IDisposable.Dispose"/>; false from the finalizer path.</param>
    protected override void Dispose(bool disposing)
    {
        if (!_disposedValue)
        {
            if (disposing)
            {
                conv.Dispose();
                activation?.Dispose();
            }

            // Fix: set the flag outside the `disposing` branch (standard dispose
            // pattern) so the guard also holds when Dispose(false) runs first and
            // repeated calls are always no-ops.
            _disposedValue = true;
        }

        base.Dispose(disposing);
    }
}
}