-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsas_linear.py
More file actions
90 lines (68 loc) · 2.83 KB
/
sas_linear.py
File metadata and controls
90 lines (68 loc) · 2.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Copyright (C) 2025 Denso IT Laboratory, Inc.
# All Rights Reserved
import math
import torch
from torch import autograd, nn
import torch.nn.functional as F
import numpy as np
from timm.models.layers import trunc_normal_
########################################################################
# SASLinear
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
class SASLinear(nn.Module):
    """Linear layer with a Sign-Aware Sparse (SAS) input projection.

    Before the affine transform, each input feature is split by sign into
    ``sparse_m`` channels (for ``sparse_m == 2``: a positive part and a
    negative part, both non-negative), so the weight's input dimension is
    ``sparse_m * in_channels // groups``.

    Args:
        in_channels: number of input features (per group).
        out_channels: number of output features.
        kernel_size: (kh, kw) pair; only used to size the weight when
            ``is_conv=True``. Ignored for the linear case.
        stride, padding, dilation, groups: stored convolution-style
            hyper-parameters; ``forward`` itself only uses ``groups``
            implicitly via the weight shape.
        bias: if True, adds a learnable bias of shape (out_channels,).
        sparse_m: sign-split factor; 1 disables the projection entirely.
        sparse_n: stored for N:M-sparsity bookkeeping; unused here.
        is_conv: if True, allocate a 4-D conv-shaped weight instead of 2-D.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=(0, 0), dilation=1, groups=1, bias=True,
                 sparse_m=2, sparse_n=2, is_conv=False):
        super().__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.sparse_m = sparse_m
        self.sparse_n = sparse_n
        self.is_conv = is_conv
        # Normalize a scalar stride to a (h, w) pair before storing it.
        if not isinstance(stride, (tuple, list)):
            stride = (stride, stride)
        self.stride = stride
        # The input dimension is expanded by sparse_m to hold the sign-split
        # channels produced by SAS_proj.
        if is_conv:
            weight_shape = (out_channels, sparse_m * in_channels // groups,
                            kernel_size[0], kernel_size[1])
        else:
            weight_shape = (out_channels, sparse_m * in_channels // groups)
        self.register_parameter("weight", nn.Parameter(torch.zeros(weight_shape)))
        if bias:
            self.register_parameter("bias", nn.Parameter(torch.zeros(out_channels)))
        else:
            self.bias = None
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            # The weight's fan-in is inflated by sparse_m; divide it out so
            # the bias bound matches an un-split layer of the same size.
            fan_in = fan_in / self.sparse_m
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def SAS_proj(self, X):
        """Sign-split projection: (B, C) -> (B, C * sparse_m).

        For each feature x, slot 0 holds max(x, 0) and slot 1 holds
        max(-x, 0); exactly one of the two is non-zero, which is the
        source of the sparsity. Gradients flow through the magnitudes;
        only the sign mask is detached.

        NOTE(review): the slot assignment below is hard-coded to two
        slots, so this is only meaningful for |sparse_m| == 2 — confirm
        callers never pass other values besides 1 and 2.
        """
        if self.sparse_m == 1:
            return X  # identity: no projection requested
        M = abs(self.sparse_m)
        B, C = X.shape  # assumes a 2-D (batch, features) input
        x_proj = torch.zeros((B, C, M), dtype=X.dtype, device=X.device)
        Y = X.abs()
        with torch.no_grad():
            sgn = (X > 0)  # e.g. sgn = [False, True, False, True, ...]
        yp = Y * sgn.float()          # positive magnitudes
        yn = Y * (1.0 - sgn.float())  # negative magnitudes
        x_proj[..., 0] = yp
        x_proj[..., 1] = yn
        # Interleave slots per feature: (B, C, M) -> (B, C * M)
        return x_proj.flatten(1)

    def forward(self, x):
        """Apply the SAS projection followed by a standard linear map."""
        x_proj = self.SAS_proj(x)
        return F.linear(x_proj, self.weight, bias=self.bias)