def _has_apple_silicon_mlx() -> bool:
    """Return True if MPS is available and mlx-vlm is installed.

    Both conditions are probed lazily so that neither torch nor mlx-vlm
    becomes a hard import-time dependency of this module.
    """
    try:
        import torch

        # is_built() alone is not enough: a torch build can include MPS
        # support while the running machine has no usable MPS device.
        has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    except ImportError:
        has_mps = False

    if not has_mps:
        return False

    try:
        import mlx_vlm  # type: ignore  # noqa: F401

        return True
    except ImportError:
        return False


def _select_vlm_spec(mlx_spec, transformers_spec, model_name: str):
    """Return *mlx_spec* on Apple Silicon with mlx-vlm installed, else *transformers_spec*.

    Shared selection logic for the auto-selecting GRANITEDOCLING / SMOLDOCLING
    constants; *model_name* is only used for the debug log message.
    """
    if _has_apple_silicon_mlx():
        _log.debug("Auto-selected %s MLX variant (Apple Silicon)", model_name)
        return mlx_spec
    _log.debug("Auto-selected %s Transformers variant", model_name)
    return transformers_spec


def _get_granitedocling_model():
    """Get the best GraniteDocling variant for the current hardware.

    Automatically selects MLX variant on Apple Silicon if mlx-vlm is installed,
    otherwise falls back to Transformers variant.
    """
    return _select_vlm_spec(
        GRANITEDOCLING_MLX, GRANITEDOCLING_TRANSFORMERS, "GraniteDocling"
    )


# Auto-selecting: picks MLX on Apple Silicon, Transformers otherwise
GRANITEDOCLING = _get_granitedocling_model()


def _get_smoldocling_model():
    """Get the best SmolDocling variant for the current hardware.

    Automatically selects MLX variant on Apple Silicon if mlx-vlm is installed,
    otherwise falls back to Transformers variant.
    """
    return _select_vlm_spec(SMOLDOCLING_MLX, SMOLDOCLING_TRANSFORMERS, "SmolDocling")


# Auto-selecting: picks MLX on Apple Silicon, Transformers otherwise
SMOLDOCLING = _get_smoldocling_model()
- if device == AcceleratorDevice.MPS.value: - device = AcceleratorDevice.CPU.value + device = decide_device( + accelerator_options.device, + supported_devices=[ + AcceleratorDevice.CPU, + AcceleratorDevice.CUDA, + AcceleratorDevice.MPS, + AcceleratorDevice.XPU, + ], + ) + _log.debug(f"TableStructureModel using device: {device}") self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json") self.tm_config["model"]["save_dir"] = artifacts_path diff --git a/docling/models/stages/table_structure/table_structure_model_v2.py b/docling/models/stages/table_structure/table_structure_model_v2.py index 47b3091db1..dd75fd3704 100644 --- a/docling/models/stages/table_structure/table_structure_model_v2.py +++ b/docling/models/stages/table_structure/table_structure_model_v2.py @@ -58,9 +58,16 @@ def __init__( model_path = artifacts_path # Determine device - device = decide_device(accelerator_options.device) - if device == AcceleratorDevice.MPS.value: - device = AcceleratorDevice.CPU.value + device = decide_device( + accelerator_options.device, + supported_devices=[ + AcceleratorDevice.CPU, + AcceleratorDevice.CUDA, + AcceleratorDevice.MPS, + AcceleratorDevice.XPU, + ], + ) + _log.debug(f"TableStructureModelV2 using device: {device}") self.device = device # Set number of threads for CPU inference diff --git a/docs/usage/model_catalog.md b/docs/usage/model_catalog.md index 95c0f879e5..41f69a6845 100644 --- a/docs/usage/model_catalog.md +++ b/docs/usage/model_catalog.md @@ -191,10 +191,8 @@ The following table shows all processing stages in Docling, their model families | Model | Inference Engine | Supported Devices | |-------|------------------|-------------------| -| TableFormer (fast) | docling-ibm-models | CPU, CUDA, XPU | -| TableFormer (accurate) | docling-ibm-models | CPU, CUDA, XPU | - -**Note:** MPS is currently disabled for TableFormer due to performance issues. 
+| TableFormer (fast) | docling-ibm-models | CPU, CUDA, MPS, XPU | +| TableFormer (accurate) | docling-ibm-models | CPU, CUDA, MPS, XPU | ### Image Classifier (Picture Classifier) diff --git a/docs/usage/vision_models.md b/docs/usage/vision_models.md index 3a0ba141eb..eeeab799c4 100644 --- a/docs/usage/vision_models.md +++ b/docs/usage/vision_models.md @@ -52,8 +52,10 @@ The following table reports the models currently available out-of-the-box. | Model instance | Model | Framework | Device | Num pages | Inference time (sec) | | ---------------|------ | --------- | ------ | --------- | ---------------------| +| `vlm_model_specs.GRANITEDOCLING` | Auto-selects MLX or Transformers | `Auto` | MPS | 1 | - | | `vlm_model_specs.GRANITEDOCLING_TRANSFORMERS` | [ibm-granite/granite-docling-258M](https://huggingface.co/ibm-granite/granite-docling-258M) | `Transformers/AutoModelForVision2Seq` | MPS | 1 | - | | `vlm_model_specs.GRANITEDOCLING_MLX` | [ibm-granite/granite-docling-258M-mlx-bf16](https://huggingface.co/ibm-granite/granite-docling-258M-mlx-bf16) | `MLX`| MPS | 1 | - | +| `vlm_model_specs.SMOLDOCLING` | Auto-selects MLX or Transformers | `Auto` | MPS | 1 | - | | `vlm_model_specs.SMOLDOCLING_TRANSFORMERS` | [ds4sd/SmolDocling-256M-preview](https://huggingface.co/ds4sd/SmolDocling-256M-preview) | `Transformers/AutoModelForVision2Seq` | MPS | 1 | 102.212 | | `vlm_model_specs.SMOLDOCLING_MLX` | [ds4sd/SmolDocling-256M-preview-mlx-bf16](https://huggingface.co/ds4sd/SmolDocling-256M-preview-mlx-bf16) | `MLX`| MPS | 1 | 6.15453 | | `vlm_model_specs.QWEN25_VL_3B_MLX` | [mlx-community/Qwen2.5-VL-3B-Instruct-bf16](https://huggingface.co/mlx-community/Qwen2.5-VL-3B-Instruct-bf16) | `MLX`| MPS | 1 | 23.4951 | diff --git a/tests/test_apple_silicon_optimization.py b/tests/test_apple_silicon_optimization.py new file mode 100644 index 0000000000..89d78a9c22 --- /dev/null +++ b/tests/test_apple_silicon_optimization.py @@ -0,0 +1,264 @@ +"""Tests for Apple Silicon optimization: 
"""Tests for Apple Silicon optimization: TableFormer MPS and VLM auto-selection.

Validates:
1. TableFormer models no longer override MPS to CPU
2. VLM auto-selecting constants pick MLX on Apple Silicon, Transformers otherwise
3. _has_apple_silicon_mlx() helper detects hardware correctly
"""

import sys

import pytest

from docling.datamodel.accelerator_options import AcceleratorDevice
from docling.datamodel.pipeline_options_vlm_model import InferenceFramework
from docling.datamodel.vlm_model_specs import (
    GRANITEDOCLING,
    GRANITEDOCLING_MLX,
    GRANITEDOCLING_TRANSFORMERS,
    SMOLDOCLING,
    SMOLDOCLING_MLX,
    SMOLDOCLING_TRANSFORMERS,
)
from docling.utils.accelerator_utils import decide_device

# Device list mirroring what the TableFormer models now pass to decide_device().
_ALL_DEVICES = [
    AcceleratorDevice.CPU,
    AcceleratorDevice.CUDA,
    AcceleratorDevice.MPS,
    AcceleratorDevice.XPU,
]


def _mps_available() -> bool:
    """Return True when the real torch reports a usable MPS backend."""
    import torch

    return torch.backends.mps.is_built() and torch.backends.mps.is_available()


def _fake_torch(mps_on: bool):
    """Build a minimal torch stand-in whose MPS backend reports *mps_on*.

    Only the attributes read by _has_apple_silicon_mlx() are provided:
    torch.backends.mps.is_built() and torch.backends.mps.is_available().
    """

    class _Mps:
        def is_built(self):
            return mps_on

        def is_available(self):
            return mps_on

    class _Torch:
        class backends:
            mps = _Mps()

    return _Torch()


class TestTableFormerMpsSupport:
    """Verify TableFormer models support MPS device selection."""

    def test_decide_device_allows_mps(self):
        """decide_device with MPS in supported_devices returns 'mps' when available."""
        if not _mps_available():
            pytest.skip("MPS not available on this machine")

        device = decide_device(AcceleratorDevice.AUTO, supported_devices=_ALL_DEVICES)
        # On Apple Silicon without CUDA, AUTO should resolve to mps
        assert device == "mps"

    def test_decide_device_mps_explicit(self):
        """Explicitly requesting MPS with MPS in supported_devices returns 'mps'."""
        if not _mps_available():
            pytest.skip("MPS not available on this machine")

        device = decide_device(AcceleratorDevice.MPS, supported_devices=_ALL_DEVICES)
        assert device == "mps"

    def test_decide_device_cpu_fallback(self):
        """CPU is always a valid fallback."""
        device = decide_device(AcceleratorDevice.CPU, supported_devices=_ALL_DEVICES)
        assert device == "cpu"


class TestVlmAutoSelection:
    """Verify VLM auto-selecting constants and factory functions."""

    def test_auto_selecting_constants_exist(self):
        """GRANITEDOCLING and SMOLDOCLING auto-selecting constants are importable."""
        assert GRANITEDOCLING is not None
        assert SMOLDOCLING is not None
        assert hasattr(GRANITEDOCLING, "inference_framework")
        assert hasattr(SMOLDOCLING, "inference_framework")

    def test_explicit_constants_unchanged(self):
        """Explicit MLX and Transformers constants remain intact."""
        assert (
            GRANITEDOCLING_TRANSFORMERS.inference_framework
            == InferenceFramework.TRANSFORMERS
        )
        assert GRANITEDOCLING_MLX.inference_framework == InferenceFramework.MLX
        assert (
            SMOLDOCLING_TRANSFORMERS.inference_framework
            == InferenceFramework.TRANSFORMERS
        )
        assert SMOLDOCLING_MLX.inference_framework == InferenceFramework.MLX

    def test_explicit_constants_repo_ids(self):
        """Explicit constants have correct repo IDs."""
        assert GRANITEDOCLING_TRANSFORMERS.repo_id == "ibm-granite/granite-docling-258M"
        # NOTE(review): docs/usage/vision_models.md lists the MLX repo as
        # granite-docling-258M-mlx-bf16 — confirm which repo id is current.
        assert GRANITEDOCLING_MLX.repo_id == "ibm-granite/granite-docling-258M-mlx"
        assert "SmolDocling" in SMOLDOCLING_TRANSFORMERS.repo_id
        assert "SmolDocling" in SMOLDOCLING_MLX.repo_id

    def test_selectors_mlx_path(self, monkeypatch):
        """Factory functions return MLX variants when MPS and mlx-vlm are available."""
        from docling.datamodel import vlm_model_specs as specs

        monkeypatch.setitem(sys.modules, "torch", _fake_torch(True))
        monkeypatch.setitem(sys.modules, "mlx_vlm", object())

        granite = specs._get_granitedocling_model()
        smol = specs._get_smoldocling_model()

        assert granite.inference_framework == InferenceFramework.MLX
        assert granite.repo_id == "ibm-granite/granite-docling-258M-mlx"
        assert smol.inference_framework == InferenceFramework.MLX
        assert "mlx" in smol.repo_id

    def test_selectors_transformers_fallback(self, monkeypatch):
        """Factory functions return Transformers variants when MPS is unavailable."""
        from docling.datamodel import vlm_model_specs as specs

        monkeypatch.setitem(sys.modules, "torch", _fake_torch(False))
        # Use monkeypatch (not a bare `del`) so the entry is restored after the test.
        monkeypatch.delitem(sys.modules, "mlx_vlm", raising=False)

        granite = specs._get_granitedocling_model()
        smol = specs._get_smoldocling_model()

        assert granite.inference_framework == InferenceFramework.TRANSFORMERS
        assert granite.repo_id == "ibm-granite/granite-docling-258M"
        assert smol.inference_framework == InferenceFramework.TRANSFORMERS
        assert "preview" in smol.repo_id
        assert "mlx" not in smol.repo_id

    def test_selectors_no_mlx_vlm_installed(self, monkeypatch):
        """Factory functions fall back to Transformers when mlx-vlm is not installed."""
        from docling.datamodel import vlm_model_specs as specs

        # MPS available but mlx-vlm not installed.
        monkeypatch.setitem(sys.modules, "torch", _fake_torch(True))
        # Setting to None causes ImportError on `import mlx_vlm`
        monkeypatch.setitem(sys.modules, "mlx_vlm", None)

        granite = specs._get_granitedocling_model()
        assert granite.inference_framework == InferenceFramework.TRANSFORMERS

    def test_selectors_torch_import_error(self, monkeypatch):
        """If torch cannot be imported, selectors return Transformers variant."""
        from docling.datamodel import vlm_model_specs as specs

        # A None entry in sys.modules raises ImportError on `import torch`.
        monkeypatch.setitem(sys.modules, "torch", None)
        monkeypatch.delitem(sys.modules, "mlx_vlm", raising=False)

        granite = specs._get_granitedocling_model()
        assert granite.inference_framework == InferenceFramework.TRANSFORMERS


class TestHasAppleSiliconMlx:
    """Test the _has_apple_silicon_mlx() shared helper."""

    def test_returns_true_when_both_available(self, monkeypatch):
        """Returns True when MPS is available and mlx-vlm is installed."""
        from docling.datamodel import vlm_model_specs as specs

        monkeypatch.setitem(sys.modules, "torch", _fake_torch(True))
        monkeypatch.setitem(sys.modules, "mlx_vlm", object())

        assert specs._has_apple_silicon_mlx() is True

    def test_returns_false_no_mps(self, monkeypatch):
        """Returns False when MPS is not available."""
        from docling.datamodel import vlm_model_specs as specs

        monkeypatch.setitem(sys.modules, "torch", _fake_torch(False))
        monkeypatch.setitem(sys.modules, "mlx_vlm", object())

        assert specs._has_apple_silicon_mlx() is False

    def test_returns_false_no_mlx_vlm(self, monkeypatch):
        """Returns False when mlx-vlm is not installed."""
        from docling.datamodel import vlm_model_specs as specs

        monkeypatch.setitem(sys.modules, "torch", _fake_torch(True))
        # Setting to None causes ImportError on `import mlx_vlm`
        monkeypatch.setitem(sys.modules, "mlx_vlm", None)

        assert specs._has_apple_silicon_mlx() is False

    def test_returns_false_no_torch(self, monkeypatch):
        """Returns False when torch is not installed."""
        from docling.datamodel import vlm_model_specs as specs

        monkeypatch.setitem(sys.modules, "torch", None)

        assert specs._has_apple_silicon_mlx() is False