@@ -278,6 +278,34 @@ def test_array_compare(shape, tile, dtype, op_symbol, op_func, tmp_path):
278278 assert_equal (z , ref )
279279
280280
281+ def make_is_operator_kernel (cmp ):
282+ @ct .kernel
283+ def is_operator (x ):
284+ bid = ct .bid (0 )
285+ a = 1 if cmp is None else - 1
286+ ct .store (x , index = (bid ,), tile = a )
287+ return is_operator
288+
289+
290+ def make_is_not_operator_kernel (cmp ):
291+ @ct .kernel
292+ def is_not_operator (x ):
293+ bid = ct .bid (0 )
294+ a = - 1 if cmp is not None else 1
295+ ct .store (x , index = (bid ,), tile = a )
296+ return is_not_operator
297+
298+
299+ @pytest .mark .parametrize ("make_kernel" , [make_is_operator_kernel , make_is_not_operator_kernel ])
300+ @pytest .mark .parametrize ("cmp" , [None , 1 ])
301+ def test_is_or_not_operator (make_kernel , cmp ):
302+ x = torch .zeros ((1 ,), dtype = torch .int32 , device = 'cuda' )
303+ kernel = make_kernel (cmp )
304+ ct .launch (torch .cuda .current_stream (), (1 , 1 , 1 ), kernel , (x , ))
305+ ref = 1 if cmp is None else - 1
306+ assert_equal (x , torch .tensor ([ref ], dtype = torch .int32 , device = 'cuda' ))
307+
308+
281309@pytest .mark .parametrize ("max_func" , ["max" , "ct.maximum" ])
282310@pytest .mark .parametrize ("dtype" , int_dtypes + float_dtypes , ids = dtype_id )
283311def test_array_max (shape , tile , dtype , tmp_path , max_func ):
0 commit comments