Skip to content

Commit 31d4f6a

Browse files
authored
Merge pull request #1294 from TransformerLensOrg/dev
Release v3.2.0
2 parents 6f56518 + 52e4f6d commit 31d4f6a

48 files changed

Lines changed: 68653 additions & 31613 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/checks.yml

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,76 @@ jobs:
103103
- name: Build check
104104
run: uv build
105105

106+
mps-checks:
107+
name: MPS Checks
108+
runs-on: macos-latest
109+
# Only run on PRs merging to main or pushes directly to main
110+
if: >
111+
(github.event_name == 'pull_request' && github.base_ref == 'main') ||
112+
(github.event_name == 'push' && github.ref == 'refs/heads/main')
113+
steps:
114+
- uses: actions/checkout@v4
115+
- name: Install uv
116+
uses: astral-sh/setup-uv@v6
117+
with:
118+
python-version: "3.11"
119+
activate-environment: true
120+
enable-cache: true
121+
- name: MPS Cache Models
122+
uses: actions/cache@v3
123+
with:
124+
path: |
125+
~/.cache/huggingface/hub/models--roneneldan--TinyStories-1M*
126+
key: ${{ runner.os }}-huggingface-models-mps-v1
127+
- name: Install dependencies
128+
run: |
129+
uv lock --check
130+
uv sync
131+
- name: MPS Availability Check
132+
run: |
133+
uv run python -c "
134+
import torch
135+
print(f'PyTorch: {torch.__version__}')
136+
print(f'MPS available: {torch.backends.mps.is_available()}')
137+
print(f'MPS built: {torch.backends.mps.is_built()}')
138+
assert torch.backends.mps.is_available(), 'MPS not available on this runner!'
139+
"
140+
env:
141+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
142+
- name: MPS Unit Tests
143+
run: >
144+
uv run pytest tests/unit -v
145+
--ignore=tests/unit/model_bridge/test_optimizer_compatibility.py
146+
--ignore=tests/unit/model_bridge/test_gpt_oss_moe.py
147+
--ignore=tests/unit/model_bridge/test_component_inspection.py
148+
--ignore=tests/unit/model_bridge/test_key_analysis.py
149+
--ignore=tests/unit/model_bridge/test_benchmark_gated_hooks_fire.py
150+
--ignore=tests/unit/model_bridge/test_weight_processing_adapter_paths.py
151+
--ignore=tests/unit/model_bridge/test_bridge_generate_kv_cache.py
152+
--ignore=tests/unit/model_bridge/test_bridge_vs_hooked_transformer_patching.py
153+
--ignore=tests/unit/model_bridge/compatibility/
154+
env:
155+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
156+
- name: MPS Integration Tests
157+
run: >
158+
uv run pytest tests/integration -v
159+
--ignore=tests/integration/model_bridge/test_optimizer_compatibility.py
160+
--ignore=tests/integration/model_bridge/test_bridge_generation.py
161+
--ignore=tests/integration/model_bridge/test_bridge_integration.py
162+
--ignore=tests/integration/model_bridge/compatibility/
163+
--ignore=tests/integration/test_prepend_bos.py
164+
--ignore=tests/integration/test_generation_compatibility.py
165+
--ignore=tests/integration/test_match_huggingface.py
166+
--ignore=tests/integration/test_fold_layer_integration.py
167+
--ignore=tests/integration/test_centralized_weight_processing.py
168+
env:
169+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
170+
- name: MPS Smoke Tests
171+
run: uv run pytest tests/mps -v
172+
env:
173+
TRANSFORMERLENS_ALLOW_MPS: "1"
174+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
175+
106176
format-check:
107177
name: Format Check
108178
runs-on: ubuntu-latest
@@ -231,6 +301,7 @@ jobs:
231301
- "Activation_Patching_in_TL_Demo"
232302
- "ARENA_Content"
233303
- "BERT"
304+
- "Bridge_Evals_Demo"
234305
- "Exploratory_Analysis_Demo"
235306
# - "Grokking_Demo"
236307
- "Head_Detector_Demo"
@@ -272,7 +343,7 @@ jobs:
272343
- name: Install dependencies
273344
run: |
274345
uv lock --check
275-
uv sync --group quantization
346+
uv sync --group quantization --extra evals
276347
- name: Install pandoc
277348
uses: awalsh128/cache-apt-pkgs-action@latest
278349
with:

0 commit comments

Comments
 (0)