11# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22#
33# SPDX-License-Identifier: Apache-2.0
4+ import inspect
45import math
56import re
67from dataclasses import dataclass
1617import tempfile
1718import threading
1819import traceback
19- from typing import Callable , Optional , Any , Set
20+ from typing import Callable , Optional , Any , Set , Sequence
2021import zipfile
2122
2223from cuda .tile ._cext import get_compute_capability , TileContext , default_tile_context
2324from cuda .tile ._compiler_options import CompilerOptions
2425from cuda .tile ._const_utils import get_constant_annotations
26+ from cuda .tile ._context import TileContextConfig
2527from cuda .tile ._exception import (
2628 TileCompilerError ,
2729 TileCompilerExecutionError ,
2830 TileCompilerTimeoutError , TileValueError , TileTypeError
2931)
3032from cuda .tile ._ir import ir , hir
31- from cuda .tile ._ir .ir import Argument
3233from cuda .tile ._ir .typing_support import typeof_pyval , get_constant_value
3334from cuda .tile ._passes .ast2hir import get_function_hir
3435from cuda .tile ._passes .code_motion import hoist_loop_invariants
@@ -73,13 +74,13 @@ def wrapper(*args, **kwargs):
7374 return wrapper
7475
7576
76- def _get_final_ir (pyfunc , args , tile_context ) -> ir .Function :
77- ir_ctx = ir .IRContext (tile_context )
77+ def _get_final_ir (pyfunc ,
78+ args : Sequence [ir .KernelArgument ],
79+ config : TileContextConfig ) -> ir .Function :
7880 func_hir : hir .Function = get_function_hir (pyfunc , entry_point = True )
7981
80- ir_args = _bind_kernel_arguments (tuple (func_hir .signature .parameters ),
81- args , get_constant_annotations (pyfunc ))
82- func_body = hir2ir (func_hir , ir_args , ir_ctx )
82+ ir_ctx = ir .IRContext (config )
83+ func_body = hir2ir (func_hir , args , ir_ctx )
8384 eliminate_assign_ops (func_body )
8485 dead_code_elimination_pass (func_body )
8586
@@ -100,7 +101,7 @@ def _get_final_ir(pyfunc, args, tile_context) -> ir.Function:
100101
101102def _bind_kernel_arguments (param_names : tuple [str , ...],
102103 args : tuple [Any , ...],
103- constant_args : Set [str ]) -> tuple [Argument , ...]:
104+ constant_args : Set [str ]) -> tuple [ir . KernelArgument , ...]:
104105 # TODO: unify this logic with dispatcher from c extension
105106 # Refactor "extract_cuda_args" to return type descriptor
106107 # that can be wrapped as IR Type for type inference.
@@ -120,9 +121,7 @@ def _bind_kernel_arguments(param_names: tuple[str, ...],
120121 raise TileTypeError (
121122 f"Argument `{ param_name } ` is a constexpr, "
122123 f"but the value is not a supported constant." )
123- ir_args .append (Argument (type = ty ,
124- is_const = is_const ,
125- const_value = const_val ))
124+ ir_args .append (ir .KernelArgument (type = ty , is_const = is_const , const_value = const_val ))
126125 return tuple (ir_args )
127126
128127
@@ -181,7 +180,9 @@ def compile_tile(pyfunc,
181180 args ,
182181 compiler_options : CompilerOptions ,
183182 context : TileContext = default_tile_context ) -> TileLibrary :
184- func_ir = _get_final_ir (pyfunc , args , context )
183+ param_names = tuple (inspect .signature (pyfunc ).parameters .keys ())
184+ ir_args = _bind_kernel_arguments (param_names , args , get_constant_annotations (pyfunc ))
185+ func_ir = _get_final_ir (pyfunc , ir_args , context .config )
185186
186187 if 'CUTILEIR' in context .config .log_keys :
187188 code = (f"==== CuTile IR for { func_ir .name } ==== \n \n "
0 commit comments