Skip to content

Commit e8ebdeb

Browse files
authored
Add barrier to JACCL (ml-explore#3459)
1 parent d7d0992 commit e8ebdeb

8 files changed

Lines changed: 63 additions & 0 deletions

File tree

mlx/distributed/jaccl/lib/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ in macOS 26.2.
2929
- **Point-to-Point Operations**:
3030
- `send`: Send data to a specific node
3131
- `recv`: Receive data from a specific node
32+
- **Synchronization**:
33+
- `barrier`: Block until all nodes in the group reach this point
3234
- **Type Support**: Bool, Int8-64, UInt8-64, Float16, BFloat16, Float32,
3335
Float64, Complex64
3436

@@ -286,6 +288,9 @@ class Group {
286288
// Simple send/recv primitives.
287289
virtual void send(const void* input, size_t n_bytes, int dst) = 0;
288290
virtual void recv(void* output, size_t n_bytes, int src) = 0;
291+
292+
// Block until every rank reaches this point.
293+
virtual void barrier() = 0;
289294
};
290295
```
291296

mlx/distributed/jaccl/lib/examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ endfunction()
3535
# Examples
3636
build_example(minimal_env.cpp)
3737
build_example(minimal_cfg.cpp)
38+
build_example(minimal_barrier.cpp)
3839

3940
# Benchmarks
4041
build_example(allreduce_bench.cpp)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright © 2026 Apple Inc.
2+
//
3+
// Exercises Group::barrier(). Ranks arrive at the barrier at staggered times;
4+
// after the barrier returns we do a small all_sum to confirm the group is
5+
// healthy and that barrier() carried the correct fence semantics.
6+
7+
#include <chrono>
8+
#include <iostream>
9+
#include <thread>
10+
11+
#include <jaccl/jaccl.h>
12+
13+
int main() {
14+
auto group = jaccl::init();
15+
if (!group) {
16+
std::cerr << "Failed to initialize JACCL" << std::endl;
17+
return 1;
18+
}
19+
20+
int rank = group->rank();
21+
int size = group->size();
22+
23+
std::this_thread::sleep_for(std::chrono::milliseconds(100 * rank));
24+
std::cout << "rank " << rank << " entering barrier" << std::endl;
25+
26+
group->barrier();
27+
28+
std::cout << "rank " << rank << " exited barrier" << std::endl;
29+
30+
int in = rank + 1;
31+
int out = 0;
32+
group->all_sum(&in, &out, sizeof(in), jaccl::Int32);
33+
int expected = size * (size + 1) / 2;
34+
if (out != expected) {
35+
std::cerr << "rank " << rank << ": post-barrier all_sum mismatch (got "
36+
<< out << ", expected " << expected << ")" << std::endl;
37+
return 1;
38+
}
39+
std::cout << "rank " << rank << ": post-barrier all_sum OK (" << out << ")"
40+
<< std::endl;
41+
return 0;
42+
}

mlx/distributed/jaccl/lib/jaccl/group.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class Group {
3030

3131
virtual void send(const void* input, size_t n_bytes, int dst) = 0;
3232
virtual void recv(void* output, size_t n_bytes, int src) = 0;
33+
virtual void barrier() = 0;
3334
};
3435

3536
/**

mlx/distributed/jaccl/lib/jaccl/mesh.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,11 @@ void MeshGroup::recv(void* output, size_t n_bytes, int src) {
184184
mesh_.recv(static_cast<char*>(output), n_bytes, src);
185185
}
186186

187+
void MeshGroup::barrier() {
188+
uint8_t b = 0;
189+
all_sum(&b, &b, sizeof(b), Dtype::UInt8);
190+
}
191+
187192
template <typename T, typename ReduceOp>
188193
void MeshGroup::all_reduce(
189194
const void* input,

mlx/distributed/jaccl/lib/jaccl/mesh.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ class MeshGroup : public Group {
4747
void send(const void* input, size_t n_bytes, int dst) override;
4848
void recv(void* output, size_t n_bytes, int src) override;
4949

50+
void barrier() override;
51+
5052
private:
5153
template <typename T, typename ReduceOp>
5254
void all_reduce(

mlx/distributed/jaccl/lib/jaccl/ring.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ void RingGroup::recv(void* output, size_t n_bytes, int src) {
190190
ring_.recv(static_cast<char*>(output), n_bytes, src, n_conns_);
191191
}
192192

193+
void RingGroup::barrier() {
194+
uint8_t b = 0;
195+
all_sum(&b, &b, sizeof(b), Dtype::UInt8);
196+
}
197+
193198
template <typename T, typename ReduceOp>
194199
void RingGroup::all_reduce(
195200
const void* input,

mlx/distributed/jaccl/lib/jaccl/ring.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class RingGroup : public Group {
4848
void send(const void* input, size_t n_bytes, int dst) override;
4949
void recv(void* output, size_t n_bytes, int src) override;
5050

51+
void barrier() override;
52+
5153
private:
5254
template <typename T, typename ReduceOp>
5355
void all_reduce(

0 commit comments

Comments
 (0)