Commit 56b66ba
Add hardswish_backward_default
* It seems like we could factor this code out to handle all activation-function backwards easily, and maybe even all functions that can be partialled into a pointwise function (see the sketch after the diff).
1 parent: c97612f

1 file changed: src/torchjd/sparse/_diagonal_sparse_tensor.py (9 additions, 0 deletions)
@@ -727,6 +727,15 @@ def hardtanh_backward_default(
     return DiagonalSparseTensor(new_physical, grad_output.v_to_ps)
 
 
+@DiagonalSparseTensor.implements(aten.hardswish_backward.default)
+def hardswish_backward_default(grad_output: DiagonalSparseTensor, self: Tensor):
+    if isinstance(self, DiagonalSparseTensor):
+        raise NotImplementedError()
+
+    new_physical = aten.hardswish_backward.default(grad_output.physical, self)
+    return DiagonalSparseTensor(new_physical, grad_output.v_to_ps)
+
+
 @DiagonalSparseTensor.implements(aten.slice.Tensor)
 def slice_Tensor(
     t: DiagonalSparseTensor, dim: int, start: int | None, end: int | None, step: int = 1
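
The factorization mentioned in the commit message could look roughly like the sketch below. It is only a guess at the idea, written as if it lived inside src/torchjd/sparse/_diagonal_sparse_tensor.py and reusing that file's existing names (DiagonalSparseTensor, aten, Tensor); the helper name register_pointwise_backward is hypothetical, not part of torchjd.

# Hypothetical sketch, not part of the commit: one helper that registers any
# backward op that is pointwise once its extra scalar arguments are fixed.
def register_pointwise_backward(op):
    @DiagonalSparseTensor.implements(op)
    def backward(grad_output: DiagonalSparseTensor, self: Tensor, *args):
        if isinstance(self, DiagonalSparseTensor):
            raise NotImplementedError()
        # Apply the dense backward to the physical tensor and keep
        # grad_output's virtual-to-physical mapping unchanged.
        new_physical = op(grad_output.physical, self, *args)
        return DiagonalSparseTensor(new_physical, grad_output.v_to_ps)

    return backward

# hardtanh_backward takes (grad_output, self, min_val, max_val); the trailing
# scalars pass through *args, which is what "partialled into a pointwise
# function" suggests.
register_pointwise_backward(aten.hardtanh_backward.default)
register_pointwise_backward(aten.hardswish_backward.default)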
