Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 102 additions & 38 deletions python/tinyProp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F
from typing import Union
from torch.nn.common_types import _size_2_t
from torch.nn.common_types import _size_1_t, _size_2_t


# classes to hold TinyProp parameters on Net and Layer scope
Expand Down Expand Up @@ -55,7 +55,22 @@ def selectGradients(self, grad_output, params):
idx = torch.hstack(idx)
val = torch.cat(val)
return idx, val



#========== Helper functions ==========#

def _apply_tinyprop_mask(tp_info: "TinyPropLayer", grad_output: torch.Tensor, tp_params: TinyPropParams) -> torch.Tensor:
"""Apply the TinyProp gradient selection to the gradient tensor."""

flattened = torch.flatten(grad_output, start_dim=1)
indices, values = tp_info.selectGradients(flattened, tp_params)

masked_flat = torch.zeros_like(flattened)
if values.numel() > 0:
masked_flat[indices[0], indices[1]] = values

return masked_flat.view_as(grad_output)


#========== LINEAR ==========#

Expand Down Expand Up @@ -113,9 +128,79 @@ def forward(self, input):
return SparseLinear.apply(input, self.weight, self.tpParams, self, self.bias)


#========== CONVOLUTION ==========#
#========== CONVOLUTION 1D ==========#

class SparseConv1d(torch.autograd.Function):

@staticmethod
def forward(ctx, input, weight, bias, stride, padding, dilation, groups, padding_mode,
_reversed_padding_repeated_twice, tpParams: TinyPropParams, tpInfo: TinyPropLayer):
ctx.save_for_backward(input, weight, bias)

ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.groups = groups
ctx.tpParams = tpParams
ctx.tpInfo = tpInfo

if padding_mode != 'zeros':
return F.conv1d(F.pad(input, _reversed_padding_repeated_twice, mode=padding_mode),
weight, bias, stride, 0, dilation, groups)
return F.conv1d(input, weight, bias, stride, padding, dilation, groups)

@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors

grad_input = grad_weight = grad_bias = None

masked_grad = _apply_tinyprop_mask(ctx.tpInfo, grad_output, ctx.tpParams)

if ctx.needs_input_grad[0]:
grad_input = torch.nn.grad.conv1d_input(input.shape, weight, masked_grad, ctx.stride,
ctx.padding, ctx.dilation, ctx.groups)
if ctx.needs_input_grad[1]:
grad_weight = torch.nn.grad.conv1d_weight(input, weight.shape, masked_grad, ctx.stride,
ctx.padding, ctx.dilation, ctx.groups)
if bias is not None and ctx.needs_input_grad[2]:
sum_dims = (0,) + tuple(range(2, masked_grad.dim()))
grad_bias = masked_grad.sum(dim=sum_dims)

return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None, None


class TinyPropConv1d(TinyPropLayer, nn.Conv1d):
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
tinyPropParams: TinyPropParams,
layer_number: int,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None):
TinyPropLayer.__init__(self, tinyPropParams.number_of_layers - layer_number)
nn.Conv1d.__init__(self, in_channels, out_channels, kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode,
device=device, dtype=dtype)

self.tpParams = tinyPropParams

class SparseConv2d(torch.autograd.Function):
def forward(self, input):
return SparseConv1d.apply(input, self.weight, self.bias, self.stride, self.padding, self.dilation,
self.groups, self.padding_mode, self._reversed_padding_repeated_twice,
self.tpParams, self)


#========== CONVOLUTION 2D ==========#

class SparseConv2d(torch.autograd.Function):
# keep in mind that convolution operations DO NOT reduce the batchSize (in contrast to matmul)!

@staticmethod
Expand Down Expand Up @@ -150,44 +235,19 @@ def backward(ctx, grad_output):
# Initialize all gradients w.r.t. inputs to None
grad_input = grad_weight = grad_bias = None

# This is the TinyProp part: conv can't handle sparse matrices so I have to build a masked version based on the selected gradients
out_ch = grad_output.shape[1]
out_width = grad_output.shape[2]
out_height = grad_output.shape[3]
# flatten elements to work with the gradient selection
flattened = torch.flatten(grad_output, start_dim=1)
indices, values = ctx.tpInfo.selectGradients(flattened, ctx.tpParams)
# mask grad_output by reinitializing with zeros
grad_output = torch.zeros(flattened.size())
# then loop over and set all selected gradient entries
for i in range(indices.size(1)):
grad_output[indices[0, i], indices[1, i]] = values[i]
# undo the flattening
grad_output = grad_output.view(-1, out_ch, out_width, out_height).to(weight.device)


# proceed with layer specific computations
masked_grad = _apply_tinyprop_mask(ctx.tpInfo, grad_output, ctx.tpParams)

if ctx.needs_input_grad[0]:
# can be solved by deconvolving grad_output with weight
grad_input = F.conv_transpose2d(grad_output, weight, None, ctx.stride, ctx.padding, groups=ctx.groups, dilation=ctx.dilation)
grad_input = torch.nn.grad.conv2d_input(input.shape, weight, masked_grad, ctx.stride,
ctx.padding, ctx.dilation, ctx.groups)

if ctx.needs_input_grad[1]:
# can be solved by convolving input with grad_output, but the resulting grad_weight is 5d which the conv function can't handle.
# I mitigate this problem by slicing the input by input channel. I can then do the convolution with this reduced dimension, where
# I can process the batch-dimension as input channel. Later grad_weight is constructed from these sub-convolutions

# use batch-dimension as in-channel [out, b, w, h] = [out, in, w, h]
permutated = grad_output.permute(1, 0, 2, 3)
# dismantle real input-channel
input_channels = torch.unbind(input, dim=1)
res = []
for channel in input_channels:
res.append(F.conv2d(channel, permutated, None, ctx.stride, ctx.padding, groups=ctx.groups, dilation=ctx.dilation))
grad_weight = torch.stack(res, dim=0).permute(1, 0, 2, 3)
grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, masked_grad, ctx.stride,
ctx.padding, ctx.dilation, ctx.groups)

if bias is not None and ctx.needs_input_grad[2]:
# simply sum up all elements over width, height
grad_bias = torch.sum(grad_output, dim=(2,3))
sum_dims = (0,) + tuple(range(2, masked_grad.dim()))
grad_bias = masked_grad.sum(dim=sum_dims)
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None, None


Expand All @@ -202,11 +262,15 @@ def __init__(self,
stride: _size_2_t = 1,
padding: Union[str, _size_2_t] = 0,
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device = None,
dtype = None):
TinyPropLayer.__init__(self, tinyPropParams.number_of_layers - layer_number)
nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, 1, bias, device=device, dtype=dtype)
nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode,
device=device, dtype=dtype)

# Saving variables like this will pass it by REFERENCE, so changes
# made in backwards are reflected in layer
Expand Down
4 changes: 4 additions & 0 deletions src/aifes.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ extern "C" {

// Include the layer base implementations
#include "basic/base/ailayer/ailayer_dense.h"
#include "basic/base/ailayer/ailayer_conv1d.h"
#include "basic/base/ailayer/ailayer_conv2d.h"
#include "basic/base/ailayer/ailayer_input.h"
#include "basic/base/ailayer/ailayer_relu.h"
#include "basic/base/ailayer/ailayer_leaky_relu.h"
Expand All @@ -68,6 +70,8 @@ extern "C" {

// Include the layers in default implementation
#include "basic/default/ailayer/ailayer_dense_default.h"
#include "basic/default/ailayer/ailayer_conv1d_default.h"
#include "basic/default/ailayer/ailayer_conv2d_default.h"
#include "basic/default/ailayer/ailayer_input_default.h"
#include "basic/default/ailayer/ailayer_relu_default.h"
#include "basic/default/ailayer/ailayer_leaky_relu_default.h"
Expand Down
Loading