Skip to content

Commit 441d0f4

Browse files
committed
Integrate MLX-c 0.31.1
1 parent 05717c4 commit 441d0f4

7 files changed

Lines changed: 65 additions & 29 deletions

File tree

DISTRIBUTED-LM-INTEGRATION.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -894,7 +894,7 @@ Implement in this order. Each step produces a testable, shippable increment:
894894
| Limitation | Impact | Workaround |
895895
|-----------|--------|------------|
896896
| All distributed ops are CPU-only | Must use `Device.withDefaultDevice(.cpu)` | Wrap model loading and generation in CPU scope |
897-
| MLX-C has no backend selection parameter | Cannot programmatically choose ring vs JACCL | MLX-C tries JACCL first, then ring — usually correct |
897+
| No backend introspection API | Cannot query which backend was initialized for an existing group | Use `isAvailable(backend:)` to check before init |
898898
| `mlx_distributed_group_free()` not in public C API | Group deallocation relies on C++ shared_ptr | No action needed — works via ref counting |
899899
| `group.split()` unsupported by ring/JACCL | Cannot create subgroups | Not needed for tensor parallelism |
900900
| `sumScatter` not implemented in ring backend | Cannot use reduce-scatter collective | Use allSum instead (slightly more bandwidth) |

Source/Examples/DistributedWorker.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ struct DistributedWorker {
4242
}
4343

4444
static func runWorker(rank: Int, testOp: String) {
45-
// Initialize distributed with strict=true (ring backend must be available)
46-
guard let group = MLXDistributed.`init`(strict: true) else {
45+
// Initialize distributed with strict=true using the ring backend
46+
guard let group = MLXDistributed.`init`(strict: true, backend: .ring) else {
4747
fputs("ERROR: Failed to initialize distributed group (strict=true)\n", stderr)
4848
exit(1)
4949
}

Source/MLX/Distributed.swift

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ import Foundation
66
/// Wrapper around the MLX C distributed group handle.
77
///
88
/// A `DistributedGroup` represents a group of independent MLX processes
9-
/// that can communicate using collective operations. Use ``MLXDistributed/init(strict:)``
9+
/// that can communicate using collective operations. Use ``MLXDistributed/init(strict:backend:)``
1010
/// to create the initial group, then ``split(color:key:)`` to create sub-groups.
1111
///
1212
/// ### See Also
1313
/// - ``MLXDistributed``
14-
/// - ``MLXDistributed/init(strict:)``
14+
/// - ``MLXDistributed/init(strict:backend:)``
1515
public final class DistributedGroup: @unchecked Sendable {
1616

1717
let ctx: mlx_distributed_group
@@ -69,6 +69,23 @@ public final class DistributedGroup: @unchecked Sendable {
6969
}
7070
}
7171

72+
/// The distributed communication backend to use.
73+
///
74+
/// When ``DistributedBackend/any`` is specified, MLX chooses the best available
75+
/// backend automatically. Use a specific case to force a particular backend.
76+
public enum DistributedBackend: String, CaseIterable, Sendable {
77+
/// Let MLX choose the best available backend automatically.
78+
case any
79+
/// TCP socket-based ring backend.
80+
case ring
81+
/// Joint Accelerator Communication Library (Thunderbolt 5 RDMA).
82+
case jaccl
83+
/// Message Passing Interface backend.
84+
case mpi
85+
/// NVIDIA Collective Communications Library backend.
86+
case nccl
87+
}
88+
7289
/// Collection of distributed communication operations.
7390
///
7491
/// Use ``MLXDistributed`` to check for distributed backend availability,
@@ -77,7 +94,7 @@ public final class DistributedGroup: @unchecked Sendable {
7794
///
7895
/// ```swift
7996
/// // Initialize distributed communication
80-
/// let group = MLXDistributed.init()
97+
/// let group = MLXDistributed.`init`()
8198
/// print("Rank \(group.rank) of \(group.size)")
8299
///
83100
/// // Perform an all-sum reduction
@@ -91,10 +108,10 @@ public enum MLXDistributed {
91108

92109
/// Check if a distributed communication backend is available.
93110
///
94-
/// Returns `true` when the ring backend (or another backend) is compiled and
95-
/// available for use.
96-
public static func isAvailable() -> Bool {
97-
mlx_distributed_is_available()
111+
/// - Parameter backend: the backend to check (default: `.any`, checks all)
112+
/// - Returns: `true` when the specified backend is available
113+
public static func isAvailable(backend: DistributedBackend = .any) -> Bool {
114+
backend.rawValue.withCString { mlx_distributed_is_available($0) }
98115
}
99116

100117
/// Initialize the distributed backend and return the group containing
@@ -105,16 +122,21 @@ public enum MLXDistributed {
105122
/// When `strict` is `true`, returns `nil` if initialization fails
106123
/// (e.g., no hostfile configured).
107124
///
108-
/// > Note: MLX-C does not currently expose a backend selection parameter.
109-
/// > The C layer tries backends in priority order (JACCL first, then ring).
110-
/// > Track upstream mlx-c for a future `backend` parameter.
125+
/// ```swift
126+
/// // Use a specific backend
127+
/// let group = MLXDistributed.`init`(strict: true, backend: .ring)
128+
/// ```
111129
///
112-
/// - Parameter strict: if `true`, return `nil` on initialization failure
113-
/// instead of falling back to a singleton group
130+
/// - Parameters:
131+
/// - strict: if `true`, return `nil` on initialization failure
132+
/// instead of falling back to a singleton group
133+
/// - backend: the backend to use (default: `.any`, let MLX choose)
114134
/// - Returns: the ``DistributedGroup`` for this process, or `nil` if
115135
/// `strict` is `true` and initialization failed
116-
public static func `init`(strict: Bool = false) -> DistributedGroup? {
117-
let group = mlx_distributed_init(strict)
136+
public static func `init`(strict: Bool = false, backend: DistributedBackend = .any)
137+
-> DistributedGroup?
138+
{
139+
let group = backend.rawValue.withCString { mlx_distributed_init(strict, $0) }
118140
if group.ctx == nil {
119141
return nil
120142
}

Tests/MLXTests/DistributedTests.swift

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ class DistributedTests: XCTestCase {
6767
func testIsAvailable() {
6868
// Ring backend is compiled in, so isAvailable should return true
6969
XCTAssertTrue(MLXDistributed.isAvailable())
70+
71+
// Verify backend-specific availability check
72+
XCTAssertTrue(
73+
MLXDistributed.isAvailable(backend: .ring),
74+
"Ring backend should always be available")
7075
}
7176

7277
// MARK: - (2b) JACCL availability check
@@ -86,10 +91,9 @@ class DistributedTests: XCTestCase {
8691
// 2. The ring backend is available (true)
8792
// 3. On this hardware, the overall availability is true (ring)
8893
//
89-
// NOTE: We cannot directly query which backend (ring vs JACCL) was
90-
// selected because MLX-C does not expose a backend-name API. The
91-
// isAvailable() call returns true if ANY backend is available. On
92-
// machines without RDMA/TB5, this is the ring backend.
94+
// NOTE: Backend selection is supported (e.g., .ring, .jaccl), but
95+
// MLX-C does not expose a backend introspection API — there is no way
96+
// to query which backend was actually initialized for an existing group.
9397

9498
// (1) Verify isAvailable() returns a Bool
9599
let available = MLXDistributed.isAvailable()

skills/mlx-distributed/SKILL.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ let avgGrads3 = averageGradients(
403403

404404
| Limitation | Impact |
405405
|------------|--------|
406-
| MLX-C doesn't expose backend selection parameter | Cannot choose between JACCL and ring; tries JACCL first, falls back to ring |
406+
| No backend introspection API | Cannot query which backend was initialized for an existing group; use `isAvailable(backend:)` to check before init |
407407
| `mlx_distributed_group_free()` not exposed in public C API | Groups leak small amounts of memory on deallocation (minimal practical impact) |
408408
| `group.split()` unsupported by ring and JACCL backends | Only MPI (not available on macOS) supports sub-group creation |
409409
| `sumScatter`/`reduceScatter` not implemented in ring backend | Use allSum + manual slicing as a workaround |

skills/mlx-distributed/references/multi-process.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ JACCL (Joint Accelerator Communication Library) uses RDMA over Thunderbolt 5 for
2525
- RDMA explicitly enabled in Recovery Mode (`csrutil`)
2626
- Physical Thunderbolt 5 cable between nodes
2727

28-
> **Note:** MLX-C does not expose a backend selection parameter. You cannot force one backend over the other. If JACCL hardware is present, it will be preferred.
28+
> **Note:** You can select a specific backend using the `backend` parameter (e.g., `MLXDistributed.\`init\`(backend: .jaccl)`). Use `.any` (the default) to let MLX choose automatically.
2929
3030
---
3131

skills/mlx-distributed/references/primitives.md

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,36 +81,46 @@ public enum MLXDistributed
8181

8282
### Static Methods
8383

84-
#### isAvailable()
84+
#### isAvailable(backend:)
8585

8686
Check if a distributed communication backend is available.
8787

8888
```swift
89-
public static func isAvailable() -> Bool
89+
public static func isAvailable(backend: DistributedBackend = .any) -> Bool
9090
```
9191

92-
**Returns:** `true` when the ring backend (or another backend) is compiled and available.
92+
**Parameters:**
93+
- `backend`: The backend to check. Default is `.any`, which checks if any backend is available.
94+
95+
**Returns:** `true` when the specified backend is available.
9396

9497
```swift
98+
// Check if any backend is available
9599
if MLXDistributed.isAvailable() {
96100
print("Distributed backend ready")
97101
}
102+
103+
// Check a specific backend
104+
if MLXDistributed.isAvailable(backend: .ring) {
105+
print("Ring backend ready")
106+
}
98107
```
99108

100-
#### init(strict:)
109+
#### init(strict:backend:)
101110

102111
Initialize the distributed backend and return the group containing all discoverable processes.
103112

104113
```swift
105-
public static func `init`(strict: Bool = false) -> DistributedGroup?
114+
public static func `init`(strict: Bool = false, backend: DistributedBackend = .any) -> DistributedGroup?
106115
```
107116

108117
**Parameters:**
109118
- `strict`: If `true`, returns `nil` on initialization failure instead of falling back to a singleton group. Default is `false`.
119+
- `backend`: The backend to use. Default is `.any`, which lets MLX choose automatically.
110120

111121
**Returns:** The `DistributedGroup` for this process, or `nil` if `strict` is `true` and initialization failed.
112122

113-
When `strict` is `false` (default), returns a singleton group (rank 0, size 1) if no distributed backend can be initialized. MLX-C does not expose a backend selection parameter — it tries JACCL first, then ring.
123+
When `strict` is `false` (default), returns a singleton group (rank 0, size 1) if no distributed backend can be initialized.
114124

115125
```swift
116126
// Non-strict: always returns a group (size-1 fallback)

0 commit comments

Comments
 (0)