Skip to content

Commit 6d52905

Browse files
authored
[API Compatibility] add torch.nn.ParameterDict -part (#831)
* add torch.nn.ParameterDict * add tests
1 parent 60e235d commit 6d52905

2 files changed

Lines changed: 162 additions & 9 deletions

File tree

paconvert/api_mapping.json

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8181,15 +8181,7 @@
81818181
"Matcher": "ChangePrefixMatcher"
81828182
},
81838183
"torch.nn.ParameterDict": {
8184-
"Matcher": "GenericMatcher",
8185-
"paddle_api": "paddle.nn.ParameterDict",
8186-
"min_input_args": 0,
8187-
"args_list": [
8188-
"values"
8189-
],
8190-
"kwargs_change": {
8191-
"values": "parameters"
8192-
}
8184+
"Matcher": "ChangePrefixMatcher"
81938185
},
81948186
"torch.nn.ParameterList": {
81958187
"Matcher": "GenericMatcher",

tests/test_nn_ParameterDict.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,164 @@ def test_case_5():
9292
"""
9393
)
9494
obj.run(pytorch_code, ["result"])
95+
96+
97+
def test_case_6():
98+
pytorch_code = textwrap.dedent(
99+
"""
100+
import torch.nn as nn
101+
import torch
102+
choices = nn.ParameterDict(parameters={
103+
'a': nn.Parameter(torch.ones(2, 3)),
104+
'b': nn.Parameter(torch.zeros(4)),
105+
})
106+
result = list(choices)
107+
"""
108+
)
109+
obj.run(pytorch_code, ["result"])
110+
111+
112+
def test_case_7():
113+
pytorch_code = textwrap.dedent(
114+
"""
115+
import torch.nn as nn
116+
import torch
117+
choices = nn.ParameterDict([
118+
('a', nn.Parameter(torch.ones(2, 3))),
119+
('b', nn.Parameter(torch.zeros(4))),
120+
])
121+
result = list(choices)
122+
"""
123+
)
124+
obj.run(pytorch_code, ["result"])
125+
126+
127+
def test_case_8():
128+
pytorch_code = textwrap.dedent(
129+
"""
130+
import torch.nn as nn
131+
import torch
132+
choices = nn.ParameterDict({'w': nn.Parameter(torch.ones(2, 3))})
133+
result = choices['w']
134+
"""
135+
)
136+
obj.run(pytorch_code, ["result"], check_stop_gradient=False)
137+
138+
139+
def test_case_9():
140+
pytorch_code = textwrap.dedent(
141+
"""
142+
import torch.nn as nn
143+
import torch
144+
choices = nn.ParameterDict({
145+
'a': nn.Parameter(torch.ones(1)),
146+
'b': nn.Parameter(torch.ones(2)),
147+
'c': nn.Parameter(torch.ones(3)),
148+
})
149+
result = len(choices)
150+
"""
151+
)
152+
obj.run(pytorch_code, ["result"])
153+
154+
155+
def test_case_10():
156+
pytorch_code = textwrap.dedent(
157+
"""
158+
import torch.nn as nn
159+
import torch
160+
choices = nn.ParameterDict({
161+
'a': nn.Parameter(torch.ones(1)),
162+
'b': nn.Parameter(torch.ones(2)),
163+
})
164+
result = list(choices.keys())
165+
"""
166+
)
167+
obj.run(pytorch_code, ["result"])
168+
169+
170+
def test_case_11():
171+
pytorch_code = textwrap.dedent(
172+
"""
173+
import torch.nn as nn
174+
import torch
175+
choices = nn.ParameterDict({
176+
'a': nn.Parameter(torch.ones(2, 3)),
177+
'b': nn.Parameter(torch.zeros(4)),
178+
})
179+
result = list(choices.values())
180+
"""
181+
)
182+
obj.run(pytorch_code, ["result"], check_stop_gradient=False)
183+
184+
185+
def test_case_12():
186+
pytorch_code = textwrap.dedent(
187+
"""
188+
import torch.nn as nn
189+
import torch
190+
choices = nn.ParameterDict({'a': nn.Parameter(torch.ones(2))})
191+
result = 'a' in choices
192+
"""
193+
)
194+
obj.run(pytorch_code, ["result"])
195+
196+
197+
def test_case_14():
198+
pytorch_code = textwrap.dedent(
199+
"""
200+
import torch.nn as nn
201+
import torch
202+
choices = nn.ParameterDict(parameters={
203+
'a': nn.Parameter(torch.ones(2, 3)),
204+
'b': nn.Parameter(torch.zeros(4)),
205+
})
206+
result = choices['a']
207+
"""
208+
)
209+
obj.run(pytorch_code, ["result"], check_stop_gradient=False)
210+
211+
212+
def test_case_15():
213+
pytorch_code = textwrap.dedent(
214+
"""
215+
import torch.nn as nn
216+
import torch
217+
choices = nn.ParameterDict(parameters={
218+
'a': nn.Parameter(torch.ones(2, 3)),
219+
'b': nn.Parameter(torch.zeros(4)),
220+
})
221+
result = choices['b']
222+
"""
223+
)
224+
obj.run(pytorch_code, ["result"], check_stop_gradient=False)
225+
226+
227+
def test_case_13():
228+
pytorch_code = textwrap.dedent(
229+
"""
230+
import torch.nn as nn
231+
import torch
232+
choices = nn.ParameterDict({
233+
'a': nn.Parameter(torch.ones(2, 3)),
234+
'b': nn.Parameter(torch.zeros(4)),
235+
})
236+
result = choices.pop('a')
237+
"""
238+
)
239+
obj.run(pytorch_code, ["result"], check_stop_gradient=False)
240+
241+
242+
def test_case_16():
243+
pytorch_code = textwrap.dedent(
244+
"""
245+
import torch.nn as nn
246+
import torch
247+
pd1 = nn.ParameterDict({
248+
'a': nn.Parameter(torch.ones(2, 3)),
249+
'b': nn.Parameter(torch.zeros(4)),
250+
})
251+
pd2 = nn.ParameterDict(pd1)
252+
result = list(pd2)
253+
"""
254+
)
255+
obj.run(pytorch_code, ["result"])

0 commit comments

Comments
 (0)