Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 114 additions & 39 deletions src/ninetoothed/aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ def _find_tensor_by_source_name(tensors, name):
grid = f"{ast.unparse(grid_extractor.grid[0])}, 1, 1"

launch_arg_names = tuple(arg.arg for arg in launch_func.args.args)
variant_specs = _enumerate_stride_specs(
variant_specs = _enumerate_variant_specs(
launch_arg_names, tensors, _find_tensor_by_source_name
)

output_contents = {}

for variant_suffix, stride_spec in variant_specs:
variant_outputs = _build_stride_variant(
for variant_suffix, divisibility_spec, contiguity_spec in variant_specs:
variant_outputs = _build_variant(
source_file,
kernel_func,
launch_func,
Expand All @@ -98,11 +98,12 @@ def _find_tensor_by_source_name(tensors, name):
grid=grid,
num_warps=num_warps,
num_stages=num_stages,
stride_spec=stride_spec,
divisibility_spec=divisibility_spec,
contiguity_spec=contiguity_spec,
)
output_contents.update(variant_outputs)

dispatcher_source, dispatcher_header = _generate_stride_dispatcher(
dispatcher_source, dispatcher_header = _generate_dispatcher(
kernel_name, launch_arg_names, variant_specs
)

Expand All @@ -112,7 +113,7 @@ def _find_tensor_by_source_name(tensors, name):
return output_contents


def _generate_stride_dispatcher(kernel_name, launch_arg_names, variant_specs):
def _generate_dispatcher(kernel_name, launch_arg_names, variant_specs):
tensor_params = ", ".join(f"NineToothedTensor {name}" for name in launch_arg_names)
signature_params = (
f"NineToothedStream stream, {tensor_params}"
Expand All @@ -138,19 +139,20 @@ def _generate_stride_dispatcher(kernel_name, launch_arg_names, variant_specs):
externs = []
branches = []

for variant_suffix, stride_spec in variant_specs:
for variant_suffix, divisibility_spec, contiguity_spec in variant_specs:
variant_name = f"launch_{kernel_name}_{variant_suffix}"
externs.append(
f'extern "C" NineToothedResult {variant_name}({signature_params});'
)

call = f"return {variant_name}({call_args});"

if stride_spec:
check = " && ".join(
f"{name}.strides[{dim}] == 1" for name, dim in stride_spec
)
branches.append(f"{_INDENTATION}if ({check}) {call}")
checks = tuple(
f"{name}.shape[{dim}] % 16 == 0" for name, dim in divisibility_spec
) + tuple(f"{name}.strides[{dim}] == 1" for name, dim in contiguity_spec)

if checks:
branches.append(f"{_INDENTATION}if ({' && '.join(checks)}) {call}")
else:
branches.append(f"{_INDENTATION}{call}")

Expand All @@ -165,7 +167,7 @@ def _generate_stride_dispatcher(kernel_name, launch_arg_names, variant_specs):
return source, header


def _build_stride_variant(
def _build_variant(
source_file,
kernel_func,
launch_func,
Expand All @@ -178,9 +180,15 @@ def _build_stride_variant(
grid,
num_warps,
num_stages,
stride_spec,
divisibility_spec,
contiguity_spec,
):
spec_set = {(naming.remove_prefixes(name), dim) for name, dim in stride_spec}
divisibility_set = {
(naming.remove_prefixes(name), dim) for name, dim in divisibility_spec
}
contiguity_set = {
(naming.remove_prefixes(name), dim) for name, dim in contiguity_spec
}

param_strings = ["stream"]
param_types = []
Expand All @@ -198,14 +206,21 @@ def _build_stride_variant(
dtype = tensor.source.dtype

param_types.append(f"*{dtype}:16")
elif Tensor.size_pattern().fullmatch(param):
param_types.append(ninetoothed.dtype.int64)
elif match := Tensor.size_pattern().fullmatch(param):
source_name = match.group(1)
dim_index = int(match.group(3))
bare_source_name = naming.remove_prefixes(source_name)

if (bare_source_name, dim_index) in divisibility_set:
param_types.append(f"{ninetoothed.dtype.int64}:16")
else:
param_types.append(ninetoothed.dtype.int64)
elif match := Tensor.stride_pattern().fullmatch(param):
source_name = match.group(1)
dim_index = int(match.group(3))
bare_source_name = naming.remove_prefixes(source_name)

if (bare_source_name, dim_index) in spec_set:
if (bare_source_name, dim_index) in contiguity_set:
param_types.append("1")
constexpr_param_indices.append(len(param_types) - 1)
constexpr_strides.append((source_name, dim_index))
Expand Down Expand Up @@ -282,50 +297,110 @@ def _build_stride_variant(
return output_contents


def _enumerate_variant_specs(launch_arg_names, tensors, find_tensor):
    """Enumerate ``(suffix, divisibility_spec, contiguity_spec)`` triples.

    Each spec is a tuple of ``(arg_name, dim)`` pairs naming the dimensions
    a variant assumes to be 16-divisible (divisibility) or stride-1
    (contiguity).  Every per-tensor dimension combination is paired with
    every other, plus the empty (fully generic) spec, and the result is
    sorted most-specific first so a dispatcher can take the first variant
    whose runtime checks pass.
    """
    per_tensor_dims, tensor_ndims, innermost_dims = _per_tensor_dim_options(
        launch_arg_names, tensors, find_tensor
    )

    def _spec_from_combo(combo):
        # Drop arguments that contribute no dimension (``None``) to this combo.
        return tuple(
            (name, dim) for name, dim in zip(launch_arg_names, combo) if dim is not None
        )

    # All non-empty per-dimension selections, plus the empty (generic) spec.
    dim_specs = tuple(
        _spec_from_combo(combo)
        for combo in itertools.product(*per_tensor_dims)
        if any(dim is not None for dim in combo)
    ) + ((),)

    specs = []

    for divisibility_spec in dim_specs:
        for contiguity_spec in dim_specs:
            suffix = _variant_suffix(
                divisibility_spec, contiguity_spec, launch_arg_names, tensor_ndims
            )
            specs.append((suffix, divisibility_spec, contiguity_spec))

    def _num_innermost(spec):
        # Count entries that hit a tensor's innermost dimension.
        return sum(1 for name, dim in spec if innermost_dims.get(name) == dim)

    def _specificity(entry):
        _, divisibility_spec, contiguity_spec = entry

        # Most specific first: larger specs, then innermost-dimension hits,
        # with divisibility taking precedence over contiguity.
        return (
            -len(divisibility_spec),
            -_num_innermost(divisibility_spec),
            -len(contiguity_spec),
            -_num_innermost(contiguity_spec),
        )

    specs.sort(key=_specificity)

    return specs


def _per_tensor_dim_options(launch_arg_names, tensors, find_tensor):
per_tensor_dims = []
tensor_ndims = []

for name in launch_arg_names:
tensor = find_tensor(tensors, name)
ndim = tensor.source.ndim if tensor is not None else 0
tensor_ndims.append(ndim)

if tensor is None or tensor.source.ndim == 0:
if ndim == 0:
per_tensor_dims.append((None,))
elif tensor.source.ndim == 1:
elif ndim == 1:
per_tensor_dims.append((0,))
else:
ndim = tensor.source.ndim
per_tensor_dims.append((ndim - 1, ndim - 2))

innermost = {
per_tensor_dims = tuple(per_tensor_dims)
tensor_ndims = tuple(tensor_ndims)

innermost_dims = {
name: dims[0]
for name, dims in zip(launch_arg_names, per_tensor_dims)
if dims[0] is not None
}

specs = []
return per_tensor_dims, tensor_ndims, innermost_dims

def _variant_suffix(divisibility_spec, contiguity_spec, launch_arg_names, tensor_ndims):
    """Return the unique name suffix encoding both assumption specs.

    Concatenates the divisibility flags and the contiguity flags (see
    ``_divisibility_suffix`` and ``_contiguity_suffix``) so that every
    ``(divisibility, contiguity)`` combination maps to a distinct kernel
    variant symbol.
    """
    divisibility_part = _divisibility_suffix(
        divisibility_spec, launch_arg_names, tensor_ndims
    )
    contiguity_part = _contiguity_suffix(
        contiguity_spec, launch_arg_names, tensor_ndims
    )

    return f"{divisibility_part}_{contiguity_part}"
def _divisibility_suffix(divisibility_spec, launch_arg_names, tensor_ndims):
hinted = set(divisibility_spec)

return (-len(spec), -num_innermost)
parts = tuple(
"16" if (name, dim) in hinted else "1"
for name, ndim in zip(launch_arg_names, tensor_ndims)
for dim in range(ndim)
)

specs.sort(key=_specificity)
return "divisibility_" + "_".join(parts) if parts else "divisibility"

return specs

def _contiguity_suffix(contiguity_spec, launch_arg_names, tensor_ndims):
contiguous = set(contiguity_spec)

parts = tuple(
"1" if (name, dim) in contiguous else "0"
for name, ndim in zip(launch_arg_names, tensor_ndims)
for dim in range(ndim)
)

return "contiguity_" + "_".join(parts) if parts else "contiguity"


_INDENTATION = " "
Expand Down
Loading