@@ -56,144 +56,204 @@ def test_case_3():
5656
5757
5858def test_case_4 ():
59- """With dim keyword argument (using alias) """
59+ """With keepdim keyword argument"""
6060 pytorch_code = textwrap .dedent (
6161 """
6262 import torch
6363 x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
64- result = x.nanquantile(0.5, dim=1 )
64+ result = x.nanquantile(0.25, 0, keepdim=True )
6565 """
6666 )
6767 obj .run (pytorch_code , ["result" ])
6868
6969
7070def test_case_5 ():
71- """With keepdim keyword argument """
71+ """Multiple quantile values """
7272 pytorch_code = textwrap .dedent (
7373 """
7474 import torch
7575 x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
76- result = x.nanquantile(0.25, 0, keepdim=True )
76+ result = x.nanquantile(torch.tensor([ 0.25, 0.5, 0.75]) )
7777 """
7878 )
7979 obj .run (pytorch_code , ["result" ])
8080
8181
8282def test_case_6 ():
83- """Multiple quantile values """
83+ """3D tensor input """
8484 pytorch_code = textwrap .dedent (
8585 """
8686 import torch
87- x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
88- result = x.nanquantile(torch.tensor([0.25, 0. 5, 0.75]) )
87+ x = torch.tensor([[[ 1.0, float('nan')], [3.0, 4.0]], [[ 5.0, 6.0], [float('nan'), 8.0] ]])
88+ result = x.nanquantile(0. 5, 2 )
8989 """
9090 )
9191 obj .run (pytorch_code , ["result" ])
9292
9393
9494def test_case_7 ():
95- """Keywords in different order ( dim first) """
95+ """Mixed parameter styles with dim and keepdim """
9696 pytorch_code = textwrap .dedent (
9797 """
9898 import torch
9999 x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
100- result = x.nanquantile(q= 0.5, dim=1, keepdim=False )
100+ result = x.nanquantile(0.5, dim=1, keepdim=True )
101101 """
102102 )
103103 obj .run (pytorch_code , ["result" ])
104104
105105
106106def test_case_8 ():
107- """Keywords completely out of order """
107+ """Verify NaN handling - all NaN row """
108108 pytorch_code = textwrap .dedent (
109109 """
110110 import torch
111- x = torch.tensor([[1.0, 2.0 , float('nan')], [4 .0, 5 .0, 6 .0]])
112- result = x.nanquantile(dim=1, q= 0.5, keepdim=True )
111+ x = torch.tensor([[float('nan'), float('nan') , float('nan')], [1 .0, 2 .0, 3 .0]])
112+ result = x.nanquantile(0.5, 1 )
113113 """
114114 )
115115 obj .run (pytorch_code , ["result" ])
116116
117117
118118def test_case_9 ():
119- """1D tensor input """
119+ """Quantile at lower boundary """
120120 pytorch_code = textwrap .dedent (
121121 """
122122 import torch
123- x = torch.tensor([1.0, 2.0, float('nan'), 4.0, 5.0])
124- result = x.nanquantile(0.5 )
123+ x = torch.tensor([[ 1.0, 2.0, float('nan')], [ 4.0, 5.0, 6.0] ])
124+ result = x.nanquantile(0.0, 1 )
125125 """
126126 )
127127 obj .run (pytorch_code , ["result" ])
128128
129129
130130def test_case_10 ():
131- """3D tensor input """
131+ """Quantile at upper boundary """
132132 pytorch_code = textwrap .dedent (
133133 """
134134 import torch
135- x = torch.tensor([[[ 1.0, float('nan')], [3.0, 4.0]], [[ 5.0, 6.0], [float('nan'), 8.0] ]])
136- result = x.nanquantile(0.5, 2 )
135+ x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
136+ result = x.nanquantile(1.0, 1 )
137137 """
138138 )
139139 obj .run (pytorch_code , ["result" ])
140140
141141
142142def test_case_11 ():
143- """Mixed parameter styles """
143+ """Chained method calls """
144144 pytorch_code = textwrap .dedent (
145145 """
146146 import torch
147147 x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
148- result = x.nanquantile(0.5, dim=1, keepdim=True )
148+ result = x.nanquantile(0.5, 1).nanquantile(0.5 )
149149 """
150150 )
151151 obj .run (pytorch_code , ["result" ])
152152
153153
154154def test_case_12 ():
155- """Verify NaN handling - all NaN row """
155+ """With interpolation='lower' """
156156 pytorch_code = textwrap .dedent (
157157 """
158158 import torch
159- x = torch.tensor([[float('nan'), float('nan') , float('nan')], [1 .0, 2 .0, 3 .0]])
160- result = x.nanquantile(0.5, 1 )
159+ x = torch.tensor([[1.0, 2.0 , float('nan')], [4 .0, 5 .0, 6 .0]])
160+ result = x.nanquantile(0.5, dim=1, interpolation='lower' )
161161 """
162162 )
163163 obj .run (pytorch_code , ["result" ])
164164
165165
166166def test_case_13 ():
167- """Quantile at boundaries """
167+ """With interpolation='higher' """
168168 pytorch_code = textwrap .dedent (
169169 """
170170 import torch
171171 x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
172- result = x.nanquantile(0.0, 1 )
172+ result = x.nanquantile(0.5, dim=1, interpolation='higher' )
173173 """
174174 )
175175 obj .run (pytorch_code , ["result" ])
176176
177177
178178def test_case_14 ():
179- """Quantile at upper boundary """
179+ """With interpolation='midpoint' """
180180 pytorch_code = textwrap .dedent (
181181 """
182182 import torch
183183 x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
184- result = x.nanquantile(1.0, 1 )
184+ result = x.nanquantile(0.5, dim=1, interpolation='midpoint' )
185185 """
186186 )
187187 obj .run (pytorch_code , ["result" ])
188188
189189
190190def test_case_15 ():
191- """Chained method calls """
191+ """With interpolation='nearest' """
192192 pytorch_code = textwrap .dedent (
193193 """
194194 import torch
195195 x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
196- result = x.nanquantile(0.5, 1).nanquantile(0.5)
196+ result = x.nanquantile(0.5, dim=1, interpolation='nearest')
197+ """
198+ )
199+ obj .run (pytorch_code , ["result" ])
200+
201+
202+ def test_case_16 ():
203+ """Multiple quantiles with interpolation parameter"""
204+ pytorch_code = textwrap .dedent (
205+ """
206+ import torch
207+ x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
208+ result = x.nanquantile(torch.tensor([0.25, 0.5, 0.75]), dim=1, interpolation='midpoint')
209+ """
210+ )
211+ obj .run (pytorch_code , ["result" ])
212+
213+
214+ def test_case_17 ():
215+ """With dim=None explicitly (flatten behavior)"""
216+ pytorch_code = textwrap .dedent (
217+ """
218+ import torch
219+ x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
220+ result = x.nanquantile(0.5, dim=None)
221+ """
222+ )
223+ obj .run (pytorch_code , ["result" ])
224+
225+
226+ def test_case_18 ():
227+ """With dim=None and keepdim=True"""
228+ pytorch_code = textwrap .dedent (
229+ """
230+ import torch
231+ x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
232+ result = x.nanquantile(0.5, dim=None, keepdim=True)
233+ """
234+ )
235+ obj .run (pytorch_code , ["result" ])
236+
237+
238+ def test_case_19 ():
239+ """Multiple quantiles with interpolation on 3D tensor"""
240+ pytorch_code = textwrap .dedent (
241+ """
242+ import torch
243+ x = torch.tensor([[[1.0, float('nan')], [3.0, 4.0]], [[5.0, 6.0], [float('nan'), 8.0]]])
244+ result = x.nanquantile(torch.tensor([0.25, 0.75]), dim=0, interpolation='lower')
245+ """
246+ )
247+ obj .run (pytorch_code , ["result" ])
248+
249+
250+ def test_case_20 ():
251+ """All parameters including keywords in mixed order"""
252+ pytorch_code = textwrap .dedent (
253+ """
254+ import torch
255+ x = torch.tensor([[1.0, 2.0, float('nan')], [4.0, 5.0, 6.0]])
256+ result = x.nanquantile(q=0.5, interpolation='higher', dim=1, keepdim=False)
197257 """
198258 )
199259 obj .run (pytorch_code , ["result" ])
0 commit comments