Skip to content

Commit f0be470

Browse files
committed
Use dataclasses for Operations
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent e1854cc commit f0be470

6 files changed

Lines changed: 443 additions & 679 deletions

File tree

src/cuda/tile/_ir/ir.py

Lines changed: 114 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44

55
from __future__ import annotations
66

7+
import dataclasses
78
import enum
89
import itertools
910
import threading
10-
from collections import OrderedDict
1111
from collections import defaultdict
1212
from collections.abc import Mapping
1313
from contextlib import contextmanager
14-
from copy import copy
1514
from dataclasses import dataclass
1615
from types import MappingProxyType
1716
from typing import (
@@ -316,11 +315,6 @@ def as_tuple(self) -> tuple["Var", ...]:
316315
)
317316

318317

319-
def terminator(cls):
320-
cls._is_terminator = True
321-
return cls
322-
323-
324318
class MemoryEffect(enum.IntEnum):
325319
# Int value assigned here is meaningful.
326320
# It implies the relative strength of memory effects.
@@ -330,18 +324,6 @@ class MemoryEffect(enum.IntEnum):
330324
STORE = 2
331325

332326

333-
def memory_effect(eff: MemoryEffect):
334-
def decorate(cls):
335-
cls.memory_effect = eff
336-
return cls
337-
return decorate
338-
339-
340-
def has_multiple_results(cls):
341-
cls._multiple_results = True
342-
return cls
343-
344-
345327
class Mapper:
346328
def __init__(self, ctx: IRContext, preserve_vars: bool = False):
347329
self._ctx = ctx
@@ -450,8 +432,8 @@ def add_operation(self, op_class,
450432
for var, ty in zip(result, result_ty, strict=True):
451433
if ty is not None:
452434
var.set_type(ty, force=force_type)
453-
if len(result) > 0 or op_class._multiple_results:
454-
attrs_and_operands["result_vars"] = result
435+
436+
result_vars = result
455437
else:
456438
if result is None:
457439
result = self.ir_ctx.make_temp(self._loc)
@@ -460,9 +442,10 @@ def add_operation(self, op_class,
460442
force_type = True
461443
if result_ty is not None:
462444
result.set_type(result_ty, force=force_type)
463-
attrs_and_operands["result_var"] = result
464445

465-
new_op = op_class(**attrs_and_operands, loc=self._loc)
446+
result_vars = (result,)
447+
448+
new_op = op_class(**attrs_and_operands, loc=self._loc, result_vars=result_vars)
466449
self._ops.append(new_op)
467450
if new_op.is_terminator:
468451
self.is_terminated = True
@@ -529,7 +512,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
529512

530513

531514
@contextmanager
532-
def nested_block(loc: Loc, reduction_body: bool = False, scan_body: bool = False):
515+
def enter_nested_block(loc: Loc, reduction_body: bool = False, scan_body: bool = False):
533516
prev_builder = Builder.get_current()
534517
block = Block(prev_builder.ir_ctx, loc=loc)
535518
with Builder(prev_builder.ir_ctx, loc,
@@ -546,66 +529,132 @@ class _CurrentBuilder(threading.local):
546529
_current_builder = _CurrentBuilder()
547530

548531

532+
class _FieldKind(enum.IntEnum):
533+
OPERAND = 0
534+
ATTRIBUTE = 1
535+
NESTED_BLOCK = 2
536+
537+
538+
_FIELD_METADATA_KEY = "operation_field_kind"
539+
540+
541+
def attribute(*, default=dataclasses.MISSING) -> dataclasses.Field:
542+
return dataclasses.field(default=default, metadata={_FIELD_METADATA_KEY: _FieldKind.ATTRIBUTE},
543+
kw_only=True)
544+
545+
546+
def operand(*, default=dataclasses.MISSING) -> dataclasses.Field:
547+
return dataclasses.field(default=default, metadata={_FIELD_METADATA_KEY: _FieldKind.OPERAND},
548+
kw_only=True)
549+
550+
551+
def nested_block() -> dataclasses.Field:
552+
return dataclasses.field(metadata={_FIELD_METADATA_KEY: _FieldKind.NESTED_BLOCK},
553+
kw_only=True)
554+
555+
556+
def _get_result_vars_tuple_for_single_result_op(self):
557+
return (self.result_var,)
558+
559+
560+
@dataclass(eq=False)
549561
class Operation:
550-
memory_effect = MemoryEffect.NONE
551-
_multiple_results = False
552-
553-
def __init__(
554-
self,
555-
op: str,
556-
operands: dict[str, Optional[Var | Tuple[Var, ...]]],
557-
result_vars: Sequence[Var],
558-
attributes: Optional[Dict[str, Any]] = None,
559-
nested_blocks: Optional[Sequence[Block]] = None,
560-
loc: Loc = Loc.unknown(),
561-
):
562-
self.op = op
563-
self.result_vars = result_vars or []
564-
self.attributes = attributes or {}
565-
self.nested_blocks = nested_blocks or []
566-
self.loc = loc
562+
result_vars: tuple[Var, ...]
563+
loc: Loc
564+
565+
def __init_subclass__(cls,
566+
opcode: str,
567+
terminator: bool = False,
568+
memory_effect: MemoryEffect = MemoryEffect.NONE):
569+
cls._opcode = opcode
570+
cls._is_terminator = terminator
571+
cls.memory_effect = memory_effect
572+
573+
operand_names = []
574+
attribute_names = []
575+
nested_block_names = []
576+
for field_name in cls.__annotations__.keys():
577+
f = getattr(cls, field_name, None)
578+
kind = f.metadata.get(_FIELD_METADATA_KEY) if isinstance(f, dataclasses.Field) else None
579+
if kind == _FieldKind.OPERAND:
580+
operand_names.append(field_name)
581+
elif kind == _FieldKind.ATTRIBUTE:
582+
attribute_names.append(field_name)
583+
elif kind == _FieldKind.NESTED_BLOCK:
584+
nested_block_names.append(field_name)
585+
else:
586+
raise TypeError(f"Field {field_name} of {cls} must be annotated with either"
587+
f" operand(), attribute() or nested_block()")
567588

568-
self._operands = OrderedDict()
569-
for k, v in operands.items():
570-
self._add_operand(k, v)
571-
self._is_terminator = getattr(self.__class__, "_is_terminator", False)
572-
self._parent_block = None
589+
cls._operand_names = tuple(operand_names)
590+
cls._attribute_names = tuple(attribute_names)
591+
cls._nested_block_names = tuple(nested_block_names)
592+
593+
def __post_init__(self):
594+
for var in self.all_inputs():
595+
assert isinstance(var, Var | tuple) or var is None
596+
if isinstance(var, tuple):
597+
assert all(isinstance(x, Var) for x in var)
598+
599+
if isinstance(var, Var) and var.is_aggregate() and self.op != "assign":
600+
# Don't allow aggregate values as operands, except for arrays and lists.
601+
# All other aggregates should only exist in the HIR level.
602+
# Also make an exception for the Assign op, until we find a better way to handle it.
603+
agg_val = var.get_aggregate()
604+
assert isinstance(agg_val, ArrayValue | ListValue)
605+
606+
for nb in self.nested_blocks:
607+
assert isinstance(nb, Block)
573608

574609
def clone(self, mapper: Mapper) -> Operation:
575610
result_vars = mapper.clone_vars(self.result_vars)
576611
return self._clone_impl(mapper, result_vars)
577612

578613
def _clone_impl(self, mapper: Mapper, result_vars: Sequence[Var]) -> Operation:
579-
new_nested_blocks = []
580-
for old_block in self.nested_blocks:
614+
new_fields = {}
615+
616+
for name in self._attribute_names:
617+
new_fields[name] = getattr(self, name)
618+
619+
for name in self._operand_names:
620+
var = getattr(self, name)
621+
if isinstance(var, Var):
622+
new_var = mapper.get_var(var)
623+
elif var is None:
624+
new_var = None
625+
else:
626+
new_var = tuple(mapper.get_var(v) for v in var)
627+
new_fields[name] = new_var
628+
629+
for name in self._nested_block_names:
630+
old_block = getattr(self, name)
581631
new_block = Block(old_block.ctx, old_block.loc)
582632
new_block.params = mapper.clone_vars(old_block.params)
583633
for old_op in old_block:
584634
new_block.append(old_op.clone(mapper))
585-
new_nested_blocks.append(new_block)
635+
new_fields[name] = new_block
586636

587-
ret = copy(self)
588-
ret._operands = OrderedDict()
589-
for name, var in self._operands.items():
590-
if isinstance(var, Var):
591-
ret._operands[name] = mapper.get_var(var)
592-
elif var is None:
593-
ret._operands[name] = None
594-
else:
595-
ret._operands[name] = tuple(mapper.get_var(v) for v in var)
637+
return type(self)(result_vars=tuple(result_vars), loc=self.loc, **new_fields)
596638

597-
ret.attributes = dict(ret.attributes)
598-
ret.result_vars = result_vars
599-
ret.parent_block = None
600-
ret.nested_blocks = new_nested_blocks
601-
return ret
639+
@property
640+
def op(self) -> str:
641+
return self._opcode
602642

603643
@property
604644
def operands(self) -> Mapping[str, Var | Tuple[Var, ...]]:
605-
return MappingProxyType(self._operands)
645+
return MappingProxyType({name: getattr(self, name) for name in self._operand_names})
646+
647+
@property
648+
def attributes(self):
649+
return MappingProxyType({name: getattr(self, name) for name in self._attribute_names})
650+
651+
@property
652+
def nested_blocks(self):
653+
return tuple(getattr(self, name) for name in self._nested_block_names)
606654

607655
def all_inputs(self) -> Iterator[Var]:
608-
for x in self._operands.values():
656+
for name in self._operand_names:
657+
x = getattr(self, name)
609658
if isinstance(x, tuple):
610659
yield from iter(x)
611660
elif x is not None:
@@ -615,28 +664,6 @@ def all_inputs(self) -> Iterator[Var]:
615664
def is_terminator(self) -> bool:
616665
return self._is_terminator
617666

618-
def _add_operand(self, name: str, var: Var | Tuple[Var, ...]):
619-
if isinstance(var, Var) and var.is_aggregate() and self.op != "assign":
620-
# Don't allow aggregate values as operands, except for arrays and lists.
621-
# All other aggregates should only exist in the HIR level.
622-
# Also make an exception for the Assign op, until we find a better way to handle it.
623-
agg_val = var.get_aggregate()
624-
assert isinstance(agg_val, ArrayValue | ListValue)
625-
self._operands[name] = var
626-
627-
def update_operand(self, name: str, var: Var | Tuple[Var, ...]):
628-
self._add_operand(name, var)
629-
630-
def __getattr__(self, name: str) -> Any:
631-
if name == "__setstate__":
632-
raise AttributeError(name)
633-
634-
if name in self.operands:
635-
return self.operands[name]
636-
if name in self.attributes:
637-
return self.attributes[name]
638-
raise AttributeError(f"{self.__class__.__name__} has no operand or attribute {name}")
639-
640667
@property
641668
def result_var(self) -> Var:
642669
if len(self.result_vars) != 1:
@@ -740,11 +767,6 @@ def format_var(var: Var) -> str:
740767
return ret
741768

742769

743-
# TODO: no longer needed, remove by inheriting from Operation instead
744-
class TypedOperation(Operation):
745-
pass
746-
747-
748770
class Block:
749771
def __init__(self, ctx: IRContext, loc: Loc):
750772
self.ctx = ctx

0 commit comments

Comments
 (0)