@@ -28,7 +28,9 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
2828 // Normalization operations
2929 " batchNorm" , " batchNormalization" , " BatchNormalization" ,
3030 " layerNorm" , " layerNormalization" , " LayerNormalization" ,
31- " rmsNorm" , " rms_norm" , " RMSNorm" , " RmsNorm"
31+ " rmsNorm" , " rms_norm" , " RMSNorm" , " RmsNorm" ,
32+ // Attention
33+ " scaledDotProductAttention"
3234 )
3335
3436 override fun convert (
@@ -44,6 +46,7 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
4446 " batchnorm" , " batchnormalization" -> convertBatchNorm(node, operands, context)
4547 " layernorm" , " layernormalization" -> convertLayerNorm(node, operands, context)
4648 " rmsnorm" , " rms_norm" -> convertRmsNorm(node, operands, context)
49+ " scaleddotproductattention" -> convertSdpa(node, operands, context)
4750 else -> ConversionResult .Unsupported (
4851 node.operation.name,
4952 " Operation not supported by NeuralNetOperationsConverter"
@@ -770,5 +773,149 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter {
770773 " epsilon = $epsilon , feature_index = $featureIndex : $outputType "
771774 }
772775 }
773-
776+
/**
 * Lowers `scaledDotProductAttention` to a sequence of StableHLO operations.
 *
 * Emitted pipeline:
 * 1. `scores = Q · Kᵀ` via a batched `dot_general` (batch dims 0 and 1, contracting dim 3 of both);
 * 2. multiplication by a scalar scale broadcast to the scores shape;
 * 3. optional additive mask (broadcast to the scores shape first when its shape differs);
 * 4. softmax over the last axis, decomposed as `exp(x - max(x)) / sum(exp(x - max(x)))`;
 * 5. `weights · V` via a second batched `dot_general`.
 *
 * Expected operand layout: Q `[B,H,S,D]`, K `[B,H,T,D]`, V `[B,H,T,D]`, and an optional fourth
 * operand holding a mask of shape `[B,H,S,T]` or any shape that aligns with the trailing dimensions
 * of `[B,H,S,T]` with each dimension either equal or 1. The result has the type of the node's first
 * output, or `tensor<?xf32>` when the node declares no output.
 *
 * The scale is taken from the `scale` operation parameter when it is numeric; otherwise it defaults
 * to `1 / sqrt(D)`.
 *
 * All validation happens before the first operation is emitted, so a failed conversion leaves no
 * partially emitted attention sequence in [context].
 *
 * NOTE(review): every intermediate tensor type is spelled with an `f32` element type and the softmax
 * init constants are f32 bit patterns — confirm whether non-f32 inputs have to be supported.
 * NOTE(review): the mask is always combined with `stablehlo.add`; a boolean mask would need a
 * `select` instead — confirm what the frontend produces.
 *
 * @param node graph node describing the attention op; input shapes and the `scale` parameter are read from it
 * @param operands SSA value names of Q, K, V and optionally the mask, in that order
 * @param context conversion context used to allocate temp values, emit operations and track value types
 * @return [ConversionResult.Success] naming the final value, or [ConversionResult.Failure] when the
 *         operands or their shapes are unusable
 */
private fun convertSdpa(
    node: GraphNode,
    operands: List<String>,
    context: ConversionContext
): ConversionResult {
    if (operands.size < 3) {
        return ConversionResult.Failure("SDPA requires at least 3 operands (Q, K, V), got ${operands.size}")
    }

    val query = operands[0] // [B, H, S, D]
    val key = operands[1] // [B, H, T, D]
    val value = operands[2] // [B, H, T, D]
    val mask = operands.getOrNull(3)

    val qShape = node.inputs.getOrNull(0)?.shape ?: return ConversionResult.Failure("Unknown Q shape")
    val kShape = node.inputs.getOrNull(1)?.shape ?: return ConversionResult.Failure("Unknown K shape")
    val vShape = node.inputs.getOrNull(2)?.shape ?: return ConversionResult.Failure("Unknown V shape")

    val rank = qShape.size
    if (rank != 4) {
        return ConversionResult.Failure("SDPA expects 4D tensors [B,H,S,D], got rank $rank")
    }
    // K and V are indexed and contracted as 4D tensors below; reject anything else up front instead of
    // throwing on kShape[2] or emitting an ill-formed dot_general.
    if (kShape.size != 4 || vShape.size != 4) {
        return ConversionResult.Failure(
            "SDPA expects 4D K and V tensors [B,H,T,D], got ranks ${kShape.size} and ${vShape.size}"
        )
    }

    val batch = qShape[0]
    val heads = qShape[1]
    val seqQ = qShape[2]
    val headDim = qShape[3]
    val seqK = kShape[2]

    // Accept any numeric parameter (Float, Double, Int, ...) rather than silently ignoring non-Float values.
    val explicitScale = (node.operation.parameters["scale"] as? Number)?.toFloat()
    if (explicitScale == null && headDim <= 0) {
        // 1/sqrt(D) would be NaN or Infinity and end up inside the emitted constant.
        return ConversionResult.Failure(
            "SDPA cannot derive a default scale from head dimension $headDim; provide an explicit scale"
        )
    }
    val scale = explicitScale ?: (1.0f / kotlin.math.sqrt(headDim.toFloat()))

    // Work out whether the mask has to be broadcast before it can be added to the scores.
    // `maskBroadcastDims` stays null when the mask already has the scores shape or its shape is unknown;
    // in both cases the mask is added directly.
    val scoresDims = listOf(batch, heads, seqQ, seqK)
    val maskShape = if (mask != null) node.inputs.getOrNull(3)?.shape else null
    var maskBroadcastDims: List<Int>? = null
    if (maskShape != null) {
        val offset = rank - maskShape.size
        var broadcastable = offset >= 0
        var sameShape = offset == 0
        if (broadcastable) {
            for (i in 0 until maskShape.size) {
                if (maskShape[i] != scoresDims[offset + i]) {
                    sameShape = false
                    if (maskShape[i] != 1) broadcastable = false
                }
            }
        }
        if (!broadcastable) {
            return ConversionResult.Failure(
                "SDPA mask shape [${maskShape.joinToString(", ")}] is not broadcastable to " +
                    "scores shape [${scoresDims.joinToString(", ")}]"
            )
        }
        if (!sameShape) {
            // Align the mask with the trailing dimensions of the scores tensor.
            maskBroadcastDims = (offset until rank).toList()
        }
    }

    val outputType = node.outputs.firstOrNull()?.let { context.getTypeMapper().mapTensorType(it) }
        ?: "tensor<?xf32>"

    val queryType = context.getValueType(query) ?: "tensor<${qShape.joinToString("x")}xf32>"
    val keyType = context.getValueType(key) ?: "tensor<${kShape.joinToString("x")}xf32>"
    val valueType = context.getValueType(value) ?: "tensor<${vShape.joinToString("x")}xf32>"

    val scoresType = "tensor<${batch}x${heads}x${seqQ}x${seqK}xf32>"
    val reducedType = "tensor<${batch}x${heads}x${seqQ}xf32>"

    // scores = Q @ K.T: [B,H,S,D] @ [B,H,T,D] -> [B,H,S,T]
    val scoresVal = context.nextTempValue()
    context.emitOperation(
        "$scoresVal = stablehlo.dot_general $query, $key, " +
            "batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [3] " +
            ": ($queryType, $keyType) -> $scoresType"
    )
    context.setValueType(scoresVal, scoresType)

    // Scale: scalar constant -> broadcast to the scores shape -> elementwise multiply.
    // Temp names are reserved in the original order (broadcast before constant) so the emitted text
    // for already-supported inputs does not change.
    val scaleBcast = context.nextTempValue()
    val scaleConst = context.nextTempValue()
    context.emitOperation("$scaleConst = stablehlo.constant dense<$scale> : tensor<f32>")
    context.emitOperation(
        "$scaleBcast = stablehlo.broadcast_in_dim $scaleConst, dims = [] " +
            ": (tensor<f32>) -> $scoresType"
    )
    val scaledScores = context.nextTempValue()
    context.emitOperation(
        "$scaledScores = stablehlo.multiply $scoresVal, $scaleBcast : $scoresType"
    )
    context.setValueType(scaledScores, scoresType)

    // Optional additive mask.
    var presoft = scaledScores
    if (mask != null) {
        val maskedVal = context.nextTempValue()
        var maskOperand: String = mask
        val bcastDims = maskBroadcastDims
        if (bcastDims != null && maskShape != null) {
            val maskType = context.getValueType(mask) ?: "tensor<${maskShape.joinToString("x")}xf32>"
            val maskBcast = context.nextTempValue()
            context.emitOperation(
                "$maskBcast = stablehlo.broadcast_in_dim $mask, dims = [${bcastDims.joinToString(", ")}] " +
                    ": ($maskType) -> $scoresType"
            )
            context.setValueType(maskBcast, scoresType)
            maskOperand = maskBcast
        }
        context.emitOperation(
            "$maskedVal = stablehlo.add $presoft, $maskOperand : $scoresType"
        )
        context.setValueType(maskedVal, scoresType)
        presoft = maskedVal
    }

    // Softmax over the last dim (seqK): exp(x - max(x)) / sum(exp(x - max(x)))
    val maxVal = context.nextTempValue()
    val maxInitVal = context.nextTempValue()
    context.emitOperation("$maxInitVal = stablehlo.constant dense<0xFF800000> : tensor<f32>") // -inf
    context.emitOperation(
        "$maxVal = stablehlo.reduce($presoft init: $maxInitVal) applies stablehlo.maximum " +
            "across dimensions = [${rank - 1}] : ($scoresType, tensor<f32>) -> " +
            reducedType
    )

    val maxBcast = context.nextTempValue()
    context.emitOperation(
        "$maxBcast = stablehlo.broadcast_in_dim $maxVal, dims = [0, 1, 2] " +
            ": ($reducedType) -> $scoresType"
    )

    val shifted = context.nextTempValue()
    context.emitOperation("$shifted = stablehlo.subtract $presoft, $maxBcast : $scoresType")

    val expVal = context.nextTempValue()
    context.emitOperation("$expVal = stablehlo.exponential $shifted : $scoresType")

    val sumInit = context.nextTempValue()
    context.emitOperation("$sumInit = stablehlo.constant dense<0.0> : tensor<f32>")
    val sumVal = context.nextTempValue()
    context.emitOperation(
        "$sumVal = stablehlo.reduce($expVal init: $sumInit) applies stablehlo.add " +
            "across dimensions = [${rank - 1}] : ($scoresType, tensor<f32>) -> $reducedType"
    )

    val sumBcast = context.nextTempValue()
    context.emitOperation(
        "$sumBcast = stablehlo.broadcast_in_dim $sumVal, dims = [0, 1, 2] " +
            ": ($reducedType) -> $scoresType"
    )

    val weightsVal = context.nextTempValue()
    context.emitOperation("$weightsVal = stablehlo.divide $expVal, $sumBcast : $scoresType")
    context.setValueType(weightsVal, scoresType)

    // output = weights @ V: [B,H,S,T] @ [B,H,T,D] -> [B,H,S,D]
    val resultValue = context.nextTempValue()
    context.emitOperation(
        "$resultValue = stablehlo.dot_general $weightsVal, $value, " +
            "batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] " +
            ": ($scoresType, $valueType) -> $outputType"
    )

    return ConversionResult.Success(
        outputValueName = resultValue,
        emittedOperations = emptyList()
    )
}
920+
774921}
0 commit comments