Skip to content

Commit 9ff7692

Browse files
committed
Add inplace parameter for relu
1 parent 046583c commit 9ff7692

3 files changed

Lines changed: 26 additions & 12 deletions

File tree

src/ntops/kernels/relu.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,22 @@
66
from ntops.kernels.element_wise import arrangement
77

88

9-
def application(input, output):
9+
def default_application(input, output):
1010
output = max(0.0, input) # noqa: F841
1111

1212

13+
def True_application(input, output):
14+
input = max(0.0, input)
15+
output = input # noqa: F841
16+
17+
1318
@functools.cache
14-
def make(ndim):
15-
return ninetoothed.make(arrangement, application,(Tensor(ndim), Tensor(ndim)))
19+
def make(ndim, replace):
20+
if replace == True:
21+
application = True_application
22+
else:
23+
application = default_application
24+
25+
tensors = (Tensor(ndim), Tensor(ndim))
26+
27+
return ninetoothed.make(arrangement, application, tensors)

src/ntops/torch.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,15 +120,14 @@ def mul(input, other, *, out=None):
120120
return out
121121

122122

123-
def relu(input, out=None):
124-
if out is None:
125-
out = torch.empty_like(input)
123+
def relu(input, inplace=False):
124+
output = torch.empty_like(input)
126125

127-
kernel = ntops.kernels.relu.make(input.ndim)
126+
kernel = ntops.kernels.relu.make(input.ndim, inplace)
128127

129-
kernel(input, out)
128+
kernel(input, output)
130129

131-
return out
130+
return output
132131

133132

134133
def rsqrt(input, *, out=None):

tests/test_relu.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ def test_cuda(shape, dtype, atol, rtol):
1414

1515
input = torch.randn(shape, dtype=dtype, device=device)
1616

17-
ninetoothed_output = ntops.torch.relu(input)
18-
reference_output = F.relu(input)
17+
for replace in (False, True):
18+
ninetoothed_output = ntops.torch.relu(input, replace)
19+
reference_output = F.relu(input, replace)
1920

20-
assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
21+
assert torch.allclose(
22+
ninetoothed_output, reference_output, atol=atol, rtol=rtol
23+
)

0 commit comments

Comments
 (0)