diff --git a/README.md b/README.md index b5dc7f4e0d..ebe5c74942 100644 --- a/README.md +++ b/README.md @@ -39,37 +39,65 @@ pip install mlx-embeddings Qwen3-VL uses a model-specific processor and a high-level `model.process(...)` API for multimodal embedding and reranking. -#### Multimodal Embedding +#### Multimodal Retrieval + +Text-to-image retrieval over a small gallery — embed images once, then score any number of text queries against them. Full notebook (with heatmap + top-K plot): [`examples/qwen3_vl_retrieval.ipynb`](examples/qwen3_vl_retrieval.ipynb). ```python +from io import BytesIO + +import matplotlib.pyplot as plt import mlx.core as mx -from mlx_embeddings import load +import numpy as np +import requests +from PIL import Image -model, processor = load("Qwen/Qwen3-VL-Embedding-2B") +from mlx_embeddings import load -inputs = [ - { - "text": "A woman playing with her dog on a beach at sunset.", - "instruction": "Retrieve images or text relevant to the user's query.", - }, - { - "text": "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset." - }, - { - "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" - }, - { - "text": "A woman shares a joyful moment with her golden retriever on a sun-drenched beach at sunset.", - "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", - }, +GALLERY = [ + ("woman with dog on beach", + "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"), + ("two cats on a couch", + "http://images.cocodataset.org/val2017/000000039769.jpg"), + ("tennis player on a court", + "http://images.cocodataset.org/val2017/000000000872.jpg"), + ("bear in the wild", + "http://images.cocodataset.org/val2017/000000000285.jpg"), + ("dark train tunnel", + "http://images.cocodataset.org/val2017/000000001268.jpg"), + ("group of people standing together", + "http://images.cocodataset.org/val2017/000000001000.jpg"), ] +QUERIES = [ + "a person spending time with their pet outdoors at sunset", + "sleepy cats relaxing indoors", + "someone playing a racquet sport", + "wildlife in a natural habitat", + "the inside of a transit tunnel", + "a crowd of people gathered outside", +] +INSTRUCTION = "Retrieve images that match the user's query." + +def fetch(src): + if src.startswith(("http://", "https://")): + return Image.open(BytesIO(requests.get(src, timeout=30).content)).convert("RGB") + return Image.open(src).convert("RGB") -embeddings = model.process(inputs, processor=processor) -similarity = embeddings @ embeddings.T +labels, urls = zip(*GALLERY) +images = [fetch(u) for u in urls] + +model, processor = load("Qwen/Qwen3-VL-Embedding-2B") +img_embeds = model.process([{"image": i} for i in images], processor=processor) +txt_embeds = model.process( + [{"text": q, "instruction": INSTRUCTION} for q in QUERIES], processor=processor, +) +sim = np.array((txt_embeds @ img_embeds.T).astype(mx.float32)) -mx.eval(embeddings, similarity) -print(embeddings.shape) # (4, 2048) -print(similarity) +for qi, q in enumerate(QUERIES): + top = np.argsort(-sim[qi])[:3] + print(f"q{qi}: {q}") + for k, idx in enumerate(top): + print(f" #{k + 1} {sim[qi, idx]:.3f} {labels[idx]}") ``` #### Multimodal Reranking diff --git a/examples/qwen3_vl_retrieval.ipynb b/examples/qwen3_vl_retrieval.ipynb new file mode 100644 index 0000000000..a6a48363f2 --- /dev/null +++ b/examples/qwen3_vl_retrieval.ipynb @@ -0,0 +1,900 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f74d81bd", + "metadata": {}, + "source": [ + "# Qwen3-VL text-to-image retrieval\n", + "\n", + "Embed a small image gallery once, score a batch of text queries against it, and plot the similarity heatmap + top-K images per query." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e100671b", + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-24T00:17:28.099053Z", + "iopub.status.busy": "2026-04-24T00:17:28.098933Z", + "iopub.status.idle": "2026-04-24T00:17:38.121706Z", + "shell.execute_reply": "2026-04-24T00:17:38.121159Z" + } + }, + "outputs": [], + "source": [ + "from io import BytesIO\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import mlx.core as mx\n", + "import numpy as np\n", + "import requests\n", + "from PIL import Image\n", + "\n", + "from mlx_embeddings import load\n", + "\n", + "GALLERY = [\n", + " (\"woman with dog on beach\",\n", + " \"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg\"),\n", + " (\"two cats on a couch\",\n", + " \"http://images.cocodataset.org/val2017/000000039769.jpg\"),\n", + " (\"tennis player on a court\",\n", + " \"http://images.cocodataset.org/val2017/000000000872.jpg\"),\n", + " (\"bear in the wild\",\n", + " \"http://images.cocodataset.org/val2017/000000000285.jpg\"),\n", + " (\"dark train tunnel\",\n", + " \"http://images.cocodataset.org/val2017/000000001268.jpg\"),\n", + " (\"group of people standing together\",\n", + " \"http://images.cocodataset.org/val2017/000000001000.jpg\"),\n", + "]\n", + "QUERIES = [\n", + " \"a person spending time with their pet outdoors at sunset\",\n", + " \"sleepy cats relaxing indoors\",\n", + " \"someone playing a racquet sport\",\n", + " \"wildlife in a natural habitat\",\n", + " \"the inside of a transit tunnel\",\n", + " \"a crowd of people gathered outside\",\n", + "]\n", + "INSTRUCTION = \"Retrieve images that match the user's query.\"\n", + "TOP_K = 3\n", + "\n", + "def fetch(src):\n", + " if src.startswith((\"http://\", \"https://\")):\n", + " return Image.open(BytesIO(requests.get(src, timeout=30).content)).convert(\"RGB\")\n", + " return Image.open(src).convert(\"RGB\")\n", + "\n", + "labels, urls = zip(*GALLERY)\n", + "images = [fetch(u) for u in urls]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "abac81a9", + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-24T00:17:38.123717Z", + "iopub.status.busy": "2026-04-24T00:17:38.123505Z", + "iopub.status.idle": "2026-04-24T00:17:43.709091Z", + "shell.execute_reply": "2026-04-24T00:17:43.708508Z" + } + }, + "outputs": [], + "source": [ + "model, processor = load(\"Qwen/Qwen3-VL-Embedding-2B\")\n", + "\n", + "img_embeds = model.process([{\"image\": i} for i in images], processor=processor)\n", + "txt_embeds = model.process(\n", + " [{\"text\": q, \"instruction\": INSTRUCTION} for q in QUERIES], processor=processor,\n", + ")\n", + "sim = np.array((txt_embeds @ img_embeds.T).astype(mx.float32))\n", + "\n", + "for qi, q in enumerate(QUERIES):\n", + " top = np.argsort(-sim[qi])[:TOP_K]\n", + " print(f\"q{qi}: {q}\")\n", + " for k, idx in enumerate(top):\n", + " print(f\" #{k + 1} {sim[qi, idx]:.3f} {labels[idx]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "994bbe40", + "metadata": { + "execution": { + "iopub.execute_input": "2026-04-24T00:17:43.710772Z", + "iopub.status.busy": "2026-04-24T00:17:43.710634Z", + "iopub.status.idle": "2026-04-24T00:17:44.233571Z", + "shell.execute_reply": "2026-04-24T00:17:44.233107Z" + } + }, + "outputs": [], + "source": [ + "fig, (ax_hm, ax_tk) = plt.subplots(\n", + " 2, 1, figsize=(2 + TOP_K * 3, len(QUERIES) * 2.2),\n", + " gridspec_kw={\"height_ratios\": [1, 2]},\n", + ")\n", + "\n", + "# Heatmap\n", + "im = ax_hm.imshow(sim, cmap=\"viridis\", aspect=\"auto\")\n", + "ax_hm.set_xticks(range(len(GALLERY)), labels, rotation=30, ha=\"right\", fontsize=8)\n", + "ax_hm.set_yticks(range(len(QUERIES)), [f\"q{i}\" for i in range(len(QUERIES))], fontsize=8)\n", + "ax_hm.set_title(\"text \\u2194 image similarity\")\n", + "thresh = sim.mean()\n", + "for i in range(sim.shape[0]):\n", + " for j in range(sim.shape[1]):\n", + " ax_hm.text(j, i, f\"{sim[i, j]:.2f}\", ha=\"center\", va=\"center\", fontsize=7,\n", + " color=\"white\" if sim[i, j] < thresh else \"black\")\n", + "fig.colorbar(im, ax=ax_hm, shrink=0.8)\n", + "\n", + "# Top-K images per query\n", + "ax_tk.axis(\"off\")\n", + "gs = ax_tk.get_subplotspec().subgridspec(len(QUERIES), TOP_K + 1, wspace=0.1, hspace=0.4)\n", + "for qi, q in enumerate(QUERIES):\n", + " lbl_ax = fig.add_subplot(gs[qi, 0])\n", + " lbl_ax.axis(\"off\")\n", + " lbl_ax.text(1.0, 0.5, f\"q{qi}: {q}\", ha=\"right\", va=\"center\", fontsize=9, wrap=True)\n", + " for k, idx in enumerate(np.argsort(-sim[qi])[:TOP_K]):\n", + " ax = fig.add_subplot(gs[qi, k + 1])\n", + " ax.imshow(images[idx])\n", + " ax.set_xticks([]); ax.set_yticks([])\n", + " ax.set_title(f\"#{k + 1} {sim[qi, idx]:.2f}\\n{labels[idx]}\", fontsize=8,\n", + " color=\"tab:green\" if k == 0 else \"black\")\n", + "\n", + "fig.suptitle(\"Qwen3-VL retrieval\", fontsize=13)\n", + "fig.tight_layout(rect=[0, 0, 1, 0.98])\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.9" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": { + "028d083ebb344a91a42c6d8c71a95c7a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_ee09293dd1dc4974953d53b665ecb74f", + "placeholder": "​", + "style": "IPY_MODEL_b2c931ae21ce45c9bac82472ff76df66", + "tabbable": null, + "tooltip": null, + "value": "Download complete: " + } + }, + "0b6fecad1e3c4036b506cf9404ef8832": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_aaabbb1278604cb1b443bd13bde18427", + "IPY_MODEL_9ad0ebfb361d4cf28a6e655ba77f421f", + "IPY_MODEL_a12c20acaafd4e8ebc75fd7f4e27710c" + ], + "layout": "IPY_MODEL_d3d85583a6e54c6da625fc5bd5b29b50", + "tabbable": null, + "tooltip": null + } + }, + "0e960da3862c44bc933cfb3c31f51802": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1e5b89b2d4da47a9a8e3997b6a456691": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "1ec205bf5fbe4dd9b245e16587ead219": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "20px" + } + }, + "2877188fa7454857952fbed99a16dce4": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "39de3cb3ecad43efa46b7fbb0320121f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "3d6ccb46d06745b9bd3065789ace3d39": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "80e0835dda36478cac41fb574cbec5e1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "8a16c1bc65a54fd98f991a34c9ce295e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9ad0ebfb361d4cf28a6e655ba77f421f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_0e960da3862c44bc933cfb3c31f51802", + "max": 15, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_1e5b89b2d4da47a9a8e3997b6a456691", + "tabbable": null, + "tooltip": null, + "value": 15 + } + }, + "a12c20acaafd4e8ebc75fd7f4e27710c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_a63f979356f64a3583286cef7c3302d3", + "placeholder": "​", + "style": "IPY_MODEL_39de3cb3ecad43efa46b7fbb0320121f", + "tabbable": null, + "tooltip": null, + "value": " 15/15 [00:00<00:00, 5117.92it/s]" + } + }, + "a63f979356f64a3583286cef7c3302d3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "aaabbb1278604cb1b443bd13bde18427": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_d62f330b6e2a48c69c1e3cf820eefe2d", + "placeholder": "​", + "style": "IPY_MODEL_3d6ccb46d06745b9bd3065789ace3d39", + "tabbable": null, + "tooltip": null, + "value": "Fetching 15 files: 100%" + } + }, + "b0d1d50b1f6d4803bce132f79e73bb70": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_028d083ebb344a91a42c6d8c71a95c7a", + "IPY_MODEL_d90d4073b3d74b75b166a467be2617e2", + "IPY_MODEL_e4a7a60b5eec4bfc9a5be15a303a65e7" + ], + "layout": "IPY_MODEL_2877188fa7454857952fbed99a16dce4", + "tabbable": null, + "tooltip": null + } + }, + "b2c931ae21ce45c9bac82472ff76df66": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "c02139b7c1904f03acd60373d717c4d4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "d3d85583a6e54c6da625fc5bd5b29b50": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d62f330b6e2a48c69c1e3cf820eefe2d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "d90d4073b3d74b75b166a467be2617e2": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_1ec205bf5fbe4dd9b245e16587ead219", + "max": 1, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_c02139b7c1904f03acd60373d717c4d4", + "tabbable": null, + "tooltip": null, + "value": 0 + } + }, + "e4a7a60b5eec4bfc9a5be15a303a65e7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_8a16c1bc65a54fd98f991a34c9ce295e", + "placeholder": "​", + "style": "IPY_MODEL_80e0835dda36478cac41fb574cbec5e1", + "tabbable": null, + "tooltip": null, + "value": " 0.00/0.00 [00:00<?, ?B/s]" + } + }, + "ee09293dd1dc4974953d53b665ecb74f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + }, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mlx_embeddings/models/qwen3_vl/processor.py b/mlx_embeddings/models/qwen3_vl/processor.py index b595f13fe5..090f8b905c 100644 --- a/mlx_embeddings/models/qwen3_vl/processor.py +++ b/mlx_embeddings/models/qwen3_vl/processor.py @@ -1,10 +1,16 @@ +import math from contextlib import contextmanager -from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx -from transformers import AutoImageProcessor, AutoTokenizer -from transformers.models.qwen3_vl.processing_qwen3_vl import Qwen3VLProcessor +import numpy as np +from mlx_vlm.models.base import load_chat_template, to_mlx +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_processing_utils import ImageProcessingMixin +from transformers.image_utils import ImageInput +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput +from transformers.video_processing_utils import BaseVideoProcessor DEFAULT_EMBEDDING_INSTRUCTION = "Represent the user's input." DEFAULT_RERANKING_INSTRUCTION = ( @@ -21,15 +27,647 @@ MAX_PIXELS = 1800 * IMAGE_FACTOR * IMAGE_FACTOR -class _UnsupportedVideoProcessor: - def __init__(self, merge_size: int = 2): +def _smart_resize_video( + num_frames: int, + height: int, + width: int, + temporal_factor: int = 2, + factor: int = 32, + min_pixels: int = 128 * 128, + max_pixels: int = 16 * 16 * 2 * 2 * 2 * 6144, +) -> Tuple[int, int]: + if height < factor or width < factor: + raise ValueError( + f"height:{height} or width:{width} must be larger than factor:{factor}" + ) + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got " + f"{max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + t_bar = math.ceil(num_frames / temporal_factor) * temporal_factor + + if t_bar * h_bar * w_bar > max_pixels: + beta = math.sqrt((num_frames * height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif t_bar * h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (num_frames * height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +def _resize_video_frames(video: np.ndarray, target_h: int, target_w: int) -> np.ndarray: + from PIL import Image + + T, C, H, W = video.shape + if target_h == H and target_w == W: + return video + out = np.empty((T, C, target_h, target_w), dtype=video.dtype) + for i, frame in enumerate(video): + arr = np.transpose(frame, (1, 2, 0)) + if arr.dtype in (np.float32, np.float64): + arr = (arr * 255).clip(0, 255).astype(np.uint8) + pil = Image.fromarray(arr) + pil = pil.resize((target_w, target_h), resample=Image.BICUBIC) + out[i] = np.transpose(np.array(pil), (2, 0, 1)) + return out + + +def _smart_resize_image( + height: int, + width: int, + factor: int = 32, + min_pixels: int = 56 * 56, + max_pixels: int = 14 * 14 * 4 * 1280, +) -> Tuple[int, int]: + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got " + f"{max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +def _to_numpy_image(img) -> np.ndarray: + from io import BytesIO + + from PIL import Image + + if isinstance(img, str): + if img.startswith(("http://", "https://")): + import requests + + img = Image.open(BytesIO(requests.get(img, timeout=30).content)) + else: + img = Image.open(img) + if hasattr(img, "convert"): + img = img.convert("RGB") + arr = np.array(img) + elif isinstance(img, np.ndarray): + arr = img + else: + arr = np.asarray(img) + if arr.ndim == 2: + arr = np.stack([arr] * 3, axis=-1) + if arr.shape[-1] in (1, 3, 4) and arr.ndim == 3: + arr = np.transpose(arr, (2, 0, 1)) + if arr.shape[0] == 4: + arr = arr[:3] + return arr + + +class Qwen3VLImageProcessor(ImageProcessingMixin): + model_input_names = ["pixel_values", "image_grid_thw"] + + def __init__( + self, + patch_size: int = 16, + temporal_patch_size: int = 2, + merge_size: int = 2, + min_pixels: int = 56 * 56, + max_pixels: int = 14 * 14 * 4 * 1280, + do_rescale: bool = True, + rescale_factor: float = 1 / 255.0, + do_normalize: bool = True, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, + do_convert_rgb: bool = True, + **kwargs, + ): + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size self.merge_size = merge_size - self.temporal_patch_size = 2 + self.min_pixels = min_pixels + self.max_pixels = max_pixels + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean or [0.5, 0.5, 0.5] + self.image_std = image_std or [0.5, 0.5, 0.5] + self.do_convert_rgb = do_convert_rgb + + def fetch_images(self, images): + if not isinstance(images, list): + images = [images] + return [_to_numpy_image(img) for img in images] + + def _process_one(self, image: np.ndarray) -> Tuple[np.ndarray, List[int]]: + C, H, W = image.shape + resized_h, resized_w = _smart_resize_image( + H, + W, + factor=self.patch_size * self.merge_size, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + frame = _resize_video_frames(image[None, ...], resized_h, resized_w)[0] + + img = frame.astype(np.float32) + if self.do_rescale and image.dtype == np.uint8: + img = img * self.rescale_factor + if self.do_normalize: + mean = np.array(self.image_mean, dtype=np.float32)[:, None, None] + std = np.array(self.image_std, dtype=np.float32)[:, None, None] + img = (img - mean) / std + + patches = np.repeat(img[None, None, ...], self.temporal_patch_size, axis=1) + + ps = self.patch_size + tps = self.temporal_patch_size + ms = self.merge_size + grid_t = 1 + grid_h = resized_h // ps + grid_w = resized_w // ps + + patches = patches.reshape( + 1, + grid_t, + tps, + C, + grid_h // ms, + ms, + ps, + grid_w // ms, + ms, + ps, + ) + patches = patches.transpose(0, 1, 4, 7, 5, 8, 3, 2, 6, 9) + flatten = patches.reshape(1, grid_t * grid_h * grid_w, C * tps * ps * ps) + return flatten[0], [grid_t, grid_h, grid_w] + + def __call__(self, images, **kwargs): + # HF's apply_chat_template passes images as a list-of-list (one inner + # list per batch item). Flatten into a single list of images. + if not isinstance(images, list): + images = [images] + flat = [] + for item in images: + if isinstance(item, list): + flat.extend(item) + else: + flat.append(item) + imgs = [ + ( + img + if (isinstance(img, np.ndarray) and img.ndim == 3) + else _to_numpy_image(img) + ) + for img in flat + ] + all_patches = [] + all_thw = [] + for v in imgs: + patches, thw = self._process_one(v) + all_patches.append(patches) + all_thw.append(thw) + return { + "pixel_values": np.concatenate(all_patches, axis=0), + "image_grid_thw": np.array(all_thw, dtype=np.int64), + } - def __call__(self, *args, **kwargs): - del args, kwargs - raise ValueError( - "Qwen3-VL video inputs are not supported by the custom MLX processor." + def preprocess(self, images, **kwargs): + return self(images, **kwargs) + + +class Qwen3VLVideoProcessor(BaseVideoProcessor): + model_input_names = ["pixel_values_videos", "video_grid_thw"] + + def __init__( + self, + patch_size: int = 16, + temporal_patch_size: int = 2, + merge_size: int = 2, + min_pixels: int = 128 * 32 * 32, + max_pixels: int = 32 * 32 * 768, + do_rescale: bool = True, + rescale_factor: float = 1 / 255.0, + do_normalize: bool = True, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, + do_convert_rgb: bool = True, + fps: float = 2.0, + min_frames: int = 4, + max_frames: int = 768, + **kwargs, + ): + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.merge_size = merge_size + self.min_pixels = min_pixels + self.max_pixels = max_pixels + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean or [0.5, 0.5, 0.5] + self.image_std = image_std or [0.5, 0.5, 0.5] + self.do_convert_rgb = do_convert_rgb + self.fps = fps + self.min_frames = min_frames + self.max_frames = max_frames + + def _process_one(self, video: np.ndarray) -> Tuple[np.ndarray, List[int]]: + if video.ndim != 4: + raise ValueError( + f"Expected video as (T, C, H, W), got shape {video.shape}." + ) + T, C, H, W = video.shape + if C == 1 and self.do_convert_rgb: + video = np.repeat(video, 3, axis=1) + C = 3 + + resized_h, resized_w = _smart_resize_video( + num_frames=T, + height=H, + width=W, + temporal_factor=self.temporal_patch_size, + factor=self.patch_size * self.merge_size, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + video = _resize_video_frames(video, resized_h, resized_w) + + video_f = video.astype(np.float32) + if self.do_rescale and video.dtype == np.uint8: + video_f = video_f * self.rescale_factor + if self.do_normalize: + mean = np.array(self.image_mean, dtype=np.float32)[None, :, None, None] + std = np.array(self.image_std, dtype=np.float32)[None, :, None, None] + video_f = (video_f - mean) / std + + pad = (-video_f.shape[0]) % self.temporal_patch_size + if pad: + video_f = np.concatenate( + [video_f, np.repeat(video_f[-1:], pad, axis=0)], axis=0 + ) + + T_padded = video_f.shape[0] + grid_t = T_padded // self.temporal_patch_size + grid_h = resized_h // self.patch_size + grid_w = resized_w // self.patch_size + ps = self.patch_size + tps = self.temporal_patch_size + ms = self.merge_size + + patches = video_f[None, ...] + patches = patches.reshape( + 1, + grid_t, + tps, + C, + grid_h // ms, + ms, + ps, + grid_w // ms, + ms, + ps, + ) + patches = patches.transpose(0, 1, 4, 7, 5, 8, 3, 2, 6, 9) + flatten = patches.reshape(1, grid_t * grid_h * grid_w, C * tps * ps * ps) + return flatten[0], [grid_t, grid_h, grid_w] + + def __call__(self, videos, **kwargs): + # Same list-of-list batching convention as the image processor. + if not isinstance(videos, list): + videos = [videos] + flat = [] + for item in videos: + if isinstance(item, list): + flat.extend(item) + else: + flat.append(item) + all_patches = [] + all_thw = [] + for v in flat: + if not isinstance(v, np.ndarray): + v = np.asarray(v) + patches, thw = self._process_one(v) + all_patches.append(patches) + all_thw.append(thw) + return { + "pixel_values_videos": np.concatenate(all_patches, axis=0), + "video_grid_thw": np.array(all_thw, dtype=np.int64), + } + + +def _load_file(pretrained_model_name_or_path, relative_name: str): + """Read a file from the checkpoint (local dir or HF Hub). + + Returns the parsed dict when *relative_name* ends in ``.json``, otherwise + the raw text. Returns ``None`` if the file isn't available. + """ + import json + from pathlib import Path + + local = Path(pretrained_model_name_or_path) / relative_name + if local.exists(): + text = local.read_text(encoding="utf-8") + else: + try: + from huggingface_hub import hf_hub_download + + fetched = Path( + hf_hub_download(pretrained_model_name_or_path, relative_name) + ) + text = fetched.read_text(encoding="utf-8") + except Exception: + return None + return json.loads(text) if relative_name.endswith(".json") else text + + +def _image_kwargs(pretrained_model_name_or_path, default_patch_size: int = 16): + proc_cfg = ( + _load_file(pretrained_model_name_or_path, "processor_config.json") or {} + ) + raw = ( + _load_file(pretrained_model_name_or_path, "preprocessor_config.json") + or {} + ) + raw.update(proc_cfg.get("image_processor", {}) or {}) + out = {"patch_size": default_patch_size} + for k in ( + "patch_size", + "temporal_patch_size", + "merge_size", + "image_mean", + "image_std", + "rescale_factor", + "do_rescale", + "do_normalize", + "do_convert_rgb", + ): + if raw.get(k) is not None: + out[k] = raw[k] + size = raw.get("size") or {} + if size.get("shortest_edge") is not None: + out["min_pixels"] = size["shortest_edge"] + if size.get("longest_edge") is not None: + out["max_pixels"] = size["longest_edge"] + if raw.get("min_pixels") is not None: + out["min_pixels"] = raw["min_pixels"] + if raw.get("max_pixels") is not None: + out["max_pixels"] = raw["max_pixels"] + return out + + +def _video_kwargs(pretrained_model_name_or_path, default_patch_size: int = 16): + raw = _load_file( + pretrained_model_name_or_path, "video_preprocessor_config.json" + ) + if raw is None: + raw = ( + _load_file( + pretrained_model_name_or_path, "preprocessor_config.json" + ) + or {} + ) + out = {"patch_size": default_patch_size} + for k in ( + "patch_size", + "temporal_patch_size", + "merge_size", + "fps", + "min_frames", + "max_frames", + "image_mean", + "image_std", + "rescale_factor", + "do_rescale", + "do_normalize", + "do_convert_rgb", + ): + if raw.get(k) is not None: + out[k] = raw[k] + size = raw.get("size") or {} + if size.get("shortest_edge") is not None: + out["min_pixels"] = size["shortest_edge"] + if size.get("longest_edge") is not None: + out["max_pixels"] = size["longest_edge"] + if raw.get("min_pixels") is not None: + out["min_pixels"] = raw["min_pixels"] + if raw.get("max_pixels") is not None: + out["max_pixels"] = raw["max_pixels"] + return out + + +class Qwen3VLProcessor(ProcessorMixin): + attributes = ["image_processor", "tokenizer", "video_processor"] + valid_kwargs = ["chat_template"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + video_processor_class = "AutoVideoProcessor" + + # HF's ProcessorMixin resolves expected base classes at runtime; in torch- + # free environments it picks up dummy classes from + # ``transformers.utils.dummy_torchvision_objects``, so our (real) numpy + # subclasses fail ``isinstance``. Skip that validation — our processors + # are duck-typed to the interfaces the call sites use. + def check_argument_for_proper_class(self, argument_name, argument): + return type(argument) + + def __init__( + self, + image_processor=None, + tokenizer=None, + video_processor=None, + chat_template=None, + **kwargs, + ): + self.image_token = ( + "<|image_pad|>" + if not hasattr(tokenizer, "image_token") + else tokenizer.image_token + ) + self.video_token = ( + "<|video_pad|>" + if not hasattr(tokenizer, "video_token") + else tokenizer.video_token + ) + self.image_token_id = ( + tokenizer.image_token_id + if getattr(tokenizer, "image_token_id", None) + else tokenizer.convert_tokens_to_ids(self.image_token) + ) + self.video_token_id = ( + tokenizer.video_token_id + if getattr(tokenizer, "video_token_id", None) + else tokenizer.convert_tokens_to_ids(self.video_token) + ) + super().__init__( + image_processor, tokenizer, video_processor, chat_template=chat_template + ) + + self.vision_start_token = ( + "<|vision_start|>" + if not hasattr(tokenizer, "vision_start_token") + else tokenizer.vision_start_token + ) + self.vision_end_token = ( + "<|vision_end|>" + if not hasattr(tokenizer, "vision_end_token") + else tokenizer.vision_end_token + ) + self.vision_start_token_id = ( + tokenizer.vision_start_token_id + if getattr(tokenizer, "vision_start_token_id", None) + else tokenizer.convert_tokens_to_ids(self.vision_start_token) + ) + self.vision_end_token_id = ( + tokenizer.vision_end_token_id + if getattr(tokenizer, "vision_end_token_id", None) + else tokenizer.convert_tokens_to_ids(self.vision_end_token) + ) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[ + Union[ + TextInput, + PreTokenizedInput, + List[TextInput], + List[PreTokenizedInput], + ] + ] = None, + videos=None, + **kwargs, + ) -> BatchFeature: + image_inputs = {} + videos_inputs = {} + + if images is not None: + image_inputs = self.image_processor(images=images) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_grid_thw = None + + if videos is not None: + _video_proc = self.video_processor or self.image_processor + videos_inputs = _video_proc(videos=videos) + video_grid_thw = videos_inputs["video_grid_thw"] + else: + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + text = text.copy() + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + num_image_tokens = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace( + self.image_token, + "<|placeholder|>" * num_image_tokens, + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + _video_proc = self.video_processor or self.image_processor + merge_length = _video_proc.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + num_video_tokens = video_grid_thw[index].prod() // merge_length + text[i] = text[i].replace( + self.video_token, + "<|placeholder|>" * num_video_tokens, + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + kwargs.pop("return_tensors", None) + return_mm_token_type_ids = kwargs.pop("return_mm_token_type_ids", None) + text_inputs = self.tokenizer(text, **kwargs) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + mm_token_type_ids[array_ids == self.video_token_id] = 2 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature( + data=to_mlx({**text_inputs, **image_inputs, **videos_inputs}) + ) + + def batch_decode(self, *args, **kwargs): + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list( + dict.fromkeys( + tokenizer_input_names + + image_processor_input_names + + ["mm_token_type_ids"] + ) + ) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + from transformers import AutoTokenizer + + kwargs.pop("use_fast", None) + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, **kwargs + ) + load_chat_template(tokenizer, pretrained_model_name_or_path) + + image_processor = Qwen3VLImageProcessor( + **_image_kwargs( + pretrained_model_name_or_path, default_patch_size=16 + ) + ) + video_processor = Qwen3VLVideoProcessor( + **_video_kwargs( + pretrained_model_name_or_path, default_patch_size=16 + ) + ) + + proc_cfg = ( + _load_file(pretrained_model_name_or_path, "processor_config.json") + or {} + ) + chat_template = proc_cfg.get( + "chat_template", getattr(tokenizer, "chat_template", None) + ) + + if chat_template is None: + chat_template = _load_file( + pretrained_model_name_or_path, "chat_template.jinja" + ) + if chat_template is not None: + tokenizer.chat_template = chat_template + + return cls( + image_processor=image_processor, + tokenizer=tokenizer, + video_processor=video_processor, + chat_template=chat_template, ) @@ -69,32 +707,10 @@ def from_pretrained(cls, model_path, **kwargs): reranking_system_prompt = kwargs.pop( "reranking_system_prompt", DEFAULT_RERANKING_SYSTEM_PROMPT ) - trust_remote_code = kwargs.pop("trust_remote_code", True) + kwargs.setdefault("trust_remote_code", True) kwargs.pop("use_fast", None) - model_dir = Path(model_path) - is_local = model_dir.exists() and model_dir.is_dir() - load_path = str(model_dir) if is_local else model_path - hub_kwargs = { - key: kwargs[key] - for key in ["cache_dir", "force_download", "revision", "token"] - if key in kwargs - } - - tokenizer = AutoTokenizer.from_pretrained( - load_path, - trust_remote_code=trust_remote_code, - local_files_only=is_local, - **kwargs, - ) - image_processor = AutoImageProcessor.from_pretrained( - load_path, - trust_remote_code=trust_remote_code, - local_files_only=is_local, - use_fast=False, - **hub_kwargs, - ) - processor = cls._build_processor(tokenizer, image_processor) + processor = Qwen3VLProcessor.from_pretrained(model_path, **kwargs) return cls( processor=processor, embedding_max_length=embedding_max_length, @@ -106,70 +722,6 @@ def from_pretrained(cls, model_path, **kwargs): reranking_system_prompt=reranking_system_prompt, ) - @staticmethod - def _build_processor(tokenizer, image_processor): - processor = object.__new__(Qwen3VLProcessor) - processor.tokenizer = tokenizer - processor.image_processor = image_processor - processor.video_processor = _UnsupportedVideoProcessor( - merge_size=getattr(image_processor, "merge_size", 2) - ) - processor.chat_template = getattr(tokenizer, "chat_template", None) - - if processor.chat_template is None: - try: - processor.chat_template = Processor._load_chat_template( - tokenizer.name_or_path - ) - except Exception: - processor.chat_template = None - - processor.image_token = ( - getattr(tokenizer, "image_token", None) or "<|image_pad|>" - ) - processor.video_token = ( - getattr(tokenizer, "video_token", None) or "<|video_pad|>" - ) - processor.image_token_id = ( - getattr(tokenizer, "image_token_id", None) - if getattr(tokenizer, "image_token_id", None) is not None - else tokenizer.convert_tokens_to_ids(processor.image_token) - ) - processor.video_token_id = ( - getattr(tokenizer, "video_token_id", None) - if getattr(tokenizer, "video_token_id", None) is not None - else tokenizer.convert_tokens_to_ids(processor.video_token) - ) - processor.vision_start_token = ( - getattr(tokenizer, "vision_start_token", None) or "<|vision_start|>" - ) - processor.vision_end_token = ( - getattr(tokenizer, "vision_end_token", None) or "<|vision_end|>" - ) - processor.vision_start_token_id = ( - getattr(tokenizer, "vision_start_token_id", None) - if getattr(tokenizer, "vision_start_token_id", None) is not None - else tokenizer.convert_tokens_to_ids(processor.vision_start_token) - ) - processor.vision_end_token_id = ( - getattr(tokenizer, "vision_end_token_id", None) - if getattr(tokenizer, "vision_end_token_id", None) is not None - else tokenizer.convert_tokens_to_ids(processor.vision_end_token) - ) - - return processor - - @staticmethod - def _load_chat_template(model_path: str) -> str: - path = Path(model_path) - if path.exists() and path.is_dir(): - template_path = path / "chat_template.jinja" - else: - from huggingface_hub import hf_hub_download - - template_path = Path(hf_hub_download(model_path, "chat_template.jinja")) - return template_path.read_text(encoding="utf-8") - def __call__(self, *args, **kwargs): return self.processor(*args, **kwargs) @@ -390,3 +942,11 @@ def prepare_model_inputs(self, inputs, return_tensors: str = "mlx", **kwargs): return self.prepare_embedding_inputs( inputs, return_tensors=return_tensors, **kwargs ) + + +__all__ = [ + "Processor", + "Qwen3VLImageProcessor", + "Qwen3VLProcessor", + "Qwen3VLVideoProcessor", +] diff --git a/mlx_embeddings/tests/test_models.py b/mlx_embeddings/tests/test_models.py index efe2cf9604..c3b922a585 100644 --- a/mlx_embeddings/tests/test_models.py +++ b/mlx_embeddings/tests/test_models.py @@ -637,43 +637,27 @@ def test_qwen3_vl_processor_formats_embedding_and_reranker_inputs(self): def test_qwen3_vl_processor_from_pretrained_uses_custom_loader(self): from mlx_embeddings.models import qwen3_vl - class DummyTokenizer: - def __init__(self): - self.chat_template = "dummy-template" - self.padding_side = "right" - self.name_or_path = "dummy-model" - - def convert_tokens_to_ids(self, token): - mapping = { - "<|image_pad|>": 1, - "<|video_pad|>": 2, - "<|vision_start|>": 3, - "<|vision_end|>": 4, - } - return mapping[token] - - dummy_tokenizer = DummyTokenizer() - dummy_image_processor = MagicMock() - dummy_image_processor.merge_size = 2 - - with ( - patch.object( - qwen3_vl.processor.AutoTokenizer, - "from_pretrained", - return_value=dummy_tokenizer, - ) as mock_tokenizer, - patch.object( - qwen3_vl.processor.AutoImageProcessor, - "from_pretrained", - return_value=dummy_image_processor, - ) as mock_image_processor, - ): + dummy_inner = MagicMock() + dummy_inner.tokenizer = MagicMock() + dummy_inner.image_processor = MagicMock() + dummy_inner.video_processor = MagicMock() + dummy_inner.video_processor.merge_size = 2 + dummy_inner.chat_template = "dummy-template" + + with patch.object( + qwen3_vl.processor.Qwen3VLProcessor, + "from_pretrained", + return_value=dummy_inner, + ) as mock_from_pretrained: processor = qwen3_vl.Processor.from_pretrained("dummy-model") - mock_tokenizer.assert_called_once() - mock_image_processor.assert_called_once() - self.assertIs(processor.tokenizer, dummy_tokenizer) - self.assertIs(processor.image_processor, dummy_image_processor) + mock_from_pretrained.assert_called_once() + call_args, call_kwargs = mock_from_pretrained.call_args + self.assertEqual(call_args[0], "dummy-model") + self.assertTrue(call_kwargs.get("trust_remote_code")) + self.assertIs(processor.processor, dummy_inner) + self.assertIs(processor.tokenizer, dummy_inner.tokenizer) + self.assertIs(processor.image_processor, dummy_inner.image_processor) self.assertEqual(processor.processor.chat_template, "dummy-template") self.assertEqual(processor.processor.video_processor.merge_size, 2)