Skip to content

🐛 Fix issue with MPS device#1073

Merged
shaneahmed merged 1 commit into
developfrom
fix-MPS-issue
Jun 4, 2026
Merged

🐛 Fix issue with MPS device#1073
shaneahmed merged 1 commit into
developfrom
fix-MPS-issue

Conversation

@Jiaqi-Lv
Copy link
Copy Markdown
Collaborator

@Jiaqi-Lv Jiaqi-Lv commented Jun 3, 2026

This PR fixes #1070.
It standardizes the way tensors are moved to a device and cast to float32 across multiple model architectures in the codebase. The updates ensure that both the device and data type are explicitly specified in a single call to .to(), improving code clarity and consistency. Additionally, the image preprocessing in kongnet.py is updated to enforce float32 precision throughout.

  • Updated all instances of .to(device).type(torch.float32) to .to(device=device, dtype=torch.float32) in various infer_batch functions, ensuring that tensors are moved to the correct device and cast to float32 in a single step. This change affects the following files: efficientunet_tissue_mask_model.py [1] grandqc.py [2] hovernet.py [3] hovernetplus.py [4] kongnet.py [5] mapde.py [6] micronet.py [7] nuclick.py [8] sccnn.py [9] unet.py [10] and vanilla.py [11].

@Jiaqi-Lv Jiaqi-Lv requested a review from Copilot June 3, 2026 15:08
@Jiaqi-Lv Jiaqi-Lv self-assigned this Jun 3, 2026
@Jiaqi-Lv Jiaqi-Lv added the bug Something isn't working label Jun 3, 2026
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR addresses #1070 by preventing float64 tensors from being moved onto Apple Silicon’s MPS backend (which doesn’t support float64) before being cast to float32. It standardizes inference-time tensor transfers across multiple model architectures by using a single Tensor.to(device=..., dtype=torch.float32) call, and additionally ensures KongNet’s NumPy preprocessing stays in float32 to avoid producing float64 tensors upstream.

Changes:

  • Update multiple infer_batch implementations to move-and-cast inputs in one call (.to(device=..., dtype=torch.float32)) to avoid MPS float64 transfer errors.
  • Refactor vanilla._infer_batch input preparation to use to(..., dtype=...) instead of .to(...).type(...).
  • Make KongNet.preproc explicitly float32 throughout (mean/std and scaling) to avoid returning float64 arrays.

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated no comments.

Show a summary per file
File Description
tiatoolbox/models/architecture/vanilla.py Uses to(device=..., dtype=torch.float32) for batch transfer/cast before permuting.
tiatoolbox/models/architecture/unet.py Updates input transfer/cast in infer_batch to avoid MPS float64 issues.
tiatoolbox/models/architecture/sccnn.py Updates input transfer/cast in infer_batch to avoid MPS float64 issues.
tiatoolbox/models/architecture/nuclick.py Updates input transfer/cast in infer_batch to avoid MPS float64 issues.
tiatoolbox/models/architecture/micronet.py Updates input transfer/cast in infer_batch to avoid MPS float64 issues.
tiatoolbox/models/architecture/mapde.py Updates input transfer/cast in infer_batch to avoid MPS float64 issues (directly tied to #1070).
tiatoolbox/models/architecture/kongnet.py Updates input transfer/cast in infer_batch and makes preproc float32 end-to-end.
tiatoolbox/models/architecture/hovernetplus.py Updates input transfer/cast in infer_batch to avoid MPS float64 issues.
tiatoolbox/models/architecture/hovernet.py Updates input transfer/cast in infer_batch to avoid MPS float64 issues.
tiatoolbox/models/architecture/grandqc.py Updates input transfer/cast in infer_batch to avoid MPS float64 issues.
tiatoolbox/models/architecture/efficientunet_tissue_mask_model.py Updates input transfer/cast in infer_batch to avoid MPS float64 issues.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@codecov
Copy link
Copy Markdown

codecov Bot commented Jun 3, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 99.88%. Comparing base (e337bea) to head (72668a9).

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #1073   +/-   ##
========================================
  Coverage    99.88%   99.88%           
========================================
  Files           85       85           
  Lines        11625    11626    +1     
  Branches      1524     1524           
========================================
+ Hits         11612    11613    +1     
  Misses           7        7           
  Partials         6        6           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@Jiaqi-Lv Jiaqi-Lv requested review from eshasadia and shaneahmed June 3, 2026 16:04
@shaneahmed
Copy link
Copy Markdown
Member

@gozdeg Please can you review this PR?

@gozdeg gozdeg self-requested a review June 4, 2026 11:14
Copy link
Copy Markdown
Collaborator

@gozdeg gozdeg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All looks good to me

@shaneahmed shaneahmed merged commit c9c72c9 into develop Jun 4, 2026
21 of 25 checks passed
@shaneahmed shaneahmed deleted the fix-MPS-issue branch June 4, 2026 13:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working

Projects

None yet

Development

Successfully merging this pull request may close these issues.

KongNet and MapDe infer_batch crash on Apple Silicon (MPS): Cannot convert a MPS Tensor to float64

4 participants