You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt
+71-28Lines changed: 71 additions & 28 deletions
Original file line number
Diff line number
Diff line change
@@ -184,11 +184,13 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
184
184
val padding = extractPadding(params)
185
185
186
186
val resultValue = context.nextTempValue()
187
-
187
+
val inputType = node.inputs.firstOrNull()?.let { context.getTypeMapper().mapTensorType(it) } ?: outputType
188
+
188
189
// Build StableHLO reduce_window operation for max pooling
189
190
val operations = buildMaxPoolOperations(
190
191
resultValue = resultValue,
191
192
input = operands[0],
193
+
inputType = inputType,
192
194
outputType = outputType,
193
195
kernelSize = kernelSize,
194
196
stride = stride,
@@ -225,11 +227,13 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
225
227
val padding = extractPadding(params)
226
228
227
229
val resultValue = context.nextTempValue()
228
-
230
+
val inputType = node.inputs.firstOrNull()?.let { context.getTypeMapper().mapTensorType(it) } ?: outputType
231
+
229
232
// Build StableHLO reduce_window operation for average pooling
230
233
val operations = buildAvgPoolOperations(
231
234
resultValue = resultValue,
232
235
input = operands[0],
236
+
inputType = inputType,
233
237
outputType = outputType,
234
238
kernelSize = kernelSize,
235
239
stride = stride,
@@ -691,64 +695,103 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
691
695
}
692
696
}
693
697
698
+
/** The MLIR element type ("f32"/"f16"/…) parsed from a `tensor<…xT>` string. */
Copy file name to clipboardExpand all lines: skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/NeuralNetOperationsConverterTest.kt
+3-3Lines changed: 3 additions & 3 deletions
Original file line number
Diff line number
Diff line change
@@ -73,9 +73,9 @@ class NeuralNetOperationsConverterTest {
0 commit comments