Updated Bug Report: Two Critical Issues in Qwen3-Next Training
Summary
MLX-LM v0.28.0 introduces Qwen3-Next support but has two critical bugs that prevent fine-tuning:
- FIXED: UnboundLocalError in variable initialization
- NEW: Missing gradient implementation for CustomKernel
Environment
- MLX-LM Version: 0.28.0
- Python Version: 3.11.13
- Platform: macOS-15.6.1-arm64-arm-64bit
- Architecture: arm64
- MLX Version: 0.29.1
Bug 1: UnboundLocalError (FIXED)
Issue
UnboundLocalError: cannot access local variable 'state' where it is not associated with a value
Root Cause
In mlx_lm/models/qwen3_next.py (lines 254-261), the state variable is used before it is initialized:
# Lines 254-261 (PROBLEMATIC)
if cache is not None:
    state = cache[1]
out, state = gated_delta_update(q, k, v, a, b, self.A_log, self.dt_bias, state)
# state is undefined when cache is None!
Fix Applied
# Lines 254-261 (FIXED)
if cache is not None:
    state = cache[1]
else:
    state = None
out, state = gated_delta_update(q, k, v, a, b, self.A_log, self.dt_bias, state)
Status: ✅ FIXED - Training now starts successfully
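The control-flow problem behind Bug 1 can be reproduced without MLX. The sketch below is a standalone illustration only: fake_update is a hypothetical stand-in for gated_delta_update, and forward_buggy / forward_fixed are made-up names that mirror the two versions of the branch shown above.

```python
# Hypothetical stand-in for gated_delta_update; only the control flow matters here.
def fake_update(x, state):
    state = 0 if state is None else state
    return x + state, state + 1


def forward_buggy(x, cache=None):
    if cache is not None:
        state = cache[1]
    # When cache is None, `state` was never bound in this scope.
    return fake_update(x, state)


def forward_fixed(x, cache=None):
    # Bind `state` on every path before it is read.
    state = cache[1] if cache is not None else None
    return fake_update(x, state)


try:
    forward_buggy(1)
    buggy_error = None
except UnboundLocalError as exc:
    buggy_error = type(exc).__name__

fixed_result = forward_fixed(1)
cached_result = forward_fixed(1, cache=[None, 5])
```

The buggy version only fails when no cache is passed, which matches training (no cache) failing while cached generation works.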
Bug 2: CustomKernel Gradient Implementation (NEW CRITICAL ISSUE)
Issue
ValueError: [Primitive::vjp] Not implemented for CustomKernel.
Root Cause
The custom kernels used in Qwen3-Next's linear attention mechanism lack gradient implementations (VJP - Vector-Jacobian Product) required for backpropagation.
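To illustrate why inference succeeds while training fails, here is a toy model of an operation that has a forward function but no VJP rule. ToyPrimitive and its methods are hypothetical and are not MLX APIs; the error string is copied from the report only to make the analogy visible.

```python
class ToyPrimitive:
    """Hypothetical primitive: a forward function plus an optional VJP rule."""

    def __init__(self, name, forward, vjp=None):
        self.name = name
        self.forward = forward
        self.vjp = vjp

    def __call__(self, x):
        # Inference only needs the forward function.
        return self.forward(x)

    def grad(self, x, cotangent=1.0):
        # Backpropagation additionally needs the VJP rule.
        if self.vjp is None:
            raise ValueError(f"[Primitive::vjp] Not implemented for {self.name}.")
        return self.vjp(x, cotangent)


# Forward-only primitive, analogous to a custom kernel without a gradient.
kernel = ToyPrimitive("CustomKernel", forward=lambda x: x * x)
forward_value = kernel(3.0)  # the forward pass works

try:
    kernel.grad(3.0)
    error_message = None
except ValueError as exc:
    error_message = str(exc)

# Supplying the VJP (d/dx x^2 = 2x, scaled by the cotangent) unblocks the backward pass.
kernel_with_vjp = ToyPrimitive(
    "CustomKernel", forward=lambda x: x * x, vjp=lambda x, ct: ct * 2.0 * x
)
grad_value = kernel_with_vjp.grad(3.0)
```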
When It Occurs
- Forward pass works perfectly ✅
- Tokenization works perfectly ✅
- Model loading works perfectly ✅
- Training fails during gradient computation ❌
Stack Trace
File "mlx/nn/utils.py", line 35, in wrapped_value_grad_fn
value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
ValueError: [Primitive::vjp] Not implemented for CustomKernel.
Impact
- Severity: CRITICAL - Completely blocks all Qwen3-Next training
- Scope: All Qwen3-Next models requiring gradient computation
- Workaround: None available
Testing Results
✅ Working Functionality
- Model loading and initialization
- Tokenization and decoding
- Forward pass in evaluation mode
- Basic inference
❌ Broken Functionality
- Any gradient computation
- LoRA fine-tuning
- Full fine-tuning
- Any training workflow
Reproduction
Minimal Test Case
import mlx.core as mx
import mlx.nn as nn
from mlx_lm import load
# Load model (works)
model, tokenizer = load("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit")
# Forward pass (works)
model.eval()
tokens = tokenizer.encode("Hello")
output = model(mx.array([tokens])) # ✅ Success
# Gradient computation (fails)
model.train()
loss_fn = lambda m, x: mx.sum(m(x))
grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = grad_fn(model, mx.array([tokens])) # ❌ CustomKernel error
Recommendation
The Qwen3-Next implementation needs gradient (VJP) implementations for all custom kernels used in the linear attention mechanism. This likely requires:
- Implementing VJP functions for gated_delta_update and related custom ops
- Registering these VJP implementations in the MLX framework
- Testing gradient computation across all Qwen3-Next components
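One way the testing step could look is a finite-difference check of a hand-written VJP. The sketch below uses a deliberately simplified scalar recurrence as a stand-in; it is not the gated delta rule used by Qwen3-Next, and decayed_sum, decayed_sum_vjp, and numerical_grad are hypothetical names introduced for illustration.

```python
def decayed_sum(g, xs):
    """Toy recurrence s_t = g * s_{t-1} + x_t with s_0 = 0; returns the final state."""
    s = 0.0
    for x in xs:
        s = g * s + x
    return s


def decayed_sum_vjp(g, xs, cotangent=1.0):
    """Hand-written gradient of decayed_sum with respect to g, scaled by the cotangent."""
    s, ds_dg = 0.0, 0.0
    for x in xs:
        # Product rule on g * s: d/dg (g * s) = s + g * ds/dg.
        ds_dg = s + g * ds_dg
        s = g * s + x
    return cotangent * ds_dg


def numerical_grad(fn, g, eps=1e-6):
    """Central finite difference, used only as an independent reference."""
    return (fn(g + eps) - fn(g - eps)) / (2.0 * eps)


xs = [1.0, 2.0, 3.0]
g = 0.5
analytic = decayed_sum_vjp(g, xs)
numeric = numerical_grad(lambda value: decayed_sum(value, xs), g)
gradients_match = abs(analytic - numeric) < 1e-5
```

The same pattern (analytic VJP compared against a numerical reference on small inputs) would presumably carry over to the real kernels, with tensors in place of scalars.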
Files Involved
- mlx_lm/models/qwen3_next.py (lines 254-261 for Bug 1)
- Custom kernel implementations (location unknown, likely in MLX core)
- Gradient registration system
Priority: CRITICAL - This blocks all training workflows for Qwen3-Next models