# Copyright © 2025 Apple Inc.
#
# File: mlx_lm/examples/rag.py
# Referenced from README.md as the "RAG example": shows how to use
# retrieval-augmented generation to ground LLM responses in external documents.

"""Minimal retrieval-augmented generation (RAG) example.

The language model itself is used as the text encoder: every document and the
question are embedded by mean-pooling the model's hidden states, the document
with the highest cosine similarity to the question is selected, and it is
injected into the prompt as context before generating the answer.
"""

import numpy as np


def retrieve(question, documents, model, tokenizer):
    """Return the document most similar to ``question``.

    Args:
        question: The user question to embed.
        documents: Non-empty, indexable sequence of candidate document strings.
        model: Model forwarded unchanged to ``get_embedding``.
        tokenizer: Tokenizer forwarded unchanged to ``get_embedding``.

    Returns:
        The element of ``documents`` with the highest cosine similarity to the
        question. On ties the first such document is returned.

    Raises:
        ValueError: If ``documents`` is empty.
    """
    # Fail before doing any model work; otherwise the question is embedded
    # first and the failure only surfaces as an opaque argmax error.
    if not documents:
        raise ValueError("`documents` must contain at least one document.")

    # Embed the question once, then score each document against it.
    question_embedding = get_embedding(question, model, tokenizer)
    similarities = [
        cosine_similarity(question_embedding, get_embedding(doc, model, tokenizer))
        for doc in documents
    ]

    # Return the document with the highest similarity score
    best_idx = int(np.argmax(similarities))
    return documents[best_idx]


def get_embedding(text, model, tokenizer):
    """Embed ``text`` by mean-pooling the model's final hidden states.

    Args:
        text: The string to embed.
        model: A loaded model whose ``model.model`` attribute is the
            transformer backbone. NOTE(review): assumes the backbone is
            callable on a ``(1, seq)`` token array and returns
            ``(1, seq, hidden)`` states, as the Llama-style models do -
            confirm for other architectures.
        tokenizer: Tokenizer providing ``encode(text)``.

    Returns:
        A 1-D float32 NumPy array: the mean over the token axis.

    Raises:
        ValueError: If ``text`` tokenizes to zero tokens.
    """
    # Local import keeps the NumPy-only helpers in this file usable without
    # touching MLX.
    import mlx.core as mx

    # Tokenize the text
    tokens = tokenizer.encode(text)
    if not tokens:
        # Mean-pooling over zero tokens would produce NaNs.
        raise ValueError("Cannot embed text that tokenizes to zero tokens.")
    token_array = mx.array([tokens])

    # Call the backbone rather than looping over its layers by hand, so the
    # hidden states are produced by the model's own forward pass (including
    # whatever masking and final normalization it applies).
    hidden = model.model(token_array)

    # Mean pool across the token dimension on-device. Cast to float32 first:
    # NumPy has no bfloat16, so half-precision activations cannot be converted
    # directly. Converting to NumPy also forces evaluation of the lazy graph.
    pooled = mx.mean(hidden[0].astype(mx.float32), axis=0)
    return np.array(pooled)


def cosine_similarity(a, b):
    """Compute the cosine similarity between two vectors.

    Args:
        a: First vector (1-D array-like).
        b: Second vector, same length as ``a``.

    Returns:
        The cosine of the angle between ``a`` and ``b``, or ``0.0`` if either
        vector has zero norm (instead of NaN from a division by zero, which
        ``np.argmax`` would otherwise treat as the maximum).
    """
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return np.dot(a, b) / denom


def main():
    """Run the RAG demo: load a model, retrieve context, generate an answer."""
    from mlx_lm import generate, load

    # Specify the checkpoint
    checkpoint = "mlx-community/Llama-3.2-3B-Instruct-4bit"

    # Load the model and tokenizer
    model, tokenizer = load(path_or_hf_repo=checkpoint)

    # A list of documents the model does not have direct access to
    documents = [
        "MLX is an array framework for machine learning on Apple silicon, developed by Apple.",
        "The Eiffel Tower is located in Paris, France, and was completed in 1889.",
        "Photosynthesis is the process by which plants use sunlight, water, and CO2 to produce energy.",
        "The Python programming language was created by Guido van Rossum and first released in 1991.",
        "Apple silicon chips use a unified memory architecture where CPU and GPU share the same memory pool.",
    ]

    # The user question
    question = "What is MLX and who made it?"

    # Retrieve the most relevant document using embedding similarity
    retrieved_doc = retrieve(question, documents, model, tokenizer)

    # Build the prompt with the retrieved context injected
    prompt = (
        "Use the following context to answer the question.\n"
        "\n"
        f"Context: {retrieved_doc}\n"
        "\n"
        f"Question: {question}\n"
        "Answer:"
    )

    # Format using the chat template
    conversation = [{"role": "user", "content": prompt}]
    formatted_prompt = tokenizer.apply_chat_template(
        conversation=conversation,
        add_generation_prompt=True,
    )

    # Generate the answer (verbose=True streams it to stdout)
    generate(
        model=model,
        tokenizer=tokenizer,
        prompt=formatted_prompt,
        max_tokens=300,
        verbose=True,
    )


if __name__ == "__main__":
    main()
# Copyright © 2025 Apple Inc.
#
# File: tests/test_rag.py

import unittest
from unittest.mock import MagicMock, patch

import numpy as np

from mlx_lm.examples.rag import cosine_similarity, retrieve


class TestCosineSimilarity(unittest.TestCase):
    """Sanity checks for ``cosine_similarity`` on hand-picked vectors."""

    def test_identical_vectors(self):
        # A vector compared with itself scores 1.0
        vec = np.array([1.0, 2.0, 3.0])
        score = cosine_similarity(vec, vec)
        self.assertAlmostEqual(score, 1.0, places=5)

    def test_orthogonal_vectors(self):
        # Perpendicular vectors score 0.0
        x_axis = np.array([1.0, 0.0])
        y_axis = np.array([0.0, 1.0])
        score = cosine_similarity(x_axis, y_axis)
        self.assertAlmostEqual(score, 0.0, places=5)

    def test_opposite_vectors(self):
        # Vectors pointing in opposite directions score -1.0
        forward = np.array([1.0, 0.0])
        backward = np.array([-1.0, 0.0])
        score = cosine_similarity(forward, backward)
        self.assertAlmostEqual(score, -1.0, places=5)


class TestRetrieve(unittest.TestCase):
    """Checks ``retrieve`` with ``get_embedding`` patched out (no real model)."""

    def setUp(self):
        self.documents = [
            "MLX is an array framework for machine learning on Apple silicon, developed by Apple.",
            "The Eiffel Tower is located in Paris, France, and was completed in 1889.",
            "Photosynthesis is the process by which plants use sunlight, water, and CO2 to produce energy.",
        ]

    def test_retrieves_most_similar_document(self):
        # Controlled embeddings: the question vector is closest to document 0,
        # while documents 1 and 2 are orthogonal to it.
        query = "what is mlx?"
        doc_vectors = [
            np.array([0.9, 0.1, 0.0]),  # most similar
            np.array([0.0, 1.0, 0.0]),
            np.array([0.0, 0.0, 1.0]),
        ]
        table = dict(zip(self.documents, doc_vectors))
        table[query] = np.array([1.0, 0.0, 0.0])

        def fake_embedding(text, _model, _tokenizer):
            return table[text]

        fake_model, fake_tokenizer = MagicMock(), MagicMock()

        embedding_patch = patch(
            "mlx_lm.examples.rag.get_embedding",
            side_effect=fake_embedding,
        )
        with embedding_patch:
            result = retrieve(query, self.documents, fake_model, fake_tokenizer)

        self.assertEqual(result, self.documents[0])

    def test_single_document_always_returned(self):
        # With only one candidate, that candidate must be the answer
        only = ["Only document."]
        fake_model, fake_tokenizer = MagicMock(), MagicMock()

        embedding_patch = patch(
            "mlx_lm.examples.rag.get_embedding",
            return_value=np.array([1.0, 0.0]),
        )
        with embedding_patch:
            result = retrieve("any question", only, fake_model, fake_tokenizer)

        self.assertEqual(result, only[0])


if __name__ == "__main__":
    unittest.main()