Skip to content

Commit 4d0ea3e

Browse files
bug/kl-divergence (#9)
* fix: improve logprob extraction and error handling in mlx backend, correct token ID parsing in runner
* test: add KL divergence test for rank index alignment in llama.cpp distributions
* test: add tests for MLXBackend generate fallback and double failure handling
1 parent: fbc417c · commit: 4d0ea3e

4 files changed

Lines changed: 122 additions & 13 deletions

File tree

src/infer_check/backends/mlx_lm.py

Lines changed: 5 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -47,9 +47,9 @@ async def generate(self, prompt: Prompt) -> InferenceResult:
4747
try:
4848
return self._generate_with_logprobs(prompt)
4949
except Exception as exc:
50-
import logging
50+
from rich.console import Console
5151

52-
logging.debug("generate_step failed (%s), falling back to simple generate", exc)
52+
Console().print(f"[yellow]⚠ generate_step failed, falling back to simple generate: {exc}[/yellow]")
5353
try:
5454
return self._generate_simple(prompt)
5555
except Exception as inner:
@@ -184,7 +184,7 @@ def _generate_with_logprobs(self, prompt: Prompt) -> InferenceResult:
184184
temp = prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0
185185
sampler = make_sampler(temp=temp)
186186
formatted = self._format_prompt(prompt.text)
187-
input_ids = self._tokenizer.encode(formatted, return_tensors="mlx")
187+
input_ids = mx.array(self._tokenizer.encode(formatted))
188188

189189
# Configurable top-K to avoid memory explosion. Default to 10.
190190
top_k = prompt.metadata.get("top_k_logprobs", 10) if prompt.metadata else 10
@@ -225,13 +225,7 @@ def _generate_with_logprobs(self, prompt: Prompt) -> InferenceResult:
225225
if effective_top_k > vocab_size:
226226
effective_top_k = vocab_size
227227

228-
# Get top-K indices and values
229-
if hasattr(mx, "topk"):
230-
# mx.topk is the most efficient way to get top-K if available.
231-
top_k_values, top_k_indices = mx.topk(logprob_dist, effective_top_k)
232-
dist_list = cast(list[float], top_k_values.tolist())
233-
dist_indices = cast(list[int], top_k_indices.tolist())
234-
elif hasattr(mx, "argpartition"):
228+
if hasattr(mx, "argpartition"):
235229
# Fallback to argpartition which is often available in newer MLX.
236230
top_k_indices = mx.argpartition(-logprob_dist, effective_top_k - 1)[:effective_top_k]
237231
top_k_values = logprob_dist[top_k_indices]
@@ -264,7 +258,7 @@ def _generate_with_logprobs(self, prompt: Prompt) -> InferenceResult:
264258
meta[f"id_{i}"] = int(idx)
265259
distribution_metadata.append(meta)
266260

267-
token_id = int(token.item())
261+
token_id = int(token)
268262
token_str = self._tokenizer.decode([token_id])
269263
tokens.append(token_str)
270264

src/infer_check/runner.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -124,8 +124,8 @@ def to_probs(dist: Any) -> Any:
124124
# llama-server: id_0, id_1, ... (top-K)
125125
elif b_meta and t_meta and "id_0" in b_meta and "id_0" in t_meta:
126126
# Align on union of token IDs (can be int IDs or token strings)
127-
b_ids = {v: i for k, v in b_meta.items() if k.startswith("id_")}
128-
t_ids = {v: i for k, v in t_meta.items() if k.startswith("id_")}
127+
b_ids = {v: int(k.split("_")[1]) for k, v in b_meta.items() if k.startswith("id_")}
128+
t_ids = {v: int(k.split("_")[1]) for k, v in t_meta.items() if k.startswith("id_")}
129129

130130
# Skip if ID types are different (e.g. int vs str)
131131
b_id_types = {type(v) for v in b_ids}

tests/unit/test_kl_alignment.py

Lines changed: 49 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,55 @@ def test_kl_alignment_llama_cpp() -> None:
7272
assert result.kl_divergence > 0
7373

7474

75+
def test_kl_alignment_llama_cpp_rank_index() -> None:
76+
"""Test KL computation for top-K aligned distributions using rank index from key name (id_N).
77+
78+
This test asserts that it's the rank index 'N' in 'id_N' that matters,
79+
not the iteration order or the index in metadata dictionary.
80+
"""
81+
import numpy as np
82+
83+
runner = TestRunner()
84+
85+
# Case: id_0 and id_1 are swapped in the metadata dictionary,
86+
# but the distributions are [prob_of_id_0, prob_of_id_1].
87+
# If the logic incorrectly uses the order in which keys are processed,
88+
# it might swap the probabilities.
89+
90+
baseline = InferenceResult(
91+
prompt_id="p1",
92+
backend_name="b1",
93+
model_id="m1",
94+
text="hi",
95+
tokens=["h"],
96+
distributions=[[0.8, 0.2]],
97+
# Swapped order in dict
98+
distribution_metadata=[{"id_1": 11, "id_0": 10}],
99+
latency_ms=10.0,
100+
)
101+
102+
test = InferenceResult(
103+
prompt_id="p1",
104+
backend_name="b2",
105+
model_id="m1",
106+
text="hi",
107+
tokens=["h"],
108+
distributions=[[0.7, 0.3]],
109+
# Normal order
110+
distribution_metadata=[{"id_0": 10, "id_1": 11}],
111+
latency_ms=10.0,
112+
)
113+
114+
result = runner._compare(baseline, test)
115+
assert result.kl_divergence is not None
116+
117+
expected_p = np.array([0.8, 0.2])
118+
expected_q = np.array([0.7, 0.3])
119+
expected_kl = np.sum(expected_p * np.log(expected_p / expected_q))
120+
121+
assert np.isclose(result.kl_divergence, expected_kl, atol=1e-5)
122+
123+
75124
def test_kl_skips_unaligned() -> None:
76125
"""Ensure KL is None if distributions cannot be aligned."""
77126
runner = TestRunner()

tests/unit/test_mlx_backend.py

Lines changed: 66 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -67,3 +67,69 @@ def test_mlx_cleanup(mock_mlx: tuple[MagicMock, MagicMock, MagicMock]) -> None:
6767
asyncio.run(backend.cleanup())
6868
assert backend._model is None
6969
assert backend._tokenizer is None
70+
71+
72+
@pytest.mark.asyncio
73+
async def test_mlx_generate_fallback(mock_mlx: tuple[MagicMock, MagicMock, MagicMock]) -> None:
74+
from unittest.mock import patch
75+
76+
from infer_check.types import InferenceResult, Prompt
77+
78+
backend = MLXBackend(model_id="dummy-model")
79+
backend._model = mock_mlx[1]
80+
backend._tokenizer = mock_mlx[2]
81+
82+
prompt = Prompt(text="test prompt")
83+
simple_result = InferenceResult(
84+
prompt_id=prompt.id,
85+
backend_name="mlx-lm",
86+
model_id="dummy-model",
87+
tokens=["hello"],
88+
text="hello",
89+
latency_ms=10.0,
90+
)
91+
92+
with (
93+
patch.object(MLXBackend, "_generate_with_logprobs") as mock_logprobs,
94+
patch.object(MLXBackend, "_generate_simple") as mock_simple,
95+
patch("rich.console.Console.print") as mock_print,
96+
):
97+
mock_logprobs.side_effect = Exception("Logprobs failed")
98+
mock_simple.return_value = simple_result
99+
100+
result = await backend.generate(prompt)
101+
102+
assert result == simple_result
103+
mock_logprobs.assert_called_once_with(prompt)
104+
mock_simple.assert_called_once_with(prompt)
105+
mock_print.assert_called_once()
106+
args, _ = mock_print.call_args
107+
assert "generate_step failed, falling back to simple generate" in args[0]
108+
assert "Logprobs failed" in args[0]
109+
110+
111+
@pytest.mark.asyncio
112+
async def test_mlx_generate_double_failure(mock_mlx: tuple[MagicMock, MagicMock, MagicMock]) -> None:
113+
from unittest.mock import patch
114+
115+
from infer_check.types import Prompt
116+
117+
backend = MLXBackend(model_id="dummy-model")
118+
backend._model = mock_mlx[1]
119+
backend._tokenizer = mock_mlx[2]
120+
121+
prompt = Prompt(text="test prompt")
122+
123+
with (
124+
patch.object(MLXBackend, "_generate_with_logprobs") as mock_logprobs,
125+
patch.object(MLXBackend, "_generate_simple") as mock_simple,
126+
patch("rich.console.Console.print"),
127+
):
128+
mock_logprobs.side_effect = Exception("Logprobs failed")
129+
mock_simple.side_effect = Exception("Simple failed")
130+
131+
with pytest.raises(RuntimeError) as exc_info:
132+
await backend.generate(prompt)
133+
134+
assert "MLX generation failed" in str(exc_info.value)
135+
assert "Simple failed" in str(exc_info.value)

0 commit comments

Comments
 (0)