|
1 | 1 | import functools |
2 | 2 |
|
3 | | -import ninetoothed |
4 | 3 | import ninetoothed.language as ntl |
5 | | -from ninetoothed import Symbol, Tensor |
| 4 | +from ninetoothed import Tensor |
6 | 5 |
|
7 | | -BLOCK_SIZE = ninetoothed.block_size() |
8 | | - |
9 | | -KERNEL_SIZE_H = Symbol("kernel_size_h", constexpr=True, upper_bound=16) |
10 | | -KERNEL_SIZE_W = Symbol("kernel_size_w", constexpr=True, upper_bound=16) |
11 | | -STRIDE_H = Symbol("stride_h", constexpr=True) |
12 | | -STRIDE_W = Symbol("stride_w", constexpr=True) |
13 | | -PADDING_H = Symbol("padding_h", constexpr=True) |
14 | | -PADDING_W = Symbol("padding_w", constexpr=True) |
15 | | -DILATION_H = Symbol("dilation_h", constexpr=True) |
16 | | -DILATION_W = Symbol("dilation_w", constexpr=True) |
17 | | - |
18 | | - |
19 | | -def arrangement( |
20 | | - input, |
21 | | - output, |
22 | | - kernel_size_h=None, |
23 | | - kernel_size_w=None, |
24 | | - stride_h=None, |
25 | | - stride_w=None, |
26 | | - padding_h=None, |
27 | | - padding_w=None, |
28 | | - dilation_h=None, |
29 | | - dilation_w=None, |
30 | | - ceil_mode=None, |
31 | | - block_size=None, |
32 | | -): |
33 | | - if kernel_size_h is None: |
34 | | - kernel_size_h = KERNEL_SIZE_H |
35 | | - |
36 | | - if kernel_size_w is None: |
37 | | - kernel_size_w = KERNEL_SIZE_W |
38 | | - |
39 | | - if stride_h is None: |
40 | | - stride_h = STRIDE_H |
41 | | - |
42 | | - if stride_w is None: |
43 | | - stride_w = STRIDE_W |
44 | | - |
45 | | - if padding_h is None: |
46 | | - padding_h = PADDING_H |
47 | | - |
48 | | - if padding_w is None: |
49 | | - padding_w = PADDING_W |
50 | | - |
51 | | - if dilation_h is None: |
52 | | - dilation_h = DILATION_H |
53 | | - |
54 | | - if dilation_w is None: |
55 | | - dilation_w = DILATION_W |
56 | | - |
57 | | - if ceil_mode is None: |
58 | | - ceil_mode = False |
59 | | - |
60 | | - if block_size is None: |
61 | | - block_size = BLOCK_SIZE |
62 | | - |
63 | | - input_arranged = input.pad( |
64 | | - ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w)) |
65 | | - ) |
66 | | - input_arranged = input_arranged.tile( |
67 | | - (1, 1, kernel_size_h, kernel_size_w), |
68 | | - strides=(-1, -1, stride_h, stride_w), |
69 | | - dilation=(1, 1, dilation_h, dilation_w), |
70 | | - floor_mode=not ceil_mode, |
71 | | - ) |
72 | | - input_arranged = input_arranged.ravel() |
73 | | - input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1) |
74 | | - input_arranged = input_arranged.tile((block_size, -1)) |
75 | | - |
76 | | - output_arranged = output.tile((1, 1, 1, 1)) |
77 | | - output_arranged = output_arranged.ravel() |
78 | | - output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1) |
79 | | - output_arranged = output_arranged.tile((block_size, -1)) |
80 | | - output_arranged.dtype = output_arranged.dtype.squeeze(1) |
81 | | - |
82 | | - return input_arranged, output_arranged |
| 6 | +from ntops.kernels.pooling import arrangement |
83 | 7 |
|
84 | 8 |
|
85 | 9 | def application(input, output): |
|
0 commit comments