Skip to content

Commit debe8a9

Browse files
committed
fix CI
1 parent b0c9ec5 commit debe8a9

9 files changed

Lines changed: 148 additions & 43 deletions

File tree

paconvert/api_mapping.json

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5328,9 +5328,6 @@
53285328
"axis": 0
53295329
}
53305330
},
5331-
"torch.float8_e4m3fn": {
5332-
"Matcher": "ChangePrefixMatcher"
5333-
},
53345331
"torch.float_power": {
53355332
"Matcher": "FloatPowerMatcher",
53365333
"min_input_args": 2,

paconvert/attribute_mapping.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
},
6969
"torch.autograd.function.FunctionCtx.saved_tensors": {
7070
"Matcher": "Attribute2Func",
71-
"paddle_api": "paddle.autograd.PyLayerContext.saved_tensor"
71+
"paddle_api": "paddle.autograd.function.FunctionCtx.saved_tensor"
7272
},
7373
"torch.autograd.profiler.profile.self_cpu_time_total": {},
7474
"torch.backends.cuda.matmul.allow_tf32": {
@@ -148,6 +148,9 @@
148148
"torch.float64": {
149149
"Matcher": "ChangePrefixMatcher"
150150
},
151+
"torch.float8_e4m3fn": {
152+
"Matcher": "ChangePrefixMatcher"
153+
},
151154
"torch.inf": {
152155
"Matcher": "ChangePrefixMatcher"
153156
},

paconvert/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def parse_func(self, func):
383383
self.paddleClass = new_func[0 : new_func.rfind(".")]
384384
if self.get_paddle_api():
385385
new_paddle_api = re.sub(
386-
"paddle.Tensor|paddle.nn.Module|paddle.optimizer.Optimizer|paddle.distribution.Distribution|paddle.autograd.function.FunctionCtx|paddle.autograd.PyLayerContext|paddle.profiler.Profiler",
386+
"paddle.Tensor|paddle.nn.Module|paddle.optimizer.Optimizer|paddle.distribution.Distribution|paddle.autograd.function.FunctionCtx|paddle.profiler.Profiler",
387387
re.escape(self.paddleClass),
388388
self.get_paddle_api(),
389389
)

paconvert/transformer/basic_transformer.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,14 @@ def visit_Attribute(self, node):
119119
node.lineno,
120120
)
121121

122-
matcher = self.get_api_matcher(torch_api)
123-
# can be api_matcher or attribute_matcher
124-
if matcher is None:
125-
matcher = self.get_attribute_mather(torch_api)
126-
122+
# can be attribute_matcher or attribute_matcher
123+
attribute_matcher = self.get_attribute_matcher(torch_api)
124+
api_matcher = self.get_api_matcher(torch_api)
125+
assert (
126+
attribute_matcher and api_matcher
127+
), f"{torch_api} can not be both in attribute_matcher and api_matcher"
128+
129+
matcher = attribute_matcher or api_matcher
127130
if matcher:
128131
paddle_api = matcher.get_paddle_api()
129132
if paddle_api == "delete":
@@ -275,7 +278,7 @@ def visit_Attribute(self, node):
275278
return node
276279

277280
def trans_class_attribute(self, node, torch_api):
278-
matcher = self.get_attribute_mather(torch_api)
281+
matcher = self.get_attribute_matcher(torch_api)
279282
if matcher:
280283
self.all_api_map[torch_api]["paddle_api"] = (
281284
matcher.get_paddle_api() if matcher.get_paddle_api() else ""
@@ -736,6 +739,9 @@ def in_api_mapping(self, torch_api):
736739
def get_api_matcher(self, torch_api):
737740
api_mapping_dict = {}
738741
if torch_api in GlobalManager.ALIAS_MAPPING:
742+
assert (
743+
torch_api not in GlobalManager.API_MAPPING
744+
), f"{torch_api} can not be both in alias mapping and api mapping"
739745
torch_api = GlobalManager.ALIAS_MAPPING[torch_api]
740746
if torch_api in GlobalManager.API_MAPPING:
741747
api_mapping_dict = GlobalManager.API_MAPPING[torch_api]
@@ -759,10 +765,12 @@ def in_attribute_mapping(self, torch_api):
759765
return True
760766
return False
761767

762-
def get_attribute_mather(self, torch_api):
768+
def get_attribute_matcher(self, torch_api):
763769
attr_mapping_dict = {}
764-
765770
if torch_api in GlobalManager.ALIAS_MAPPING:
771+
assert (
772+
torch_api not in GlobalManager.ATTRIBUTE_MAPPING
773+
), f"{torch_api} can not be both in alias mapping and attribute mapping"
766774
torch_api = GlobalManager.ALIAS_MAPPING[torch_api]
767775
if torch_api in GlobalManager.ATTRIBUTE_MAPPING:
768776
attr_mapping_dict = GlobalManager.ATTRIBUTE_MAPPING[torch_api]

tests/test_Tensor_nanquantile.py

Lines changed: 89 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -56,144 +56,204 @@ def test_case_3():
5656

5757

5858
def test_case_4():
59-
"""With dim keyword argument (using alias)"""
59+
"""With keepdim keyword argument"""
6060
pytorch_code = textwrap.dedent(
6161
"""
6262
import torch
6363
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
64-
result = x.nanquantile(0.5, dim=1)
64+
result = x.nanquantile(0.25, 0, keepdim=True)
6565
"""
6666
)
6767
obj.run(pytorch_code, ["result"])
6868

6969

7070
def test_case_5():
71-
"""With keepdim keyword argument"""
71+
"""Multiple quantile values"""
7272
pytorch_code = textwrap.dedent(
7373
"""
7474
import torch
7575
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
76-
result = x.nanquantile(0.25, 0, keepdim=True)
76+
result = x.nanquantile(torch.tensor([0.25, 0.5, 0.75]))
7777
"""
7878
)
7979
obj.run(pytorch_code, ["result"])
8080

8181

8282
def test_case_6():
83-
"""Multiple quantile values"""
83+
"""3D tensor input"""
8484
pytorch_code = textwrap.dedent(
8585
"""
8686
import torch
87-
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
88-
result = x.nanquantile(torch.tensor([0.25, 0.5, 0.75]))
87+
x = torch.tensor([[[1.0, float('nan')], [3.0, 4.0]], [[5.0, 6.0], [float('nan'), 8.0]]])
88+
result = x.nanquantile(0.5, 2)
8989
"""
9090
)
9191
obj.run(pytorch_code, ["result"])
9292

9393

9494
def test_case_7():
95-
"""Keywords in different order (dim first)"""
95+
"""Mixed parameter styles with dim and keepdim"""
9696
pytorch_code = textwrap.dedent(
9797
"""
9898
import torch
9999
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
100-
result = x.nanquantile(q=0.5, dim=1, keepdim=False)
100+
result = x.nanquantile(0.5, dim=1, keepdim=True)
101101
"""
102102
)
103103
obj.run(pytorch_code, ["result"])
104104

105105

106106
def test_case_8():
107-
"""Keywords completely out of order"""
107+
"""Verify NaN handling - all NaN row"""
108108
pytorch_code = textwrap.dedent(
109109
"""
110110
import torch
111-
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
112-
result = x.nanquantile(dim=1, q=0.5, keepdim=True)
111+
x = torch.tensor([[float('nan'), float('nan'), float('nan')], [1.0, 2.0, 3.0]])
112+
result = x.nanquantile(0.5, 1)
113113
"""
114114
)
115115
obj.run(pytorch_code, ["result"])
116116

117117

118118
def test_case_9():
119-
"""1D tensor input"""
119+
"""Quantile at lower boundary"""
120120
pytorch_code = textwrap.dedent(
121121
"""
122122
import torch
123-
x = torch.tensor([1.0, 2.0, float('nan'), 4.0, 5.0])
124-
result = x.nanquantile(0.5)
123+
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
124+
result = x.nanquantile(0.0, 1)
125125
"""
126126
)
127127
obj.run(pytorch_code, ["result"])
128128

129129

130130
def test_case_10():
131-
"""3D tensor input"""
131+
"""Quantile at upper boundary"""
132132
pytorch_code = textwrap.dedent(
133133
"""
134134
import torch
135-
x = torch.tensor([[[1.0, float('nan')], [3.0, 4.0]], [[5.0, 6.0], [float('nan'), 8.0]]])
136-
result = x.nanquantile(0.5, 2)
135+
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
136+
result = x.nanquantile(1.0, 1)
137137
"""
138138
)
139139
obj.run(pytorch_code, ["result"])
140140

141141

142142
def test_case_11():
143-
"""Mixed parameter styles"""
143+
"""Chained method calls"""
144144
pytorch_code = textwrap.dedent(
145145
"""
146146
import torch
147147
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
148-
result = x.nanquantile(0.5, dim=1, keepdim=True)
148+
result = x.nanquantile(0.5, 1).nanquantile(0.5)
149149
"""
150150
)
151151
obj.run(pytorch_code, ["result"])
152152

153153

154154
def test_case_12():
155-
"""Verify NaN handling - all NaN row"""
155+
"""With interpolation='lower'"""
156156
pytorch_code = textwrap.dedent(
157157
"""
158158
import torch
159-
x = torch.tensor([[float('nan'), float('nan'), float('nan')], [1.0, 2.0, 3.0]])
160-
result = x.nanquantile(0.5, 1)
159+
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
160+
result = x.nanquantile(0.5, dim=1, interpolation='lower')
161161
"""
162162
)
163163
obj.run(pytorch_code, ["result"])
164164

165165

166166
def test_case_13():
167-
"""Quantile at boundaries"""
167+
"""With interpolation='higher'"""
168168
pytorch_code = textwrap.dedent(
169169
"""
170170
import torch
171171
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
172-
result = x.nanquantile(0.0, 1)
172+
result = x.nanquantile(0.5, dim=1, interpolation='higher')
173173
"""
174174
)
175175
obj.run(pytorch_code, ["result"])
176176

177177

178178
def test_case_14():
179-
"""Quantile at upper boundary"""
179+
"""With interpolation='midpoint'"""
180180
pytorch_code = textwrap.dedent(
181181
"""
182182
import torch
183183
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
184-
result = x.nanquantile(1.0, 1)
184+
result = x.nanquantile(0.5, dim=1, interpolation='midpoint')
185185
"""
186186
)
187187
obj.run(pytorch_code, ["result"])
188188

189189

190190
def test_case_15():
191-
"""Chained method calls"""
191+
"""With interpolation='nearest'"""
192192
pytorch_code = textwrap.dedent(
193193
"""
194194
import torch
195195
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
196-
result = x.nanquantile(0.5, 1).nanquantile(0.5)
196+
result = x.nanquantile(0.5, dim=1, interpolation='nearest')
197+
"""
198+
)
199+
obj.run(pytorch_code, ["result"])
200+
201+
202+
def test_case_16():
203+
"""Multiple quantiles with interpolation parameter"""
204+
pytorch_code = textwrap.dedent(
205+
"""
206+
import torch
207+
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
208+
result = x.nanquantile(torch.tensor([0.25, 0.5, 0.75]), dim=1, interpolation='midpoint')
209+
"""
210+
)
211+
obj.run(pytorch_code, ["result"])
212+
213+
214+
def test_case_17():
215+
"""With dim=None explicitly (flatten behavior)"""
216+
pytorch_code = textwrap.dedent(
217+
"""
218+
import torch
219+
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
220+
result = x.nanquantile(0.5, dim=None)
221+
"""
222+
)
223+
obj.run(pytorch_code, ["result"])
224+
225+
226+
def test_case_18():
227+
"""With dim=None and keepdim=True"""
228+
pytorch_code = textwrap.dedent(
229+
"""
230+
import torch
231+
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
232+
result = x.nanquantile(0.5, dim=None, keepdim=True)
233+
"""
234+
)
235+
obj.run(pytorch_code, ["result"])
236+
237+
238+
def test_case_19():
239+
"""Multiple quantiles with interpolation on 3D tensor"""
240+
pytorch_code = textwrap.dedent(
241+
"""
242+
import torch
243+
x = torch.tensor([[[1.0, float('nan')], [3.0, 4.0]], [[5.0, 6.0], [float('nan'), 8.0]]])
244+
result = x.nanquantile(torch.tensor([0.25, 0.75]), dim=0, interpolation='lower')
245+
"""
246+
)
247+
obj.run(pytorch_code, ["result"])
248+
249+
250+
def test_case_20():
251+
"""All parameters including keywords in mixed order"""
252+
pytorch_code = textwrap.dedent(
253+
"""
254+
import torch
255+
x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
256+
result = x.nanquantile(q=0.5, interpolation='higher', dim=1, keepdim=False)
197257
"""
198258
)
199259
obj.run(pytorch_code, ["result"])

tests/test_cummax.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,17 @@ def test_case_6():
106106
"""
107107
)
108108
obj.run(pytorch_code, ["result", "values", "indices"])
109+
110+
111+
def test_case_7():
112+
pytorch_code = textwrap.dedent(
113+
"""
114+
import torch
115+
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
116+
values = torch.empty(2, 2)
117+
indices = torch.empty(2, 2, dtype=torch.int64)
118+
out = (values, indices)
119+
result = torch.cummax(out=out, dim=0, input=x)
120+
"""
121+
)
122+
obj.run(pytorch_code, ["result", "out"])

tests/test_cummin.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,17 @@ def test_case_6():
106106
"""
107107
)
108108
obj.run(pytorch_code, ["result", "values", "indices"])
109+
110+
111+
def test_case_7():
112+
pytorch_code = textwrap.dedent(
113+
"""
114+
import torch
115+
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
116+
values = torch.empty(2, 2)
117+
indices = torch.empty(2, 2, dtype=torch.int64)
118+
out = (values, indices)
119+
result = torch.cummin(out=out, dim=0, input=x)
120+
"""
121+
)
122+
obj.run(pytorch_code, ["result", "out"])

tests/test_inner.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def test_case_4():
6969
obj.run(pytorch_code, ["result"])
7070

7171

72-
# The paddle input does not support integer type
7372
def test_case_5():
7473
pytorch_code = textwrap.dedent(
7574
"""

tests/test_rot90.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,13 @@ def test_case_8():
114114
"""
115115
)
116116
obj.run(pytorch_code, ["y", "x_grad"], check_stop_gradient=False)
117+
118+
119+
def test_case_9():
120+
pytorch_code = textwrap.dedent(
121+
"""
122+
import torch
123+
result = torch.rot90(dims=[0, 1], k=1, input=torch.tensor([[1, 2, 3], [4, 5, 6]]))
124+
"""
125+
)
126+
obj.run(pytorch_code, ["result"])

0 commit comments

Comments
 (0)