Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
86cfce6
add placeholder code for spatz platform
Bumblebee00 Mar 2, 2026
0774556
code generation with generic c code
Bumblebee00 Mar 2, 2026
a8fd323
modified spatz c code to use proper memory allocation and copying fun…
Bumblebee00 Mar 12, 2026
bbf2751
tmp commit to send to badie103
Bumblebee00 Mar 16, 2026
b1a5868
modified Makefile and cmakefiles to build and use spatz runtime. ugly…
Mar 23, 2026
0acf6e7
vsim simulator runnable by deeployRunner
Mar 23, 2026
8e4e8c3
Removed reference to conda environment and added commands to create v…
Mar 23, 2026
9cab786
typo
Mar 27, 2026
e3c46c4
double gvsoc temporary configuration
Mar 27, 2026
651d4cf
reunited gvsoc build for spatz and other platforms
Mar 27, 2026
908261d
forgot comment
Mar 27, 2026
edc461f
forgot comment
Mar 27, 2026
cd13ce4
added topk generic binding (hardcoded k=10)
Mar 29, 2026
f72e147
added matmul softmax and topk generic bindings to spatz
Mar 29, 2026
93db81b
switched default simulator to gvsoc bc is faster
Mar 29, 2026
eee910c
added topk test network
Mar 29, 2026
1523825
added sparse attention test network
Mar 29, 2026
8b99e3a
added topk binding to generic platform
Mar 29, 2026
b541e17
improved generic gather node to support more than one index
Mar 29, 2026
e35433d
added gather binding for spatz
Mar 29, 2026
7eb2e35
now for any k inot the graph
Apr 1, 2026
8f99e67
added big attention
Apr 13, 2026
bfcdda6
first draft of tiling (not working)
Apr 13, 2026
bdcdd70
minimalloc fix
Apr 13, 2026
64e3874
fixed makefile adding missing things
Apr 13, 2026
23017e8
added yaml for my conda environment
Apr 14, 2026
8b46fd5
modified bindings to use snrt_dma_wait_all function
Apr 14, 2026
3187d64
added different dimensions of FP32/MatMul
Apr 16, 2026
1624642
fixed simulation stalling issue
Apr 16, 2026
08ddf29
fixed memory levels
Apr 16, 2026
27f9558
added cycles indication
Apr 21, 2026
e77936d
added fp32 matmul kernel that uses vector instructions
Apr 21, 2026
0798c50
went back to not using snrt_l3alloc because its not working
Apr 21, 2026
8e552d7
[format] alignment fix
Apr 23, 2026
6a412c0
[tiling] enable tiling extension on Spatz
Apr 23, 2026
81dbb89
[template] Add proper allocation template and fix DMA template
Apr 23, 2026
eee94d5
[sw] use memcpy instead of DMA for DRAM buffer init
Apr 23, 2026
6e3a0f0
changed commit hash to include new version of spatz that has snrt_l3a…
Apr 23, 2026
8f073d1
removed redundant memcpy
Apr 23, 2026
ae0ef79
added gather only test
Apr 23, 2026
020df36
improved gather template
Apr 23, 2026
ea9e073
added nice dimensions of Matmul
Apr 27, 2026
693d1cb
improved main
Apr 27, 2026
3112c0f
spatz matmulfunction now splits work between cores
Apr 27, 2026
f209ae4
fixed name of input in graph
Apr 27, 2026
a8f547b
removed unnecessary print
Apr 27, 2026
b083883
gather tiling
Apr 27, 2026
213e832
topk tiling
Apr 27, 2026
3aca310
updated tiling to work with constant buffers
Apr 27, 2026
fae9d65
updated tiling to work with nodes with >1 output
Apr 27, 2026
9c4f9e7
fixed topk template
Apr 28, 2026
3e225aa
added softmax tiled (not working)
Apr 30, 2026
0ad2cfc
added another softmax test
Apr 30, 2026
5085400
added softmax function with custom exp and inv functions
May 1, 2026
aad67f8
detect quiet nan float output
May 8, 2026
97b9762
use non vector for when one dim is one
May 8, 2026
d73f9d8
fixed matmul kernel
May 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ if(TOOLCHAIN STREQUAL GCC)
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
endif()

set(platform MemPool CACHE STRING "Platform (MemPool, SoftHier, QEMU, Siracusa, Siracusa_w_neureka, PULP-Open, GAP9, Generic, Snitch)")
set_property(CACHE platform PROPERTY STRINGS MemPool SoftHier QEMU Siracusa Siracusa_w_neureka PULP-Open GAP9 Generic Snitch)
set(platform MemPool CACHE STRING "Platform (MemPool, SoftHier, QEMU, Siracusa, Siracusa_w_neureka, PULP-Open, GAP9, Generic, Snitch, Spatz)")
set_property(CACHE platform PROPERTY STRINGS MemPool SoftHier QEMU Siracusa Siracusa_w_neureka PULP-Open GAP9 Generic Snitch Spatz)

if(platform STREQUAL MemPool)
message(STATUS "Building for platform 'MemPool'")
Expand All @@ -46,6 +46,8 @@ elseif(platform STREQUAL SoftHier)
message(STATUS "Building for platform 'SoftHier'")
elseif(platform STREQUAL Chimera)
message(STATUS "Building for platform 'Chimera'")
elseif(platform STREQUAL Spatz)
message(STATUS "Building for platform 'Spatz'")
else()
message(FATAL_ERROR "Invalid platform '${platform}' specified!")
endif()
Expand Down Expand Up @@ -299,5 +301,33 @@ if(platform STREQUAL Chimera)

endif()

# Spatz platform: select the Spatz LLVM toolchain, then build the Generic and
# Spatz target libraries together with the DeeployTest harness.
if(platform STREQUAL Spatz)

# Abort configuration early when SPATZ_HOME is missing from the environment.
if(NOT DEFINED ENV{SPATZ_HOME})
message(FATAL_ERROR "Environment variable SPATZ_HOME not set.")
endif()

# Mirror the environment value into a CMake variable; $ENV{} is read at
# configure time only.
set(SPATZ_HOME $ENV{SPATZ_HOME})

# The toolchain file must be chosen before project() below, because CMake
# reads CMAKE_TOOLCHAIN_FILE when the first project() call enables languages.
set(CMAKE_TOOLCHAIN_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/spatz/toolchain_llvm.cmake)

# Platform-specific settings live in cmake/spatz/spatz.cmake.
include(${CMAKE_CURRENT_LIST_DIR}/cmake/spatz/spatz.cmake)

project(deeploy LANGUAGES C ASM)

# Configuration summary. ISA is not set in this block -- presumably it comes
# from spatz.cmake or the toolchain file; TODO confirm.
message(STATUS "============================= ${platform} Configuration ============================")
message(STATUS "[cMake ] ISA = " ${ISA})
message(STATUS "================================================================================")
message(STATUS "")

# Build the generic kernels and the Spatz-specific library. The Generic headers
# are added to deeployspatz's PUBLIC include path (relative paths are resolved
# against the current source directory).
# NOTE(review): deeployspatz is presumably defined by TargetLibraries/Spatz -- confirm.
add_subdirectory(TargetLibraries/Generic)
add_subdirectory(TargetLibraries/Spatz)
target_include_directories(deeployspatz PUBLIC TargetLibraries/Generic/inc)

# deeploylib must exist before it can be linked, hence DeeployTest is added first.
# NOTE(review): deeploylib is presumably created inside DeeployTest -- confirm.
add_subdirectory(DeeployTest)
target_link_libraries(deeploylib INTERFACE deeploybasic deeployspatz)

endif()


print_simulation_config()
15 changes: 13 additions & 2 deletions Deeploy/Targets/Generic/Bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
GatherTemplate, GemmTemplate, IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, MatMulTemplate, \
MaxPoolTemplate, MulTemplate, PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, \
RequantShiftTemplate, ReshapeTemplate, RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, \
iGELUTemplate, iLayernormTemplate, iRMSNormTemplate, iSoftmaxTemplate
iGELUTemplate, iLayernormTemplate, iRMSNormTemplate, iSoftmaxTemplate, TopKTemplate
from Deeploy.Targets.Generic.TypeCheckers import AddChecker, BatchNormChecker, ConcatChecker, ConvChecker, \
DebugPrintChecker, DequantChecker, DivChecker, DummyChecker, GatherChecker, GELUChecker, GEMMChecker, \
LayerNormChecker, MatMulChecker, MaxPoolChecker, MulChecker, PadChecker, QuantChecker, ReduceMeanChecker, \
ReduceSumChecker, ReluChecker, RequantShiftChecker, ReshapeChecker, RQIntegerDivChecker, SliceChecker, \
SoftmaxChecker, TransposeChecker
SoftmaxChecker, TransposeChecker, TopKChecker

BasicTransformer = CodeTransformation([ArgumentStructGeneration(), MemoryManagementGeneration(), FutureGeneration()])

Expand Down Expand Up @@ -327,3 +327,14 @@
ConvTransposeTemplate.referenceTemplate,
BasicTransformer) for type in FloatDataTypes
]

# Generic TopK binding: a single NodeBinding that pairs the TopK type checker
# with the reference TopK template and the basic (non-tiled) code transformer.
# Inputs are (data, k); outputs are (values, indices).
# NOTE(review): k and the index output are bound as int8_t pointers, while the
# reference TopK template stores a uint32_t `index` into indices_out. Indices
# beyond the int8 range would presumably not be representable -- confirm whether
# a wider integer type is intended here.
BasicTopKBindings = [
    NodeBinding(
        TopKChecker(
            [PointerClass(float32_t), PointerClass(int8_t)],  # inputs
            [PointerClass(float32_t), PointerClass(int8_t)]  # outputs
        ),
        TopKTemplate.referenceTemplate,
        BasicTransformer,
    )
]
12 changes: 12 additions & 0 deletions Deeploy/Targets/Generic/Layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,3 +709,15 @@ def computeOps(self):
numPx = opRep['dim_im_out_x']

return numPx * opsPerPx


class TopKLayer(ONNXLayer):
    """Layer wrapper for TopK nodes.

    Adds no behaviour of its own: construction is forwarded unchanged to
    ``ONNXLayer``.
    """

    def __init__(self, maps: List[NodeMapper]):
        # maps: the node mappers tried for this layer; passed straight to ONNXLayer.
        super().__init__(maps)

    # TODO: computeOps() and computeShapes() are not overridden for TopK yet;
    # whatever ONNXLayer provides is used as-is.
42 changes: 37 additions & 5 deletions Deeploy/Targets/Generic/Parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,7 @@ def parseNode(self, node: gs.Node) -> (bool):
return False

indices_shape = node.inputs[1].shape
assert np.prod(indices_shape) == 1, f"Only indices of size 1 supported. Got indices of shape {indices_shape}"
self.operatorRepresentation['num_indices'] = int(np.prod(indices_shape))

self.operatorRepresentation['axis'] = node.attrs['axis'] if 'axis' in node.attrs else 0
return True
Expand All @@ -1002,10 +1002,17 @@ def parseNodeCtxt(self,

axis = self.operatorRepresentation['axis']
shape = ctxt.lookup(node.inputs[0].name).shape
self.operatorRepresentation['batch'] = np.prod(shape[:axis])
self.operatorRepresentation['batch_length'] = np.prod(shape[axis:])
self.operatorRepresentation['axis_length'] = np.prod(shape[axis + 1:])
self.operatorRepresentation['index'] = int(node.inputs[1].values.item())
self.operatorRepresentation['batch'] = int(np.prod(shape[:axis])) if axis > 0 else 1
self.operatorRepresentation['batch_length'] = int(np.prod(shape[axis:]))
self.operatorRepresentation['axis_length'] = int(np.prod(shape[axis + 1:])) if axis + 1 < len(shape) else 1

if self.operatorRepresentation['num_indices'] == 1:
try:
self.operatorRepresentation['index'] = int(node.inputs[1].values.item())
except Exception:
self.operatorRepresentation['index'] = f"{self.operatorRepresentation['indices']}[0]"
else:
self.operatorRepresentation['index'] = 0 # in this case is not used but is needed for mako template

return ctxt, True

Expand Down Expand Up @@ -2886,3 +2893,28 @@ def parseNodeCtxt(self,
self.operatorRepresentation['size'] = int(np.prod(data_in.shape))

return ctxt, True

# TopKParser: selects the largest k elements from a vector
class TopKParser(NodeParser):
    """Parser for TopK nodes with a compile-time constant ``k``.

    Fills ``operatorRepresentation`` with the buffer names of the data input and
    of both outputs (values, indices), the flattened input size and the value
    of ``k``.

    NOTE(review): the ``axis``, ``largest`` and ``sorted`` attributes of the
    node are not inspected; the input is treated as one flat vector.
    """

    def __init__(self):
        super().__init__()

    def parseNode(self, node: gs.Node) -> bool:
        """Accept only 'TopK' nodes with exactly two inputs (data, k) and two outputs."""
        if node.op != 'TopK':
            return False
        return len(node.inputs) == 2 and len(node.outputs) == 2

    def parseNodeCtxt(self,
                      ctxt: NetworkContext,
                      node: gs.Node,
                      channels_first: bool = True) -> Tuple[NetworkContext, bool]:
        """Resolve the node's buffers in ``ctxt`` and record the TopK parameters.

        Returns ``(ctxt, True)`` on success. Returns ``(ctxt, False)`` when ``k``
        is not available as a single constant value or lies outside
        ``1..data_in_size``.
        """
        data_in = ctxt.lookup(node.inputs[0].name)
        k_in = ctxt.lookup(node.inputs[1].name)
        values_out = ctxt.lookup(node.outputs[0].name)
        indices_out = ctxt.lookup(node.outputs[1].name)

        # k must be known at code-generation time because it is baked into the
        # template. A buffer without constant values cannot be handled.
        k_values = getattr(k_in, 'values', None)
        if k_values is None:
            return ctxt, False

        # Flatten first so that both a 0-d scalar and a one-element 1-D tensor work;
        # plain `values[0]` raises on a 0-d array.
        k_flat = np.asarray(k_values).reshape(-1)
        if k_flat.size != 1:
            return ctxt, False
        k_value = int(k_flat[0])

        data_in_size = int(np.prod(data_in.shape))

        # The generated selection loop indexes element i for every i < k, so k must
        # not exceed the number of input elements.
        if not 0 < k_value <= data_in_size:
            return ctxt, False

        self.operatorRepresentation['data_in'] = data_in.name
        self.operatorRepresentation['data_in_size'] = data_in_size
        self.operatorRepresentation['k_value'] = k_value
        self.operatorRepresentation['values_out'] = values_out.name
        self.operatorRepresentation['indices_out'] = indices_out.name

        return ctxt, True
8 changes: 5 additions & 3 deletions Deeploy/Targets/Generic/Platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@
BasicPad1DBindings, BasicPad2DBindings, BasicPowBindings, BasicQuantBindings, BasicReduceMeanBindings, \
BasicReduceSumBindings, BasicReluBinding, BasicReshapeBindings, BasicRQIntegerDivBinding, BasicRQSBindings, \
BasicRQSGELUBinding, BasicSliceBindings, BasicSoftmaxBindings, BasicSqrtBindings, BasicTransposeBindings, \
DummyBinding
DummyBinding, BasicTopKBindings
from Deeploy.Targets.Generic.Layers import AddLayer, BatchNormalizationLayer, ConcatLayer, ConvLayer, \
ConvTransposeLayer, DebugPrintLayer, DequantLayer, DivLayer, GatherLayer, GELULayer, GEMMLayer, ITAMaxLayer, \
LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, PowLayer, QuantLayer, ReduceMeanLayer, \
ReduceSumLayer, ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, SliceLayer, \
SoftmaxLayer, SqrtLayer, TransposeLayer
SoftmaxLayer, SqrtLayer, TransposeLayer, TopKLayer
from Deeploy.Targets.Generic.Parsers import AddParser, BatchNormParser, ConcatParser, ConvTranspose1DParser, \
DebugParser, DequantParser, DivParser, DummyParser, FlattenParser, GatherParser, GELUParser, GenericConv1DParser, \
GenericConv2DParser, GenericDWConv1DParser, GenericDWConv2DParser, GenericGEMMParser, GenericMaxPool2DParser, \
IntegerDivParser, ITAMaxParser, ITAPartialMaxParser, LayerNormParser, MatMulParser, MaxPool1DParser, MulParser, \
Pad1DParser, Pad2DParser, PowParser, QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, \
RequantShiftParser, ReshapeParser, RQIntegerDivParser, RQSiGELUParser, SliceParser, SoftmaxParser, SqrtParser, \
TransposeParser, UnsqueezeParser, iLayerNormParser, iSoftmaxParser
TransposeParser, TopKParser, UnsqueezeParser, iLayerNormParser, iSoftmaxParser
from Deeploy.Targets.Generic.Templates import AllocateTemplate, FreeTemplate
from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import DequantPatternPass, ExtractPaddingFromConvPass, \
ExtractPaddingFromPoolPass, MatMulAddMergePass, MergeConstAddAndRequantPass, QuantPatternPass, \
Expand Down Expand Up @@ -67,6 +67,7 @@
SoftmaxMapper = NodeMapper(SoftmaxParser(), BasicSoftmaxBindings)
iSoftmaxMapper = NodeMapper(iSoftmaxParser(), BasicSoftmaxBindings)
TransposeMapper = NodeMapper(TransposeParser(), BasicTransposeBindings)
TopKMapper = NodeMapper(TopKParser(), BasicTopKBindings)
UnsqueezeMapper = NodeMapper(UnsqueezeParser(), BasicReshapeBindings)
QuantMapper = NodeMapper(QuantParser(), BasicQuantBindings)
DequantMapper = NodeMapper(DequantParser(), BasicDequantBindings)
Expand Down Expand Up @@ -113,6 +114,7 @@
'RQIntegerDiv': RQIntegerDivLayer([RQIntegerDivMapper]),
'Squeeze': ReshapeLayer([UnsqueezeMapper]),
'Transpose': TransposeLayer([TransposeMapper]),
'TopK': TopKLayer([TopKMapper]),
'Unsqueeze': ReshapeLayer([UnsqueezeMapper]),
'Slice': SliceLayer([SliceMapper]),
'Quant': QuantLayer([QuantMapper]),
Expand Down
10 changes: 10 additions & 0 deletions Deeploy/Targets/Generic/Templates/GatherTemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,18 @@
width = int(data_in_type.referencedType.typeWidth/8)
%>
BEGIN_SINGLE_CORE
% if num_indices == 1:
for (uint32_t i=0; i<${batch}; ++i) {
memcpy(${data_out} + i * ${axis_length}, ${data_in} + i * ${batch_length} + ${index} * ${axis_length}, ${axis_length} * ${width});
}
% else:
for (uint32_t i=0; i<${batch}; ++i) {
for (uint32_t j=0; j<${num_indices}; ++j) {
memcpy(${data_out} + i * (${num_indices} * ${axis_length}) + j * ${axis_length},
${data_in} + i * ${batch_length} + ${indices}[j] * ${axis_length},
${axis_length} * ${width});
}
}
% endif
END_SINGLE_CORE
""")
40 changes: 40 additions & 0 deletions Deeploy/Targets/Generic/Templates/TopKTemplate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Dict, List, Tuple

from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation


# Reference (single-core) TopK kernel template.
#
# Template variables used: nodeName, nodeOp, k_value, data_in, data_in_size,
# data_in_type, values_out, values_out_type, indices_out, indices_out_type.
#
# Generated code:
#   1. copies every input element together with its position into a `pairs`
#      array of data_in_size entries declared inside the generated block,
#   2. runs k_value rounds of selection: each round scans the remaining entries
#      for the largest value, swaps it to position i and writes value/index to
#      the two outputs.
# The outputs are therefore ordered from largest to smaller values. On equal
# values the earlier entry is kept, because the comparison is a strict `>`.
#
# NOTE(review): `pairs` is sized by data_in_size and declared in the generated
# block -- for large inputs this is presumably a large stack allocation; confirm
# it fits the target. The loops also assume k_value <= data_in_size.
referenceTemplate = NodeTemplate("""
// TopK (Name: ${nodeName}, Op: ${nodeOp})
BEGIN_SINGLE_CORE
// Find the top ${k_value} values and their indices
// Assumes 1D input for simplicity
typedef struct {
${data_in_type.referencedType.typeName} value;
uint32_t index;
} topk_pair_t;

topk_pair_t pairs[${data_in_size}];
for (uint32_t i = 0; i < ${data_in_size}; ++i) {
pairs[i].value = ((${data_in_type.referencedType.typeName}*)${data_in})[i];
pairs[i].index = i;
}
// Simple selection sort for top-k
for (uint32_t i = 0; i < ${k_value}; ++i) {
uint32_t max_idx = i;
for (uint32_t j = i + 1; j < ${data_in_size}; ++j) {
if (pairs[j].value > pairs[max_idx].value) {
max_idx = j;
}
}
// Swap
if (max_idx != i) {
topk_pair_t tmp = pairs[i];
pairs[i] = pairs[max_idx];
pairs[max_idx] = tmp;
}
// Write output
((${values_out_type.referencedType.typeName}*)${values_out})[i] = pairs[i].value;
((${indices_out_type.referencedType.typeName}*)${indices_out})[i] = pairs[i].index;
}
END_SINGLE_CORE
""")
14 changes: 14 additions & 0 deletions Deeploy/Targets/Generic/TypeCheckers.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,3 +610,17 @@ def _inferNumLevels(self, inputs: List[VariableBuffer],
def _inferSignedness(self, inputs: List[VariableBuffer],
operatorRepresentation: OperatorRepresentation) -> List[bool]:
return [True]

# TopKChecker: infers types for both values and indices outputs of TopK operation
class TopKChecker(SignPropTypeChecker):
    """Type checker for TopK nodes, which have two outputs (values, indices).

    Both inference hooks return one entry per output, in the order
    ``[values, indices]``.
    """

    def __init__(self, input_types: Sequence[Type[Pointer]], output_types: Sequence[Type[Pointer]]):
        super().__init__(input_types, output_types)

    def _inferNumLevels(self, inputs: List[VariableBuffer], operatorRepresentation: OperatorRepresentation) -> List[int]:
        # Output 0 (values): same number of levels as the data input.
        # Output 1 (indices): treated as not quantized and reported with 1 level.
        return [inputs[0].nLevels, 1]

    def _inferSignedness(self, inputs: List[VariableBuffer], operatorRepresentation: OperatorRepresentation) -> List[bool]:
        # Output 0 (values): same signedness as the data input.
        # Output 1 (indices): reported as unsigned.
        return [inputs[0]._signed, False]
117 changes: 117 additions & 0 deletions Deeploy/Targets/Spatz/Bindings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from functools import partial

from Deeploy.DeeployTypes import CodeTransformation, NodeBinding
from Deeploy.CommonExtensions.CodeTransformationPasses.MemoryAllocation import ArgumentStructGeneration, \
MemoryManagementGeneration
from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration
from Deeploy.AbstractDataTypes import PointerClass
from Deeploy.CommonExtensions.DataTypes import IntegerDataTypes, SignedIntegerDataTypes, float32_t, int8_t, int32_t
from Deeploy.Targets.Generic.TypeCheckers import GatherChecker, MatMulChecker, TopKChecker, SoftmaxChecker

from Deeploy.CommonExtensions.CodeTransformationPasses.Closure import ClosureGeneration, MemoryAwareClosureGeneration
from Deeploy.Targets.Snitch.CodeTransformationPasses.SnitchClusterTiling import SnitchClusterTiling
from Deeploy.Targets.Snitch.CodeTransformationPasses.SnitchCoreFilter import SnitchCoreFilterPass
from Deeploy.Targets.Snitch.CodeTransformationPasses.SnitchClusterSynch import SnitchSynchCoresPass
from Deeploy.Targets.Spatz.DMA.SpatzDma import SpatzDma
from Deeploy.Targets.Spatz.Templates import GatherTemplate, MatMulTemplate as SpatzMatMulTemplate, TopKTemplate, SoftmaxTemplate
from Deeploy.Targets.Generic.Templates import MatMulTemplate, FloatMatMulTemplate
from Deeploy.TilingExtension.CodeTransformationPasses.TilingVariableReplacement import TilingVariableReplacement, \
TilingVariableReplacementUpdate

# Closure factories with their suffix / memory-region arguments pre-bound.
# TilingCallClosure: plain closure generation, suffix "_tiling_closure".
TilingCallClosure = partial(ClosureGeneration, closureSuffix = "_tiling_closure")
# MemoryAwareFunctionCallClosure: memory-aware closure generation, suffix "_closure",
# configured with startRegion "L3" and endRegion "L1".
MemoryAwareFunctionCallClosure = partial(MemoryAwareClosureGeneration,
                                         closureSuffix = "_closure",
                                         startRegion = "L3",
                                         endRegion = "L1")

# Non-tiled transformation pipeline. Within this file it is only referenced
# from disabled (commented-out) bindings.
BasicTransformer = CodeTransformation(
    [ArgumentStructGeneration(),
     MemoryManagementGeneration(),
     FutureGeneration()])

# Tiled transformation pipeline used by the active Spatz bindings below.
# It reuses the Snitch core-filter, core-synchronisation and cluster-tiling
# passes, with SpatzDma supplied as the DMA object for the "L3" -> "L1" tiling.
# The passes are listed in the order they are handed to CodeTransformation;
# do not reorder without checking how the passes depend on each other.
TiledTransformer = CodeTransformation([
    SnitchCoreFilterPass("compute"),
    TilingVariableReplacement("L1"),
    TilingCallClosure(writeback = False),
    SnitchSynchCoresPass(),  # snrt_cluster_hw_barrier()
    TilingVariableReplacementUpdate("L1"),
    SnitchClusterTiling("L3", "L1", SpatzDma()),
    ArgumentStructGeneration(),
    MemoryManagementGeneration("L1"),
    MemoryAwareFunctionCallClosure(writeback = False, generateStruct = True),
    MemoryManagementGeneration()
])

# Tiled Gather bindings for Spatz: float32 data in and out, with one binding per
# entry of IntegerDataTypes for the index pointer type. GatherTemplate is the
# Spatz template module here, not the Generic one.
# NOTE: the loop variable `type` shadows the Python builtin inside the comprehension.
SpatzGatherBindings = [
    NodeBinding(
        GatherChecker(
            [PointerClass(float32_t), PointerClass(type)],
            [PointerClass(float32_t)]
        ),
        GatherTemplate.tilingReferenceTemplate,
        TiledTransformer
    ) for type in IntegerDataTypes
]
# Disabled alternative (kept as a note only): non-tiled bindings over
# SignedIntegerDataTypes data with int32_t indices, using
# GatherTemplate.referenceTemplate and BasicTransformer.

# Tiled MatMul bindings for Spatz, both using TiledTransformer:
#   * int8 x int8 -> int32 via spatzSIMatMulTemplate
#   * float32 x float32 -> float32 via spatzFloatMatMulTemplate
# The non-tiled and generic-template variants that used to be parked below in a
# module-level string literal were dead code and have been dropped; recover them
# from version control if needed.
SpatzMatMulBindings = [
    NodeBinding(MatMulChecker([PointerClass(int8_t), PointerClass(int8_t)], [PointerClass(int32_t)]),
                SpatzMatMulTemplate.spatzSIMatMulTemplate, TiledTransformer),
    NodeBinding(
        MatMulChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]),
        SpatzMatMulTemplate.spatzFloatMatMulTemplate, TiledTransformer)
]

# Tiled TopK binding for Spatz: float32 data and values, int32 k and indices.
# TopKTemplate is the Spatz template module here, not the Generic one.
SpatzTopKBindings = [
    NodeBinding(
        TopKChecker(
            [PointerClass(float32_t), PointerClass(int32_t)],  # inputs: data, k
            [PointerClass(float32_t), PointerClass(int32_t)]  # outputs: values, indices
        ),
        TopKTemplate.SpatzTilingTemplate,
        TiledTransformer,
    )
]


# Tiled Softmax binding for Spatz: float32 in, float32 out.
SpatzSoftmaxBindings = [
    NodeBinding(
        SoftmaxChecker([PointerClass(float32_t)], [PointerClass(float32_t)]),
        SoftmaxTemplate.floatTilingTemplate,
        TiledTransformer
    )
]
# Disabled alternative (kept as a note only): an int8 -> int8 binding using
# SoftmaxTemplate.integerTilingTemplate with TiledTransformer.
Loading