Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions python/flydsl/compiler/jit_argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,7 @@ def get_dsl_type(cls, jit_arg_type: type) -> Type[DslType]:
return cls.jit_arg2dsl_type[jit_arg_type]


def _is_constexpr_annotation(annotation) -> bool:
"""Check if annotation is Constexpr or Constexpr[T]."""
if annotation is Constexpr:
return True
return get_origin(annotation) is Constexpr


def _is_type_param_annotation(annotation) -> bool:
def is_type_param_annotation(annotation) -> bool:
"""Check if annotation is Type, Type[T]."""
origin = get_origin(annotation)
return annotation is Type or origin is Type or origin is type
Expand All @@ -95,11 +88,11 @@ def convert_to_jit_arguments(
param = sig.parameters[param_name]
annotation = param.annotation

if annotation is not inspect.Parameter.empty and _is_constexpr_annotation(annotation):
if annotation is not inspect.Parameter.empty and Constexpr.is_constexpr_annotation(annotation):
constexpr_values[param_name] = value
continue

if annotation is not inspect.Parameter.empty and _is_type_param_annotation(annotation):
if annotation is not inspect.Parameter.empty and is_type_param_annotation(annotation):
constexpr_values[param_name] = value
continue

Expand Down
14 changes: 6 additions & 8 deletions python/flydsl/compiler/jit_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
from .._mlir import ir
from .._mlir.dialects import func
from .._mlir.passmanager import PassManager
from ..expr.typing import Stream
from ..expr.typing import Constexpr, Stream
from ..utils import env, log
from .ast_rewriter import ASTRewriter
from .backends import compile_backend_name, get_backend
from .jit_argument import convert_to_jit_arguments
from .jit_argument import convert_to_jit_arguments, is_type_param_annotation
from .jit_executor import CompiledArtifact
from .kernel_function import (
CompilationContext,
Expand Down Expand Up @@ -945,18 +945,17 @@ def _build_call_state(sig, args_tuple, func_exe):

Returns a CallState, or None if any parameter can't be fast-pathed.
"""
from .jit_argument import _is_constexpr_annotation, _is_type_param_annotation

slot_specs = []
has_user_stream = False

for i, (param_name, param) in enumerate(sig.parameters.items()):
annotation = param.annotation

if annotation is not inspect.Parameter.empty and _is_constexpr_annotation(annotation):
if annotation is not inspect.Parameter.empty and Constexpr.is_constexpr_annotation(annotation):
continue

if annotation is not inspect.Parameter.empty and _is_type_param_annotation(annotation):
if annotation is not inspect.Parameter.empty and is_type_param_annotation(annotation):
continue

if getattr(annotation, "_is_stream_param", False):
Expand Down Expand Up @@ -1148,7 +1147,6 @@ def _make_cache_key(self, bound_args):
key — only the Python type matters. This prevents unnecessary
recompilation when only runtime values change.
"""
from .jit_argument import _is_constexpr_annotation, _is_type_param_annotation

sig = self._sig
key_parts = [("_target_", self._target)]
Expand All @@ -1158,11 +1156,11 @@ def _make_cache_key(self, bound_args):
param = sig.parameters.get(name)
ann = param.annotation if param else inspect.Parameter.empty

if ann is not inspect.Parameter.empty and _is_constexpr_annotation(ann):
if ann is not inspect.Parameter.empty and Constexpr.is_constexpr_annotation(ann):
key_parts.append((name, type(arg), arg))
continue

if ann is not inspect.Parameter.empty and _is_type_param_annotation(ann):
if ann is not inspect.Parameter.empty and is_type_param_annotation(ann):
key_parts.append((name, "Type", arg))
continue

Expand Down
10 changes: 2 additions & 8 deletions python/flydsl/compiler/kernel_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
from contextlib import contextmanager
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_origin
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from .._mlir import ir
from .._mlir.dialects import arith, gpu
Expand Down Expand Up @@ -459,12 +459,6 @@ def __get__(self, obj, objtype=None):
return self
return partial(self.__call__, obj)

@staticmethod
def _is_constexpr_annotation(annotation) -> bool:
if annotation is Constexpr:
return True
return get_origin(annotation) is Constexpr

def _emit_kernel(self, ctx: CompilationContext, args: Tuple, kwargs: Dict, bound_self: Any = None):
"""Emit gpu.func for this kernel into the GPU module."""
sig = self._sig
Expand All @@ -478,7 +472,7 @@ def _emit_kernel(self, ctx: CompilationContext, args: Tuple, kwargs: Dict, bound
for param_name, value in bound.arguments.items():
param = sig.parameters[param_name]
annotation = param.annotation
if annotation is not inspect.Parameter.empty and self._is_constexpr_annotation(annotation):
if annotation is not inspect.Parameter.empty and Constexpr.is_constexpr_annotation(annotation):
constexpr_values[param_name] = value
else:
param_names.append(param_name)
Expand Down
36 changes: 36 additions & 0 deletions python/flydsl/compiler/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ def __get_ir_types__(self) -> List[ir.Type]: ...
def __get_c_pointers__(self) -> List[ctypes.c_void_p]: ...


@runtime_checkable
class Storable(Protocol):
@classmethod
def __dsl_size_of__(cls) -> int: ...
@classmethod
def __dsl_align_of__(cls) -> int: ...
@classmethod
def __peek_from_ptr__(cls, ptr: ir.Value): ...
@classmethod
def __poke_into_ptr__(cls, ptr: ir.Value, value): ...


def get_ir_types(obj) -> List[ir.Type]:
if isinstance(obj, ir.Value):
return [obj.type]
Expand Down Expand Up @@ -80,3 +92,27 @@ def construct_from_ir_values(dsl_type, args, values: List[ir.Value]) -> DslType:
values = values[count:]
return type(dsl_type)(elems)
raise TypeError(f"Cannot construct DSL value for {dsl_type}")


def dsl_size_of(dsl_type) -> int:
if hasattr(dsl_type, "__dsl_size_of__"):
return dsl_type.__dsl_size_of__()
raise TypeError(f"type {dsl_type} does not implement the Storable protocol")


def dsl_align_of(dsl_type) -> int:
if hasattr(dsl_type, "__dsl_align_of__"):
return dsl_type.__dsl_align_of__()
raise TypeError(f"type {dsl_type} does not implement the Storable protocol")


def peek_from_ptr(dsl_type, ptr: ir.Value):
if hasattr(dsl_type, "__peek_from_ptr__"):
return dsl_type.__peek_from_ptr__(ptr)
raise TypeError(f"type {dsl_type} does not implement the Storable protocol")


def poke_into_ptr(dsl_type, ptr: ir.Value, value):
if hasattr(dsl_type, "__poke_into_ptr__"):
return dsl_type.__poke_into_ptr__(ptr, value)
raise TypeError(f"type {dsl_type} does not implement the Storable protocol")
13 changes: 9 additions & 4 deletions python/flydsl/expr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@
from .primitive import *
from .gpu import *
from .derived import *
from .struct import *

from . import utils

from . import arith, vector, gpu, buffer_ops, rocdl, math
from .rocdl import tdm_ops
from . import utils as utils
from . import arith as arith
from . import buffer_ops as buffer_ops
from . import gpu as gpu
from . import math as math
from . import rocdl as rocdl
from . import vector as vector
from .rocdl import tdm_ops as tdm_ops
42 changes: 42 additions & 0 deletions python/flydsl/expr/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,44 @@ def _get_c_pointers(self):
return [ptr]

inferred_np = np_dtype if np_dtype is not None else _infer_np_dtype(width, signed, name)
is_storable = width >= 8

def _dsl_size_of(cls):
return 1 if cls.width < 8 else (cls.width + 7) // 8

def _dsl_align_of(cls):
return 1 if cls.width < 8 else (cls.width + 7) // 8

def _peek_from_ptr(cls, ptr):
from .primitive import ptr_load, recast_iter

typed_ptr = recast_iter(cls, ptr)
return cls(ptr_load(typed_ptr, cls))

def _poke_into_ptr(cls, ptr, value):
from .primitive import ptr_store, recast_iter

typed_ptr = recast_iter(cls, ptr)
coerced = value.to(cls) if isinstance(value, Numeric) else cls(value)
return ptr_store(coerced.ir_value(), typed_ptr)

def _not_storable(cls):
raise TypeError(f"sub-byte type {cls.__name__} (width={cls.width}) is not Storable")

new_attrs = {
"__extract_to_ir_values__": _extract_to_ir_values,
"__construct_from_ir_values__": classmethod(_construct_from_ir_values),
}
if is_storable:
new_attrs["__dsl_size_of__"] = classmethod(_dsl_size_of)
new_attrs["__dsl_align_of__"] = classmethod(_dsl_align_of)
new_attrs["__peek_from_ptr__"] = classmethod(_peek_from_ptr)
new_attrs["__poke_into_ptr__"] = classmethod(_poke_into_ptr)
elif any(hasattr(base, "__dsl_size_of__") for base in bases):
new_attrs["__dsl_size_of__"] = classmethod(_not_storable)
new_attrs["__dsl_align_of__"] = classmethod(_not_storable)
new_attrs["__peek_from_ptr__"] = classmethod(_not_storable)
new_attrs["__poke_into_ptr__"] = classmethod(lambda cls, ptr, value: _not_storable(cls))
if signed is not None:
new_attrs["__get_c_pointers__"] = _get_c_pointers

Expand Down Expand Up @@ -300,6 +333,15 @@ def select(self, true_value, false_value, *, loc=None):
"""Ternary select (for Boolean conditions from Int32 comparisons)."""
return ArithValue(self).select(true_value, false_value, loc=loc)

@classmethod
def __coerce__(cls, value):
if isinstance(value, cls):
return value
try:
return cls(value)
except Exception:
raise TypeError(f"expects {cls.__name__}, got {type(value).__name__}")

@property
def dtype(self) -> Type["Numeric"]:
return type(self)
Expand Down
Loading
Loading