|
| 1 | +import functools |
1 | 2 | import itertools |
2 | 3 |
|
3 | 4 | import ninetoothed |
|
9 | 10 |
|
10 | 11 | _DEFAULT_NDIMS = (2, 3, 4) |
11 | 12 |
|
| 13 | +_BATCH_DIM_AUTO_TUNE_SIZE = 1 |
| 14 | +_NORMALIZED_DIM_AUTO_TUNE_SIZE = 256 |
| 15 | + |
12 | 16 | _CONFIGS = tuple( |
13 | 17 | ( |
14 | 18 | (), |
|
30 | 34 | ) |
31 | 35 |
|
32 | 36 |
|
| 37 | +def _shape_options(ndim, num_normalized_dims): |
| 38 | + batch_dims = ndim - num_normalized_dims |
| 39 | + |
| 40 | + return ( |
| 41 | + *({"upper_bound": _BATCH_DIM_AUTO_TUNE_SIZE} for _ in range(batch_dims)), |
| 42 | + *( |
| 43 | + {"upper_bound": _NORMALIZED_DIM_AUTO_TUNE_SIZE} |
| 44 | + for _ in range(num_normalized_dims) |
| 45 | + ), |
| 46 | + ) |
| 47 | + |
| 48 | + |
| 49 | +def _premake( |
| 50 | + ndim, |
| 51 | + num_normalized_dims, |
| 52 | + input_dtype=None, |
| 53 | + weight_dtype=None, |
| 54 | + output_dtype=None, |
| 55 | + block_size=None, |
| 56 | +): |
| 57 | + dims = tuple(-(dim + 1) for dim in range(num_normalized_dims)) |
| 58 | + arrangement = functools.partial( |
| 59 | + ntops.kernels.reduction.arrangement, |
| 60 | + dim=dims, |
| 61 | + block_size=block_size, |
| 62 | + ) |
| 63 | + shape_options = _shape_options(ndim, num_normalized_dims) |
| 64 | + tensors = ( |
| 65 | + ninetoothed.Tensor( |
| 66 | + ndim, |
| 67 | + other=0, |
| 68 | + dtype=input_dtype, |
| 69 | + shape_options=shape_options, |
| 70 | + ), |
| 71 | + ninetoothed.Tensor(ndim, dtype=weight_dtype, shape_options=shape_options), |
| 72 | + ninetoothed.Tensor(0, dtype=ninetoothed.float32), |
| 73 | + ninetoothed.Tensor(ndim, dtype=output_dtype, shape_options=shape_options), |
| 74 | + ninetoothed.Tensor(0, dtype=ninetoothed.uint64), |
| 75 | + ) |
| 76 | + |
| 77 | + return arrangement, ntops.kernels.rms_norm.application, tensors |
| 78 | + |
| 79 | + |
33 | 80 | def build(output_dir): |
34 | 81 | variant_dir = output_dir / "rms_norm" |
35 | 82 | variant_dir.mkdir(parents=True, exist_ok=True) |
36 | 83 | ninetoothed.build( |
37 | | - ntops.kernels.rms_norm.premake, |
| 84 | + _premake, |
38 | 85 | _CONFIGS, |
39 | 86 | meta_parameters=("block_size",), |
40 | 87 | caller="cuda", |
|
0 commit comments