Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ to see how to use the API in more detail. Check out the [batch generation
example](https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/examples/batch_generate_response.py)
to see how to efficiently generate continuations for a batch of prompts.

Check out the [RAG example](https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/examples/rag.py)
to see how to use retrieval-augmented generation to ground LLM responses in external documents.

The `mlx-lm` package also comes with functionality to quantize and optionally
upload models to the Hugging Face Hub.

Expand Down
93 changes: 93 additions & 0 deletions mlx_lm/examples/rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright © 2025 Apple Inc.

import numpy as np
from mlx_lm import load, generate
import mlx.core as mx


def retrieve(question, documents, model, tokenizer):
    """Return the document most relevant to ``question``.

    Relevance is measured by cosine similarity between the embedding of
    the question and the embedding of each candidate document.

    Args:
        question (str): The user query.
        documents (list[str]): Candidate documents to search over.
        model: A loaded mlx-lm model used to compute embeddings.
        tokenizer: The tokenizer matching ``model``.

    Returns:
        str: The single document with the highest similarity score.
    """
    # Embed the query once, then score every candidate against it.
    query_vec = get_embedding(question, model, tokenizer)
    scores = [
        cosine_similarity(query_vec, get_embedding(doc, model, tokenizer))
        for doc in documents
    ]

    # Highest-scoring candidate wins.
    return documents[int(np.argmax(scores))]


def get_embedding(text, model, tokenizer):
    """Compute a sentence embedding for ``text``.

    Runs a full forward pass through the model's transformer layers and
    mean-pools the final hidden states across the token dimension.

    Args:
        text (str): The text to embed.
        model: A loaded mlx-lm model; ``model.model`` must expose
            ``embed_tokens`` and ``layers``.
        tokenizer: The tokenizer matching ``model``.

    Returns:
        np.ndarray: A 1-D embedding vector of size ``hidden_dim``.
    """
    # Tokenize and add a batch dimension: shape (1, num_tokens).
    tokens = tokenizer.encode(text)
    token_array = mx.array([tokens])

    # Forward pass through the embedding table and every transformer
    # layer to obtain the last hidden state: (1, num_tokens, hidden_dim).
    hidden = model.model.embed_tokens(token_array)
    for layer in model.model.layers:
        hidden = layer(hidden, mask=None, cache=None)

    # Mean-pool across tokens *inside* mx, then convert the small pooled
    # vector straight to numpy. This avoids the original's slow
    # `np.array(hidden[0].tolist())` round trip through a Python list
    # and transfers hidden_dim values instead of num_tokens*hidden_dim.
    pooled = mx.mean(hidden[0], axis=0)
    mx.eval(pooled)  # force lazy computation before handing off to numpy
    return np.array(pooled)


def cosine_similarity(a, b):
    """Return the cosine similarity between vectors ``a`` and ``b``.

    Args:
        a (np.ndarray): First vector.
        b (np.ndarray): Second vector.

    Returns:
        float: Similarity in [-1, 1]. Defined as 0.0 when either vector
        has zero norm, which avoids the division-by-zero NaN the naive
        formula produces.
    """
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        # A zero vector has no direction; treat it as dissimilar.
        return 0.0
    return float(np.dot(a, b) / denom)


if __name__ == "__main__":
    # Model checkpoint to pull from the Hugging Face Hub.
    checkpoint = "mlx-community/Llama-3.2-3B-Instruct-4bit"

    # Load the model together with its tokenizer.
    model, tokenizer = load(path_or_hf_repo=checkpoint)

    # Small knowledge base: facts the model should ground its answer in.
    documents = [
        "MLX is an array framework for machine learning on Apple silicon, developed by Apple.",
        "The Eiffel Tower is located in Paris, France, and was completed in 1889.",
        "Photosynthesis is the process by which plants use sunlight, water, and CO2 to produce energy.",
        "The Python programming language was created by Guido van Rossum and first released in 1991.",
        "Apple silicon chips use a unified memory architecture where CPU and GPU share the same memory pool.",
    ]

    # The user question
    question = "What is MLX and who made it?"

    # Pick the document whose embedding best matches the question.
    retrieved_doc = retrieve(question, documents, model, tokenizer)

    # Inject the retrieved context into the prompt.
    prompt = f"""Use the following context to answer the question.

Context: {retrieved_doc}

Question: {question}
Answer:"""

    # Render the conversation with the tokenizer's chat template.
    messages = [{"role": "user", "content": prompt}]
    templated_prompt = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
    )

    # Generate the grounded answer.
    response = generate(
        model,
        tokenizer,
        prompt=templated_prompt,
        max_tokens=300,
        verbose=True,
    )
79 changes: 79 additions & 0 deletions tests/test_rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright © 2025 Apple Inc.

import unittest
from unittest.mock import MagicMock, patch

import numpy as np

from mlx_lm.examples.rag import cosine_similarity, retrieve


class TestCosineSimilarity(unittest.TestCase):
    """Unit tests for the cosine_similarity helper."""

    def test_identical_vectors(self):
        # A vector compared against itself scores exactly 1.0.
        v = np.array([1.0, 2.0, 3.0])
        self.assertAlmostEqual(cosine_similarity(v, v), 1.0, places=5)

    def test_orthogonal_vectors(self):
        # Perpendicular vectors share no direction, so the score is 0.0.
        x_axis = np.array([1.0, 0.0])
        y_axis = np.array([0.0, 1.0])
        self.assertAlmostEqual(cosine_similarity(x_axis, y_axis), 0.0, places=5)

    def test_opposite_vectors(self):
        # Anti-parallel vectors score -1.0.
        v = np.array([1.0, 0.0])
        self.assertAlmostEqual(cosine_similarity(v, -v), -1.0, places=5)


class TestRetrieve(unittest.TestCase):
    """Unit tests for retrieve() with get_embedding patched out."""

    def setUp(self):
        # Candidate documents shared by the tests below.
        self.documents = [
            "MLX is an array framework for machine learning on Apple silicon, developed by Apple.",
            "The Eiffel Tower is located in Paris, France, and was completed in 1889.",
            "Photosynthesis is the process by which plants use sunlight, water, and CO2 to produce energy.",
        ]

    def test_retrieves_most_similar_document(self):
        # Controlled embeddings: the question vector points almost
        # exactly at document 0, so that one must be chosen.
        fake_vectors = {
            "what is mlx?": np.array([1.0, 0.0, 0.0]),
            self.documents[0]: np.array([0.9, 0.1, 0.0]),  # most similar
            self.documents[1]: np.array([0.0, 1.0, 0.0]),
            self.documents[2]: np.array([0.0, 0.0, 1.0]),
        }

        def fake_embed(text, model, tokenizer):
            return fake_vectors[text]

        with patch("mlx_lm.examples.rag.get_embedding", side_effect=fake_embed):
            chosen = retrieve(
                "what is mlx?", self.documents, MagicMock(), MagicMock()
            )

        self.assertEqual(chosen, self.documents[0])

    def test_single_document_always_returned(self):
        # A one-element corpus leaves no choice: that document wins.
        corpus = ["Only document."]

        with patch(
            "mlx_lm.examples.rag.get_embedding",
            return_value=np.array([1.0, 0.0]),
        ):
            chosen = retrieve("any question", corpus, MagicMock(), MagicMock())

        self.assertEqual(chosen, corpus[0])


if __name__ == "__main__":
    # Allow running the suite directly: `python tests/test_rag.py`.
    unittest.main()