Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion src/finchlite/autoschedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,31 @@
with_default_scheduler,
)
from .executor import LogicExecutor
from .formatter import DefaultLogicFormatter, LogicFormatter
from .formatter import BufferizedNDArrayFormatter, DefaultLogicFormatter, LogicFormatter
from .loop_ordering import DefaultLoopOrderer
from .normalize import LogicNormalizer, normalize_names
from .optimize import DefaultLogicOptimizer
from .rep_operations import (
aggregate_rep,
data_rep,
dropdims_rep,
eltype,
expanddims_rep,
fill_value,
map_rep,
permutedims_rep,
)
from .representation import (
DenseData,
ElementData,
ExtrudeData,
HollowData,
RepeatData,
SparseData,
)
from .stages import LogicEinsumLowerer, LogicNotationLowerer
from .standardize import LogicStandardizer
from .suitable_rep import SmartLogicFormatter, SuitableRep, toposort

__all__ = [
"COMPILE_NUMBA",
Expand All @@ -44,10 +64,15 @@
"OPTIMIZE_LOGIC",
"Aggregate",
"Alias",
"BufferizedNDArrayFormatter",
"DefaultLogicFormatter",
"DefaultLogicOptimizer",
"DefaultLoopOrderer",
"DenseData",
"ElementData",
"ExtrudeData",
"Field",
"HollowData",
"Literal",
"LogicCapture",
"LogicCompiler",
Expand All @@ -67,10 +92,23 @@
"Query",
"Relabel",
"Reorder",
"RepeatData",
"SmartLogicFormatter",
"SparseData",
"SuitableRep",
"Table",
"Value",
"aggregate_rep",
"data_rep",
"dropdims_rep",
"eltype",
"expanddims_rep",
"fill_value",
"get_default_scheduler",
"map_rep",
"normalize_names",
"permutedims_rep",
"set_default_scheduler",
"toposort",
"with_default_scheduler",
]
4 changes: 2 additions & 2 deletions src/finchlite/autoschedule/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from finchlite.finch_logic.nodes import TableValue
from finchlite.symbolic import Namespace, PostWalk, Rewrite, UnvalidatedForm

from .formatter import DefaultLogicFormatter
from .formatter import BufferizedNDArrayFormatter


def extract_tensors(
Expand Down Expand Up @@ -52,7 +52,7 @@ def __init__(
cache: bool = False,
):
if ctx is None:
ctx = DefaultLogicFormatter()
ctx = BufferizedNDArrayFormatter()
if stats_factory is None:
stats_factory = DenseStatsFactory()
self.ctx: LogicLoader = ctx
Expand Down
4 changes: 4 additions & 0 deletions src/finchlite/autoschedule/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,7 @@ def get_output_tns_ftype(self, fill_value: Any, shape_type: tuple[FType, ...]):
ndim=len(shape_type),
dimension_type=TupleFType.from_tuple(shape_type),
)


class BufferizedNDArrayFormatter(DefaultLogicFormatter):
pass
Loading
Loading