@@ -302,6 +302,46 @@ def test_conv2d(
302302 assert torch .allclose (output , expected , rtol = rtol , atol = atol )
303303
304304
@pytest.mark.parametrize("device", get_available_devices())
def test_fp32_scalar(device):
    """End-to-end check that a generated kernel accepts an fp32 scalar
    argument alongside 1-D fp32 tensor arguments.

    Builds a kernel that computes ``output = input * scale`` where ``scale``
    is declared as a 0-argument ``Tensor`` (presumably 0-dimensional, i.e. a
    scalar — TODO confirm against ``ninetoothed.Tensor``), runs it on a
    256-element random input, and compares against the eager-torch result.
    """

    def _arrangement(input, scale, output):
        # Tile the two 1-D tensors into 256-element blocks; `scale` is
        # returned unmodified (no tiling applied to it).
        return input.tile((256,)), scale, output.tile((256,))

    def _application(input, scale, output):
        # NOTE(review): this assignment looks dead in plain Python (hence
        # the F841 suppression), but it presumably IS the kernel body that
        # ninetoothed.make extracts from this function's source — do not
        # "clean it up". Confirm against ninetoothed's tracing docs.
        output = input * scale  # noqa: F841

    # Symbolic parameter declarations for the kernel, in call order:
    # 1-D input, scalar scale, 1-D output — all fp32.
    tensors = (
        Tensor(1, dtype=ninetoothed.float32),
        Tensor(0, dtype=ninetoothed.float32),
        Tensor(1, dtype=ninetoothed.float32),
    )

    caller = device
    # Unique kernel name per test invocation so cached artifacts from
    # earlier runs cannot be picked up by mistake.
    kernel_name = f"fp32_scalar{_generate_kernel_name_suffix()}"
    output_dir = ninetoothed.generation.CACHE_DIR

    kernel = ninetoothed.make(
        _arrangement,
        _application,
        tensors,
        caller=caller,
        kernel_name=kernel_name,
        output_dir=output_dir,
    )

    # Input size matches the tile size above, so exactly one tile is used.
    size = 256

    input = torch.randn(size, dtype=torch.float32, device=device)
    scale = 0.125  # plain Python float passed where Tensor(0) was declared
    output = torch.empty_like(input)

    kernel(input, scale, output)

    # Reference result computed eagerly by torch.
    expected = input * scale

    assert torch.allclose(output, expected)
343+
344+
305345def _generate_kernel_name_suffix ():
306346 count = _generate_kernel_name_suffix ._kernel_count
307347 _generate_kernel_name_suffix ._kernel_count += 1
0 commit comments