diff --git a/src/finchlite/autoschedule/compiler.py b/src/finchlite/autoschedule/compiler.py index 07632a0a..3ed14614 100644 --- a/src/finchlite/autoschedule/compiler.py +++ b/src/finchlite/autoschedule/compiler.py @@ -11,7 +11,6 @@ from .. import finch_logic as lgc from .. import finch_notation as ntn from ..algebra import FinchOperator, FType, ffuncs, ftypes -from ..compile.lower import make_extent from ..finch_assembly import AssemblyLibrary from ..finch_logic import LogicLoader, StatsFactory, TensorStats, compute_shape_vars from ..finch_notation import NotationInterpreter @@ -101,6 +100,8 @@ def _lower_query_of_reorder( arg: lgc.Table, reorder_idxs: tuple[lgc.Field, ...], ): + from ..compile.lower import make_extent + arg_dims = arg.dimmap(merge_shapes, self.shapes) shapes_map = dict(zip(arg.idxs, arg_dims, strict=True)) shapes = { @@ -189,6 +190,8 @@ def _lower_query_of_aggregate( agg_arg: lgc.Reorder, agg_idxs: tuple[lgc.Field, ...], ): + from ..compile.lower import make_extent + # Build a dict mapping fields to their shapes arg_dims = agg_arg.dimmap(merge_shapes, self.shapes) shapes_map = dict(zip(agg_arg.idxs, arg_dims, strict=True)) diff --git a/src/finchlite/autoschedule/formatter.py b/src/finchlite/autoschedule/formatter.py index f0cccf19..c19c4565 100644 --- a/src/finchlite/autoschedule/formatter.py +++ b/src/finchlite/autoschedule/formatter.py @@ -7,7 +7,6 @@ from .. import finch_logic as lgc from ..algebra import FType, TensorFType, TupleFType, ftype from ..codegen import NumpyBufferFType -from ..compile import BufferizedNDArrayFType from ..finch_assembly import AssemblyLibrary from ..finch_logic import LogicLoader, MockLogicLoader, StatsFactory, TensorStats from ..util.logging import LOG_LOGIC_POST_OPT @@ -95,6 +94,8 @@ def __init__( super().__init__(loader) def get_output_tns_ftype(self, fill_value: Any, shape_type: tuple[FType, ...]): + from ..compile.bufferized_ndarray import BufferizedNDArrayFType + return BufferizedNDArrayFType( buffer_type=NumpyBufferFType(ftype(fill_value)), ndim=len(shape_type), diff --git a/src/finchlite/autoschedule/tensor_stats/dc_stats.py b/src/finchlite/autoschedule/tensor_stats/dc_stats.py index 48c30147..b95f3d74 100644 --- a/src/finchlite/autoschedule/tensor_stats/dc_stats.py +++ b/src/finchlite/autoschedule/tensor_stats/dc_stats.py @@ -13,7 +13,6 @@ from ... import finch_notation as ntn from ...algebra import Tensor, ffuncs, ftype, int64 from ...algebra.algebra import FinchOperator -from ...compile import BufferizedNDArray, make_extent from .numeric_stats import NumericStats from .tensor_def import TensorDef from .tensor_stats import BaseTensorStatsFactory @@ -37,9 +36,6 @@ class DC: value: float -_INT64_VECTOR_FTYPE = BufferizedNDArray.from_numpy(np.zeros(1, dtype=np.int64)).ftype - - def _int_tuple_ftype(size: int): return ftype(tuple(np.int64(0) for _ in range(size))) @@ -152,17 +148,22 @@ def _structure_to_dcs(self, arr: Tensor, fields: Iterable[Field]) -> set[DC]: # For each field i, we compute DC({}, {i}) and DC({i}, {*fields}). # Additionally, we compute the nnz for the full tensor DC({}, {*fields}). def _array_to_dcs(self, arr: Any, fields: Iterable[Field]) -> set[DC]: + from ...compile import BufferizedNDArray, make_extent + + int64_vector_ftype = BufferizedNDArray.from_numpy( + np.zeros(1, dtype=np.int64) + ).ftype fields = list(fields) ndims = len(fields) dim_loop_variables = [ntn.Variable(f"{fields[i]}", int64) for i in range(ndims)] dim_array_variables = [ - ntn.Variable(f"x_{fields[i]}", _INT64_VECTOR_FTYPE) for i in range(ndims) + ntn.Variable(f"x_{fields[i]}", int64_vector_ftype) for i in range(ndims) ] dim_size_variables = [ ntn.Variable(f"n_{fields[i]}", int64) for i in range(ndims) ] dim_array_slots = [ - ntn.Slot(f"x_{fields[i]}_", _INT64_VECTOR_FTYPE) for i in range(ndims) + ntn.Slot(f"x_{fields[i]}_", int64_vector_ftype) for i in range(ndims) ] dim_proj_variables = [ ntn.Variable(f"proj_{fields[i]}", int64) for i in range(ndims) diff --git a/src/finchlite/compile/__init__.py b/src/finchlite/compile/__init__.py index b6c453bc..c306a565 100644 --- a/src/finchlite/compile/__init__.py +++ b/src/finchlite/compile/__init__.py @@ -1,4 +1,3 @@ -from .bufferized_ndarray import BufferizedNDArray, BufferizedNDArrayFType from .lower import ( AssemblyContext, Extent, @@ -9,6 +8,9 @@ make_extent, ) +# isort: split +from .bufferized_ndarray import BufferizedNDArray, BufferizedNDArrayFType + __all__ = [ "AssemblyContext", "BufferizedNDArray", diff --git a/src/finchlite/compile/bufferized_ndarray.py b/src/finchlite/compile/bufferized_ndarray.py index 9a8c0081..ba0f454e 100644 --- a/src/finchlite/compile/bufferized_ndarray.py +++ b/src/finchlite/compile/bufferized_ndarray.py @@ -8,6 +8,7 @@ from ..algebra import FType, ImmutableStructFType, Tensor, TupleFType, ffuncs, ftype from ..codegen import NumpyBuffer, NumpyBufferFType from ..codegen.numba_codegen import to_numpy_type +from ..interface.eager import EagerTensor from . import looplets as lplt from .lower import AssemblyContext, FinchTensorFType @@ -16,7 +17,7 @@ def _get_default_strides(size: tuple[int, ...]) -> tuple[int, ...]: return tuple(np.cumprod((1,) + size[::-1]).astype(int))[-2::-1] -class BufferizedNDArray(Tensor): +class BufferizedNDArray(EagerTensor): def __init__( self, val: NumpyBuffer, diff --git a/src/finchlite/interface/eager.py b/src/finchlite/interface/eager.py index 70b6cec7..cca5b107 100644 --- a/src/finchlite/interface/eager.py +++ b/src/finchlite/interface/eager.py @@ -6,8 +6,6 @@ from typing import Any from ..algebra import FinchOperator -from . import lazy -from .fuse import compute from .overrides import OverrideTensor @@ -231,6 +229,10 @@ def __ne__(self, other): return not_equal(self, other) +from . import lazy # noqa: E402 +from .fuse import compute # noqa: E402 + + def full( shape: int | tuple[int, ...], fill_value: bool | complex, @@ -418,6 +420,7 @@ def matmul(x1, x2, /): """ if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor): return lazy.matmul(x1, x2) + c = lazy.matmul(x1, x2) return compute(c) diff --git a/src/finchlite/interface/lazy.py b/src/finchlite/interface/lazy.py index 918eebee..1c620d89 100644 --- a/src/finchlite/interface/lazy.py +++ b/src/finchlite/interface/lazy.py @@ -28,7 +28,6 @@ ) from ..algebra.ftypes import FDType from ..autoschedule.tensor_stats import StatsInterpreter -from ..compile import BufferizedNDArray from ..finch_logic import ( Aggregate, Alias, @@ -463,6 +462,8 @@ def asarray(arg: Any, format: TensorFType | None = None) -> Any: from finchlite.interface.scalar import Scalar if isinstance(arg, np.ndarray): + from ..compile import BufferizedNDArray + return BufferizedNDArray.from_numpy(arg) if np.isscalar(arg) or arg is None: return Scalar(arg) diff --git a/tests/test_interface.py b/tests/test_interface.py index aa7122ae..e361a6c8 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -650,6 +650,31 @@ def test_matmul(a, b, a_wrap, b_wrap): finch_assert_allclose(result_with_np, expected) +@pytest.mark.parametrize( + "a, b", + [ + ( + np.array([[1.0, 2.0], [3.0, 4.0]]), + np.array([[5.0, 6.0], [7.0, 8.0]]), + ), + ( + np.arange(12, dtype=np.float64).reshape(3, 4), + np.arange(8, dtype=np.float64).reshape(4, 2), + ), + ], +) +def test_matmul_bufferized_ndarray(a, b): + ba = finchlite.asarray(a) + bb = finchlite.asarray(b) + expected = a @ b + + result = finchlite.matmul(ba, bb) + result_with_op = ba @ bb + + finch_assert_allclose(result, expected) + finch_assert_allclose(result_with_op, expected) + + @pytest.mark.usefixtures("interpreter_scheduler") # TODO: remove @pytest.mark.parametrize( "a", diff --git a/tests/test_notation_interpreter.py b/tests/test_notation_interpreter.py index 59b09bfb..7deb40f6 100644 --- a/tests/test_notation_interpreter.py +++ b/tests/test_notation_interpreter.py @@ -220,4 +220,4 @@ def test_count_nonfill_vector(a): mod = ntn.NotationInterpreter()(prgm) cnt = mod.count_nonfill_vector(a) - assert cnt == np.count_nonzero(a) + assert cnt == np.count_nonzero(a.to_numpy())