Skip to content

2. Add Apple Silicon (MPS) training support#8

Open
musicalplatypus wants to merge 8 commits into
TexasInstruments:mainfrom
musicalplatypus:pr/mps-support
Open

2. Add Apple Silicon (MPS) training support#8
musicalplatypus wants to merge 8 commits into
TexasInstruments:mainfrom
musicalplatypus:pr/mps-support

Conversation

@musicalplatypus

Copy link
Copy Markdown

Summary

Adds full Apple Silicon (MPS/Metal) support for training across all task types. This enables GPU-accelerated training on Mac devices without CUDA.

Changes

MPS Device Compatibility

  • dtype casting order — Cast tensors to .float() before .to(device) to avoid MPS dtype errors (float64 not supported on MPS)
  • torcheval metrics — Move metric tensors to CPU before computation (torcheval doesn't support MPS backend)
  • non_blocking=True reverted — MPS doesn't support async transfers; reverted to synchronous .to(device) calls
  • float64 → float32 — Updated all test_onnx.py and training scripts to use float() instead of double()

MPS Performance Optimization

  • Deferred .item() in training loopsSmoothedValue.update() now stores detached tensors instead of calling .item() every batch, avoiding MPS command-buffer flush synchronization stalls
  • AMP support — Added --native-amp flag threading through timeseries_base.py (auto-enable on MPS was reverted due to gradient underflow — left as opt-in)
  • torch.compile integration — Added model compilation support for accelerated training

NAS on MPS

  • Fixed NAS pipeline for MPS/macOS compatibility
  • Added NAS support test suite (test_nas_support.py)

New Models

  • Added open-source implementations of application-specific models in tinyml-modelzoo

Files Changed (16 files)

Across tinyml-modelmaker, tinyml-tinyverse, tinyml-modelzoo, and tinyml-modeloptimization

Testing

  • Verified classification, regression, forecasting, and anomaly detection training on MPS (M1/M2/M3)
  • NAS search verified on MPS with test_nas_support.py
  • No regression on CUDA or CPU training paths

t5fkg8d44d-beep and others added 8 commits April 7, 2026 07:19
MPS doesn't support float64 tensors. The previous .to(device).float()
pattern moved float64 data to MPS first, which fails. Swapping to
.float().to(device) casts to float32 on CPU then moves to device.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
torcheval's multiclass_confusion_matrix, multiclass_f1_score, and
multiclass_auroc use sparse COO tensors internally, which are not
supported on MPS. Move inputs to CPU for these metric computations.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add non_blocking=True to all training loop .to(device) calls to
  overlap CPU-GPU transfers with computation
- Auto-enable native AMP (bfloat16) on MPS devices when not explicitly
  set; add --no-native-amp opt-out flag
- Reduce per-batch GPU syncs: accumulate loss as detached tensor,
  defer .item() to SmoothedValue.update() at print time
- Fix _get_device() to respect training_device config param instead
  of always auto-detecting
- Add MPS memory reporting to MetricLogger (current_allocated_memory)
- Tune default num_workers to 4 on macOS (spawn overhead vs 8 on Linux)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
MPS autocast defaults to float16 but GradScaler is unsupported,
so small TinyML models suffer gradient underflow — all predictions
collapse to a single class.  Revert to opt-in only (--native-amp).

Also adds macOS developer setup instructions to README.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
On MPS unified memory, non_blocking=True has negligible benefit
(CPU and GPU share physical memory) and may cause subtle issues
with certain DataLoader/model configurations. Revert to synchronous
transfers which were proven working.

The deferred .item() optimization in SmoothedValue is retained —
that provides the main MPS performance win (7.8x faster metric
logging path).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
run_tinyml_modelmaker.py:
- Skip model catalog validation when nas_enabled=True
- Guard params.update() against None model_description
- Generate fallback model_description with generic_model=True,
  training_backend, and model_training_id for NAS

train.py (timeseries_classification):
- Fix nas_enabled check: was comparing bool True to string 'True'
  (str2bool converts arg to bool, so == 'True' always failed)

train_cnn_search.py:
- Fix MPS float64 crash: cast to float32 before .to(device)
  (MPS doesn't support float64, so .to(device).float() fails)

tests/test_nas_support.py:
- 9 tests: model validation bypass, fallback description, str2bool,
  argparse integration
MPS does not support float64 dtype. When DataLoader returns float64
tensors, .to(device) transfers them as-is to MPS, and the subsequent
.long()/.float() conversion fails with TypeError.

Fix: cast dtype BEFORE device transfer (.long().to(device) instead of
.to(device).long()). Also add explicit dtype=torch.float32 to empty
tensor creation (torch.tensor([]) defaults to float64).

Fixed files (7):
- timeseries_classification/test_onnx.py
- image_classification/test_onnx.py
- timeseries_anomalydetection/test_onnx.py (2 functions)
- timeseries_anomalydetection/test_onnx_cls.py
- timeseries_anomalydetection/train.py
- timeseries_regression/test_onnx.py
- timeseries_forecasting/test_onnx.py
Replace proprietary tinyml-mlbackend model references with built-in
open-source CNN implementations for arc fault, motor fault, and fan
imbalance classification. Previously these 10 models failed with
ValueError because the proprietary model files were not available.

New model classes: CNN_AF_3L_{200,300,700,1400}, CNN_MF_{1L,2L,3L}

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@musicalplatypus musicalplatypus changed the title Add Apple Silicon (MPS) training support 2. Add Apple Silicon (MPS) training support Apr 7, 2026
Adithya-Thonse pushed a commit that referenced this pull request Jun 12, 2026
Merge in TINYML-ALGO/tinyml-agent-skills from 2026/pranav_a to main

* commit '8c3260bcdd549353b2cdbf3562bb3f32753fdddf':
  improving readme
Adithya-Thonse added a commit that referenced this pull request Jun 12, 2026
de8af16d Pull request #45: https://jira.itg.ti.com/browse/TINYML_ALGO-698
REVERT: e48ef1a Pull request #14: TINYML_ALGO-711: fixing readme
REVERT: 16fc6a6 TINYML_ALGO-711: fixing readme
REVERT: e3639d2 Pull request #13: removing pycache
REVERT: f8bb3b7 removing pycache
REVERT: dd38428 Pull request #12: restructuring agent skill
REVERT: ff02a0e restructuring agent skill
REVERT: d26c6a5 Pull request #11: fixing tiny ml name
REVERT: 640ffd3 fixing tiny ml name
REVERT: 4ee3a19 Pull request #10: 2026/pranav a
REVERT: be83fc6 minor fixes
REVERT: e3a5700 removed assets, included autoMP quant
REVERT: 1af575a Pull request #9: correcting npu devices list
REVERT: 31e9eb1 correcting npu devices list
REVERT: 59b209b Pull request #8: improving readme
REVERT: 8c3260b improving readme
REVERT: 668916f Pull request #7: improving readme
REVERT: 68686b3 improving readme
REVERT: 814316e Pull request #6: fixes to readme and marketplace json
REVERT: e4bc0b4 fixes to readme and marketplace json
REVERT: 6a64208 Pull request #5: fixes to readme
REVERT: 0f9c868 fixes to readme
REVERT: 52f95ff Pull request #4: 2026/pranav a
REVERT: 443295d fixes to readme
REVERT: 1881112 fixes to readme and marketplace json
REVERT: 229ab57 Pull request #3: 2026/pranav a
REVERT: 6519104 minor readme fix
REVERT: 38e9f9f minor readme fix
REVERT: db81f81 Pull request #2: minor readme fix
REVERT: 1c0737a minor readme fix
REVERT: 0a0c02d Pull request #1: minor readme fix
REVERT: b682335 minor readme fix
REVERT: 062eb39 Initial Commit

git-subtree-dir: tinyml-agent-skills
git-subtree-split: de8af16d9e23de3e9bda3d811a0ebdece1178260
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants