From 9bb9385a005d9092ba561a0ae0a97a0b107d2432 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Mon, 11 May 2026 10:02:19 +0000 Subject: [PATCH 1/4] [Feat] support composite type based on DslType --- python/flydsl/compiler/jit_argument.py | 13 +- python/flydsl/compiler/jit_function.py | 14 +- python/flydsl/compiler/kernel_function.py | 10 +- python/flydsl/compiler/protocol.py | 36 ++ python/flydsl/expr/__init__.py | 1 + python/flydsl/expr/numeric.py | 43 ++ python/flydsl/expr/struct.py | 638 ++++++++++++++++++++++ python/flydsl/expr/typing.py | 198 ++++++- tests/unit/test_struct.py | 533 ++++++++++++++++++ 9 files changed, 1456 insertions(+), 30 deletions(-) create mode 100644 python/flydsl/expr/struct.py create mode 100644 tests/unit/test_struct.py diff --git a/python/flydsl/compiler/jit_argument.py b/python/flydsl/compiler/jit_argument.py index 160d5ca44..5ec1f9b63 100644 --- a/python/flydsl/compiler/jit_argument.py +++ b/python/flydsl/compiler/jit_argument.py @@ -70,14 +70,7 @@ def get_dsl_type(cls, jit_arg_type: type) -> Type[DslType]: return cls.jit_arg2dsl_type[jit_arg_type] -def _is_constexpr_annotation(annotation) -> bool: - """Check if annotation is Constexpr or Constexpr[T].""" - if annotation is Constexpr: - return True - return get_origin(annotation) is Constexpr - - -def _is_type_param_annotation(annotation) -> bool: +def is_type_param_annotation(annotation) -> bool: """Check if annotation is Type, Type[T].""" origin = get_origin(annotation) return annotation is Type or origin is Type or origin is type @@ -95,11 +88,11 @@ def convert_to_jit_arguments( param = sig.parameters[param_name] annotation = param.annotation - if annotation is not inspect.Parameter.empty and _is_constexpr_annotation(annotation): + if annotation is not inspect.Parameter.empty and Constexpr.is_constexpr_annotation(annotation): constexpr_values[param_name] = value continue - if annotation is not inspect.Parameter.empty and _is_type_param_annotation(annotation): + if annotation is not inspect.Parameter.empty and is_type_param_annotation(annotation): constexpr_values[param_name] = value continue diff --git a/python/flydsl/compiler/jit_function.py b/python/flydsl/compiler/jit_function.py index 1b4ec48db..4276c85bf 100644 --- a/python/flydsl/compiler/jit_function.py +++ b/python/flydsl/compiler/jit_function.py @@ -22,11 +22,11 @@ from .._mlir import ir from .._mlir.dialects import func from .._mlir.passmanager import PassManager -from ..expr.typing import Stream +from ..expr.typing import Constexpr, Stream from ..utils import env, log from .ast_rewriter import ASTRewriter from .backends import compile_backend_name, get_backend -from .jit_argument import convert_to_jit_arguments +from .jit_argument import convert_to_jit_arguments, is_type_param_annotation from .jit_executor import CompiledArtifact from .kernel_function import ( CompilationContext, @@ -945,7 +945,6 @@ def _build_call_state(sig, args_tuple, func_exe): Returns a CallState, or None if any parameter can't be fast-pathed. """ - from .jit_argument import _is_constexpr_annotation, _is_type_param_annotation slot_specs = [] has_user_stream = False @@ -953,10 +952,10 @@ def _build_call_state(sig, args_tuple, func_exe): for i, (param_name, param) in enumerate(sig.parameters.items()): annotation = param.annotation - if annotation is not inspect.Parameter.empty and _is_constexpr_annotation(annotation): + if annotation is not inspect.Parameter.empty and Constexpr.is_constexpr_annotation(annotation): continue - if annotation is not inspect.Parameter.empty and _is_type_param_annotation(annotation): + if annotation is not inspect.Parameter.empty and is_type_param_annotation(annotation): continue if getattr(annotation, "_is_stream_param", False): @@ -1148,7 +1147,6 @@ def _make_cache_key(self, bound_args): key — only the Python type matters. This prevents unnecessary recompilation when only runtime values change. """ - from .jit_argument import _is_constexpr_annotation, _is_type_param_annotation sig = self._sig key_parts = [("_target_", self._target)] @@ -1158,11 +1156,11 @@ def _make_cache_key(self, bound_args): param = sig.parameters.get(name) ann = param.annotation if param else inspect.Parameter.empty - if ann is not inspect.Parameter.empty and _is_constexpr_annotation(ann): + if ann is not inspect.Parameter.empty and Constexpr.is_constexpr_annotation(ann): key_parts.append((name, type(arg), arg)) continue - if ann is not inspect.Parameter.empty and _is_type_param_annotation(ann): + if ann is not inspect.Parameter.empty and is_type_param_annotation(ann): key_parts.append((name, "Type", arg)) continue diff --git a/python/flydsl/compiler/kernel_function.py b/python/flydsl/compiler/kernel_function.py index f40abe422..1d5a4c2e7 100644 --- a/python/flydsl/compiler/kernel_function.py +++ b/python/flydsl/compiler/kernel_function.py @@ -5,7 +5,7 @@ import threading from contextlib import contextmanager from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_origin +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from .._mlir import ir from .._mlir.dialects import arith, gpu @@ -459,12 +459,6 @@ def __get__(self, obj, objtype=None): return self return partial(self.__call__, obj) - @staticmethod - def _is_constexpr_annotation(annotation) -> bool: - if annotation is Constexpr: - return True - return get_origin(annotation) is Constexpr - def _emit_kernel(self, ctx: CompilationContext, args: Tuple, kwargs: Dict, bound_self: Any = None): """Emit gpu.func for this kernel into the GPU module.""" sig = self._sig @@ -478,7 +472,7 @@ def _emit_kernel(self, ctx: CompilationContext, args: Tuple, kwargs: Dict, bound for param_name, value in bound.arguments.items(): param = sig.parameters[param_name] annotation = param.annotation - if annotation is not inspect.Parameter.empty and self._is_constexpr_annotation(annotation): + if annotation is not inspect.Parameter.empty and Constexpr.is_constexpr_annotation(annotation): constexpr_values[param_name] = value else: param_names.append(param_name) diff --git a/python/flydsl/compiler/protocol.py b/python/flydsl/compiler/protocol.py index 1a630113f..5c135fb8e 100644 --- a/python/flydsl/compiler/protocol.py +++ b/python/flydsl/compiler/protocol.py @@ -22,6 +22,18 @@ def __get_ir_types__(self) -> List[ir.Type]: ... def __get_c_pointers__(self) -> List[ctypes.c_void_p]: ... +@runtime_checkable +class Storable(Protocol): + @classmethod + def __dsl_size_of__(cls) -> int: ... + @classmethod + def __dsl_align_of__(cls) -> int: ... + @classmethod + def __peek_from_ptr__(cls, ptr: ir.Value): ... + @classmethod + def __poke_into_ptr__(cls, ptr: ir.Value, value): ... + + def get_ir_types(obj) -> List[ir.Type]: if isinstance(obj, ir.Value): return [obj.type] @@ -80,3 +92,27 @@ def construct_from_ir_values(dsl_type, args, values: List[ir.Value]) -> DslType: values = values[count:] return type(dsl_type)(elems) raise TypeError(f"Cannot construct DSL value for {dsl_type}") + + +def dsl_size_of(dsl_type) -> int: + if hasattr(dsl_type, "__dsl_size_of__"): + return dsl_type.__dsl_size_of__() + raise TypeError(f"type {dsl_type} does not implement the Storable protocol") + + +def dsl_align_of(dsl_type) -> int: + if hasattr(dsl_type, "__dsl_align_of__"): + return dsl_type.__dsl_align_of__() + raise TypeError(f"type {dsl_type} does not implement the Storable protocol") + + +def peek_from_ptr(dsl_type, ptr: ir.Value): + if hasattr(dsl_type, "__peek_from_ptr__"): + return dsl_type.__peek_from_ptr__(ptr) + raise TypeError(f"type {dsl_type} does not implement the Storable protocol") + + +def poke_into_ptr(dsl_type, ptr: ir.Value, value): + if hasattr(dsl_type, "__poke_into_ptr__"): + return dsl_type.__poke_into_ptr__(ptr, value) + raise TypeError(f"type {dsl_type} does not implement the Storable protocol") diff --git a/python/flydsl/expr/__init__.py b/python/flydsl/expr/__init__.py index 7c2bd331c..ddf93c270 100644 --- a/python/flydsl/expr/__init__.py +++ b/python/flydsl/expr/__init__.py @@ -6,6 +6,7 @@ from .primitive import * from .gpu import * from .derived import * +from .struct import * from . import utils diff --git a/python/flydsl/expr/numeric.py b/python/flydsl/expr/numeric.py index 8c993082a..575e85377 100644 --- a/python/flydsl/expr/numeric.py +++ b/python/flydsl/expr/numeric.py @@ -72,10 +72,44 @@ def _get_c_pointers(self): inferred_np = np_dtype if np_dtype is not None else _infer_np_dtype(width, signed, name) + is_storable = width >= 8 or (width == 1 and name == "Boolean") + + def _dsl_size_of(cls): + return 1 if cls.width < 8 else (cls.width + 7) // 8 + + def _dsl_align_of(cls): + return 1 if cls.width < 8 else (cls.width + 7) // 8 + + def _peek_from_ptr(cls, ptr): + from .primitive import ptr_load, recast_iter + + typed_ptr = recast_iter(cls, ptr) + return cls(ptr_load(typed_ptr, cls)) + + def _poke_into_ptr(cls, ptr, value): + from .primitive import ptr_store, recast_iter + + typed_ptr = recast_iter(cls, ptr) + coerced = value.to(cls) if isinstance(value, Numeric) else cls(value) + return ptr_store(coerced.ir_value(), typed_ptr) + + def _not_storable(cls): + raise TypeError(f"sub-byte type {cls.__name__} (width={cls.width}) is not Storable") + new_attrs = { "__extract_to_ir_values__": _extract_to_ir_values, "__construct_from_ir_values__": classmethod(_construct_from_ir_values), } + if is_storable: + new_attrs["__dsl_size_of__"] = classmethod(_dsl_size_of) + new_attrs["__dsl_align_of__"] = classmethod(_dsl_align_of) + new_attrs["__peek_from_ptr__"] = classmethod(_peek_from_ptr) + new_attrs["__poke_into_ptr__"] = classmethod(_poke_into_ptr) + elif any(hasattr(base, "__dsl_size_of__") for base in bases): + new_attrs["__dsl_size_of__"] = classmethod(_not_storable) + new_attrs["__dsl_align_of__"] = classmethod(_not_storable) + new_attrs["__peek_from_ptr__"] = classmethod(_not_storable) + new_attrs["__poke_into_ptr__"] = classmethod(lambda cls, ptr, value: _not_storable(cls)) if signed is not None: new_attrs["__get_c_pointers__"] = _get_c_pointers @@ -300,6 +334,15 @@ def select(self, true_value, false_value, *, loc=None): """Ternary select (for Boolean conditions from Int32 comparisons).""" return ArithValue(self).select(true_value, false_value, loc=loc) + @classmethod + def __coerce__(cls, value): + if isinstance(value, cls): + return value + try: + return cls(value) + except Exception: + raise TypeError(f"expects {cls.__name__}, got {type(value).__name__}") + @property def dtype(self) -> Type["Numeric"]: return type(self) diff --git a/python/flydsl/expr/struct.py b/python/flydsl/expr/struct.py new file mode 100644 index 000000000..7dc800588 --- /dev/null +++ b/python/flydsl/expr/struct.py @@ -0,0 +1,638 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +from __future__ import annotations + +from dataclasses import FrozenInstanceError, dataclass +from enum import Enum +from itertools import chain +from typing import Any, List + +from .._mlir import ir +from ..compiler.protocol import ( + dsl_align_of, + dsl_size_of, + extract_to_ir_values, + get_c_pointers, + get_ir_types, + peek_from_ptr, + poke_into_ptr, +) +from .primitive import add_offset +from .typing import Array, Constexpr, Pointer + +__all__ = [ + "struct", + "Struct", + "union", + "Union", + "Array", + "Align", + "Storage", +] + + +class CompositeKind(Enum): + Product = 0 + Sum = 1 + + +def is_composite_type(obj: Any) -> bool: + return isinstance(obj, type) and hasattr(obj, "__dsl_composite_kind__") + + +def is_struct_type(obj: Any) -> bool: + return is_composite_type(obj) and obj.__dsl_composite_kind__ == CompositeKind.Product + + +@dataclass(slots=True) +class FieldDef: + name: str + type_spec: Any + + +def _is_constexpr_type(type_spec: Any) -> bool: + return isinstance(type_spec, type) and issubclass(type_spec, Constexpr) + + +def _type_name(dtype: Any) -> str: + return getattr(dtype, "__name__", repr(dtype)) + + +def _display_name(schema: type) -> str: + return getattr(schema, "__dsl_display_name__", getattr(schema, "__name__", repr(schema))) + + +_RESERVED_FIELD_NAMES = frozenset({"peek", "poke", "replace"}) + + +def _validate_field_name(name: str, context: str): + if name.startswith("_"): + raise ValueError(f"{context}: field name '{name}' must not start with underscore") + if name in _RESERVED_FIELD_NAMES: + raise ValueError(f"{context}: field name '{name}' is reserved") + + +def _normalize_decorator_fields(klass: type) -> tuple[FieldDef, ...]: + annotations = getattr(klass, "__annotations__", {}) + context = klass.__name__ + for name in annotations: + _validate_field_name(name, context) + return tuple(FieldDef(name, annotation) for name, annotation in annotations.items()) + + +def _align_up(value: int, align: int) -> int: + if align <= 0: + raise ValueError(f"alignment must be positive, got {align}") + return (value + align - 1) // align * align + + +def _storage_layout(schema: type) -> tuple[int, int, dict[str, int]]: + cached = getattr(schema, "__dsl_storage_layout_cache__", None) + if cached is not None: + return cached + + fields = [field for field in schema.__dsl_field_defs__ if not _is_constexpr_type(field.type_spec)] + if not fields: + result = (0, 1, {}) + schema.__dsl_storage_layout_cache__ = result + return result + + def _field_layout(field: FieldDef) -> tuple[int, int]: + try: + return dsl_size_of(field.type_spec), dsl_align_of(field.type_spec) + except TypeError as exc: + raise TypeError( + f"Cannot compute layout for schema {_display_name(schema)}: field '{field.name}' has type " + f"{_type_name(field.type_spec)} which does not implement the Storable protocol." + ) from exc + + if schema.__dsl_composite_kind__ == CompositeKind.Sum: + sizes_aligns = [_field_layout(field) for field in fields] + align = max(a for _, a in sizes_aligns) + size = max(s for s, _ in sizes_aligns) + result = (_align_up(size, align), align, {field.name: 0 for field in fields}) + schema.__dsl_storage_layout_cache__ = result + return result + else: + offset = 0 + align = 1 + offsets: dict[str, int] = {} + for field in fields: + field_size, field_align = _field_layout(field) + offset = _align_up(offset, field_align) + offsets[field.name] = offset + offset += field_size + align = max(align, field_align) + result = (_align_up(offset, align), align, offsets) + schema.__dsl_storage_layout_cache__ = result + return result + + +def _coerce_value_type(schema: type, field: FieldDef, value: Any) -> Any: + type_spec = field.type_spec + coerce_fn = getattr(type_spec, "__coerce__", None) + if coerce_fn is not None: + try: + return coerce_fn(value) + except TypeError as exc: + raise TypeError(f"{_display_name(schema)}(...) field '{field.name}' {exc}") from exc + if isinstance(type_spec, type) and not isinstance(value, type_spec): + raise TypeError( + f"{_display_name(schema)}(...) field '{field.name}' expects " + f"{_type_name(type_spec)}, got {type(value).__name__}." + ) + return value + + +def _resolve_field(schema: type, key: int | str) -> FieldDef: + fields = schema.__dsl_field_defs__ + if isinstance(key, int): + if key < 0 or key >= len(fields): + raise IndexError(f"Index {key} out of range for schema {_display_name(schema)} with {len(fields)} fields.") + return fields[key] + if isinstance(key, str): + for f in fields: + if f.name == key: + return f + available = [f.name for f in fields] + raise KeyError(f"Field '{key}' not found in schema {_display_name(schema)}. Available fields: {available}.") + + +def _type_cache_key(type_spec: Any): + if is_composite_type(type_spec): + return ("composite", type_spec.__dsl_type_identity__) + sig = getattr(type_spec, "__cache_signature__", None) + if callable(sig): + try: + return ("sig", sig()) + except TypeError: + pass + try: + hash(type_spec) + except TypeError: + return ("repr", repr(type_spec)) + return type_spec + + +def _make_type_identity(policy: CompositeKind, fields: tuple[FieldDef, ...]): + return (policy, tuple((field.name, _type_cache_key(field.type_spec)) for field in fields)) + + +def _field_values_from_args(schema: type, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]: + fields = schema.__dsl_field_defs__ + if len(args) > len(fields): + raise TypeError( + f"{_display_name(schema)}(...) expected {len(fields)} field(s), got {len(args)} positional value(s)." + ) + + values: dict[str, Any] = {} + for field, value in zip(fields, args, strict=False): + values[field.name] = value + + field_names = {field.name for field in fields} + unexpected = set(kwargs) - field_names + for field in fields: + if field.name in kwargs: + if field.name in values: + raise TypeError(f"{_display_name(schema)}(...) got multiple values for field '{field.name}'.") + values[field.name] = kwargs[field.name] + + if unexpected: + raise TypeError(f"{_display_name(schema)}(...) got unexpected field(s): {sorted(unexpected)}.") + + missing = [field.name for field in fields if field.name not in values] + if missing: + raise TypeError(f"{_display_name(schema)}(...) missing required field(s): {missing}.") + return values + + +_specialization_cache: dict[tuple, type] = {} + + +def _specialize_type(base_cls: type, fields: tuple[FieldDef, ...], values: dict[str, Any]) -> type: + effective: list[tuple[str, Any]] = [] + needs_specialization = False + for field in fields: + value = values[field.name] + type_spec = field.type_spec + specializer = getattr(type_spec, "__specialize_for_value__", None) + if specializer is not None: + spec_type = specializer(value) + effective.append((field.name, spec_type)) + if spec_type is not type_spec: + needs_specialization = True + elif is_composite_type(type(value)): + sub_type = type(value) + effective.append((field.name, sub_type)) + if sub_type is not type_spec: + needs_specialization = True + else: + effective.append((field.name, type_spec)) + if not needs_specialization: + return base_cls + + cache_key = (base_cls, tuple(effective)) + cached = _specialization_cache.get(cache_key) + if cached is not None: + return cached + + suffix_parts = [] + for (name, eff_type), field in zip(effective, fields, strict=True): + if eff_type is field.type_spec: + continue + if isinstance(eff_type, type) and issubclass(eff_type, Constexpr) and eff_type.is_specialized: + suffix_parts.append(f"{name}={eff_type.value!r}") + suffix = f"[{', '.join(suffix_parts)}]" if suffix_parts else "" + namespace: dict[str, Any] = { + "__dsl_effective_field_defs__": tuple(effective), + "__dsl_base_type__": base_cls, + "__dsl_display_name__": _display_name(base_cls) + suffix, + } + specialized = type(base_cls.__name__ + suffix, (base_cls,), namespace) + _specialization_cache[cache_key] = specialized + return specialized + + +def _effective_field_defs(schema: type) -> tuple[tuple[str, Any], ...]: + effective = getattr(schema, "__dsl_effective_field_defs__", None) + if effective is not None: + return effective + base_cls = getattr(schema, "__dsl_base_type__", schema) + return tuple(getattr(base_cls, "__annotations__", {}).items()) + + +def _carrier_for_field(eff_type: Any, value: Any) -> Any: + if isinstance(eff_type, type) and issubclass(eff_type, Constexpr): + return eff_type + return value + + +def _construct_field_from_ir(type_spec: Any, values): + ctor = getattr(type_spec, "__construct_from_ir_values__", None) + if ctor is None: + raise TypeError(f"struct field type {_type_name(type_spec)} does not implement __construct_from_ir_values__") + return ctor(values) + + +def _ir_value_count_from_type(type_spec: Any) -> int: + if is_struct_type(type_spec): + return sum(_ir_value_count_from_type(eff) for _, eff in _effective_field_defs(type_spec)) + types_fn = getattr(type_spec, "__get_ir_types__", None) + if types_fn is not None and isinstance(type_spec, type): + try: + return len(types_fn()) + except TypeError: + pass + return 1 + + +def _normalize_inline_fields(params) -> tuple[FieldDef, ...]: + if not isinstance(params, tuple): + params = (params,) + if len(params) == 0: + raise ValueError("inline schema requires at least one field") + fields = [] + seen: set[str] = set() + for idx, item in enumerate(params): + if isinstance(item, slice): + if not isinstance(item.start, str) or item.stop is None: + raise TypeError("named inline fields must use Schema['name': Type] syntax") + name = item.start + _validate_field_name(name, "inline schema") + type_spec = item.stop + else: + name = f"_{idx}" + type_spec = item + if name in seen: + raise ValueError(f"duplicate inline field name '{name}'") + seen.add(name) + fields.append(FieldDef(name, type_spec)) + return tuple(fields) + + +def _inline_display_name(display: str, params, fields: tuple[FieldDef, ...]) -> str: + raw = params if isinstance(params, tuple) else (params,) + all_anonymous = all(not isinstance(item, slice) for item in raw) + if all_anonymous: + body = ", ".join(_type_name(field.type_spec) for field in fields) + else: + body = ", ".join(f"{field.name!r}: {_type_name(field.type_spec)}" for field in fields) + return f"{display}[{body}]" + + +def _make_composite_class( + *, + name: str, + module: str, + fields: tuple[FieldDef, ...], + policy: CompositeKind, + display_name: str, +): + identity = _make_type_identity(policy, fields) + + def __init__(self, *args, **kwargs): + if policy == CompositeKind.Sum: + raise TypeError( + f"Union schema {_display_name(type(self))} has no value form; use Storage[...] or allocator.allocate()." + ) + base_cls = getattr(type(self), "__dsl_base_type__", type(self)) + values = _field_values_from_args(base_cls, args, kwargs) + coerced = {field.name: _coerce_value_type(base_cls, field, values[field.name]) for field in fields} + specialized = _specialize_type(base_cls, fields, coerced) + if specialized is not type(self): + object.__setattr__(self, "__class__", specialized) + for field in fields: + object.__setattr__(self, field.name, coerced[field.name]) + object.__setattr__(self, "_schema_frozen", True) + + def __setattr__(self, key, value): + if getattr(self, "_schema_frozen", False): + raise FrozenInstanceError(f"cannot assign to field '{key}'") + object.__setattr__(self, key, value) + + def __delattr__(self, key): + if getattr(self, "_schema_frozen", False): + raise FrozenInstanceError(f"cannot delete field '{key}'") + object.__delattr__(self, key) + + def __repr__(self): + body = ", ".join(f"{field.name}={getattr(self, field.name)!r}" for field in fields) + return f"{_display_name(type(self))}({body})" + + def __eq__(self, other): + self_base = getattr(type(self), "__dsl_base_type__", None) + other_base = getattr(type(other), "__dsl_base_type__", None) + if self_base is None or self_base is not other_base: + return NotImplemented + return all(getattr(self, f.name) == getattr(other, f.name) for f in fields) + + def __hash__(self): + base = getattr(type(self), "__dsl_base_type__", type(self)) + return hash((base,) + tuple(getattr(self, f.name) for f in fields)) + + def replace(self, **kwargs): + values = {field.name: getattr(self, field.name) for field in fields} + for key, value in kwargs.items(): + field_def = _resolve_field(type(self), key) + values[field_def.name] = value + return type(self)(**values) + + def __extract_to_ir_values__(self) -> List[ir.Value]: + return list( + chain.from_iterable( + extract_to_ir_values(_carrier_for_field(eff_type, getattr(self, name))) + for name, eff_type in _effective_field_defs(type(self)) + ) + ) + + @classmethod + def __construct_from_ir_values__(cls, values): + rebuilt = {} + cursor = 0 + for name, eff_type in _effective_field_defs(cls): + nvalues = _ir_value_count_from_type(eff_type) + rebuilt[name] = _construct_field_from_ir(eff_type, values[cursor : cursor + nvalues]) + cursor += nvalues + if cursor != len(values): + raise ValueError(f"struct {_display_name(cls)} expected {cursor} ir.Values, got {len(values)}") + return cls(**rebuilt) + + def __get_ir_types__(self) -> List[ir.Type]: + return list( + chain.from_iterable( + get_ir_types(_carrier_for_field(eff_type, getattr(self, name))) + for name, eff_type in _effective_field_defs(type(self)) + ) + ) + + def __get_c_pointers__(self): + return list( + chain.from_iterable( + get_c_pointers(_carrier_for_field(eff_type, getattr(self, name))) + for name, eff_type in _effective_field_defs(type(self)) + ) + ) + + @classmethod + def __dsl_size_of__(cls) -> int: + return _storage_layout(cls)[0] + + @classmethod + def __dsl_align_of__(cls) -> int: + return _storage_layout(cls)[1] + + @classmethod + def __peek_from_ptr__(cls, ptr: Pointer): + raise NotImplementedError(f"{_display_name(cls)} does not support __peek_from_ptr__ yet") + + @classmethod + def __poke_into_ptr__(cls, ptr: Pointer, value): + raise NotImplementedError(f"{_display_name(cls)} does not support __poke_into_ptr__ yet") + + def __cache_signature__(self): + parts = [type(self)] + for field in fields: + value = getattr(self, field.name) + sig = getattr(value, "__cache_signature__", None) + if callable(sig): + parts.append((field.name, sig())) + return tuple(parts) + + namespace = { + "__module__": module, + "__annotations__": {field.name: field.type_spec for field in fields}, + "__dsl_composite_kind__": policy, + "__dsl_field_defs__": fields, + "__dsl_type_identity__": identity, + "__dsl_display_name__": display_name, + "__init__": __init__, + "__setattr__": __setattr__, + "__delattr__": __delattr__, + "__repr__": __repr__, + "__eq__": __eq__, + "__hash__": __hash__, + "__extract_to_ir_values__": __extract_to_ir_values__, + "__construct_from_ir_values__": __construct_from_ir_values__, + "__cache_signature__": __cache_signature__, + "__get_ir_types__": __get_ir_types__, + "__get_c_pointers__": __get_c_pointers__, + "__dsl_size_of__": __dsl_size_of__, + "__dsl_align_of__": __dsl_align_of__, + "__peek_from_ptr__": __peek_from_ptr__, + "__poke_into_ptr__": __poke_into_ptr__, + "replace": replace, + } + schema = type(name, (), namespace) + schema.__dsl_base_type__ = schema + return schema + + +class CompositeMeta(type): + def __new__(mcs, name, bases, namespace, *, policy=CompositeKind.Product, display=None, **kwargs): + cls = super().__new__(mcs, name, bases, namespace, **kwargs) + cls._policy = policy + cls._display = display or name + return cls + + def __call__(cls, klass=None, /, **kwargs): + policy = cls._policy + + def wrap(wrapped): + fields = _normalize_decorator_fields(wrapped) + return _make_composite_class( + name=wrapped.__name__, + module=wrapped.__module__, + fields=fields, + policy=policy, + display_name=wrapped.__name__, + ) + + if klass is None: + return wrap + return wrap(klass) + + def __getitem__(cls, params): + fields = _normalize_inline_fields(params) + display_name = _inline_display_name(cls._display, params, fields) + return _make_composite_class( + name=f"_Dsl{cls._display}_{abs(hash(_make_type_identity(cls._policy, fields)))}", + module=__name__, + fields=fields, + policy=cls._policy, + display_name=display_name, + ) + + def __repr__(cls): + return cls._display.lower() + + +class struct(metaclass=CompositeMeta, policy=CompositeKind.Product, display="Struct"): ... + + +class union(metaclass=CompositeMeta, policy=CompositeKind.Sum, display="Union"): ... + + +Struct = struct +Union = union + + +class Align: + __dsl_align_wrapper__: bool = False + dtype: Any = None + align: int | None = None + + def __class_getitem__(cls, params): + if cls is not Align: + raise TypeError(f"{cls.__name__} cannot be re-parametrized") + if not isinstance(params, tuple) or len(params) != 2: + raise TypeError("struct.Align expects struct.Align[Type, N]") + dtype, requested_align = params + if isinstance(requested_align, bool) or not isinstance(requested_align, int): + raise TypeError(f"struct.Align alignment must be an int, got {requested_align!r}") + if requested_align <= 0: + raise ValueError(f"struct.Align alignment must be positive, got {requested_align}") + if not (requested_align > 0 and (requested_align & (requested_align - 1)) == 0): + raise ValueError(f"struct.Align alignment must be a power of two, got {requested_align}") + natural = dsl_align_of(dtype) + if requested_align < natural: + raise ValueError( + f"struct.Align[{_type_name(dtype)}, {requested_align}]: requested alignment {requested_align} " + f"is smaller than natural alignment {natural} of {_type_name(dtype)}; use a value >= {natural}." + ) + + def _aligned_size_of(inner=dtype): + return dsl_size_of(inner) + + def _aligned_align_of(val=requested_align): + return val + + def _aligned_peek_from_ptr(ptr, inner=dtype): + return peek_from_ptr(inner, ptr) + + def _aligned_poke_into_ptr(ptr, value, inner=dtype): + return poke_into_ptr(inner, ptr, value) + + inner_key = _type_cache_key(dtype) + + def _cache_sig(key=inner_key, a=requested_align): + return ("align", key, a) + + return type( + f"struct.Align[{_type_name(dtype)}, {requested_align}]", + (cls,), + { + "dtype": dtype, + "align": requested_align, + "__dsl_align_wrapper__": True, + "__cache_signature__": classmethod(lambda cls, _f=_cache_sig: _f()), + "__dsl_size_of__": classmethod(lambda cls, _f=_aligned_size_of: _f()), + "__dsl_align_of__": classmethod(lambda cls, _f=_aligned_align_of: _f()), + "__peek_from_ptr__": classmethod(lambda cls, ptr, _f=_aligned_peek_from_ptr: _f(ptr)), + "__poke_into_ptr__": classmethod(lambda cls, ptr, value, _f=_aligned_poke_into_ptr: _f(ptr, value)), + }, + ) + + +class Storage: + """Typed memory view: ``Storage[T]`` wraps an ``i8`` pointer for Storable ``T``. + + - ``._ptr`` — underlying i8* pointer + - ``.peek()`` — calls ``T.__peek_from_ptr__(ptr)`` + - ``.poke(value)`` — calls ``T.__poke_into_ptr__(ptr, value)`` + - For composite T: attribute access (``storage.field_name``) returns ``Storage[FieldType]`` + """ + + _target_type = None + _cache: dict[Any, type] = {} + + def __class_getitem__(cls, target_type): + cached = Storage._cache.get(target_type) + if cached is not None: + return cached + + target_name = _type_name(target_type) + + class _StorageImpl(Storage): + _target_type = target_type + + def __init__(self, ptr): + object.__setattr__(self, "_ptr", ptr) + + def peek(self): + dsl_type = type(self)._target_type + return peek_from_ptr(dsl_type, object.__getattribute__(self, "_ptr")) + + def poke(self, value): + dsl_type = type(self)._target_type + return poke_into_ptr(dsl_type, object.__getattribute__(self, "_ptr"), value) + + def __getattr__(self, name): + dsl_type = type(self)._target_type + if is_composite_type(dsl_type): + try: + field_def = _resolve_field(dsl_type, name) + except KeyError: + raise AttributeError( + f"Storage[{target_name}] has no field '{name}'. " + f"Available: {[f.name for f in dsl_type.__dsl_field_defs__]}" + ) from None + _, _, offsets = _storage_layout(dsl_type) + if field_def.name not in offsets: + raise AttributeError( + f"Storage[{target_name}] field '{field_def.name}' is compile-time only and has no storage" + ) from None + offset = offsets[field_def.name] + return Storage[field_def.type_spec](add_offset(object.__getattribute__(self, "_ptr"), offset)) + raise AttributeError(f"Storage[{target_name}] has no attribute '{name}'") + + def __repr__(self): + return f"Storage[{target_name}]({object.__getattribute__(self, '_ptr')})" + + _StorageImpl.__name__ = f"Storage[{target_name}]" + _StorageImpl.__qualname__ = f"Storage[{target_name}]" + Storage._cache[target_type] = _StorageImpl + return _StorageImpl + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) diff --git a/python/flydsl/expr/typing.py b/python/flydsl/expr/typing.py index d79dad7eb..ba6267ba7 100644 --- a/python/flydsl/expr/typing.py +++ b/python/flydsl/expr/typing.py @@ -5,7 +5,7 @@ import enum import operator from inspect import isclass -from typing import Generic, Type, TypeVar, overload +from typing import Any, List, Type, overload from flydsl.runtime.device import get_rocm_arch @@ -321,11 +321,107 @@ def is_target_address_space(address_space, expected) -> bool: return str(actual) == str(exp) -ValueT = TypeVar("ValueT") +class Constexpr: + _annotation_cache: dict = {} + _value_cache: dict = {} + value_type: type | None = None + value: Any = None + is_specialized: bool = False -class Constexpr(Generic[ValueT]): - pass + def __class_getitem__(cls, param): + if cls is not Constexpr: + raise TypeError(f"{cls.__name__} cannot be re-parametrized") + if not isinstance(param, type): + raise TypeError( + f"Constexpr[...] expects a type (e.g. Constexpr[int]), " + f"got value {param!r}; constexpr values are provided at call site" + ) + cached = Constexpr._annotation_cache.get(param) + if cached is not None: + return cached + result = type( + f"Constexpr[{getattr(param, '__name__', repr(param))}]", + (Constexpr,), + { + "__origin__": Constexpr, + "__args__": (param,), + "value_type": param, + "value": None, + "is_specialized": False, + }, + ) + Constexpr._annotation_cache[param] = result + return result + + @classmethod + def _specialize(cls, value): + cache_key = (type(value), value) + try: + cached = Constexpr._value_cache.get(cache_key) + if cached is not None: + return cached + except TypeError: + cached = None + result = type( + f"Constexpr[{value!r}]", + (Constexpr,), + { + "__origin__": Constexpr, + "__args__": (type(value),), + "value_type": type(value), + "value": value, + "is_specialized": True, + }, + ) + try: + Constexpr._value_cache[cache_key] = result + except TypeError: + # cache_key is not hashable + pass + return result + + @classmethod + def __construct_from_ir_values__(cls, values): + if values: + raise ValueError(f"{cls.__name__} expects 0 ir.Values, got {len(values)}") + if not cls.is_specialized: + raise TypeError( + f"{cls.__name__} must be value-specialized (e.g. Constexpr[42]) " + f"before reconstruction; the surrounding schema did not bind a value." + ) + return cls.value + + @classmethod + def __extract_to_ir_values__(cls): + return [] + + @classmethod + def __get_ir_types__(cls): + return [] + + @classmethod + def __get_c_pointers__(cls): + return [] + + @classmethod + def __coerce__(cls, value): + inner = cls.value_type + if inner is not None and not isinstance(value, inner): + raise TypeError(f"expects {getattr(inner, '__name__', repr(inner))}, got {type(value).__name__}") + return value + + @classmethod + def __specialize_for_value__(cls, value): + if cls.is_specialized: + return cls + return Constexpr._specialize(value) + + @staticmethod + def is_constexpr_annotation(annotation) -> bool: + if annotation is Constexpr: + return True + return isinstance(annotation, type) and issubclass(annotation, Constexpr) class BuiltinDslType(ir.Value): @@ -614,6 +710,10 @@ def memspace(self): def alignment(self): return self.type.alignment + @traced_op + def view(self, layout, loc=None, ip=None): + return make_view(self, layout, loc=loc, ip=ip) + @ir.register_value_caster(MemRefType.static_typeid, replace=True) @ir.register_value_caster(CoordTensorType.static_typeid, replace=True) @@ -1493,3 +1593,93 @@ def ones_like(a: Vector, dtype=None, *, loc=None, ip=None) -> Vector: def zeros_like(a: Vector, dtype=None, *, loc=None, ip=None) -> Vector: return Vector.zeros_like(a, dtype, loc=loc, ip=ip) + + +class Array: + _cache: dict[tuple, type] = {} + + class _Base: + dtype = None + size = None + align = None + + def __init__(self, ptr_value): + self._ptr_value = ptr_value + + def __repr__(self): + cls = type(self) + name = getattr(cls.dtype, "__name__", repr(cls.dtype)) + suffix = f", {cls.align}" if cls.align != max(1, cls.dtype.width // 8) else "" + return f"Array[{name}, {cls.size}{suffix}]({self._ptr_value})" + + @classmethod + def __construct_from_ir_values__(cls, values): + if len(values) != 1: + raise ValueError(f"{cls.__name__} expects 1 ir.Value, got {len(values)}") + return cls(values[0]) + + def __extract_to_ir_values__(self) -> List[ir.Value]: + return [self._ptr_value] + + @classmethod + def __cache_signature__(cls): + return ("array", cls.dtype, cls.size, cls.align) + + @classmethod + def __dsl_size_of__(cls) -> int: + total_bytes = max(1, cls.dtype.width * cls.size // 8) + return total_bytes + + @classmethod + def __dsl_align_of__(cls) -> int: + return cls.align + + @classmethod + def __peek_from_ptr__(cls, ptr): + typed_ptr = recast_iter(cls.dtype, ptr) + return cls(typed_ptr) + + @classmethod + def __poke_into_ptr__(cls, ptr, value): + raise NotImplementedError(f"{cls.__name__} does not support __poke_into_ptr__ yet") + + def view(self, layout, *, loc=None, ip=None): + return make_view(self._ptr_value, layout, loc=loc, ip=ip) + + def __class_getitem__(cls, params): + if not isinstance(params, tuple): + params = (params,) + if len(params) == 2: + dtype, size = params + align = None + elif len(params) == 3: + dtype, size, align = params + else: + raise TypeError("Array expects Array[dtype, size] or Array[dtype, size, align]") + + if not (isinstance(dtype, type) and issubclass(dtype, Numeric)): + raise TypeError(f"Array dtype must be a Numeric subclass, got {dtype!r}") + if not isinstance(size, int) or size <= 0: + raise TypeError(f"Array size must be a positive integer, got {size!r}") + + elem_byte_size = max(1, dtype.width // 8) + if align is None: + align = elem_byte_size + else: + if not isinstance(align, int) or align <= 0: + raise TypeError(f"Array align must be a positive integer, got {align!r}") + + cache_key = (dtype, size, align) + cached = cls._cache.get(cache_key) + if cached is not None: + return cached + + name = getattr(dtype, "__name__", repr(dtype)) + suffix = f", {align}" if align != elem_byte_size else "" + array_type = type( + f"Array[{name}, {size}{suffix}]", + (cls._Base,), + {"dtype": dtype, "size": size, "align": align}, + ) + cls._cache[cache_key] = array_type + return array_type diff --git a/tests/unit/test_struct.py b/tests/unit/test_struct.py new file mode 100644 index 000000000..49462a18a --- /dev/null +++ b/tests/unit/test_struct.py @@ -0,0 +1,533 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Unit tests for unified struct / union / Array / Storage types.""" + +import ctypes + +import pytest + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl._mlir import ir +from flydsl.compiler import jit_function +from flydsl.compiler.protocol import ( + construct_from_ir_values, + dsl_align_of, + dsl_size_of, + extract_to_ir_values, + get_c_pointers, + get_ir_types, +) +from flydsl.expr.numeric import Float32, Int32, Uint8 +from flydsl.expr.struct import Storage +from flydsl.expr.typing import Array + +pytestmark = pytest.mark.l0_backend_agnostic + + +@pytest.fixture +def frontend_only_jit(monkeypatch): + monkeypatch.setenv("FLYDSL_COMPILE_BACKEND", "rocm") + monkeypatch.setenv("FLYDSL_RUNTIME_KIND", "rocm") + monkeypatch.setenv("ARCH", "gfx942") + monkeypatch.setenv("COMPILE_ONLY", "1") + monkeypatch.setenv("FLYDSL_RUNTIME_ENABLE_CACHE", "0") + monkeypatch.setattr(jit_function, "_flydsl_key", lambda: "test-flydsl-key") + + def compile_noop(cls, module, **_kwargs): + return module + + monkeypatch.setattr(jit_function.MlirCompiler, "compile", classmethod(compile_noop)) + + +# --------------------------------------------------------------------------- +# struct basics +# --------------------------------------------------------------------------- + + +def test_struct_decorator_creates_frozen_value_schema(): + @fx.struct + class Pair: + a: Int32 + b: Float32 + + p = Pair(1, 2.0) + assert tuple(Pair.__annotations__) == ("a", "b") + assert isinstance(p.a, Int32) + assert isinstance(p.b, Float32) + assert p.a == 1 + assert p.b == 2.0 + + q = p.replace(a=3) + assert q.a == 3 + assert p.a == 1 + + with pytest.raises(Exception, match="cannot assign"): + p.a = Int32(4) + + +def test_struct_constructor_requires_exact_fields(): + @fx.struct + class Pair: + a: Int32 + b: Float32 + + with pytest.raises(TypeError, match="missing required field"): + Pair(a=1) + with pytest.raises(TypeError, match="unexpected field"): + Pair(a=1, b=2.0, c=3) + with pytest.raises(TypeError, match="expects Int32"): + Pair(a=object(), b=2.0) + + +def test_inline_struct_named_and_positional_forms(): + Named = fx.Struct["a":Int32, "b":Float32] + assert Named is not fx.Struct["a":Int32, "b":Float32] + assert Named.__dsl_type_identity__ == fx.Struct["a":Int32, "b":Float32].__dsl_type_identity__ + + n = Named(1, 2.0) + assert n.a == 1 + assert n.b == 2.0 + + Pos = fx.Struct["a":Int32, Float32] + assert tuple(Pos.__annotations__) == ("a", "_1") + + p = Pos(3, 4.0) + assert p.a == 3 + assert p._1 == 4.0 + + Anonymous = fx.Struct[Int32, Float32] + Positional = fx.Struct[Int32, Float32] + assert Anonymous.__dsl_type_identity__ == Positional.__dsl_type_identity__ + assert tuple(Anonymous.__annotations__) == ("_0", "_1") + anon = Anonymous(5, 6.0) + assert anon._0 == 5 + assert anon._1 == 6.0 + assert anon == Anonymous(5, 6.0) + with pytest.raises(AttributeError): + anon.nonexistent + + +def test_union_schema_has_no_value_form(): + @fx.union + class Variant: + i: Int32 + f: Float32 + + with pytest.raises(TypeError, match="no value form"): + Variant(i=Int32(1)) + + Inline = fx.Union["i":Int32, "f":Float32] + with pytest.raises(TypeError, match="no value form"): + Inline(Int32(1)) + + +def test_struct_flattens_non_constexpr_fields_only(frontend_only_jit): + @fx.struct + class Params: + a: Int32 + b: Float32 + n: fx.Constexpr[int] + + @flyc.jit + def build(p: Params): + values = p.__extract_to_ir_values__() + assert len(values) == 2 + assert isinstance(values[0].type, ir.IntegerType) + assert isinstance(values[1].type, ir.F32Type) + assert [str(t) for t in get_ir_types(p)] == [str(v.type) for v in values] + assert p.n == 32 + + rebuilt = construct_from_ir_values(type(p), p, values) + assert isinstance(rebuilt, Params) + assert rebuilt.n == 32 + assert [v.get_name() for v in extract_to_ir_values(rebuilt)] == [v.get_name() for v in values] + + build(Params(a=Int32(7), b=Float32(2.0), n=32)) + + +def test_constexpr_is_not_part_of_storage_layout(): + @fx.struct + class Params: + n: fx.Constexpr[int] + a: Int32 + + assert dsl_size_of(Params) == 4 + assert dsl_align_of(Params) == 4 + storage = Storage[Params](None) + with pytest.raises(AttributeError, match="compile-time only"): + storage.n + with pytest.raises(TypeError, match="Storable"): + dsl_size_of(fx.Constexpr[int]) + + +def test_nested_struct_round_trip_via_exemplar(frontend_only_jit): + @fx.struct + class Inner: + x: Int32 + y: Int32 + + @fx.struct + class Outer: + head: Int32 + inner: Inner + tail: Float32 + + @flyc.jit + def build(outer: Outer): + flat = outer.__extract_to_ir_values__() + rebuilt = construct_from_ir_values(type(outer), outer, flat) + assert isinstance(rebuilt.inner, Inner) + assert [v.get_name() for v in rebuilt.__extract_to_ir_values__()] == [v.get_name() for v in flat] + + build( + Outer( + head=Int32(1), + inner=Inner(x=Int32(2), y=Int32(3)), + tail=Float32(4.0), + ) + ) + + +def test_align_wrapper_overrides_natural_alignment(): + Aligned = fx.Align[Int32, 16] + assert Aligned.dtype is Int32 + assert Aligned.align == 16 + + @fx.struct + class WithAligned: + a: Int32 + b: fx.Align[Int32, 16] + + assert dsl_align_of(WithAligned) == 16 + + +@pytest.mark.parametrize( + "align,exc,match", + [ + (3, ValueError, "power of two"), + (6, ValueError, "power of two"), + (5, ValueError, "power of two"), + (0, ValueError, "positive"), + (-1, ValueError, "positive"), + (2, ValueError, "smaller than natural"), + (1.0, TypeError, "must be an int"), + (True, TypeError, "must be an int"), + ], +) +def test_align_validation_rejects_invalid_values(align, exc, match): + with pytest.raises(exc, match=match): + fx.Align[Int32, align] + + +def test_align_requires_two_parameters(): + with pytest.raises(TypeError, match="Align\\[Type, N\\]"): + fx.Align[Int32] + + +def test_struct_equality_and_hash(): + @fx.struct + class Pair: + a: Int32 + b: Float32 + + p1 = Pair(1, 2.0) + p2 = Pair(1, 2.0) + p3 = Pair(1, 3.0) + + assert p1 == p2 + assert p1 != p3 + + @fx.struct + class OtherPair: + a: Int32 + b: Float32 + + assert Pair(1, 2.0) != OtherPair(1, 2.0) + + +def test_inline_schema_rejects_duplicate_field_names(): + with pytest.raises(ValueError, match="duplicate"): + fx.Struct["a":Int32, "a":Float32] + + with pytest.raises(ValueError, match="must not start with underscore"): + fx.Struct["_1":Int32, Float32] + + +def test_field_name_validation_rejects_underscore_prefix(): + with pytest.raises(ValueError, match="must not start with underscore"): + + @fx.struct + class Bad: + _hidden: Int32 + + with pytest.raises(ValueError, match="must not start with underscore"): + fx.Struct["_x":Int32] + + +def test_field_name_validation_rejects_reserved_names(): + with pytest.raises(ValueError, match="reserved"): + + @fx.struct + class Bad: + peek: Int32 + + with pytest.raises(ValueError, match="reserved"): + + @fx.struct + class Bad2: + poke: Int32 + + with pytest.raises(ValueError, match="reserved"): + fx.Union["replace":Int32, "b":Float32] + + +def test_host_jit_argument_protocol_pointers(): + @fx.struct + class HostPair: + a: Int32 + b: Int32 + + with ir.Context(), ir.Location.unknown(): + p = HostPair(a=Int32(7), b=Int32(11)) + ptrs = p.__get_c_pointers__() + assert len(ptrs) == 2 + assert all(isinstance(ptr, ctypes.c_void_p) for ptr in ptrs) + assert len(get_c_pointers(p)) == 2 + + +# --------------------------------------------------------------------------- +# Array type +# --------------------------------------------------------------------------- + + +def test_array_type_creation(): + A = Array[Float32, 32] + assert issubclass(A, Array._Base) + assert A.dtype is Float32 + assert A.size == 32 + assert A.align == 4 + + A16 = Array[Float32, 32, 16] + assert A16.align == 16 + + +def test_array_type_caching(): + A1 = Array[Int32, 64] + A2 = Array[Int32, 64] + assert A1 is A2 + + A3 = Array[Int32, 64, 8] + assert A3 is not A1 + + +def test_array_storable_protocol(): + A = Array[Float32, 32] + assert dsl_size_of(A) == 4 * 32 + assert dsl_align_of(A) == 4 + + A_aligned = Array[Float32, 32, 16] + assert dsl_align_of(A_aligned) == 16 + + +def test_array_dsl_size_of_free_function(): + A = Array[Int32, 16] + assert dsl_size_of(A) == 4 * 16 + assert dsl_align_of(A) == 4 + + +def test_array_rejects_non_numeric_dtype(): + with pytest.raises(TypeError, match="Numeric subclass"): + Array[object, 32] + + +def test_array_rejects_invalid_size(): + with pytest.raises(TypeError, match="positive integer"): + Array[Float32, 0] + with pytest.raises(TypeError, match="positive integer"): + Array[Float32, -1] + + +def test_array_subbyte_dtype(): + A = Array[Uint8, 64] + assert dsl_size_of(A) == 64 + assert dsl_align_of(A) == 1 + + +# --------------------------------------------------------------------------- +# Storage type +# --------------------------------------------------------------------------- + + +def test_storage_type_creation(): + S = Storage[Int32] + assert S._target_type is Int32 + assert S.__name__ == "Storage[Int32]" + + +def test_storage_schema_field_access_struct(): + @fx.struct + class Pair: + a: Int32 + b: Float32 + + SPair = Storage[Pair] + assert SPair._target_type is Pair + + +def test_storage_schema_field_access_with_array(): + @fx.struct + class SharedStorage: + sharedA: Array[Float32, 32] + sharedB: Array[Float32, 32] + + S = Storage[SharedStorage] + assert S._target_type is SharedStorage + + +# --------------------------------------------------------------------------- +# Numeric Storable via NumericMeta +# --------------------------------------------------------------------------- + + +def test_numeric_storable_protocol(): + assert dsl_size_of(Int32) == 4 + assert dsl_align_of(Int32) == 4 + assert dsl_size_of(Float32) == 4 + assert dsl_align_of(Float32) == 4 + + from flydsl.expr.numeric import Float64, Int64 + + assert dsl_size_of(Int64) == 8 + assert dsl_align_of(Int64) == 8 + assert dsl_size_of(Float64) == 8 + assert dsl_align_of(Float64) == 8 + + +def test_numeric_storable_via_free_functions(): + assert dsl_size_of(Int32) == 4 + assert dsl_align_of(Int32) == 4 + assert dsl_size_of(Float32) == 4 + assert dsl_align_of(Float32) == 4 + + +def test_subbyte_numeric_not_storable(): + from flydsl.expr.numeric import Int4 + + with pytest.raises(TypeError, match="sub-byte|Storable"): + dsl_size_of(Int4) + + +def test_boolean_is_storable(): + from flydsl.expr.numeric import Boolean + + assert dsl_size_of(Boolean) == 1 + assert dsl_align_of(Boolean) == 1 + + +# --------------------------------------------------------------------------- +# Struct layout with new types +# --------------------------------------------------------------------------- + + +def test_struct_with_array_fields_layout(): + @fx.struct + class SharedStorage: + sharedA: Array[Float32, 32] + sharedB: Array[Float32, 32] + + assert dsl_size_of(SharedStorage) == 128 + 128 + assert dsl_align_of(SharedStorage) == 4 + + +def test_struct_with_aligned_array_fields(): + @fx.struct + class AlignedStorage: + sharedA: Array[Float32, 32, 16] + sharedB: Array[Float32, 32, 16] + + assert dsl_align_of(AlignedStorage) == 16 + + +def test_union_storable_layout(): + @fx.union + class Variant: + i: Int32 + f: Float32 + + assert dsl_size_of(Variant) == 4 + assert dsl_align_of(Variant) == 4 + + +def test_struct_containing_union_layout(): + @fx.union + class Variant: + i: Int32 + f: Float32 + + @fx.struct + class Tagged: + tag: Int32 + data: Variant + + assert dsl_size_of(Tagged) == 8 + assert dsl_align_of(Tagged) == 4 + + from flydsl.expr.struct import _storage_layout + + _, _, offsets = _storage_layout(Tagged) + assert offsets == {"tag": 0, "data": 4} + + +def test_union_containing_struct_layout(): + @fx.struct + class Pair: + a: Int32 + b: Int32 + + @fx.union + class UnionWithStruct: + pair: Pair + f: Float32 + + assert dsl_size_of(UnionWithStruct) == 8 + assert dsl_align_of(UnionWithStruct) == 4 + + from flydsl.expr.struct import _storage_layout + + _, _, offsets = _storage_layout(UnionWithStruct) + assert offsets == {"pair": 0, "f": 0} + + +def test_nested_struct_union_with_alignment(): + @fx.union + class DataVariant: + arr: Array[Float32, 16] + single: Int32 + + @fx.struct + class TaggedVariant: + tag: Int32 + data: DataVariant + + assert dsl_size_of(DataVariant) == 64 + assert dsl_align_of(DataVariant) == 4 + assert dsl_size_of(TaggedVariant) == 68 + assert dsl_align_of(TaggedVariant) == 4 + + +def test_union_with_aligned_field_pads_size(): + @fx.struct + class Inner: + x: Int32 + + @fx.union + class Outer: + inner: Inner + aligned: fx.Align[Int32, 8] + + assert dsl_align_of(Outer) == 8 + assert dsl_size_of(Outer) == 8 From f30ed1790f0cc5f0bf35b5bb90e415bad88cb2e6 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Mon, 11 May 2026 10:08:39 +0000 Subject: [PATCH 2/4] reimport package --- python/flydsl/expr/__init__.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/flydsl/expr/__init__.py b/python/flydsl/expr/__init__.py index ddf93c270..e3e194bcb 100644 --- a/python/flydsl/expr/__init__.py +++ b/python/flydsl/expr/__init__.py @@ -8,7 +8,11 @@ from .derived import * from .struct import * -from . import utils - -from . import arith, vector, gpu, buffer_ops, rocdl, math -from .rocdl import tdm_ops +from . import utils as utils +from . import arith as arith +from . import buffer_ops as buffer_ops +from . import gpu as gpu +from . import math as math +from . import rocdl as rocdl +from . import vector as vector +from .rocdl import tdm_ops as tdm_ops From a9ea352f02a19272cdb32b6dfe5c3a5691ff7bb2 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Mon, 11 May 2026 10:54:33 +0000 Subject: [PATCH 3/4] set boolean is not storable --- python/flydsl/expr/numeric.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/flydsl/expr/numeric.py b/python/flydsl/expr/numeric.py index 575e85377..e7accea1e 100644 --- a/python/flydsl/expr/numeric.py +++ b/python/flydsl/expr/numeric.py @@ -71,8 +71,7 @@ def _get_c_pointers(self): return [ptr] inferred_np = np_dtype if np_dtype is not None else _infer_np_dtype(width, signed, name) - - is_storable = width >= 8 or (width == 1 and name == "Boolean") + is_storable = width >= 8 def _dsl_size_of(cls): return 1 if cls.width < 8 else (cls.width + 7) // 8 From ca2f7f3cda698bfe4def433f5d6e6a77fad339d1 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Mon, 11 May 2026 11:21:18 +0000 Subject: [PATCH 4/4] remove boolean test --- tests/unit/test_struct.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/unit/test_struct.py b/tests/unit/test_struct.py index 49462a18a..ec1d039ea 100644 --- a/tests/unit/test_struct.py +++ b/tests/unit/test_struct.py @@ -421,13 +421,6 @@ def test_subbyte_numeric_not_storable(): dsl_size_of(Int4) -def test_boolean_is_storable(): - from flydsl.expr.numeric import Boolean - - assert dsl_size_of(Boolean) == 1 - assert dsl_align_of(Boolean) == 1 - - # --------------------------------------------------------------------------- # Struct layout with new types # ---------------------------------------------------------------------------