@@ -135,3 +135,133 @@ def kernel():
135135
136136 with pytest .raises (TileTypeError , match = re .escape ("Expected a tuple after *" )):
137137 ct .launch (torch .cuda .current_stream (), (1 ,), kernel , ())
138+
139+
140+ def test_tuple_compare_empty_eq ():
141+ @ct .kernel
142+ def kernel (x ):
143+ if () == ():
144+ ct .scatter (x , (), 1 )
145+ else :
146+ ct .scatter (x , (), 0 )
147+
148+ x = torch .zeros ((), dtype = torch .int32 , device = "cuda" )
149+ ct .launch (torch .cuda .current_stream (), (1 ,), kernel , (x ,))
150+ assert x .item () == 1
151+
152+
153+ def test_tuple_compare_constants_eq ():
154+ @ct .kernel
155+ def kernel (x ):
156+ if (1 , 2 , 3 ) == (1 , 2 , 3 ):
157+ ct .scatter (x , (), 1 )
158+ else :
159+ ct .scatter (x , (), 0 )
160+
161+ x = torch .zeros ((), dtype = torch .int32 , device = "cuda" )
162+ ct .launch (torch .cuda .current_stream (), (1 ,), kernel , (x ,))
163+ assert x .item () == 1
164+
165+
166+ def test_tuple_compare_constants_ne ():
167+ @ct .kernel
168+ def kernel (x ):
169+ if (1 , 2 ) != (1 , 3 ):
170+ ct .scatter (x , (), 1 )
171+ else :
172+ ct .scatter (x , (), 0 )
173+
174+ x = torch .zeros ((), dtype = torch .int32 , device = "cuda" )
175+ ct .launch (torch .cuda .current_stream (), (1 ,), kernel , (x ,))
176+ assert x .item () == 1
177+
178+
179+ def test_tuple_compare_different_lengths ():
180+ @ct .kernel
181+ def kernel (x ):
182+ a = ct .bid (0 )
183+ if (a , 1 ) != (a , 1 , 2 ):
184+ ct .scatter (x , (), 1 )
185+ else :
186+ ct .scatter (x , (), 0 )
187+
188+ x = torch .zeros ((), dtype = torch .int32 , device = "cuda" )
189+ ct .launch (torch .cuda .current_stream (), (1 ,), kernel , (x ,))
190+ assert x .item () == 1
191+
192+
193+ def test_tuple_compare_0d_tiles_eq ():
194+ @ct .kernel
195+ def kernel (x ):
196+ a = ct .bid (0 )
197+ b = ct .bid (1 )
198+ if (a , b ) == (0 , 0 ):
199+ ct .scatter (x , (a , b ), 1 )
200+ else :
201+ ct .scatter (x , (a , b ), - 1 )
202+
203+ x = torch .zeros ((2 , 2 ), dtype = torch .int32 , device = "cuda" )
204+ ct .launch (torch .cuda .current_stream (), (2 , 2 ), kernel , (x ,))
205+ assert x .tolist () == [[1 , - 1 ], [- 1 , - 1 ]]
206+
207+
208+ def test_tuple_compare_nd_tile_error ():
209+ @ct .kernel
210+ def kernel ():
211+ t = ct .ones ((4 ,), dtype = ct .int32 )
212+ if (t ,) == (t ,):
213+ pass
214+
215+ with pytest .raises (TileTypeError , match = "not supported for N-D tile elements" ):
216+ ct .launch (torch .cuda .current_stream (), (1 ,), kernel , ())
217+
218+
219+ def test_tuple_compare_unsupported_op ():
220+ @ct .kernel
221+ def kernel ():
222+ if (1 , 2 ) < (3 , 4 ):
223+ pass
224+
225+ with pytest .raises (TileTypeError , match = "not supported for tuples" ):
226+ ct .launch (torch .cuda .current_stream (), (1 ,), kernel , ())
227+
228+
229+ def test_tuple_compare_nested ():
230+ @ct .kernel
231+ def kernel (x ):
232+ a = ct .bid (0 )
233+ if ((a , 1 ), 2 ) == ((0 , 1 ), 2 ):
234+ ct .scatter (x , (a , ), 1 )
235+ else :
236+ ct .scatter (x , (a , ), - 1 )
237+
238+ x = torch .zeros ((2 , ), dtype = torch .int32 , device = "cuda" )
239+ ct .launch (torch .cuda .current_stream (), (2 , ), kernel , (x ,))
240+ assert x .tolist () == [1 , - 1 ]
241+
242+
243+ def test_tuple_compare_array_element_error ():
244+ @ct .kernel
245+ def kernel (x , y ):
246+ if (x ,) == (y ,):
247+ pass
248+
249+ with pytest .raises (TileTypeError , match = "not supported for elements of type" ):
250+ ct .launch (torch .cuda .current_stream (), (1 ,), kernel ,
251+ (torch .zeros (4 , dtype = torch .int32 , device = "cuda" ),
252+ torch .zeros (4 , dtype = torch .int32 , device = "cuda" )))
253+
254+
255+ def test_tuple_compare_constant_args ():
256+ @ct .kernel
257+ def kernel (x , M : ct .Constant [int ], N : ct .Constant [int ]):
258+ if (M , N ) == (4 , 8 ):
259+ ct .scatter (x , (), 1 )
260+ else :
261+ ct .scatter (x , (), - 1 )
262+
263+ x = torch .zeros ((), dtype = torch .int32 , device = "cuda" )
264+ ct .launch (torch .cuda .current_stream (), (1 ,), kernel , (x , 4 , 8 ))
265+ assert x .item () == 1
266+ ct .launch (torch .cuda .current_stream (), (1 ,), kernel , (x , 4 , 9 ))
267+ assert x .item () == - 1
0 commit comments