@@ -915,8 +915,30 @@ struct DistributedWorker {
915915 }
916916 if abs ( avg3BiasValues [ 0 ] - expectedBias[ 0 ] ) > 0.1 { commTypeMatch = false }
917917
918+ // 4. Mixed-dtype gradients — triggers fallback to non-batched mode
919+ let mixedFlat : [ String : MLXArray ] = [
920+ " weight_f32 " : MLXArray (
921+ rank == 0 ? [ 2.0 , 4.0 ] as [ Float ] : [ 4.0 , 8.0 ] as [ Float ] ) ,
922+ " weight_f16 " : MLXArray (
923+ rank == 0 ? [ 10.0 , 20.0 ] as [ Float ] : [ 30.0 , 40.0 ] as [ Float ] ) . asType ( . float16) ,
924+ ]
925+ let mixedGrads = ModuleParameters . unflattened ( mixedFlat)
926+ let mixedResult = averageGradients ( gradients: mixedGrads, group: group)
927+ eval ( mixedResult)
928+
929+ let mixedResultFlat = mixedResult. flattened ( )
930+ let f32Result = Dictionary ( uniqueKeysWithValues: mixedResultFlat) [ " weight_f32 " ] !
931+ let f16Result = Dictionary ( uniqueKeysWithValues: mixedResultFlat) [ " weight_f16 " ] !
932+
933+ let f32Values = f32Result. asArray ( Float . self)
934+ let f16Values = f16Result. asType ( . float32) . asArray ( Float . self)
935+ let mixedDtypeMatch =
936+ abs ( f32Values [ 0 ] - 3.0 ) < 0.1 && abs ( f32Values [ 1 ] - 6.0 ) < 0.1
937+ && abs ( f16Values [ 0 ] - 20.0 ) < 1.0 && abs ( f16Values [ 1 ] - 30.0 ) < 1.0
938+ let mixedDtypePreserved = f16Result. dtype == . float16
939+
918940 print (
919- " { \" defaultMatch \" : \( defaultMatch) , \" unbatchedMatch \" : \( unbatchedMatch) , \" commTypeMatch \" : \( commTypeMatch) , \" commTypeDtype \" : \" \( commTypeDtype) \" } "
941+ " { \" defaultMatch \" : \( defaultMatch) , \" unbatchedMatch \" : \( unbatchedMatch) , \" commTypeMatch \" : \( commTypeMatch) , \" commTypeDtype \" : \" \( commTypeDtype) \" , \" mixedDtypeMatch \" : \( mixedDtypeMatch ) , \" mixedDtypePreserved \" : \( mixedDtypePreserved ) } "
920942 )
921943 }
922944
@@ -1002,8 +1024,11 @@ struct DistributedWorker {
10021024 let int32Match =
10031025 i32Result. shape == [ 2 ] && i32Values [ 0 ] == 10 && i32Values [ 1 ] == 20
10041026
1027+ let float16Dtype = String ( describing: f16Result. dtype)
1028+ let int32Dtype = String ( describing: i32Result. dtype)
1029+
10051030 print (
1006- " { \" float16Match \" : \( float16Match) , \" int32Match \" : \( int32Match) , \" float16Shape \" : [ \( f16Result. shape. map { String ( $0) } . joined ( separator: " , " ) ) ], \" int32Shape \" : [ \( i32Result. shape. map { String ( $0) } . joined ( separator: " , " ) ) ]} "
1031+ " { \" float16Match \" : \( float16Match) , \" int32Match \" : \( int32Match) , \" float16Shape \" : [ \( f16Result. shape. map { String ( $0) } . joined ( separator: " , " ) ) ], \" int32Shape \" : [ \( i32Result. shape. map { String ( $0) } . joined ( separator: " , " ) ) ], \" float16Dtype \" : \" \( float16Dtype ) \" , \" int32Dtype \" : \" \( int32Dtype ) \" } "
10071032 )
10081033 }
10091034
0 commit comments