|
37 | 37 | startRegion = "L2", |
38 | 38 | endRegion = "L1") |
39 | 39 |
|
40 | | -BasicTransformer = CodeTransformation( |
41 | | - [SnitchSynchCoresPass(), |
42 | | - ArgumentStructGeneration(), |
43 | | - MemoryManagementGeneration(), |
44 | | - FutureGeneration()]) |
45 | | - |
46 | 40 | SkipTransformer = CodeTransformation( |
47 | 41 | [SnitchSynchCoresPass(), |
48 | 42 | ArgumentStructGeneration(), |
|
92 | 86 | FloatAddTemplate, TiledTransformer) |
93 | 87 | ] |
94 | 88 |
|
95 | | -# Basic (non-tiled) FP32 Add Bindings |
96 | | -BasicAddBindings = [ |
97 | | - NodeBinding(AddChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]), |
98 | | - FloatAddTemplate, BasicTransformer) |
99 | | -] |
100 | | - |
101 | 89 | SnitchGemmBindings = [ |
102 | 90 | NodeBinding( |
103 | 91 | GEMMChecker([PointerClass(int8_t), PointerClass(int8_t), |
|
119 | 107 | ], [PointerClass(int8_t)]), SnitchRqGemm_Template, TiledTransformer) |
120 | 108 | ] |
121 | 109 |
|
122 | | -# RMSNorm Bindings (Tiled) |
123 | 110 | SnitchRMSNormBindings = [ |
124 | 111 | NodeBinding(RMSNormChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]), |
125 | 112 | FloatRMSNormTemplate, TiledTransformer) |
126 | 113 | ] |
127 | 114 |
|
128 | | -# RMSNorm Bindings (Non-tiled) |
129 | | -BasicRMSNormBindings = [ |
130 | | - NodeBinding(RMSNormChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]), |
131 | | - FloatRMSNormTemplate, BasicTransformer) |
132 | | -] |
133 | | - |
134 | | -# HardSwish Bindings (Tiled) |
135 | 115 | SnitchHardSwishBindings = [ |
136 | 116 | NodeBinding(HardswishChecker([PointerClass(float32_t)], [PointerClass(float32_t)]), FloatHardSwishTemplate, |
137 | 117 | TiledTransformer) |
138 | 118 | ] |
139 | 119 |
|
140 | | -# HardSwish Bindings (Non-tiled) |
141 | | -BasicHardSwishBindings = [ |
142 | | - NodeBinding(HardswishChecker([PointerClass(float32_t)], [PointerClass(float32_t)]), FloatHardSwishTemplate, |
143 | | - BasicTransformer) |
144 | | -] |
145 | | - |
146 | | -# Div Bindings (Tiled) |
147 | 120 | SnitchDivBindings = [ |
148 | 121 | NodeBinding(DivChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]), |
149 | 122 | FloatDivTemplate, TiledTransformer) |
150 | 123 | ] |
151 | 124 |
|
152 | | -# Div Bindings (Non-tiled) |
153 | | -BasicDivBindings = [ |
154 | | - NodeBinding(DivChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]), |
155 | | - FloatDivTemplate, BasicTransformer) |
156 | | -] |
157 | | - |
158 | | -# Mul Bindings (Tiled) |
159 | 125 | SnitchMulBindings = [ |
160 | 126 | NodeBinding(MulChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]), |
161 | 127 | FloatMulTemplate, TiledTransformer) |
162 | 128 | ] |
163 | 129 |
|
164 | | -# Mul Bindings (Non-tiled) |
165 | | -BasicMulBindings = [ |
166 | | - NodeBinding(MulChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]), |
167 | | - FloatMulTemplate, BasicTransformer) |
168 | | -] |
169 | | - |
170 | 130 | # MatMul Bindings (Tiled) |
171 | 131 | SnitchMatMulBindings = [ |
172 | 132 | NodeBinding(MatMulChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]), |
|
179 | 139 | ConcatTemplate.referenceTemplate, TiledTransformer) |
180 | 140 | ] |
181 | 141 |
|
182 | | -# Transpose Bindings (Tiled) |
183 | 142 | SnitchTransposeBindings = [ |
184 | 143 | NodeBinding(TransposeChecker([PointerClass(float32_t)], [PointerClass(float32_t)]), |
185 | 144 | TransposeTemplate.referenceTemplate, TiledTransformer) |
186 | 145 | ] |
187 | 146 |
|
188 | | -# Transpose Bindings (Non-tiled, multi-core) |
189 | | -BasicSnitchTransposeBindings = [ |
190 | | - NodeBinding(TransposeChecker([PointerClass(float32_t)], [PointerClass(float32_t)]), |
191 | | - TransposeTemplate.referenceTemplate, BasicTransformer) |
192 | | -] |
193 | | - |
194 | 147 | # Reshape Bindings (pointer passthrough, no DMA needed) |
195 | 148 | SnitchReshapeBindings = [ |
196 | 149 | NodeBinding(ReshapeChecker([PointerClass(float32_t)], [PointerClass(float32_t)]), ReshapeTemplate.referenceTemplate, |
|
0 commit comments