diff --git a/mlx_lm/models/minimax.py b/mlx_lm/models/minimax.py
index 9bf78d9a4..7f4d19514 100644
--- a/mlx_lm/models/minimax.py
+++ b/mlx_lm/models/minimax.py
@@ -9,6 +9,7 @@
 from mlx.nn.layers.distributed import shard_inplace, shard_linear, sum_gradients
 
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .pipeline import PipelineMixin
 from .switch_layers import SwitchGLU
@@ -221,9 +222,10 @@ def __call__(
         return r
 
 
-class MiniMaxModel(nn.Module):
+class MiniMaxModel(PipelineMixin, nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
+        self.args = args
         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
 
         self.layers = [
@@ -241,15 +243,32 @@ def __call__(
         h = self.embed_tokens(inputs)
 
         if cache is None:
-            cache = [None] * len(self.layers)
+            cache = [None] * len(self.pipeline_layers)
 
         mask = create_attention_mask(h, cache[0])
 
-        for layer, c in zip(self.layers, cache):
+        if self.pipeline_rank < self.pipeline_size - 1:
+            h = mx.distributed.recv_like(h, self.pipeline_rank + 1)
+
+        for layer, c in zip(self.pipeline_layers, cache):
             h = layer(h, mask, c)
 
+        if self.pipeline_rank != 0:
+            h = mx.distributed.send(h, (self.pipeline_rank - 1) % self.pipeline_size)
+
+        if self.pipeline_size > 1:
+            h = mx.distributed.all_gather(h)[: h.shape[0]]
+
         return self.norm(h)
 
+    def make_cache(self):
+        from .cache import KVCache
+
+        return [
+            KVCache()
+            for _ in self.pipeline_layers
+        ]
+
 
 class Model(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -373,6 +392,12 @@ def shard(self, group: Optional[mx.distributed.Group] = None):
             )
             layer.block_sparse_moe.sharding_group = group
 
+    def make_cache(self):
+        if hasattr(self.model, "make_cache"):
+            return self.model.make_cache()
+        from .cache import KVCache
+        return [KVCache() for _ in self.model.layers]
+
     @property
     def layers(self):
         return self.model.layers
diff --git a/mlx_lm/models/qwen3_moe.py b/mlx_lm/models/qwen3_moe.py
index 52dc50f9b..38ba374a1 100644
--- a/mlx_lm/models/qwen3_moe.py
+++ b/mlx_lm/models/qwen3_moe.py
@@ -8,6 +8,7 @@
 from .activations import swiglu
 from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .pipeline import PipelineMixin
 from .switch_layers import SwitchGLU
@@ -171,7 +172,7 @@ def __call__(
         return out
 
 
-class Qwen3MoeModel(nn.Module):
+class Qwen3MoeModel(PipelineMixin, nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.args = args
@@ -197,15 +198,37 @@ def __call__(
         h = self.embed_tokens(inputs)
 
         if cache is None:
-            cache = [None] * len(self.layers)
+            cache = [None] * len(self.pipeline_layers)
 
-        mask = create_attention_mask(h, cache[0])
+        mask = create_attention_mask(h, cache[0], return_array=True)
 
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, c)
+        # In pipeline parallel, receive hidden state from the next rank
+        # (which processed earlier layers)
+        if self.pipeline_rank < self.pipeline_size - 1:
+            h = mx.distributed.recv_like(h, self.pipeline_rank + 1)
+
+        for layer, c in zip(self.pipeline_layers, cache):
+            h = layer(h, mask, cache=c)
+
+        # Send hidden state to the previous rank (which has later layers)
+        if self.pipeline_rank != 0:
+            h = mx.distributed.send(h, (self.pipeline_rank - 1) % self.pipeline_size)
+
+        # All ranks need the final output for lm_head
+        if self.pipeline_size > 1:
+            h = mx.distributed.all_gather(h)[: h.shape[0]]
 
         return self.norm(h)
 
+    def make_cache(self):
+        """Return KV cache entries only for this rank's layers."""
+        from .cache import KVCache
+
+        return [
+            KVCache()
+            for _ in self.pipeline_layers
+        ]
+
 
 class Model(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -254,6 +277,12 @@ def predicate(path, _):
         return predicate
 
+    def make_cache(self):
+        if hasattr(self.model, "make_cache"):
+            return self.model.make_cache()
+        from .cache import KVCache
+        return [KVCache() for _ in self.model.layers]
+
     @property
     def layers(self):
         return self.model.layers
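
Reviewer note: both files import `PipelineMixin` from `.pipeline`, which is not part of this diff. From the call sites above, the mixin must expose `pipeline_rank`, `pipeline_size`, and `pipeline_layers`, with the highest rank holding the first layers and rank 0 holding the last: every rank except the last receives from `rank + 1`, and every rank except 0 sends to `rank - 1`. The final `all_gather(h)[: h.shape[0]]` then broadcasts rank 0's finished hidden state to all ranks (the slice keeps the first rank's chunk of the gathered axis) so each process can run `norm` and the LM head. A minimal sketch of that assumed interface, with hypothetical names and an even-split assumption, could look like:

```python
# Hypothetical sketch of the PipelineMixin interface this diff relies on;
# the real mlx_lm/models/pipeline.py (not shown here) may differ.
import mlx.core as mx


class PipelineMixin:
    """Split `self.layers` across a distributed group for pipeline parallel.

    Hidden states flow from the highest rank (first layers) down to
    rank 0 (last layers), matching the recv/send calls in __call__.
    """

    pipeline_rank: int = 0
    pipeline_size: int = 1

    @property
    def pipeline_layers(self):
        # On a single process this is the full layer list, so the models
        # behave exactly as before.
        if self.pipeline_size == 1:
            return self.layers
        return self.layers[self._pipeline_slice]

    def pipeline(self, group: mx.distributed.Group):
        self.pipeline_rank = group.rank()
        self.pipeline_size = group.size()
        # Sketch assumes the layer count divides evenly across ranks.
        assert len(self.layers) % self.pipeline_size == 0
        per_rank = len(self.layers) // self.pipeline_size
        # Rank 0 takes the *last* layers so it produces the hidden state
        # consumed by norm/lm_head; the highest rank takes the first.
        start = (self.pipeline_size - self.pipeline_rank - 1) * per_rank
        self._pipeline_slice = slice(start, start + per_rank)
```

Under that reading, each rank would call something like `model.model.pipeline(mx.distributed.init())` before `make_cache()`, so the KV caches returned by the new `make_cache` overrides line up one-to-one with that rank's `pipeline_layers` rather than with the full layer list.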