@@ -103,6 +103,76 @@ jobs:
103103 - name : Build check
104104 run : uv build
105105
106+ mps-checks :
107+ name : MPS Checks
108+ runs-on : macos-latest
109+ # Only run on PRs merging to main or pushes directly to main
110+ if : >
111+ (github.event_name == 'pull_request' && github.base_ref == 'main') ||
112+ (github.event_name == 'push' && github.ref == 'refs/heads/main')
113+ steps :
114+ - uses : actions/checkout@v4
115+ - name : Install uv
116+ uses : astral-sh/setup-uv@v6
117+ with :
118+ python-version : " 3.11"
119+ activate-environment : true
120+ enable-cache : true
121+ - name : MPS Cache Models
122+ uses : actions/cache@v3
123+ with :
124+ path : |
125+ ~/.cache/huggingface/hub/models--roneneldan--TinyStories-1M*
126+ key : ${{ runner.os }}-huggingface-models-mps-v1
127+ - name : Install dependencies
128+ run : |
129+ uv lock --check
130+ uv sync
131+ - name : MPS Availability Check
132+ run : |
133+ uv run python -c "
134+ import torch
135+ print(f'PyTorch: {torch.__version__}')
136+ print(f'MPS available: {torch.backends.mps.is_available()}')
137+ print(f'MPS built: {torch.backends.mps.is_built()}')
138+ assert torch.backends.mps.is_available(), 'MPS not available on this runner!'
139+ "
140+ env :
141+ HF_TOKEN : ${{ secrets.HF_TOKEN }}
142+ - name : MPS Unit Tests
143+ run : >
144+ uv run pytest tests/unit -v
145+ --ignore=tests/unit/model_bridge/test_optimizer_compatibility.py
146+ --ignore=tests/unit/model_bridge/test_gpt_oss_moe.py
147+ --ignore=tests/unit/model_bridge/test_component_inspection.py
148+ --ignore=tests/unit/model_bridge/test_key_analysis.py
149+ --ignore=tests/unit/model_bridge/test_benchmark_gated_hooks_fire.py
150+ --ignore=tests/unit/model_bridge/test_weight_processing_adapter_paths.py
151+ --ignore=tests/unit/model_bridge/test_bridge_generate_kv_cache.py
152+ --ignore=tests/unit/model_bridge/test_bridge_vs_hooked_transformer_patching.py
153+ --ignore=tests/unit/model_bridge/compatibility/
154+ env :
155+ HF_TOKEN : ${{ secrets.HF_TOKEN }}
156+ - name : MPS Integration Tests
157+ run : >
158+ uv run pytest tests/integration -v
159+ --ignore=tests/integration/model_bridge/test_optimizer_compatibility.py
160+ --ignore=tests/integration/model_bridge/test_bridge_generation.py
161+ --ignore=tests/integration/model_bridge/test_bridge_integration.py
162+ --ignore=tests/integration/model_bridge/compatibility/
163+ --ignore=tests/integration/test_prepend_bos.py
164+ --ignore=tests/integration/test_generation_compatibility.py
165+ --ignore=tests/integration/test_match_huggingface.py
166+ --ignore=tests/integration/test_fold_layer_integration.py
167+ --ignore=tests/integration/test_centralized_weight_processing.py
168+ env :
169+ HF_TOKEN : ${{ secrets.HF_TOKEN }}
170+ - name : MPS Smoke Tests
171+ run : uv run pytest tests/mps -v
172+ env :
173+ TRANSFORMERLENS_ALLOW_MPS : " 1"
174+ HF_TOKEN : ${{ secrets.HF_TOKEN }}
175+
106176 format-check :
107177 name : Format Check
108178 runs-on : ubuntu-latest
@@ -231,6 +301,7 @@ jobs:
231301 - " Activation_Patching_in_TL_Demo"
232302 - " ARENA_Content"
233303 - " BERT"
304+ - " Bridge_Evals_Demo"
234305 - " Exploratory_Analysis_Demo"
235306 # - "Grokking_Demo"
236307 - " Head_Detector_Demo"
@@ -272,7 +343,7 @@ jobs:
272343 - name : Install dependencies
273344 run : |
274345 uv lock --check
275- uv sync --group quantization
346+ uv sync --group quantization --extra evals
276347 - name : Install pandoc
277348 uses : awalsh128/cache-apt-pkgs-action@latest
278349 with :
0 commit comments