Skip to content

Commit a55bb05

Browse files
committed
Add ntops.torch.pooling._calculate_output_size
1 parent d8235d6 commit a55bb05

3 files changed

Lines changed: 39 additions & 32 deletions

File tree

src/ntops/torch/avg_pool2d.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import math
2-
31
import torch
42

53
import ntops
4+
from ntops.torch.pooling import _calculate_output_size
65
from ntops.torch.utils import _cached_make
76

87

@@ -32,18 +31,12 @@ def avg_pool2d(
3231

3332
n, c, h, w = input.shape
3433

35-
def _calculate_output_size(input_size, kernel_size, stride, padding, ceil_mode):
36-
int_ = math.ceil if ceil_mode else math.floor
37-
38-
result = int_((input_size + 2 * padding - kernel_size) / stride + 1)
39-
40-
if ceil_mode and (result - 1) * stride >= input_size + padding:
41-
result -= 1
42-
43-
return result
44-
45-
h_ = _calculate_output_size(h, kernel_size[0], stride[0], padding[0], ceil_mode)
46-
w_ = _calculate_output_size(w, kernel_size[1], stride[1], padding[1], ceil_mode)
34+
h_ = _calculate_output_size(
35+
h, kernel_size[0], stride=stride[0], padding=padding[0], ceil_mode=ceil_mode
36+
)
37+
w_ = _calculate_output_size(
38+
w, kernel_size[1], stride=stride[1], padding=padding[1], ceil_mode=ceil_mode
39+
)
4740

4841
output = torch.empty((n, c, h_, w_), dtype=input.dtype, device=input.device)
4942

src/ntops/torch/max_pool2d.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import math
2-
31
import torch
42

53
import ntops
4+
from ntops.torch.pooling import _calculate_output_size
65
from ntops.torch.utils import _cached_make
76

87

@@ -33,25 +32,21 @@ def max_pool2d(
3332

3433
n, c, h, w = input.shape
3534

36-
def _calculate_output_size(
37-
input_size, kernel_size, stride, padding, dilation, ceil_mode
38-
):
39-
int_ = math.ceil if ceil_mode else math.floor
40-
41-
result = int_(
42-
(input_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
43-
)
44-
45-
if ceil_mode and (result - 1) * stride >= input_size + padding:
46-
result -= 1
47-
48-
return result
49-
5035
h_ = _calculate_output_size(
51-
h, kernel_size[0], stride[0], padding[0], dilation[0], ceil_mode
36+
h,
37+
kernel_size[0],
38+
stride=stride[0],
39+
padding=padding[0],
40+
dilation=dilation[0],
41+
ceil_mode=ceil_mode,
5242
)
5343
w_ = _calculate_output_size(
54-
w, kernel_size[1], stride[1], padding[1], dilation[1], ceil_mode
44+
w,
45+
kernel_size[1],
46+
stride=stride[1],
47+
padding=padding[1],
48+
dilation=dilation[1],
49+
ceil_mode=ceil_mode,
5550
)
5651

5752
output = torch.empty((n, c, h_, w_), dtype=input.dtype, device=input.device)

src/ntops/torch/pooling.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import math
2+
3+
4+
def _calculate_output_size(
5+
input_size, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False
6+
):
7+
if stride is None:
8+
stride = kernel_size
9+
10+
int_ = math.ceil if ceil_mode else math.floor
11+
12+
result = int_(
13+
(input_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
14+
)
15+
16+
if ceil_mode and (result - 1) * stride >= input_size + padding:
17+
result -= 1
18+
19+
return result

0 commit comments

Comments
 (0)