🐛 Fix issue with MPS device#1073
Conversation
There was a problem hiding this comment.
Pull request overview
This PR addresses #1070 by preventing float64 tensors from being moved onto Apple Silicon’s MPS backend (which doesn’t support float64) before being cast to float32. It standardizes inference-time tensor transfers across multiple model architectures by using a single Tensor.to(device=..., dtype=torch.float32) call, and additionally ensures KongNet’s NumPy preprocessing stays in float32 to avoid producing float64 tensors upstream.
Changes:
- Update multiple
infer_batchimplementations to move-and-cast inputs in one call (.to(device=..., dtype=torch.float32)) to avoid MPS float64 transfer errors. - Refactor
vanilla._infer_batchinput preparation to useto(..., dtype=...)instead of.to(...).type(...). - Make
KongNet.preprocexplicitly float32 throughout (mean/std and scaling) to avoid returning float64 arrays.
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
tiatoolbox/models/architecture/vanilla.py |
Uses to(device=..., dtype=torch.float32) for batch transfer/cast before permuting. |
tiatoolbox/models/architecture/unet.py |
Updates input transfer/cast in infer_batch to avoid MPS float64 issues. |
tiatoolbox/models/architecture/sccnn.py |
Updates input transfer/cast in infer_batch to avoid MPS float64 issues. |
tiatoolbox/models/architecture/nuclick.py |
Updates input transfer/cast in infer_batch to avoid MPS float64 issues. |
tiatoolbox/models/architecture/micronet.py |
Updates input transfer/cast in infer_batch to avoid MPS float64 issues. |
tiatoolbox/models/architecture/mapde.py |
Updates input transfer/cast in infer_batch to avoid MPS float64 issues (directly tied to #1070). |
tiatoolbox/models/architecture/kongnet.py |
Updates input transfer/cast in infer_batch and makes preproc float32 end-to-end. |
tiatoolbox/models/architecture/hovernetplus.py |
Updates input transfer/cast in infer_batch to avoid MPS float64 issues. |
tiatoolbox/models/architecture/hovernet.py |
Updates input transfer/cast in infer_batch to avoid MPS float64 issues. |
tiatoolbox/models/architecture/grandqc.py |
Updates input transfer/cast in infer_batch to avoid MPS float64 issues. |
tiatoolbox/models/architecture/efficientunet_tissue_mask_model.py |
Updates input transfer/cast in infer_batch to avoid MPS float64 issues. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## develop #1073 +/- ##
========================================
Coverage 99.88% 99.88%
========================================
Files 85 85
Lines 11625 11626 +1
Branches 1524 1524
========================================
+ Hits 11612 11613 +1
Misses 7 7
Partials 6 6 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
@gozdeg Please can you review this PR? |
This PR fixes #1070.
It standardizes the way tensors are moved to a device and cast to
float32across multiple model architectures in the codebase. The updates ensure that both the device and data type are explicitly specified in a single call to.to(), improving code clarity and consistency. Additionally, the image preprocessing inkongnet.pyis updated to enforce float32 precision throughout..to(device).type(torch.float32)to.to(device=device, dtype=torch.float32)in variousinfer_batchfunctions, ensuring that tensors are moved to the correct device and cast tofloat32in a single step. This change affects the following files:efficientunet_tissue_mask_model.py[1]grandqc.py[2]hovernet.py[3]hovernetplus.py[4]kongnet.py[5]mapde.py[6]micronet.py[7]nuclick.py[8]sccnn.py[9]unet.py[10] andvanilla.py[11].