Skip to content

Commit d8aa509

Browse files
committed
Fix mixed-dtype averageGradients multi-process
1 parent 6624907 commit d8aa509

3 files changed

Lines changed: 45 additions & 3 deletions

File tree

Source/Examples/DistributedWorker.swift

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -915,8 +915,30 @@ struct DistributedWorker {
915915
}
916916
if abs(avg3BiasValues[0] - expectedBias[0]) > 0.1 { commTypeMatch = false }
917917

918+
// 4. Mixed-dtype gradients — triggers fallback to non-batched mode
919+
let mixedFlat: [String: MLXArray] = [
920+
"weight_f32": MLXArray(
921+
rank == 0 ? [2.0, 4.0] as [Float] : [4.0, 8.0] as [Float]),
922+
"weight_f16": MLXArray(
923+
rank == 0 ? [10.0, 20.0] as [Float] : [30.0, 40.0] as [Float]).asType(.float16),
924+
]
925+
let mixedGrads = ModuleParameters.unflattened(mixedFlat)
926+
let mixedResult = averageGradients(gradients: mixedGrads, group: group)
927+
eval(mixedResult)
928+
929+
let mixedResultFlat = mixedResult.flattened()
930+
let f32Result = Dictionary(uniqueKeysWithValues: mixedResultFlat)["weight_f32"]!
931+
let f16Result = Dictionary(uniqueKeysWithValues: mixedResultFlat)["weight_f16"]!
932+
933+
let f32Values = f32Result.asArray(Float.self)
934+
let f16Values = f16Result.asType(.float32).asArray(Float.self)
935+
let mixedDtypeMatch =
936+
abs(f32Values[0] - 3.0) < 0.1 && abs(f32Values[1] - 6.0) < 0.1
937+
&& abs(f16Values[0] - 20.0) < 1.0 && abs(f16Values[1] - 30.0) < 1.0
938+
let mixedDtypePreserved = f16Result.dtype == .float16
939+
918940
print(
919-
"{\"defaultMatch\": \(defaultMatch), \"unbatchedMatch\": \(unbatchedMatch), \"commTypeMatch\": \(commTypeMatch), \"commTypeDtype\": \"\(commTypeDtype)\"}"
941+
"{\"defaultMatch\": \(defaultMatch), \"unbatchedMatch\": \(unbatchedMatch), \"commTypeMatch\": \(commTypeMatch), \"commTypeDtype\": \"\(commTypeDtype)\", \"mixedDtypeMatch\": \(mixedDtypeMatch), \"mixedDtypePreserved\": \(mixedDtypePreserved)}"
920942
)
921943
}
922944

@@ -1002,8 +1024,11 @@ struct DistributedWorker {
10021024
let int32Match =
10031025
i32Result.shape == [2] && i32Values[0] == 10 && i32Values[1] == 20
10041026

1027+
let float16Dtype = String(describing: f16Result.dtype)
1028+
let int32Dtype = String(describing: i32Result.dtype)
1029+
10051030
print(
1006-
"{\"float16Match\": \(float16Match), \"int32Match\": \(int32Match), \"float16Shape\": [\(f16Result.shape.map { String($0) }.joined(separator: ","))], \"int32Shape\": [\(i32Result.shape.map { String($0) }.joined(separator: ","))]}"
1031+
"{\"float16Match\": \(float16Match), \"int32Match\": \(int32Match), \"float16Shape\": [\(f16Result.shape.map { String($0) }.joined(separator: ","))], \"int32Shape\": [\(i32Result.shape.map { String($0) }.joined(separator: ","))], \"float16Dtype\": \"\(float16Dtype)\", \"int32Dtype\": \"\(int32Dtype)\"}"
10071032
)
10081033
}
10091034

Tests/MLXTests/DistributedNNTests.swift

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1525,7 +1525,9 @@ class DistributedNNTests: XCTestCase {
15251525
let defaultMatch = json["defaultMatch"] as? Bool,
15261526
let unbatchedMatch = json["unbatchedMatch"] as? Bool,
15271527
let commTypeMatch = json["commTypeMatch"] as? Bool,
1528-
let commTypeDtype = json["commTypeDtype"] as? String
1528+
let commTypeDtype = json["commTypeDtype"] as? String,
1529+
let mixedDtypeMatch = json["mixedDtypeMatch"] as? Bool,
1530+
let mixedDtypePreserved = json["mixedDtypePreserved"] as? Bool
15291531
else {
15301532
XCTFail("Rank \(rank) produced invalid JSON output: '\(stdout)'")
15311533
continue
@@ -1543,6 +1545,12 @@ class DistributedNNTests: XCTestCase {
15431545
XCTAssertEqual(
15441546
commTypeDtype, "float32",
15451547
"Rank \(rank): communicationType should preserve original float32 dtype")
1548+
XCTAssertTrue(
1549+
mixedDtypeMatch,
1550+
"Rank \(rank): mixed-dtype averageGradients values mismatch")
1551+
XCTAssertTrue(
1552+
mixedDtypePreserved,
1553+
"Rank \(rank): mixed-dtype averageGradients should preserve float16")
15461554
}
15471555
}
15481556
}

Tests/MLXTests/DistributedTests.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,15 @@ class DistributedTests: XCTestCase {
14381438
XCTAssertTrue(int32Match, "Rank \(rank): int32 allGather mismatch")
14391439
XCTAssertEqual(float16Shape, [4], "Rank \(rank): float16 shape mismatch")
14401440
XCTAssertEqual(int32Shape, [2], "Rank \(rank): int32 shape mismatch")
1441+
1442+
let float16Dtype = json["float16Dtype"] as? String
1443+
let int32Dtype = json["int32Dtype"] as? String
1444+
XCTAssertEqual(
1445+
float16Dtype, "float16",
1446+
"Rank \(rank): allGather should preserve float16 dtype")
1447+
XCTAssertEqual(
1448+
int32Dtype, "int32",
1449+
"Rank \(rank): allGather should preserve int32 dtype")
14411450
}
14421451
}
14431452

0 commit comments

Comments
 (0)