Skip to content

Commit c1278ba

Browse files
committed
Delay the eliminate_load_store pass until the hir2ir stage
This is in preparation for implementing closures/lambdas. It also fixes an edge case in the flattening of if-else statements with a constant condition (see test_constfold.py). Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent cde97fb commit c1278ba

9 files changed

Lines changed: 428 additions & 348 deletions

File tree

src/cuda/tile/_compile.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import tempfile
1717
import threading
1818
import traceback
19-
from typing import Callable, Optional
19+
from typing import Callable, Optional, Any, Set
2020
import zipfile
2121

2222
from cuda.tile._cext import get_compute_capability, TileContext, default_tile_context
@@ -25,10 +25,11 @@
2525
from cuda.tile._exception import (
2626
TileCompilerError,
2727
TileCompilerExecutionError,
28-
TileCompilerTimeoutError,
28+
TileCompilerTimeoutError, TileValueError, TileTypeError
2929
)
3030
from cuda.tile._ir import ir, hir
31-
from cuda.tile._ir.ir import bind_kernel_arguments
31+
from cuda.tile._ir.ir import Argument
32+
from cuda.tile._ir.typing_support import typeof_pyval, get_constant_value
3233
from cuda.tile._passes.ast2hir import get_function_hir
3334
from cuda.tile._passes.code_motion import hoist_loop_invariants
3435
from cuda.tile._passes.eliminate_assign_ops import eliminate_assign_ops
@@ -74,9 +75,9 @@ def wrapper(*args, **kwargs):
7475

7576
def _get_final_ir(pyfunc, args, tile_context) -> ir.Block:
7677
ir_ctx = ir.IRContext(tile_context)
77-
func_hir: hir.Block = get_function_hir(pyfunc, ir_ctx, call_site=None)
78+
func_hir: hir.Function = get_function_hir(pyfunc, ir_ctx, call_site=None)
7879

79-
ir_args = bind_kernel_arguments(func_hir.params, args, get_constant_annotations(pyfunc))
80+
ir_args = _bind_kernel_arguments(func_hir.param_names, args, get_constant_annotations(pyfunc))
8081
func_body = hir2ir(func_hir, ir_args, ir_ctx)
8182
eliminate_assign_ops(func_body)
8283
dead_code_elimination_pass(func_body)
@@ -96,6 +97,34 @@ def _get_final_ir(pyfunc, args, tile_context) -> ir.Block:
9697
return func_body
9798

9899

100+
def _bind_kernel_arguments(param_names: tuple[str, ...],
                           args: tuple[Any, ...],
                           constant_args: Set[str]) -> tuple[Argument, ...]:
    """Pair each kernel parameter with its call-time value and build IR arguments.

    Args:
        param_names: Names of the kernel's parameters, in declaration order.
        args: Positional values supplied by the caller, one per parameter.
        constant_args: Names of the parameters that are treated as constexpr.

    Returns:
        One ``Argument`` per parameter, in the same order as ``param_names``.

    Raises:
        TileValueError: If the number of values differs from the number of
            parameters.
        TileTypeError: If a constexpr parameter receives a value that
            ``get_constant_value`` rejects.
    """
    # TODO: unify this logic with dispatcher from c extension
    #       Refactor "extract_cuda_args" to return type descriptor
    #       that can be wrapped as IR Type for type inference.
    if len(args) != len(param_names):
        msg = f"Expected {len(param_names)} arguments, got {len(args)}"
        raise TileValueError(msg)

    ir_args = []
    for param_name, arg_value in zip(param_names, args, strict=True):
        const_val = None
        is_const = param_name in constant_args
        # Constexpr parameters are not typed as runtime kernel arguments.
        ty = typeof_pyval(arg_value, kernel_arg=not is_const)
        if is_const:
            try:
                const_val = get_constant_value(arg_value)
            except TileTypeError as err:
                # Chain explicitly so the underlying reason stays attached as
                # the cause of the user-facing error.
                raise TileTypeError(
                    f"Argument `{param_name}` is a constexpr, "
                    f"but the value is not a supported constant.") from err
        ir_args.append(Argument(type=ty,
                                is_const=is_const,
                                const_value=const_val))
    return tuple(ir_args)
126+
127+
99128
def _log_mlir(bytecode_buf):
100129
try:
101130
from cuda.tile_internal import _internal_cext

src/cuda/tile/_ir/hir.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import enum
1818
from dataclasses import dataclass
1919
from textwrap import indent
20-
from typing import Any, Set
20+
from typing import Any, Set, Mapping
2121

2222
from cuda.tile._exception import Loc
2323
from cuda.tile._ir.ir import Var
@@ -95,6 +95,14 @@ def jump_str(self):
9595
return f"{self.jump._value_}{results_str} # Line {self.jump_loc.line}"
9696

9797

98+
# HIR-level description of a function: its body block together with the
# parameter names, their source locations, and a mapping of global names.
@dataclass
class Function:
    # Block holding the function's body.
    body: Block
    # Parameter names as a tuple of strings.
    param_names: tuple[str, ...]
    # Source locations for the parameters
    # (presumably one per entry of `param_names` -- TODO confirm).
    param_locs: tuple[Loc, ...]
    # Mapping from global names to values
    # (presumably a snapshot taken when the HIR was built -- TODO confirm).
    frozen_globals: Mapping[str, Any]
104+
105+
98106
@dataclass
99107
class _OperandFormatter:
100108
blocks: list["Block"]
@@ -118,6 +126,8 @@ def __call__(self, x: Operand) -> str:
118126
# ==================================
119127

120128
def if_else(cond, then_block, else_block, /): ...
121-
def loop(body, iterable, /, *initial_values): ... # infinite if `iterable` is None
129+
def loop(body, iterable, /): ... # infinite if `iterable` is None
122130
def build_tuple(*items): ... # Makes a tuple (i.e. returns `items`)
123131
def identity(x): ... # Identity function (i.e. returns `x`)
132+
def store_var(name, value, /): ... # Store into a named variable
133+
def load_var(name, /): ... # Load from a named variable

src/cuda/tile/_ir/ir.py

Lines changed: 126 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,8 @@
1818
List, Optional, Dict, Tuple, Set, Any, TYPE_CHECKING, Sequence, Iterator
1919
)
2020
from .type import Type, InvalidType
21-
from .typing_support import typeof_pyval, get_constant_value, loose_type_of_pyval
2221
from cuda.tile._exception import (
23-
TileTypeError,
24-
TileValueError,
25-
Loc, TileInternalError
22+
TileTypeError, Loc, TileInternalError, TileSyntaxError
2623
)
2724
from .._cext import TileContext
2825

@@ -308,28 +305,122 @@ def finalize_loopvar_type(self, body_var: Var):
308305
class LoopInfo:
309306
var_states: tuple[LoopVarState, ...]
310307
is_for_loop: bool
308+
stored_names: tuple[str, ...]
309+
flatten: bool = False
311310

312311

313-
def get_innermost_loop() -> LoopInfo | None:
314-
return Builder.get_current().loop_info
312+
# Bookkeeping attached to the builder while the branches of an if-else
# statement are being processed.
@dataclass
class IfElseInfo:
    # Phi states for the values produced by the if-else.
    result_phis: tuple[PhiState, ...]
    # Variable names stored inside the if-else
    # (presumably by either branch -- TODO confirm).
    stored_names: tuple[str, ...]
    # NOTE(review): looks like this marks an if-else that is inlined into the
    # enclosing block instead of being emitted as a branch -- confirm.
    flatten: bool = False
    # Result variables recorded when the if-else is flattened; empty by default.
    flattened_results: tuple[Var, ...] = ()
    # NOTE(review): presumably set once an "end branch" jump has been seen in
    # one of the branches -- confirm against the hir2ir pass.
    have_end_branch: bool = False
315319

316320

317321
@contextmanager
318322
def nested_block(name: str, loc: Loc, params: Sequence[Var] = (),
319-
loop_info: LoopInfo | None = None):
323+
loop_info: LoopInfo | None = None,
324+
if_else_info: IfElseInfo | None = None):
320325
prev_builder = Builder.get_current()
321326
block = Block(prev_builder.ir_ctx, params=params, name=name, loc=loc)
322327
new_loop_info = loop_info or prev_builder.loop_info
323-
with Builder(prev_builder.ir_ctx, loc, new_loop_info) as builder:
328+
new_if_else_info = if_else_info or prev_builder.if_else_info
329+
scope = prev_builder.scope
330+
with Builder(prev_builder.ir_ctx, loc, scope, new_loop_info, new_if_else_info) as builder, \
331+
scope.local.enter_branch():
324332
yield block
325333
block.extend(builder.ops)
326334

327335

336+
class LocalScope:
    """Tracks which IR ``Var`` each local variable name currently refers to.

    Scopes may be chained through ``parent``; name lookups walk from this
    scope outwards until a binding is found.
    """

    def __init__(self,
                 all_locals: Set[str],
                 ir_ctx: IRContext,
                 parent: Optional["LocalScope"] = None):
        self._parent = parent
        self._ir_ctx = ir_ctx
        self._all_locals = all_locals
        # Current name -> Var bindings; swapped for an overlay by enter_branch().
        self._map = {}

    def is_local_name(self, name: str):
        """Return True if `name` is a local of this scope or of any ancestor."""
        scope = self
        while scope is not None:
            if name in scope._all_locals:
                return True
            scope = scope._parent
        return False

    def redefine(self, name: str, loc: Loc) -> Var:
        """Create a fresh ``Var`` for `name` and make it the current binding."""
        fresh = self._ir_ctx.make_var(name, loc)
        self._map[name] = fresh
        return fresh

    def __getitem__(self, name: str):
        """Return the variable bound to `name`.

        Raises:
            TileSyntaxError: If no scope in the chain has a binding for `name`.
        """
        found = self._lookup(name)
        if found is not None:
            return found
        raise TileSyntaxError(f"Undefined variable {name} used")

    def get(self, name: str, loc: Loc):
        """Return the variable bound to `name`, or a new undefined-marked one.

        The fallback variable is created with ``undefined=True`` and is not
        recorded in the scope.
        """
        found = self._lookup(name)
        if found is not None:
            return found
        return self._ir_ctx.make_var(name, loc, undefined=True)

    def _lookup(self, name: str) -> Optional[Var]:
        """Search this scope and then its ancestors; return None if unbound."""
        visited = set()
        scope = self
        while scope is not None:
            found = scope._map.get(name)
            if found is not None:
                return found
            # Sanity check, should not reach here.
            scope_id = id(scope)
            if scope_id in visited:
                raise RuntimeError("Cycle detected in Scope chain")
            visited.add(scope_id)
            scope = scope._parent
        return None

    @contextmanager
    def enter_branch(self):
        """Route bindings made inside the ``with`` body into a temporary overlay.

        The previous mapping object is put back on exit, so bindings written
        while the overlay was active are no longer visible through this scope.
        """
        saved = self._map
        self._map = _OverlayDict(saved)
        try:
            yield
        finally:
            self._map = saved
394+
395+
396+
class _OverlayDict:
397+
def __init__(self, orig_dict: dict):
398+
self._orig = orig_dict
399+
self._overlay = dict()
400+
401+
def get(self, key):
402+
value = self._overlay.get(key)
403+
return self._orig.get(key) if value is None else value
404+
405+
def __setitem__(self, key, value):
406+
self._overlay[key] = value
407+
408+
409+
# Name-resolution state carried by a Builder: the local-variable scope plus a
# mapping of global names.
@dataclass
class Scope:
    # Local variable bindings; nested_block() enters a branch overlay on it.
    local: LocalScope
    # Mapping from global names to values
    # (presumably fixed for the whole compilation -- TODO confirm).
    frozen_globals: Mapping[str, Any]
413+
414+
328415
class Builder:
329-
def __init__(self, ctx: IRContext, loc: Loc, loop_info: LoopInfo | None = None):
416+
def __init__(self, ctx: IRContext, loc: Loc, scope: Scope,
417+
loop_info: LoopInfo | None = None,
418+
if_else_info: IfElseInfo | None = None):
330419
self.ir_ctx = ctx
420+
self.scope = scope
331421
self.is_terminated = False
332422
self.loop_info = loop_info
423+
self.if_else_info = if_else_info
333424
self._loc = loc
334425
self._ops = []
335426
self._entered = False
@@ -363,6 +454,10 @@ def add_operation(self, op_class,
363454
def ops(self) -> list[Operation]:
364455
return self._ops
365456

457+
@property
def loc(self) -> Loc:
    # Read-only view of the builder's current source location; `change_loc`
    # is the context manager that temporarily replaces it.
    current_loc = self._loc
    return current_loc
460+
366461
def append_verbatim(self, op: Operation):
367462
self._ops.append(op)
368463

@@ -384,6 +479,24 @@ def change_loc(self, loc: Loc):
384479
finally:
385480
self._loc = old_loc
386481

482+
@contextmanager
def change_if_else_info(self, new_info: IfElseInfo):
    """Use `new_info` as ``if_else_info`` for the duration of the ``with`` body.

    The previous value is restored on exit, including when the body raises.
    """
    previous, self.if_else_info = self.if_else_info, new_info
    try:
        yield
    finally:
        self.if_else_info = previous
490+
491+
@contextmanager
def change_loop_info(self, new_info: LoopInfo):
    """Use `new_info` as ``loop_info`` for the duration of the ``with`` body.

    The previous value is restored on exit, including when the body raises.
    """
    previous, self.loop_info = self.loop_info, new_info
    try:
        yield
    finally:
        self.loop_info = previous
499+
387500
def __enter__(self):
388501
assert not self._entered
389502
self._prev_builder = _current_builder.builder
@@ -685,69 +798,8 @@ def __str__(self) -> str:
685798
return self.to_string()
686799

687800

688-
def bind_kernel_arguments(params: Tuple[Var, ...],
689-
args: Tuple[Any, ...],
690-
constant_args: Set[str]) -> Tuple[Argument, ...]:
691-
# TODO: unify this logic with dispatcher from c extension
692-
# Refactor "extract_cuda_args" to return type descriptor
693-
# that can be wrapped as IR Type for type inference.
694-
if len(args) != len(params):
695-
msg = f"Expected {len(params)} arguments, got {len(args)}"
696-
raise TileValueError(msg)
697-
698-
ir_args = []
699-
for param, arg_value in zip(params, args):
700-
const_val = None
701-
is_const = param.name in constant_args
702-
ty = typeof_pyval(arg_value, kernel_arg=not is_const)
703-
loose_type = ty
704-
if is_const:
705-
try:
706-
const_val = get_constant_value(arg_value)
707-
except TileTypeError:
708-
raise TileTypeError(
709-
f"Argument {param.name} is a constexpr, "
710-
f"but the value is not a supported constant.")
711-
loose_type = loose_type_of_pyval(arg_value)
712-
ir_args.append(Argument(type=ty,
713-
loose_type=loose_type,
714-
is_const=is_const,
715-
const_value=const_val))
716-
return tuple(ir_args)
717-
718-
801+
@dataclass
719802
class Argument:
720-
def __init__(self,
721-
type: Type,
722-
loose_type: Type,
723-
is_const: bool = False,
724-
const_value: Any = None):
725-
self._type = type
726-
self._loose_type = loose_type
727-
self._is_const = is_const
728-
self._const_value = const_value
729-
730-
@property
731-
def is_const(self) -> bool:
732-
return self._is_const
733-
734-
@property
735-
def const_value(self) -> Any:
736-
return self._const_value
737-
738-
@property
739-
def type(self) -> Type:
740-
return self._type
741-
742-
@property
743-
def loose_type(self) -> Type:
744-
return self._loose_type
745-
746-
def __eq__(self, value: object) -> bool:
747-
if not isinstance(value, Argument):
748-
return False
749-
return (
750-
self.type == value.type and
751-
self.is_const == value.is_const and
752-
self.const_value == value.const_value
753-
)
803+
type: Type
804+
is_const: bool = False
805+
const_value: Any = None

0 commit comments

Comments
 (0)