@@ -92,3 +92,164 @@ def test_case_5():
9292 """
9393 )
9494 obj .run (pytorch_code , ["result" ])
95+
96+
97+ def test_case_6 ():
98+ pytorch_code = textwrap .dedent (
99+ """
100+ import torch.nn as nn
101+ import torch
102+ choices = nn.ParameterDict(parameters={
103+ 'a': nn.Parameter(torch.ones(2, 3)),
104+ 'b': nn.Parameter(torch.zeros(4)),
105+ })
106+ result = list(choices)
107+ """
108+ )
109+ obj .run (pytorch_code , ["result" ])
110+
111+
112+ def test_case_7 ():
113+ pytorch_code = textwrap .dedent (
114+ """
115+ import torch.nn as nn
116+ import torch
117+ choices = nn.ParameterDict([
118+ ('a', nn.Parameter(torch.ones(2, 3))),
119+ ('b', nn.Parameter(torch.zeros(4))),
120+ ])
121+ result = list(choices)
122+ """
123+ )
124+ obj .run (pytorch_code , ["result" ])
125+
126+
127+ def test_case_8 ():
128+ pytorch_code = textwrap .dedent (
129+ """
130+ import torch.nn as nn
131+ import torch
132+ choices = nn.ParameterDict({'w': nn.Parameter(torch.ones(2, 3))})
133+ result = choices['w']
134+ """
135+ )
136+ obj .run (pytorch_code , ["result" ], check_stop_gradient = False )
137+
138+
139+ def test_case_9 ():
140+ pytorch_code = textwrap .dedent (
141+ """
142+ import torch.nn as nn
143+ import torch
144+ choices = nn.ParameterDict({
145+ 'a': nn.Parameter(torch.ones(1)),
146+ 'b': nn.Parameter(torch.ones(2)),
147+ 'c': nn.Parameter(torch.ones(3)),
148+ })
149+ result = len(choices)
150+ """
151+ )
152+ obj .run (pytorch_code , ["result" ])
153+
154+
155+ def test_case_10 ():
156+ pytorch_code = textwrap .dedent (
157+ """
158+ import torch.nn as nn
159+ import torch
160+ choices = nn.ParameterDict({
161+ 'a': nn.Parameter(torch.ones(1)),
162+ 'b': nn.Parameter(torch.ones(2)),
163+ })
164+ result = list(choices.keys())
165+ """
166+ )
167+ obj .run (pytorch_code , ["result" ])
168+
169+
170+ def test_case_11 ():
171+ pytorch_code = textwrap .dedent (
172+ """
173+ import torch.nn as nn
174+ import torch
175+ choices = nn.ParameterDict({
176+ 'a': nn.Parameter(torch.ones(2, 3)),
177+ 'b': nn.Parameter(torch.zeros(4)),
178+ })
179+ result = list(choices.values())
180+ """
181+ )
182+ obj .run (pytorch_code , ["result" ], check_stop_gradient = False )
183+
184+
185+ def test_case_12 ():
186+ pytorch_code = textwrap .dedent (
187+ """
188+ import torch.nn as nn
189+ import torch
190+ choices = nn.ParameterDict({'a': nn.Parameter(torch.ones(2))})
191+ result = 'a' in choices
192+ """
193+ )
194+ obj .run (pytorch_code , ["result" ])
195+
196+
197+ def test_case_14 ():
198+ pytorch_code = textwrap .dedent (
199+ """
200+ import torch.nn as nn
201+ import torch
202+ choices = nn.ParameterDict(parameters={
203+ 'a': nn.Parameter(torch.ones(2, 3)),
204+ 'b': nn.Parameter(torch.zeros(4)),
205+ })
206+ result = choices['a']
207+ """
208+ )
209+ obj .run (pytorch_code , ["result" ], check_stop_gradient = False )
210+
211+
212+ def test_case_15 ():
213+ pytorch_code = textwrap .dedent (
214+ """
215+ import torch.nn as nn
216+ import torch
217+ choices = nn.ParameterDict(parameters={
218+ 'a': nn.Parameter(torch.ones(2, 3)),
219+ 'b': nn.Parameter(torch.zeros(4)),
220+ })
221+ result = choices['b']
222+ """
223+ )
224+ obj .run (pytorch_code , ["result" ], check_stop_gradient = False )
225+
226+
227+ def test_case_13 ():
228+ pytorch_code = textwrap .dedent (
229+ """
230+ import torch.nn as nn
231+ import torch
232+ choices = nn.ParameterDict({
233+ 'a': nn.Parameter(torch.ones(2, 3)),
234+ 'b': nn.Parameter(torch.zeros(4)),
235+ })
236+ result = choices.pop('a')
237+ """
238+ )
239+ obj .run (pytorch_code , ["result" ], check_stop_gradient = False )
240+
241+
242+ def test_case_16 ():
243+ pytorch_code = textwrap .dedent (
244+ """
245+ import torch.nn as nn
246+ import torch
247+ pd1 = nn.ParameterDict({
248+ 'a': nn.Parameter(torch.ones(2, 3)),
249+ 'b': nn.Parameter(torch.zeros(4)),
250+ })
251+ pd2 = nn.ParameterDict(pd1)
252+ result = list(pd2)
253+ """
254+ )
255+ obj .run (pytorch_code , ["result" ])
0 commit comments