The distribution module provides tools for distributing lattice Boltzmann operators across multiple devices (e.g., GPUs or TPUs) using JAX sharding.
This enables simulations to run in parallel while ensuring correct halo communication between device partitions.
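In lattice Boltzmann methods (LBM), the streaming step moves each site's distribution functions to neighboring sites, so the update at every site depends on its neighbors.
When running on multiple devices, the domain is split (sharded) across them. Sites at a partition boundary then have neighbors on another device, so boundary (halo) data must be exchanged between devices after each step.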
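The sketch below illustrates why halo exchange is needed. It is a minimal pure-Python example, not XLB code. It assumes a periodic 1D lattice holding a single right-moving population, split into equal contiguous blocks with one block per device. If each block is streamed independently, the first site of every block receives a value from the wrong place.

```python
# Illustrative only: a periodic 1D lattice holding one right-moving
# population, split into equal contiguous blocks (one per "device").

def stream_right(cells):
    """Periodic streaming by one site: site i receives the value of site i-1."""
    return cells[-1:] + cells[:-1]

def split(cells, n_devices):
    """Split the lattice into n_devices contiguous blocks of equal size."""
    size = len(cells) // n_devices
    return [cells[d * size:(d + 1) * size] for d in range(n_devices)]

lattice = list(range(8))
n_devices = 2

# Reference: stream the whole lattice on a single device.
global_result = stream_right(lattice)

# Naive sharding: each device streams only its own block, with no communication.
naive_result = [v for block in split(lattice, n_devices) for v in stream_right(block)]

# Sites whose value differs from the single-device reference.
wrong_sites = [i for i, (a, b) in enumerate(zip(global_result, naive_result)) if a != b]
print(global_result)  # [7, 0, 1, 2, 3, 4, 5, 6]
print(naive_result)   # [3, 0, 1, 2, 7, 4, 5, 6]
print(wrong_sites)    # [0, 4]
```

The mismatches occur exactly at the first site of each block. Those are the values that must come from the neighboring device.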
The distribution module handles:
- Sharding operators across devices.
- Exchanging boundary (halo) data between devices.
- Supporting stepper operators (such as IncompressibleNavierStokesStepper) with or without boundary conditions.
distribute_operator(operator, grid, velocity_set, num_results=1, ops="permute")
Wraps an operator so that it runs in a distributed fashion across devices.
Parameters:
- operator (Operator): The LBM operator to distribute (e.g., collision, streaming).
- grid: Grid definition with device mesh info (grid.global_mesh, grid.shape, grid.nDevices).
- velocity_set: Velocity set defining the LBM stencil (e.g., D2Q9, D3Q19).
- num_results (int, default=1): Number of results returned by the operator.
- ops (str, default="permute"): Communication scheme. Currently supports "permute" for halo exchange.
Behavior:
- Uses shard_map to parallelize across devices.
- Applies halo communication via jax.lax.ppermute:
  - Sends right-edge values to the left neighbor.
  - Sends left-edge values to the right neighbor.
- Returns a JIT-compiled distributed operator.
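The halo exchange described above can be sketched in pure Python. This is an illustrative simulation, not XLB's implementation:

- Lists of per-device blocks stand in for sharded arrays.
- The hypothetical ppermute helper mimics jax.lax.ppermute, where each (source, destination) pair delivers the source's value to the destination.
- A 1D lattice with one +x population (R) and one -x population (L) replaces a full velocity set.

Each device first streams locally with wrap-around. The left edge of R is then sent to the right neighbor, and the right edge of L is sent to the left neighbor. The received values overwrite the wrapped-around cells.

```python
# Pure-Python simulation of a "permute"-style halo exchange (illustrative,
# not XLB's source). Each list entry is one device's block of a periodic 1D
# lattice; "R" populations move +x and "L" populations move -x.

def ppermute(per_device, perm):
    """Mimic jax.lax.ppermute: device dst receives per_device[src] for each
    (src, dst) pair; devices that are not a destination receive 0."""
    out = [0] * len(per_device)
    for src, dst in perm:
        out[dst] = per_device[src]
    return out

def local_stream(block, shift):
    """Periodic roll inside a single block (shift is +1 or -1)."""
    return block[-shift:] + block[:-shift]

def sharded_stream(r_blocks, l_blocks):
    n = len(r_blocks)
    # 1. Every device streams its own block, wrapping around at its edges.
    r_blocks = [local_stream(b, +1) for b in r_blocks]
    l_blocks = [local_stream(b, -1) for b in l_blocks]

    # 2. Periodic permutations over the device axis.
    right_perm = [(i, (i + 1) % n) for i in range(n)]  # device i sends to i+1
    left_perm = [((i + 1) % n, i) for i in range(n)]   # device i+1 sends to i

    # 3. Left-edge R values go to the right neighbor;
    #    right-edge L values go to the left neighbor.
    from_left = ppermute([b[0] for b in r_blocks], right_perm)
    from_right = ppermute([b[-1] for b in l_blocks], left_perm)

    # 4. Overwrite the wrapped-around edge cells with the received values.
    for d in range(n):
        r_blocks[d][0] = from_left[d]
        l_blocks[d][-1] = from_right[d]
    return r_blocks, l_blocks

lattice = list(range(8))
blocks = [lattice[:4], lattice[4:]]  # two "devices"
r, l = sharded_stream(blocks, blocks)
print([v for b in r for v in b])  # [7, 0, 1, 2, 3, 4, 5, 6]
print([v for b in l for v in b])  # [1, 2, 3, 4, 5, 6, 7, 0]
```

After the exchange, the concatenated blocks match a single-device periodic shift by one site in each direction. Guaranteeing this match is the purpose of the halo exchange.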
distribute(operator, grid, velocity_set, num_results=1, ops="permute")
Decides how to distribute an operator or stepper.
Parameters: same as distribute_operator.
Behavior:
- Checks whether the stepper's boundary conditions require post-streaming updates:
  - If yes → only the stepper's .stream operator is distributed.
  - If no → the entire stepper is distributed.
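The decision made by distribute can be illustrated with a hypothetical sketch. Stepper, BC, the post_streaming flag, and mark_distributed are stand-ins invented for this example and are not XLB classes. In particular, mark_distributed only tags a callable instead of sharding it, standing in for distribute_operator.

```python
# Hypothetical sketch of distribute's decision logic (not XLB's code).

class BC:
    """Stand-in boundary condition; post_streaming is a hypothetical flag."""
    def __init__(self, post_streaming):
        self.post_streaming = post_streaming

class Stepper:
    """Stand-in stepper holding boundary conditions and a stream operator."""
    def __init__(self, boundary_conditions):
        self.boundary_conditions = boundary_conditions
        self.stream = lambda f: f  # placeholder streaming operator

    def __call__(self, f):
        return self.stream(f)

def mark_distributed(op):
    """Stand-in for distribute_operator: wraps and tags the callable."""
    def wrapped(*args):
        return op(*args)
    wrapped.is_distributed = True
    return wrapped

def distribute_sketch(stepper, wrap=mark_distributed):
    if any(bc.post_streaming for bc in stepper.boundary_conditions):
        stepper.stream = wrap(stepper.stream)  # distribute streaming only
        return stepper                         # the stepper object is kept
    return wrap(stepper)                       # distribute the whole stepper

a = distribute_sketch(Stepper([BC(post_streaming=True)]))
b = distribute_sketch(Stepper([BC(post_streaming=False)]))
print(isinstance(a, Stepper), getattr(a.stream, "is_distributed", False))  # True True
print(isinstance(b, Stepper), getattr(b, "is_distributed", False))         # False True
```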
Example:

```python
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.distribution import distribute

# grid, velocity_set, and state are assumed to be set up beforehand.

# Create stepper
stepper = IncompressibleNavierStokesStepper(...)

# Distribute across devices
distributed_stepper = distribute(stepper, grid, velocity_set)

# Run simulation
state = distributed_stepper(state)
```