
Commit 3bbe3aa

yebai and claude committed
Rename evaluator API to ADProblems and fold autograd testing into the core interface.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 216232b commit 3bbe3aa

23 files changed

Lines changed: 903 additions & 374 deletions

.claude/skills/inspect/SKILL.md

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+---
+name: inspect
+description: Inspect the AD pipeline IR for a Julia function at each Mooncake compilation stage.
+---
+
+# Inspect
+
+Inspect IR transformations in Mooncake's AD pipeline for a given function.
+
+## Setup
+
+```julia
+using Mooncake, Mooncake.SkillUtils
+```
+
+## Gathering user intent
+
+Ask the user:
+
+1. **Function and arguments** — e.g. `sin, 1.0` or a custom function
+2. **Mode** — reverse (default) or forward
+3. **What to view** — all stages, a specific stage, a diff between two stages, or world age info
+
+Do not assume — ask the user to pick.
+
+## Pipeline stages
+
+### Reverse mode (default)
+
+| Stage | Symbol | Description |
+|:----------------- |:---------------- |:---------------------------------------------------- |
+| Raw IR | `:raw` | optimised, type-inferred SSAIR from Julia's compiler |
+| Normalised | `:normalized` | after Mooncake's normalisation passes |
+| BBCode | `:bbcode` | BBCode representation with stable IDs |
+| Forward IR | `:fwd_ir` | generated forward-pass IR |
+| Reverse IR | `:rvs_ir` | generated pullback IR |
+| Optimised Forward | `:optimized_fwd` | forward pass after optimisation |
+| Optimised Reverse | `:optimized_rvs` | pullback after optimisation |
+
+### Forward mode
+
+| Stage | Symbol | Description |
+|:---------- |:------------- |:------------------------------------------------------------- |
+| Raw IR | `:raw` | optimised, type-inferred SSAIR from Julia's compiler |
+| Normalised | `:normalized` | after Mooncake's normalisation passes |
+| BBCode | `:bbcode` | inspection-only — forward mode does not use BBCode internally |
+| Dual IR | `:dual_ir` | generated dual-number IR |
+| Optimised | `:optimized` | after optimisation passes |
+
+## Commands
+
+```julia
+# Full inspection
+ins = inspect_ir(f, args...; mode=:reverse) # or mode=:forward
+
+# View stages
+show_ir(ins) # all stages
+show_stage(ins, :raw) # one stage
+
+# Diffs between stages
+show_diff(ins; from=:raw, to=:normalized)
+show_all_diffs(ins)
+
+# World age debugging
+show_world_info(ins)
+
+# Write everything to files
+write_ir(ins, "/tmp/ir_output")
+
+# Shorthand helpers
+ins = inspect_fwd(f, args...) # forward mode
+ins = inspect_rvs(f, args...) # reverse mode
+ins = quick_inspect(f, args...) # inspect + display immediately
+
+# Options
+inspect_ir(f, args...; mode=:reverse, optimize=true, do_inline=true, debug_mode=false)
+```
+
+## Presenting results
+
+- Run commands via Bash and present IR in fenced code blocks.
+- When showing diffs, explain what changed and why the transformation matters.
+- If errors occur, check that Mooncake is loaded and the function signature is valid.
+
+## Limitations
+
+Inspects Mooncake's internal AD pipeline only. For allocation, world-age, or compiler-boundary debugging, see `docs/src/developer_documentation/advanced_debugging.md`.

.claude/skills/minimise/SKILL.md

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+---
+name: minimise
+description: Prune a bug fix or new tests down to the smallest correct diff through multiple elimination passes. Use before committing any fix or test addition.
+---
+
+# Minimise
+
+The goal is to remove every line that is not strictly required for correctness,
+then verify the result still passes the relevant tests.
+
+## Process
+
+Repeat the following until no further reductions are possible:
+
+1. **Read the diff.** Run `git diff HEAD` (or `git diff --cached` if staged) and
+   read every changed file in full.
+
+2. **Challenge each change.** For every changed line ask:
+
+   + Would removing this line cause a test to fail or a bug to reappear?
+   + Is this a cleanup, rename, refactor, or comment that is not load-bearing?
+   + For new tests: does an existing test already cover this behaviour?
+     If so, drop the new test entirely.
+3. **Remove non-essential changes.** Delete anything that does not answer
+   "yes" to the first question above. Prefer shrinking an existing case over
+   adding a new one.
+4. **Run the minimal test group.** Use the smallest focused test group that
+   exercises the changed code (see `test/runtests.jl` for group names).
+   Confirm all tests pass before continuing.
+5. **Repeat** from step 1 until a full pass produces no further removals.
+
+## Heuristics
+
+- A one-line fix is better than a five-line fix.
+- A new test case added to an existing `@testset` is better than a new `@testset`.
+- A new value constructor in `src/test_resources.jl` should be the minimum needed
+  to instantiate the type under test; no extra fields or variants.
+- Comments and blank lines added alongside a fix are not load-bearing; remove them
+  unless they explain something non-obvious.
+- Helper functions introduced solely for the fix are a red flag; inline them.
+
+## When to stop
+
+Stop when every remaining line answers "yes" to: *if I remove this, the targeted
+bug reappears or the targeted test fails*. At that point report the final diff and
+suggest committing.

.claude/skills/scrutinise/SKILL.md

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+---
+name: scrutinise
+description: Scrutinise newly added or changed code on the current branch against main. Checks new types, methods, changed signatures, overloads, and helpers for necessity, correctness, clarity, consistency, robustness, and minimality. Reviews new tests for gap coverage, overlap with existing tests, minimality, and use of established testing patterns. Invoke with /scrutinise.
+tools: Bash, Glob, Grep, Read, Edit, Write
+---
+
+# Scrutinise
+
+Review all changes on the current branch relative to `main`. Cover source and tests separately, then simplify.
+
+## Step 1: Gather the diff
+
+```bash
+git diff main...HEAD --name-only
+git diff main...HEAD -- src/ ext/ test/
+```
+
+Read changed files in full before commenting.
+
+## Step 2: Source
+
+For every new type, method, changed signature, or overload:
+
+- **Necessary?** Does existing infrastructure (`@zero_derivative`, `@from_rrule`, broader signatures) already cover this? Could an overload be eliminated by broadening an existing one?
+- **Correct?** Tangent/cotangent types consistent with `tangent_type`? `@is_primitive` declared? For `rrule!!`: pullback restores mutations, aliasing handled. For `frule!!`: dual propagation correct, removable singularities handled.
+- **Clear and consistent?** Names and structure match the surrounding file and `src/rules/`. `NoTangent`/`ZeroTangent` used correctly.
+- **Robust?** Edge cases (empty arrays, zero-size structs, complex types) handled or explicitly excluded. Fails loudly on unsupported inputs.
+- **Minimal?** No dead branches, unused arguments, or speculative generalisations.
+
+For every new helper: does it genuinely aid readability or reduce duplication, or can it be inlined? Does it belong in `src/utils.jl` or is it rule-local?
+
+For every new or changed **comment**:
+
+- **WHY not WHAT?** Delete comments that restate what the code already says (variable names, types, control flow). Keep only non-obvious constraints, invariants, and design rationale.
+- **Accurate?** Does the comment still match the code? Stale or contradictory comments are worse than none.
+- **Brief?** Trim verbose multi-line blocks to the minimum that preserves the WHY. Cross-references (`see X for WHY`) are fine but the local comment should still give enough context to understand the constraint without chasing the reference.
+
+For every new or changed **docstring**:
+
+- **Correct?** Does it accurately describe current behaviour, including any overloads (e.g. `Ptr` special cases)?
+- **No leaking internals?** Docstrings are public-facing; do not refer users to internal comments or implementation details they cannot rely on.
+- **Concise?** One sentence for simple functions; a short paragraph for complex ones. Avoid restating the signature.
+
+## Step 3: Tests
+
+For every new or changed test:
+
+- **Real gap?** Would removing it leave a regression undetected, or is it duplicating interpreter-level coverage via `TestResources.generate_test_functions()`?
+- **No overlap?** Check the corresponding test file and `test/front_matter.jl` for existing tests on the same rule/type.
+- **Minimal?** Smallest example that exercises the gap; no redundant argument combinations.
+- **Right pattern?** Rules → `test_rule`. Tangents → `test_tangent` / `test_tangent_type_and_tglob_type_agree`. Duals → `test_dual` / `test_fdata` / `test_rdata`. Allocations → `count_allocs`. Malformed rules → `DebugMode`. Flag any test reimplementing logic already in the test utilities.
+
+## Step 4: Output
+
+Findings grouped by file, labelled:
+
+- **Unnecessary** / **Incorrect** / **Unclear** / **Inconsistent** / **Non-minimal** / **Fragile**
+- **Comment: stale** / **Comment: explains WHAT** / **Comment: too verbose** / **Comment: missing WHY**
+- **Docstring: incorrect** / **Docstring: leaks internals** / **Docstring: too verbose**
+- **Test: redundant** / **Test: missing pattern** / **Test: weak gap**
+
+No issues in a section → write "No issues." Do not suggest additions beyond what the diff introduces.
+
+## Step 5: Simplify
+
+Invoke the `simplify` skill to apply code-quality and reuse fixes to the changed files.

docs/src/pplapi.md

Lines changed: 1 addition & 0 deletions
@@ -26,5 +26,6 @@ DerivativeOrder
 capabilities
 prepare
 value_and_gradient
+test_autograd
 dimension
 ```

ext/AbstractPPLDifferentiationInterfaceExt.jl

Lines changed: 6 additions & 15 deletions
@@ -8,18 +8,12 @@ struct DIPrepared{E,B,C}
     evaluator::E
     backend::B
     prep::C
-    dim::Int
 end
 
 AbstractPPL.capabilities(::Type{<:DIPrepared}) = DerivativeOrder{1}()
-AbstractPPL.dimension(p::DIPrepared) = p.dim
+AbstractPPL.dimension(p::DIPrepared) = AbstractPPL.dimension(p.evaluator)
 
-function (p::DIPrepared)(x::AbstractVector{<:AbstractFloat})
-    length(x) == p.dim || throw(
-        DimensionMismatch(
-            "Expected a vector of length $(p.dim), but got length $(length(x))."
-        ),
-    )
+function (p::DIPrepared)(x)
     return p.evaluator(x)
 end
 
@@ -29,19 +23,16 @@ end
 function AbstractPPL.prepare(
     adtype::ADTypes.AbstractADType, problem, x::AbstractVector{<:AbstractFloat}
 )
-    evaluator = AbstractPPL.prepare(problem, x)
+    evaluator = AbstractPPL.ADProblems.VectorEvaluator(
+        AbstractPPL.prepare(problem, x), length(x)
+    )
     prep = DI.prepare_gradient(evaluator, adtype, x)
-    return DIPrepared(evaluator, adtype, prep, length(x))
+    return DIPrepared(evaluator, adtype, prep)
 end
 
 @inline function AbstractPPL.value_and_gradient(
     p::DIPrepared, x::AbstractVector{<:AbstractFloat}
 )
-    length(x) == p.dim || throw(
-        DimensionMismatch(
-            "Expected a vector of length $(p.dim), but got length $(length(x))."
-        ),
-    )
     return DI.value_and_gradient(p.evaluator, p.prep, p.backend, x)
 end
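Note on the hunks above: the length check and the stored `dim` field move out of the prepared object and into the wrapped evaluator. The definition of `AbstractPPL.ADProblems.VectorEvaluator` is not shown in this part of the diff, so the following is only a minimal sketch of the pattern, assuming the wrapper stores the dimension and throws `DimensionMismatch` on a length mismatch, as the removed lines did. All names (`SketchVectorEvaluator`, `SketchPrepared`, `sketch_dimension`) are hypothetical and are not AbstractPPL API.

```julia
# Hypothetical stand-in for a dimension-checking evaluator wrapper.
# This is NOT the actual AbstractPPL.ADProblems.VectorEvaluator definition.
struct SketchVectorEvaluator{F}
    f::F
    dim::Int
end

sketch_dimension(e::SketchVectorEvaluator) = e.dim

function (e::SketchVectorEvaluator)(x::AbstractVector{<:AbstractFloat})
    # The same check the removed lines performed inside each prepared type.
    length(x) == e.dim || throw(
        DimensionMismatch(
            "Expected a vector of length $(e.dim), but got length $(length(x))."
        ),
    )
    return e.f(x)
end

# The prepared object no longer stores `dim`; it delegates to its evaluator.
struct SketchPrepared{E}
    evaluator::E
end

sketch_dimension(p::SketchPrepared) = sketch_dimension(p.evaluator)
(p::SketchPrepared)(x) = p.evaluator(x)

p = SketchPrepared(SketchVectorEvaluator(x -> sum(abs2, x), 3))
value = p([1.0, 2.0, 3.0])
```

Under that assumption, each backend extension keeps a single source of truth for the dimension, and a wrong-length input still fails loudly at the evaluator call rather than in every `Prepared` type.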

ext/AbstractPPLEnzymeExt.jl

Lines changed: 6 additions & 15 deletions
@@ -6,34 +6,25 @@ using Enzyme: Enzyme
 
 struct EnzymePrepared{E}
     evaluator::E
-    dim::Int
 end
 
 AbstractPPL.capabilities(::Type{<:EnzymePrepared}) = DerivativeOrder{1}()
-AbstractPPL.dimension(p::EnzymePrepared) = p.dim
+AbstractPPL.dimension(p::EnzymePrepared) = AbstractPPL.dimension(p.evaluator)
 
-function (p::EnzymePrepared)(x::AbstractVector{<:AbstractFloat})
-    length(x) == p.dim || throw(
-        DimensionMismatch(
-            "Expected a vector of length $(p.dim), but got length $(length(x))."
-        ),
-    )
+function (p::EnzymePrepared)(x)
     return p.evaluator(x)
 end
 
 function AbstractPPL.prepare(::AutoEnzyme, problem, x::AbstractVector{<:AbstractFloat})
-    evaluator = AbstractPPL.prepare(problem, x)
-    return EnzymePrepared(evaluator, length(x))
+    evaluator = AbstractPPL.ADProblems.VectorEvaluator(
+        AbstractPPL.prepare(problem, x), length(x)
+    )
+    return EnzymePrepared(evaluator)
 end
 
 @inline function AbstractPPL.value_and_gradient(
     p::EnzymePrepared, x::AbstractVector{<:AbstractFloat}
 )
-    length(x) == p.dim || throw(
-        DimensionMismatch(
-            "Expected a vector of length $(p.dim), but got length $(length(x))."
-        ),
-    )
     dx = zero(x)
     result = Enzyme.autodiff(
         Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal),
