2222input_t1 = Tuple [torch .Tensor ] # Input x
2323
2424test_data_suite = {
25- "4dim_last1dim" : lambda : (torch .rand (1 , 1 , 16 , 16 ), (1 , 1 , 0 , 0 , 0 , 0 , 0 , 0 ), 1 ),
26- "4dim_last2dim" : lambda : (torch .rand (1 , 1 , 16 , 16 ), (1 , 0 , 1 , 0 , 0 , 0 , 0 , 0 ), 2 ),
27- "4dim_last3dim" : lambda : (torch .rand (1 , 1 , 16 , 16 ), (1 , 1 , 0 , 2 , 0 , 2 , 0 , 0 ), 3 ),
28- "4dim_last4dim" : lambda : (torch .rand (1 , 1 , 16 , 16 ), (1 , 0 , 1 , 1 , 0 , 2 , 0 , 2 ), 4 ),
29- "3dim_last1dim" : lambda : (torch .rand (1 , 1 , 16 ), (1 , 1 , 0 , 0 , 0 , 0 ), 1 ),
30- "3dim_last2dim" : lambda : (torch .rand (1 , 1 , 16 ), (1 , 0 , 1 , 1 , 0 , 0 ), 2 ),
31- "3dim_last3dim" : lambda : (torch .rand (1 , 1 , 16 ), (1 , 0 , 1 , 0 , 1 , 1 ), 3 ),
32- "2dim_last1dim" : lambda : (torch .rand (1 , 1 , 16 ), (1 , 1 , 0 , 0 ), 1 ),
33- "2dim_last2dim" : lambda : (torch .rand (1 , 1 , 16 ), (1 , 0 , 1 , 1 ), 2 ),
25+ "4dim_last1dim" : lambda : (
26+ torch .rand (1 , 1 , 16 , 16 ),
27+ (1 , 1 , 0 , 0 , 0 , 0 , 0 , 0 ),
28+ 1 ,
29+ "constant" ,
30+ ),
31+ "4dim_last2dim" : lambda : (
32+ torch .rand (1 , 1 , 16 , 16 ),
33+ (1 , 0 , 1 , 0 , 0 , 0 , 0 , 0 ),
34+ 2 ,
35+ "constant" ,
36+ ),
37+ "4dim_last3dim" : lambda : (
38+ torch .rand (1 , 1 , 16 , 16 ),
39+ (1 , 1 , 0 , 2 , 0 , 2 , 0 , 0 ),
40+ 3 ,
41+ "constant" ,
42+ ),
43+ "4dim_last4dim" : lambda : (
44+ torch .rand (1 , 1 , 16 , 16 ),
45+ (1 , 0 , 1 , 1 , 0 , 2 , 0 , 2 ),
46+ 4 ,
47+ "constant" ,
48+ ),
49+ "3dim_last1dim" : lambda : (torch .rand (1 , 1 , 16 ), (1 , 1 , 0 , 0 , 0 , 0 ), 1 , "constant" ),
50+ "3dim_last2dim" : lambda : (torch .rand (1 , 1 , 16 ), (1 , 0 , 1 , 1 , 0 , 0 ), 2 , "constant" ),
51+ "3dim_last3dim" : lambda : (torch .rand (1 , 1 , 16 ), (1 , 0 , 1 , 0 , 1 , 1 ), 3 , "constant" ),
52+ "2dim_last1dim" : lambda : (torch .rand (1 , 1 , 16 ), (1 , 1 , 0 , 0 ), 1 , "constant" ),
53+ "2dim_last2dim" : lambda : (torch .rand (1 , 1 , 16 ), (1 , 0 , 1 , 1 ), 2 , "constant" ),
54+ "4dim_reflect" : lambda : (
55+ torch .rand (6 , 6 , 6 , 6 ),
56+ (3 , 3 , 3 , 3 , 3 , 3 ),
57+ None ,
58+ "reflect" ,
59+ ),
60+ "4dim_replicate" : lambda : (
61+ torch .rand (3 , 3 , 3 , 3 ),
62+ (3 , 3 , 3 , 3 , 3 , 3 ),
63+ None ,
64+ "replicate" ,
65+ ),
66+ "4dim_circular" : lambda : (
67+ torch .rand (3 , 3 , 3 , 3 ),
68+ (3 , 3 , 3 , 3 , 3 , 3 ),
69+ None ,
70+ "circular" ,
71+ ),
72+ "2dim_reflect" : lambda : (
73+ torch .rand (6 , 6 ),
74+ (3 , 3 ),
75+ None ,
76+ "reflect" ,
77+ ),
78+ "2dim_replicate" : lambda : (
79+ torch .rand (3 , 3 ),
80+ (3 , 3 ),
81+ None ,
82+ "replicate" ,
83+ ),
84+ "2dim_circular" : lambda : (
85+ torch .rand (3 , 3 ),
86+ (3 , 3 ),
87+ None ,
88+ "circular" ,
89+ ),
3490}
3591
3692test_data_suite_bf16 = {
3793 "4dim_last1dim_bf16" : lambda : (
3894 torch .rand (1 , 1 , 8 , 8 , dtype = torch .bfloat16 ),
3995 (1 , 1 , 0 , 0 , 0 , 0 , 0 , 0 ),
4096 1.0 ,
97+ "constant" ,
4198 ),
4299 "3dim_last1dim_bf16" : lambda : (
43100 torch .rand (1 , 1 , 8 , dtype = torch .bfloat16 ),
44101 (1 , 0 , 1 , 0 , 0 , 0 ),
45102 - 0.5 ,
103+ "constant" ,
46104 ),
47105}
48106test_data_suite_fp16 = {
49107 "4dim_last1dim_fp16" : lambda : (
50108 torch .rand (1 , 1 , 8 , 8 , dtype = torch .float16 ),
51109 (1 , 1 , 0 , 0 , 0 , 0 , 0 , 0 ),
52110 1.0 ,
111+ "constant" ,
53112 ),
54113 "3dim_last1dim_fp16" : lambda : (
55114 torch .rand (1 , 1 , 8 , dtype = torch .float16 ),
56115 (1 , 0 , 1 , 0 , 0 , 0 ),
57116 - 0.5 ,
117+ "constant" ,
58118 ),
59119}
60120
61121
62122class ConstantPadND (torch .nn .Module ):
63- def __init__ (self , pad : Tuple , value : float | None = None ):
123+ def __init__ (
124+ self ,
125+ pad : Tuple ,
126+ value : float | None = None ,
127+ mode : str = "constant" ,
128+ ):
64129 super ().__init__ ()
65130 self .value = value
131+ self .mode = mode
66132 nonzero_idx = len (pad )
67133 for i in range (0 , len (pad ), 2 ):
68134 if pad [i ] + pad [i + 1 ] == 0 :
@@ -71,18 +137,17 @@ def __init__(self, pad: Tuple, value: float | None = None):
71137 self .pad = pad [:nonzero_idx ]
72138
73139 def forward (self , x : torch .Tensor ):
74- x = F .pad (x , pad = self .pad , mode = "constant" , value = self .value )
75- return x
140+ return F .pad (x , pad = self .pad , mode = self .mode , value = self .value )
76141
77142
78143@common .parametrize (
79144 "test_data" ,
80145 test_data_suite | test_data_suite_bf16 | test_data_suite_fp16 ,
81146)
82147def test_constant_pad_nd_tosa_FP (test_data : Tuple ):
83- test_data , padding , value = test_data ()
148+ test_data , padding , value , mode = test_data ()
84149 pipeline = TosaPipelineFP [input_t1 ](
85- ConstantPadND (padding , value ),
150+ ConstantPadND (padding , value , mode ),
86151 (test_data ,),
87152 aten_op ,
88153 exir_op ,
@@ -93,9 +158,9 @@ def test_constant_pad_nd_tosa_FP(test_data: Tuple):
93158
94159@common .parametrize ("test_data" , test_data_suite )
95160def test_constant_pad_nd_tosa_INT (test_data : Tuple ):
96- test_data , padding , value = test_data ()
161+ test_data , padding , value , mode = test_data ()
97162 pipeline = TosaPipelineINT [input_t1 ](
98- ConstantPadND (padding , value ),
163+ ConstantPadND (padding , value , mode ),
99164 (test_data ,),
100165 aten_op ,
101166 exir_op ,
@@ -106,9 +171,9 @@ def test_constant_pad_nd_tosa_INT(test_data: Tuple):
106171@common .parametrize ("test_data" , test_data_suite )
107172def test_constant_pad_nd_tosa_INT_a16w8 (test_data : Tuple ):
108173 """Test constant_pad_nd op with int16 I/O quantization for TOSA INT."""
109- test_data , padding , value = test_data ()
174+ test_data , padding , value , mode = test_data ()
110175 pipeline = TosaPipelineINT [input_t1 ](
111- ConstantPadND (padding , value ),
176+ ConstantPadND (padding , value , mode ),
112177 (test_data ,),
113178 aten_op ,
114179 exir_op ,
@@ -120,9 +185,9 @@ def test_constant_pad_nd_tosa_INT_a16w8(test_data: Tuple):
120185@common .parametrize ("test_data" , test_data_suite | test_data_suite_fp16 )
121186@common .SkipIfNoModelConverter
122187def test_constant_pad_nd_vgf_no_quant (test_data : Tuple ):
123- inp , padding , value = test_data ()
188+ inp , padding , value , mode = test_data ()
124189 pipeline = VgfPipeline [input_t1 ](
125- ConstantPadND (padding , value ),
190+ ConstantPadND (padding , value , mode ),
126191 (inp ,),
127192 aten_op ,
128193 exir_op ,
@@ -134,9 +199,9 @@ def test_constant_pad_nd_vgf_no_quant(test_data: Tuple):
134199@common .parametrize ("test_data" , test_data_suite )
135200@common .SkipIfNoModelConverter
136201def test_constant_pad_nd_vgf_quant (test_data : Tuple ):
137- inp , padding , value = test_data ()
202+ inp , padding , value , mode = test_data ()
138203 pipeline = VgfPipeline [input_t1 ](
139- ConstantPadND (padding , value ),
204+ ConstantPadND (padding , value , mode ),
140205 (inp ,),
141206 aten_op ,
142207 exir_op ,
0 commit comments