Skip to content

Commit 48ab3e3

Browse files
committed
Make _get_final_ir easier to use without real arguments
- Don't require a TileContext; take a TileContextConfig instead.
- Take a list of ir.KernelArgument descriptors instead of actual args.

Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 548f99c commit 48ab3e3

File tree

4 files changed

+36
-28
lines changed

4 files changed

+36
-28
lines changed

src/cuda/tile/_compile.py

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: Apache-2.0
4+
import inspect
45
import math
56
import re
67
from dataclasses import dataclass
@@ -16,19 +17,19 @@
1617
import tempfile
1718
import threading
1819
import traceback
19-
from typing import Callable, Optional, Any, Set
20+
from typing import Callable, Optional, Any, Set, Sequence
2021
import zipfile
2122

2223
from cuda.tile._cext import get_compute_capability, TileContext, default_tile_context
2324
from cuda.tile._compiler_options import CompilerOptions
2425
from cuda.tile._const_utils import get_constant_annotations
26+
from cuda.tile._context import TileContextConfig
2527
from cuda.tile._exception import (
2628
TileCompilerError,
2729
TileCompilerExecutionError,
2830
TileCompilerTimeoutError, TileValueError, TileTypeError
2931
)
3032
from cuda.tile._ir import ir, hir
31-
from cuda.tile._ir.ir import Argument
3233
from cuda.tile._ir.typing_support import typeof_pyval, get_constant_value
3334
from cuda.tile._passes.ast2hir import get_function_hir
3435
from cuda.tile._passes.code_motion import hoist_loop_invariants
@@ -73,13 +74,13 @@ def wrapper(*args, **kwargs):
7374
return wrapper
7475

7576

76-
def _get_final_ir(pyfunc, args, tile_context) -> ir.Function:
77-
ir_ctx = ir.IRContext(tile_context)
77+
def _get_final_ir(pyfunc,
78+
args: Sequence[ir.KernelArgument],
79+
config: TileContextConfig) -> ir.Function:
7880
func_hir: hir.Function = get_function_hir(pyfunc, entry_point=True)
7981

80-
ir_args = _bind_kernel_arguments(tuple(func_hir.signature.parameters),
81-
args, get_constant_annotations(pyfunc))
82-
func_body = hir2ir(func_hir, ir_args, ir_ctx)
82+
ir_ctx = ir.IRContext(config)
83+
func_body = hir2ir(func_hir, args, ir_ctx)
8384
eliminate_assign_ops(func_body)
8485
dead_code_elimination_pass(func_body)
8586

@@ -100,7 +101,7 @@ def _get_final_ir(pyfunc, args, tile_context) -> ir.Function:
100101

101102
def _bind_kernel_arguments(param_names: tuple[str, ...],
102103
args: tuple[Any, ...],
103-
constant_args: Set[str]) -> tuple[Argument, ...]:
104+
constant_args: Set[str]) -> tuple[ir.KernelArgument, ...]:
104105
# TODO: unify this logic with dispatcher from c extension
105106
# Refactor "extract_cuda_args" to return type descriptor
106107
# that can be wrapped as IR Type for type inference.
@@ -120,9 +121,7 @@ def _bind_kernel_arguments(param_names: tuple[str, ...],
120121
raise TileTypeError(
121122
f"Argument `{param_name}` is a constexpr, "
122123
f"but the value is not a supported constant.")
123-
ir_args.append(Argument(type=ty,
124-
is_const=is_const,
125-
const_value=const_val))
124+
ir_args.append(ir.KernelArgument(type=ty, is_const=is_const, const_value=const_val))
126125
return tuple(ir_args)
127126

128127

@@ -181,7 +180,9 @@ def compile_tile(pyfunc,
181180
args,
182181
compiler_options: CompilerOptions,
183182
context: TileContext = default_tile_context) -> TileLibrary:
184-
func_ir = _get_final_ir(pyfunc, args, context)
183+
param_names = tuple(inspect.signature(pyfunc).parameters.keys())
184+
ir_args = _bind_kernel_arguments(param_names, args, get_constant_annotations(pyfunc))
185+
func_ir = _get_final_ir(pyfunc, ir_args, context.config)
185186

186187
if 'CUTILEIR' in context.config.log_keys:
187188
code = (f"==== CuTile IR for {func_ir.name}==== \n\n"

src/cuda/tile/_ir/ir.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,20 +22,21 @@
2222
TileTypeError, Loc, TileInternalError
2323
)
2424
from .._cext import TileContext
25+
from .._context import TileContextConfig
2526

2627
if TYPE_CHECKING:
2728
from cuda.tile._ir2bytecode import BytecodeContext
2829

2930

3031
class IRContext:
31-
def __init__(self, tile_ctx: TileContext):
32+
def __init__(self, config: TileContextConfig):
3233
self._all_vars: Dict[str, str] = {}
3334
self._counter_by_name: Dict[str, Iterator[int]] = defaultdict(itertools.count)
3435
self._temp_counter = itertools.count()
3536
self.typemap: Dict[str, Type] = dict()
3637
self.constants: Dict[str, Any] = dict()
3738
self._loose_typemap: Dict[str, Type] = dict()
38-
self.tile_ctx: TileContext = tile_ctx
39+
self.config: TileContext = config
3940
self._aggregate_values: Dict[str, Any] = dict()
4041

4142
# Make a Var with a unique name based on `name`.
@@ -822,7 +823,7 @@ class Function:
822823

823824

824825
@dataclass
825-
class Argument:
826+
class KernelArgument:
826827
type: Type
827828
is_const: bool = False
828829
const_value: Any = None

src/cuda/tile/_passes/hir2ir.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -4,14 +4,14 @@
44
import inspect
55
import sys
66
from contextlib import contextmanager
7-
from typing import Any
7+
from typing import Any, Sequence
88

99
from .ast2hir import get_function_hir
1010
from .. import TileTypeError
1111
from .._coroutine_util import resume_after, run_coroutine
1212
from .._exception import Loc, TileSyntaxError, TileInternalError, TileError, TileRecursionError
1313
from .._ir import hir, ir
14-
from .._ir.ir import Var, IRContext, Argument, BoundMethodValue, ClosureValue
14+
from .._ir.ir import Var, IRContext, BoundMethodValue, ClosureValue, KernelArgument
1515
from .._ir.op_impl import op_implementations
1616
from .._ir.ops import loosely_typed_const, end_branch, return_, continue_, \
1717
break_, flatten_block_parameters, store_var
@@ -23,14 +23,13 @@
2323
MAX_RECURSION_DEPTH = 1000
2424

2525

26-
def hir2ir(func_hir: hir.Function,
27-
args: tuple[Argument, ...],
28-
ir_ctx: IRContext) -> ir.Block:
26+
def hir2ir(func_hir: hir.Function, args: Sequence[KernelArgument], ir_ctx: IRContext) -> ir.Block:
2927
# Run as a coroutine using a software stack, so that we don't exceed Python's recursion limit.
3028
return run_coroutine(_hir2ir_coroutine(func_hir, args, ir_ctx))
3129

3230

33-
async def _hir2ir_coroutine(func_hir: hir.Function, args: tuple[Argument, ...], ir_ctx: IRContext):
31+
async def _hir2ir_coroutine(func_hir: hir.Function, args: Sequence[KernelArgument],
32+
ir_ctx: IRContext):
3433
scope = _create_scope(func_hir, ir_ctx, call_site=None, parent_scopes=())
3534
aggregate_params = [
3635
scope.local.redefine(local_idx, param_loc)
@@ -50,7 +49,7 @@ async def _hir2ir_coroutine(func_hir: hir.Function, args: tuple[Argument, ...],
5049

5150
await _dispatch_hir_block_inner(func_hir.body, ir_builder)
5251
except Exception as e:
53-
if 'CUTILEIR' in ir_ctx.tile_ctx.config.log_keys:
52+
if 'CUTILEIR' in ir_ctx.config.log_keys:
5453
highlight_loc = e.loc if hasattr(e, 'loc') else None
5554
ir_str = "\n".join(op.to_string(highlight_loc=highlight_loc)
5655
for op in ir_builder.ops)
@@ -93,7 +92,7 @@ async def _dispatch_hir_block_inner(block: hir.Block, builder: ir.Builder):
9392
with _wrap_exceptions(loc), builder.change_loc(loc):
9493
_dispatch_hir_jump(block, scope)
9594
except Exception:
96-
if 'CUTILEIR' in builder.ir_ctx.tile_ctx.config.log_keys:
95+
if 'CUTILEIR' in builder.ir_ctx.config.log_keys:
9796
hir_params = ", ".join(p.name for p in block.params)
9897
hir_lines = [str(c) for c in block.calls]
9998
hir_lines.append(block.jump_str())

test/test_dce.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,19 +2,26 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
import torch
6-
75
import cuda.tile as ct
6+
from cuda.tile._ir.ir import KernelArgument
87
from cuda.tile._ir.ops import Loop, Continue, Break
98
from cuda.tile._ir import ir
109
from cuda.tile._compile import _get_final_ir
1110
from cuda.tile._cext import default_tile_context
11+
from cuda.tile._ir.type import ArrayTy, TupleTy, SizeTy
1212

1313

1414
def get_ir(func) -> ir.Block:
15-
x = torch.zeros(10, device="cuda")
16-
ir = _get_final_ir(func, (x,), default_tile_context)
17-
print(ir)
15+
x = KernelArgument(type=ArrayTy(ct.int32,
16+
shape=TupleTy((SizeTy(),)),
17+
strides=TupleTy((SizeTy(1,),)),
18+
elements_disjoint=True,
19+
base_ptr_div_by=None,
20+
stride_div_by=(None,),
21+
shape_div_by=(None,)),
22+
is_const=False,
23+
const_value=None)
24+
ir = _get_final_ir(func, (x,), default_tile_context.config)
1825
return ir.body
1926

2027

0 commit comments

Comments
 (0)