Skip to content

Commit dc51d41

Browse files
authored
Promote scalar float32 to float64 in AOT signature (#153)
* Promote scalar `fp32` to `fp64` in AOT signature * Add a test case for scalar `fp32` arguments in `test_aot.py`
1 parent b1c554b commit dc51d41

2 files changed

Lines changed: 43 additions & 0 deletions

File tree

src/ninetoothed/aot.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,9 @@ def _build_variant(
235235
param_types.append(f"{tensor.value}")
236236
constexpr_param_indices.append(len(param_types) - 1)
237237
else:
238+
if dtype == ninetoothed.dtype.float32:
239+
dtype = ninetoothed.dtype.float64
240+
238241
param_types.append(dtype)
239242

240243
signature = ", ".join(param_types)

tests/test_aot.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,46 @@ def test_conv2d(
302302
assert torch.allclose(output, expected, rtol=rtol, atol=atol)
303303

304304

305+
@pytest.mark.parametrize("device", get_available_devices())
306+
def test_fp32_scalar(device):
307+
def _arrangement(input, scale, output):
308+
return input.tile((256,)), scale, output.tile((256,))
309+
310+
def _application(input, scale, output):
311+
output = input * scale # noqa: F841
312+
313+
tensors = (
314+
Tensor(1, dtype=ninetoothed.float32),
315+
Tensor(0, dtype=ninetoothed.float32),
316+
Tensor(1, dtype=ninetoothed.float32),
317+
)
318+
319+
caller = device
320+
kernel_name = f"fp32_scalar{_generate_kernel_name_suffix()}"
321+
output_dir = ninetoothed.generation.CACHE_DIR
322+
323+
kernel = ninetoothed.make(
324+
_arrangement,
325+
_application,
326+
tensors,
327+
caller=caller,
328+
kernel_name=kernel_name,
329+
output_dir=output_dir,
330+
)
331+
332+
size = 256
333+
334+
input = torch.randn(size, dtype=torch.float32, device=device)
335+
scale = 0.125
336+
output = torch.empty_like(input)
337+
338+
kernel(input, scale, output)
339+
340+
expected = input * scale
341+
342+
assert torch.allclose(output, expected)
343+
344+
305345
def _generate_kernel_name_suffix():
306346
count = _generate_kernel_name_suffix._kernel_count
307347
_generate_kernel_name_suffix._kernel_count += 1

0 commit comments

Comments
 (0)