[Bug] Two bugs in Qwen3-Next training: UnboundLocalError + Missing CustomKernel gradients #482

@ArjunDivecha

Updated Bug Report: Two Critical Issues in Qwen3-Next Training

Summary

MLX-LM v0.28.0 introduces Qwen3-Next support but has two critical bugs that prevent fine-tuning:

  1. FIXED (local patch): UnboundLocalError from an uninitialized `state` variable
  2. NEW: Missing gradient (VJP) implementation for CustomKernel

Environment

  • MLX-LM Version: 0.28.0
  • Python Version: 3.11.13
  • Platform: macOS-15.6.1-arm64-arm-64bit
  • Architecture: arm64
  • MLX Version: 0.29.1

Bug 1: UnboundLocalError (FIXED)

Issue

UnboundLocalError: cannot access local variable 'state' where it is not associated with a value

Root Cause

In mlx_lm/models/qwen3_next.py (lines 254-261), the `state` variable is read before it is assigned whenever `cache` is None:

# Lines 254-261 (PROBLEMATIC)
if cache is not None:
    state = cache[1]

out, state = gated_delta_update(q, k, v, a, b, self.A_log, self.dt_bias, state)
# state is undefined when cache is None!

Fix Applied

# Lines 254-261 (FIXED)
if cache is not None:
    state = cache[1]
else:
    state = None

out, state = gated_delta_update(q, k, v, a, b, self.A_log, self.dt_bias, state)

Status: ✅ FIXED. With this patch, training starts and the forward pass runs; it then fails at the gradient step (Bug 2).
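The failure is ordinary Python scoping: because `state` is assigned somewhere in the function body, Python treats it as a local name, so reading it on the `cache is None` path raises UnboundLocalError. A minimal, self-contained sketch of the pattern; `update`, `buggy_forward`, and `fixed_forward` are hypothetical stand-ins, not mlx_lm functions.

```python
# Hypothetical illustration of the Bug 1 pattern (not the actual mlx_lm code).

def update(x, state):
    # Stand-in for gated_delta_update: returns an output and a new state.
    new_state = x if state is None else state + x
    return x * 2, new_state


def buggy_forward(x, cache=None):
    if cache is not None:
        state = cache[1]
    # When cache is None, `state` was never bound in this scope.
    out, state = update(x, state)
    return out, state


def fixed_forward(x, cache=None):
    # Bind `state` on every path; None means "start from an empty state".
    state = cache[1] if cache is not None else None
    out, state = update(x, state)
    return out, state
```

Calling `buggy_forward(1)` raises UnboundLocalError, while `fixed_forward(1)` returns `(2, 1)`. The one-line conditional expression is equivalent to the if/else in the fix above.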

Bug 2: CustomKernel Gradient Implementation (NEW CRITICAL ISSUE)

Issue

ValueError: [Primitive::vjp] Not implemented for CustomKernel.

Root Cause

The custom kernels used in Qwen3-Next's linear-attention layers (surfaced by MLX as the CustomKernel primitive) do not implement a VJP (vector-Jacobian product), which backpropagation requires.
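For intuition: a VJP takes the cotangent of the loss with respect to an operation's output, g = dL/dy, and returns the cotangents with respect to its inputs; reverse-mode autodiff chains these per-primitive rules. The sketch below writes one by hand for a matrix-vector product y = W x (dL/dW = g x^T, dL/dx = W^T g), the kind of operation linear attention uses to read an output from its state matrix. It is plain Python for illustration only; `matvec` and `matvec_vjp` are hypothetical names, not the Qwen3-Next kernels.

```python
def matvec(W, x):
    """Forward: y[i] = sum_j W[i][j] * x[j]."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]


def matvec_vjp(W, x, g):
    """VJP of matvec: given cotangent g = dL/dy, return (dL/dW, dL/dx).

    dL/dW[i][j] = g[i] * x[j]          (outer product g x^T)
    dL/dx[j]    = sum_i W[i][j] * g[i] (W^T g)
    """
    dW = [[g_i * x_j for x_j in x] for g_i in g]
    dx = [sum(W[i][j] * g[i] for i in range(len(g))) for j in range(len(x))]
    return dW, dx


W = [[1.0, 2.0], [3.0, 4.0]]
x = [5.0, 6.0]
y = matvec(W, x)                          # [17.0, 39.0]
dW, dx = matvec_vjp(W, x, g=[1.0, -1.0])  # [[5, 6], [-5, -6]], [-2, -2]
```

A fused custom kernel hides this structure from the framework, so an equivalent rule has to be supplied explicitly before gradients can flow through it.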

When It Occurs

  • Forward pass works perfectly ✅
  • Tokenization works perfectly ✅
  • Model loading works perfectly ✅
  • Training fails during gradient computation
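This split is expected: evaluation only needs each primitive's forward function, while backpropagation also needs a VJP rule for every primitive between the loss and the parameters. The toy tape below is a hypothetical pure-Python illustration, not MLX internals: a primitive named "CustomKernel" with no registered rule evaluates fine, but makes `grad` raise an error shaped like the one in the stack trace.

```python
import math

# Primitive name -> VJP rule: (input values, output cotangent) -> input cotangents.
VJP_RULES = {
    "sin": lambda xs, g: [g * math.cos(xs[0])],
    "mul": lambda xs, g: [g * xs[1], g * xs[0]],
}


class Node:
    def __init__(self, value, prim=None, inputs=()):
        self.value = value
        self.prim = prim      # None marks a leaf (e.g. a parameter)
        self.inputs = inputs


def apply(prim, fn, *inputs):
    # The forward pass only calls `fn`; no VJP rule is consulted here.
    return Node(fn(*(n.value for n in inputs)), prim, inputs)


def grad(out, wrt):
    """Reverse-mode gradient of `out` with respect to leaf `wrt`."""
    order, seen = [], set()

    def visit(n):
        if id(n) in seen:
            return
        seen.add(id(n))
        for i in n.inputs:
            visit(i)
        order.append(n)  # post-order: inputs before consumers

    visit(out)
    grads = {id(out): 1.0}
    for node in reversed(order):  # consumers before their inputs
        if node.prim is None:
            continue
        if node.prim not in VJP_RULES:
            raise ValueError(f"[Primitive::vjp] Not implemented for {node.prim}.")
        g = grads.get(id(node), 0.0)
        in_grads = VJP_RULES[node.prim]([i.value for i in node.inputs], g)
        for inp, ig in zip(node.inputs, in_grads):
            grads[id(inp)] = grads.get(id(inp), 0.0) + ig
    return grads.get(id(wrt), 0.0)


x = Node(2.0)
y = apply("mul", lambda a, b: a * b, x, x)    # forward: 4.0, grad(y, x) == 4.0
k = apply("CustomKernel", lambda a: 3 * a, x)  # forward: 6.0, grad(k, x) raises
```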

Stack Trace

File "mlx/nn/utils.py", line 35, in wrapped_value_grad_fn
    value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
ValueError: [Primitive::vjp] Not implemented for CustomKernel.

Impact

  • Severity: CRITICAL - Completely blocks all Qwen3-Next training
  • Scope: All Qwen3-Next models requiring gradient computation
  • Workaround: None known

Testing Results

✅ Working Functionality

  • Model loading and initialization
  • Tokenization and decoding
  • Forward pass in evaluation mode
  • Basic inference

❌ Broken Functionality

  • Any gradient computation
  • LoRA fine-tuning
  • Full fine-tuning
  • Any training workflow

Reproduction

Minimal Test Case

import mlx.core as mx
import mlx.nn as nn
from mlx_lm import load

# Load model (works)
model, tokenizer = load("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit")

# Forward pass (works)
model.eval()
tokens = tokenizer.encode("Hello")
output = model(mx.array([tokens]))
mx.eval(output)  # ✅ Success (MLX is lazy; mx.eval forces the computation)

# Gradient computation (fails)
model.train()
loss_fn = lambda m, x: mx.sum(m(x))
grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = grad_fn(model, mx.array([tokens]))  # ❌ CustomKernel error

Recommendation

The Qwen3-Next implementation needs gradient (VJP) implementations for all custom kernels used in the linear attention mechanism. This likely requires:

  1. Implementing VJP functions for gated_delta_update and related custom ops
  2. Registering these VJP implementations in the MLX framework
  3. Testing gradient computation across all Qwen3-Next components
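For items 1 and 2, one plausible route (untested here) is MLX's `mx.custom_function`, whose `.vjp` decorator attaches a hand-written VJP to a Python function. For item 3, a standard check is to compare each analytic gradient against central finite differences. A framework-agnostic sketch in plain Python; `numerical_grad`, `check_grad`, and the example functions are hypothetical helpers, and in MLX the analytic side would come from `mx.grad`.

```python
import math


def numerical_grad(f, x, eps=1e-6):
    """Central differences: df/dx_j ~ (f(x + eps*e_j) - f(x - eps*e_j)) / (2*eps)."""
    grads = []
    for j in range(len(x)):
        x_plus = list(x)
        x_plus[j] += eps
        x_minus = list(x)
        x_minus[j] -= eps
        grads.append((f(x_plus) - f(x_minus)) / (2 * eps))
    return grads


def check_grad(f, analytic_grad, x, tol=1e-5):
    """True if analytic_grad(x) matches finite differences within a relative tolerance."""
    numeric = numerical_grad(f, x)
    analytic = analytic_grad(x)
    return all(
        abs(a - n) <= tol * max(1.0, abs(a), abs(n))
        for a, n in zip(analytic, numeric)
    )


# Example: f(x) = x0^2 * x1 + sin(x1), with its exact gradient and a buggy one.
def f(x):
    return x[0] ** 2 * x[1] + math.sin(x[1])


def correct_grad(x):
    return [2 * x[0] * x[1], x[0] ** 2 + math.cos(x[1])]


def buggy_grad(x):
    # Missing the cos(x1) term, as a broken VJP might.
    return [2 * x[0] * x[1], x[0] ** 2]
```

`check_grad(f, correct_grad, [1.5, -0.7])` passes and the buggy version fails. Central differences are used because their truncation error shrinks with eps squared, which keeps the tolerance tight without running into floating-point cancellation.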

Files Involved

  • mlx_lm/models/qwen3_next.py (lines 254-261 for Bug 1)
  • Custom kernel implementations (location unknown, likely in MLX core)
  • Gradient registration system

Priority: CRITICAL - This blocks all training workflows for Qwen3-Next models
