-
Notifications
You must be signed in to change notification settings - Fork 4.1k
feat: enable MPS acceleration for TableFormer and add VLM auto-selection on Apple Silicon #3203
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
adfb145
85cb5c3
ce910cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -353,6 +353,62 @@ | |
| ) | ||
|
|
||
|
|
||
| def _has_apple_silicon_mlx() -> bool: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new stage inference model should take care of automatic MLX choice when available. Where would these extra methods be needed? |
||
| """Return True if MPS is available and mlx-vlm is installed.""" | ||
| try: | ||
| import torch | ||
|
|
||
| has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available() | ||
| except ImportError: | ||
| has_mps = False | ||
|
|
||
| if not has_mps: | ||
| return False | ||
|
|
||
| try: | ||
| import mlx_vlm # type: ignore | ||
|
|
||
| return True | ||
| except ImportError: | ||
| return False | ||
|
|
||
|
|
||
def _get_granitedocling_model():
    """Pick the GraniteDocling variant suited to the current hardware.

    Returns ``GRANITEDOCLING_MLX`` when ``_has_apple_silicon_mlx()`` is true
    (MPS available and mlx-vlm installed) and ``GRANITEDOCLING_TRANSFORMERS``
    in every other case. The choice is logged at debug level.
    """
    if not _has_apple_silicon_mlx():
        _log.debug("Auto-selected GraniteDocling Transformers variant")
        return GRANITEDOCLING_TRANSFORMERS

    _log.debug("Auto-selected GraniteDocling MLX variant (Apple Silicon)")
    return GRANITEDOCLING_MLX
|
|
||
|
|
||
| # Auto-selecting, evaluated once when this assignment runs: the MLX variant when MPS is available and mlx-vlm is installed, the Transformers variant otherwise | ||
| GRANITEDOCLING = _get_granitedocling_model() | ||
|
|
||
|
|
||
def _get_smoldocling_model():
    """Pick the SmolDocling variant suited to the current hardware.

    Returns ``SMOLDOCLING_MLX`` when ``_has_apple_silicon_mlx()`` is true
    (MPS available and mlx-vlm installed) and ``SMOLDOCLING_TRANSFORMERS``
    in every other case. The choice is logged at debug level.
    """
    if not _has_apple_silicon_mlx():
        _log.debug("Auto-selected SmolDocling Transformers variant")
        return SMOLDOCLING_TRANSFORMERS

    _log.debug("Auto-selected SmolDocling MLX variant (Apple Silicon)")
    return SMOLDOCLING_MLX
|
|
||
|
|
||
| # Auto-selecting, evaluated once when this assignment runs: the MLX variant when MPS is available and mlx-vlm is installed, the Transformers variant otherwise | ||
| SMOLDOCLING = _get_smoldocling_model() | ||
|
|
||
|
|
||
| class VlmModelType(str, Enum): | ||
| SMOLDOCLING = "smoldocling" | ||
| SMOLDOCLING_VLLM = "smoldocling_vllm" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -77,11 +77,16 @@ def __init__( | |
| TFPredictor, | ||
| ) | ||
|
|
||
| device = decide_device(accelerator_options.device) | ||
|
|
||
| # Disable MPS here, until we know why it makes things slower. | ||
| if device == AcceleratorDevice.MPS.value: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This was a clear choice because of performance issues. Is this now resolved? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I can report that I also achieved a 30% reduction in execution time when converting both a 10-page paper and a 100-page report to Markdown, just by re-enabling MPS in [inline code reference lost in extraction]. |
||
| device = AcceleratorDevice.CPU.value | ||
| device = decide_device( | ||
| accelerator_options.device, | ||
| supported_devices=[ | ||
| AcceleratorDevice.CPU, | ||
| AcceleratorDevice.CUDA, | ||
| AcceleratorDevice.MPS, | ||
| AcceleratorDevice.XPU, | ||
| ], | ||
| ) | ||
| _log.debug(f"TableStructureModel using device: {device}") | ||
|
|
||
| self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json") | ||
| self.tm_config["model"]["save_dir"] = artifacts_path | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These imports are deprecated; we should not add them.