5 changes: 4 additions & 1 deletion src/finchlite/autoschedule/compiler.py
@@ -11,7 +11,6 @@
 from .. import finch_logic as lgc
 from .. import finch_notation as ntn
 from ..algebra import FinchOperator, FType, ffuncs, ftypes
-from ..compile.lower import make_extent
 from ..finch_assembly import AssemblyLibrary
 from ..finch_logic import LogicLoader, StatsFactory, TensorStats, compute_shape_vars
 from ..finch_notation import NotationInterpreter
@@ -101,6 +100,8 @@ def _lower_query_of_reorder(
         arg: lgc.Table,
         reorder_idxs: tuple[lgc.Field, ...],
     ):
+        from ..compile.lower import make_extent
+
         arg_dims = arg.dimmap(merge_shapes, self.shapes)
         shapes_map = dict(zip(arg.idxs, arg_dims, strict=True))
         shapes = {
@@ -189,6 +190,8 @@ def _lower_query_of_aggregate(
         agg_arg: lgc.Reorder,
         agg_idxs: tuple[lgc.Field, ...],
     ):
+        from ..compile.lower import make_extent
+
         # Build a dict mapping fields to their shapes
         arg_dims = agg_arg.dimmap(merge_shapes, self.shapes)
         shapes_map = dict(zip(agg_arg.idxs, arg_dims, strict=True))
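Note on the change above: this is the standard fix for a circular import. The module-level `from ..compile.lower import make_extent` is dropped and re-imported inside the two lowering methods, so the name resolves at call time, after every module in the cycle has finished initializing. A minimal sketch of the pattern (the `pkg`/`a`/`b` names are illustrative, not finchlite's real layout):

# pkg/a.py
from pkg import b  # safe: a -> b at import time

VALUE = 42

def run():
    return b.helper()

# pkg/b.py
def helper():
    # A module-level `from pkg import a` here would close the cycle
    # a -> b -> a while a is still half-initialized. Deferring the
    # import into the function body resolves it at call time instead.
    from pkg import a
    return a.VALUE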
3 changes: 2 additions & 1 deletion src/finchlite/autoschedule/formatter.py
@@ -7,7 +7,6 @@
 from .. import finch_logic as lgc
 from ..algebra import FType, TensorFType, TupleFType, ftype
 from ..codegen import NumpyBufferFType
-from ..compile import BufferizedNDArrayFType
 from ..finch_assembly import AssemblyLibrary
 from ..finch_logic import LogicLoader, MockLogicLoader, StatsFactory, TensorStats
 from ..util.logging import LOG_LOGIC_POST_OPT
@@ -95,6 +94,8 @@ def __init__(
         super().__init__(loader)
 
     def get_output_tns_ftype(self, fill_value: Any, shape_type: tuple[FType, ...]):
+        from ..compile.bufferized_ndarray import BufferizedNDArrayFType
+
         return BufferizedNDArrayFType(
             buffer_type=NumpyBufferFType(ftype(fill_value)),
             ndim=len(shape_type),
13 changes: 7 additions & 6 deletions src/finchlite/autoschedule/tensor_stats/dc_stats.py
@@ -13,7 +13,6 @@
 from ... import finch_notation as ntn
 from ...algebra import Tensor, ffuncs, ftype, int64
 from ...algebra.algebra import FinchOperator
-from ...compile import BufferizedNDArray, make_extent
 from .numeric_stats import NumericStats
 from .tensor_def import TensorDef
 from .tensor_stats import BaseTensorStatsFactory
@@ -37,9 +36,6 @@ class DC:
     value: float
 
 
-_INT64_VECTOR_FTYPE = BufferizedNDArray.from_numpy(np.zeros(1, dtype=np.int64)).ftype
-
-
 def _int_tuple_ftype(size: int):
     return ftype(tuple(np.int64(0) for _ in range(size)))
 
@@ -152,17 +148,22 @@ def _structure_to_dcs(self, arr: Tensor, fields: Iterable[Field]) -> set[DC]:
     # For each field i, we compute DC({}, {i}) and DC({i}, {*fields}).
     # Additionally, we compute the nnz for the full tensor DC({}, {*fields}).
     def _array_to_dcs(self, arr: Any, fields: Iterable[Field]) -> set[DC]:
+        from ...compile import BufferizedNDArray, make_extent
+
+        int64_vector_ftype = BufferizedNDArray.from_numpy(
+            np.zeros(1, dtype=np.int64)
+        ).ftype
         fields = list(fields)
         ndims = len(fields)
         dim_loop_variables = [ntn.Variable(f"{fields[i]}", int64) for i in range(ndims)]
         dim_array_variables = [
-            ntn.Variable(f"x_{fields[i]}", _INT64_VECTOR_FTYPE) for i in range(ndims)
+            ntn.Variable(f"x_{fields[i]}", int64_vector_ftype) for i in range(ndims)
         ]
         dim_size_variables = [
             ntn.Variable(f"n_{fields[i]}", int64) for i in range(ndims)
         ]
         dim_array_slots = [
-            ntn.Slot(f"x_{fields[i]}_", _INT64_VECTOR_FTYPE) for i in range(ndims)
+            ntn.Slot(f"x_{fields[i]}_", int64_vector_ftype) for i in range(ndims)
         ]
         dim_proj_variables = [
             ntn.Variable(f"proj_{fields[i]}", int64) for i in range(ndims)
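Worth noting in the hunk above: `_INT64_VECTOR_FTYPE` was not just relocated but demoted from a module-level constant to a per-call local. Evaluating `BufferizedNDArray.from_numpy(...)` at import time would pull `finchlite.compile` in while `dc_stats` is still loading, recreating the cycle; computing it inside `_array_to_dcs` defers that to first use, at the cost of rebuilding the ftype on every call. If that cost ever matters, a cached helper would keep the deferred import while constructing the ftype only once; a sketch under that assumption (the helper name is hypothetical):

from functools import lru_cache

import numpy as np

@lru_cache(maxsize=1)
def _int64_vector_ftype():
    # Hypothetical cached variant: still defers the compile import past
    # module load, but builds the ftype only on the first call.
    from finchlite.compile import BufferizedNDArray
    return BufferizedNDArray.from_numpy(np.zeros(1, dtype=np.int64)).ftype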
4 changes: 3 additions & 1 deletion src/finchlite/compile/__init__.py
@@ -1,4 +1,3 @@
-from .bufferized_ndarray import BufferizedNDArray, BufferizedNDArrayFType
 from .lower import (
     AssemblyContext,
     Extent,
@@ -9,6 +8,9 @@
     make_extent,
 )
 
+# isort: split
+from .bufferized_ndarray import BufferizedNDArray, BufferizedNDArrayFType
+
 __all__ = [
     "AssemblyContext",
     "BufferizedNDArray",
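The `# isort: split` comment here is load-bearing, not cosmetic: isort sorts each import section alphabetically, and without the split it would hoist `from .bufferized_ndarray import ...` back above `from .lower import ...`. The order matters, presumably so that `lower`'s names are already bound on the partially initialized `finchlite.compile` module by the time `bufferized_ndarray`, which now imports `interface.eager` and re-enters the import cycle, begins loading. The marker closes one isort section and opens another, so the two blocks keep their relative order:

from .lower import (  # sorted within its own section
    AssemblyContext,
    make_extent,
)

# isort: split
# A separate section: stays below the block above even though
# "bufferized_ndarray" sorts before "lower" alphabetically.
from .bufferized_ndarray import BufferizedNDArray, BufferizedNDArrayFType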
3 changes: 2 additions & 1 deletion src/finchlite/compile/bufferized_ndarray.py
@@ -8,6 +8,7 @@
 from ..algebra import FType, ImmutableStructFType, Tensor, TupleFType, ffuncs, ftype
 from ..codegen import NumpyBuffer, NumpyBufferFType
 from ..codegen.numba_codegen import to_numpy_type
+from ..interface.eager import EagerTensor
 from . import looplets as lplt
 from .lower import AssemblyContext, FinchTensorFType
 
@@ -16,7 +17,7 @@ def _get_default_strides(size: tuple[int, ...]) -> tuple[int, ...]:
     return tuple(np.cumprod((1,) + size[::-1]).astype(int))[-2::-1]
 
 
-class BufferizedNDArray(Tensor):
+class BufferizedNDArray(EagerTensor):
     def __init__(
         self,
         val: NumpyBuffer,
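Reparenting `BufferizedNDArray` from `Tensor` to `EagerTensor` is what the new test further down exercises: `EagerTensor` (defined in `interface/eager.py`) supplies operator dunders such as `__matmul__`, routed through the build-lazy-then-compute path visible in the `matmul` hunk below. Roughly, the mechanism looks like this (a sketch of the idea, not the actual class body):

class EagerTensorSketch:
    # Illustrative: each arithmetic dunder builds a lazy expression
    # and immediately forces it, so any subclass, now including
    # BufferizedNDArray, gets `a @ b` and friends for free.
    def __matmul__(self, other):
        from finchlite.interface import lazy
        from finchlite.interface.fuse import compute
        return compute(lazy.matmul(self, other))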
7 changes: 5 additions & 2 deletions src/finchlite/interface/eager.py
@@ -6,8 +6,6 @@
 from typing import Any
 
 from ..algebra import FinchOperator
-from . import lazy
-from .fuse import compute
 from .overrides import OverrideTensor
 
 
@@ -231,6 +229,8 @@ def __ne__(self, other):
         return not_equal(self, other)
 
 
+from . import lazy  # noqa: E402
+from .fuse import compute  # noqa: E402
+
+
 def full(
     shape: int | tuple[int, ...],
     fill_value: bool | complex,
@@ -418,6 +420,7 @@ def matmul(x1, x2, /):
     """
     if isinstance(x1, lazy.LazyTensor) or isinstance(x2, lazy.LazyTensor):
        return lazy.matmul(x1, x2)
+
     c = lazy.matmul(x1, x2)
     return compute(c)
 
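eager.py takes the other standard route out of the cycle: instead of deferring imports into each function, it moves the module-level `lazy` and `fuse` imports below the definitions they depend on, with `# noqa: E402` acknowledging the deliberate "module level import not at top of file" lint violation. By the time those bottom imports run, `EagerTensor` and the operator machinery above them are already bound, so anything in the `lazy`/`fuse` chain that presumably imports back into `eager` mid-cycle finds what it needs. In miniature (a sketch mirroring eager.py, names illustrative):

# eager_sketch.py
class EagerTensor:
    # Defined first: modules imported below may import this class
    # back from here, so it must exist before they load.
    pass

# Deliberately late module-level imports; E402 flags imports that are
# not at the top of the file, which is exactly the point here.
from . import lazy  # noqa: E402
from .fuse import compute  # noqa: E402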
3 changes: 2 additions & 1 deletion src/finchlite/interface/lazy.py
@@ -28,7 +28,6 @@
 )
 from ..algebra.ftypes import FDType
 from ..autoschedule.tensor_stats import StatsInterpreter
-from ..compile import BufferizedNDArray
 from ..finch_logic import (
     Aggregate,
     Alias,
@@ -463,6 +462,8 @@ def asarray(arg: Any, format: TensorFType | None = None) -> Any:
     from finchlite.interface.scalar import Scalar
 
     if isinstance(arg, np.ndarray):
+        from ..compile import BufferizedNDArray
+
         return BufferizedNDArray.from_numpy(arg)
     if np.isscalar(arg) or arg is None:
         return Scalar(arg)
25 changes: 25 additions & 0 deletions tests/test_interface.py
@@ -650,6 +650,31 @@ def test_matmul(a, b, a_wrap, b_wrap):
     finch_assert_allclose(result_with_np, expected)
 
 
+@pytest.mark.parametrize(
+    "a, b",
+    [
+        (
+            np.array([[1.0, 2.0], [3.0, 4.0]]),
+            np.array([[5.0, 6.0], [7.0, 8.0]]),
+        ),
+        (
+            np.arange(12, dtype=np.float64).reshape(3, 4),
+            np.arange(8, dtype=np.float64).reshape(4, 2),
+        ),
+    ],
+)
+def test_matmul_bufferized_ndarray(a, b):
+    ba = finchlite.asarray(a)
+    bb = finchlite.asarray(b)
+    expected = a @ b
+
+    result = finchlite.matmul(ba, bb)
+    result_with_op = ba @ bb
+
+    finch_assert_allclose(result, expected)
+    finch_assert_allclose(result_with_op, expected)
+
+
 @pytest.mark.usefixtures("interpreter_scheduler")  # TODO: remove
 @pytest.mark.parametrize(
     "a",
2 changes: 1 addition & 1 deletion tests/test_notation_interpreter.py
@@ -220,4 +220,4 @@ def test_count_nonfill_vector(a):
 
     mod = ntn.NotationInterpreter()(prgm)
     cnt = mod.count_nonfill_vector(a)
-    assert cnt == np.count_nonzero(a)
+    assert cnt == np.count_nonzero(a.to_numpy())