Skip to content

Commit 1015917

Browse files
committed
Implement the inplace logic in src/ntops/torch.py
1 parent c37a940 commit 1015917

2 files changed

Lines changed: 7 additions & 14 deletions

File tree

src/ntops/kernels/relu.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,12 @@
66
from ntops.kernels.element_wise import arrangement
77

88

9-
def default_application(input, output):
9+
def application(input, output):
1010
output = max(0.0, input) # noqa: F841
1111

1212

13-
def True_application(input, output):
14-
input = max(0.0, input)
15-
output = input # noqa: F841
16-
17-
1813
@functools.cache
19-
def make(ndim, replace):
20-
if replace == True:
21-
application = True_application
22-
else:
23-
application = default_application
24-
14+
def make(ndim):
2515
tensors = (Tensor(ndim), Tensor(ndim))
2616

2717
return ninetoothed.make(arrangement, application, tensors)

src/ntops/torch.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,12 @@ def mul(input, other, *, out=None):
121121

122122

123123
def relu(input, inplace=False):
124-
output = torch.empty_like(input)
124+
if inplace:
125+
output = input
126+
else:
127+
output = torch.empty_like(input)
125128

126-
kernel = ntops.kernels.relu.make(input.ndim, inplace)
129+
kernel = ntops.kernels.relu.make(input.ndim)
127130

128131
kernel(input, output)
129132

0 commit comments

Comments
 (0)