Commit 0adbe8c ("up")
1 parent 93afd3e

2 files changed: 82 additions & 43 deletions

backends/mlx/README.md
81 additions & 42 deletions
@@ -193,23 +193,25 @@ ExportedProgram (subgraph)
 
 ## How to Add a New Op
 
-This section walks through adding a new op end-to-end, using **`aten.linear`**
+This section walks through adding a new op end-to-end, using **`aten.addmm`**
 as an example.
 
 ### Step 1: Add the Node to `schema.fbs`
 
 Add a new table in the "Op nodes" section and add it to the `OpNode` union:
 
 ```fbs
-table LinearNode {
-  x: Tid (required);
-  weight: Tid (required);
+table AddmmNode {
+  mat1: Tid (required);
+  mat2: Tid (required);
   out: Tid (required);
   bias: Tid; // optional
+  alpha: float = 1.0;
+  beta: float = 1.0;
 }
 ```
 
-Then add `LinearNode` to the `union OpNode { ... }` list.
+Then add `AddmmNode` to the `union OpNode { ... }` list.
 
 ### Step 2: Run the Code Generator
 
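For orientation, this is roughly the Python dataclass the Step 2 generator would emit for the table above. The exact generated shape is an assumption (in particular, `Tid` being a plain `int` is hypothetical); field names follow the schema:

```python
from dataclasses import dataclass
from typing import Optional

# Assumption: tensor ids are plain integers in the generated module.
Tid = int


@dataclass
class AddmmNode:
    mat1: Tid  # required in the .fbs table
    mat2: Tid  # required
    out: Tid   # required
    bias: Optional[Tid] = None  # "bias: Tid;" without "(required)" -> optional
    alpha: float = 1.0
    beta: float = 1.0


# A node referring to tensor ids 0-3, with non-default scalar multipliers.
node = AddmmNode(mat1=0, mat2=1, out=2, bias=3, alpha=2.0, beta=0.5)
```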
@@ -219,34 +219,40 @@ python backends/mlx/serialization/generate.py
 
 This regenerates:
 
-- `mlx_graph_schema.py` — adds `LinearNode` Python dataclass
-- `_generated_serializers.py` — adds `_build_LinearNode` serializer
-- `runtime/MLXLoader.h` — adds `LinearNode` C++ struct, `OpCode::LINEAR`, loader
-- `runtime/MLXLoader.cpp` — adds FlatBuffer → `LinearNode` deserialization
+- `mlx_graph_schema.py` — adds `AddmmNode` Python dataclass
+- `_generated_serializers.py` — adds `_build_AddmmNode` serializer
+- `runtime/MLXLoader.h` — adds `AddmmNode` C++ struct, `OpCode::ADDMM`, loader
+- `runtime/MLXLoader.cpp` — adds FlatBuffer → `AddmmNode` deserialization
 - `runtime/schema_generated.h` — FlatBuffer C++ bindings
 
 ### Step 3: Add the Python Op Handler (`ops.py`)
 
 Register a handler that converts the ATen op to your new node. Make sure to
-import `LinearNode` from `mlx_graph_schema`:
+import `AddmmNode` from `mlx_graph_schema`:
 
 ```python
-from executorch.backends.mlx.serialization.mlx_graph_schema import LinearNode
+from executorch.backends.mlx.serialization.mlx_graph_schema import AddmmNode
 
-@REGISTRY.register(target=[torch.ops.aten.linear.default])
-def _linear_handler(P: MLXProgramBuilder, n: Node) -> Slot:
+@REGISTRY.register(target=[torch.ops.aten.addmm.default])
+def _addmm_handler(P: MLXProgramBuilder, n: Node) -> Slot:
     args = P.args(n)
-    require_args(args, 2, 3, "aten.linear")
-    require_kwargs(P.kwargs(n), set(), "aten.linear")
-    x, w = args[0], args[1]
-    b = args[2] if len(args) > 2 else None
+    kwargs = P.kwargs(n)
+    require_args(args, 3, 3, "aten.addmm")
+    require_kwargs(kwargs, {"beta", "alpha"}, "aten.addmm")
+    bias, mat1, mat2 = args[0], args[1], args[2]
+
+    beta = kwargs.get("beta", 1)
+    alpha = kwargs.get("alpha", 1)
+
     out = P.make_or_get_slot(n)
     P.emit(
-        LinearNode(
-            x=P.slot_to_tid(x),
-            weight=P.slot_to_tid(w),
+        AddmmNode(
+            mat1=P.slot_to_tid(mat1),
+            mat2=P.slot_to_tid(mat2),
             out=P.slot_to_tid(out),
-            bias=P.slot_to_tid(b) if b else None,
+            bias=P.slot_to_tid(bias),
+            alpha=float(alpha),
+            beta=float(beta),
         )
     )
     return out
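Independent of the backend code above, the arithmetic that `aten.addmm` encodes, and that the `alpha`/`beta` fields must carry through to the runtime, is `out = beta * bias + alpha * (mat1 @ mat2)`. A dependency-free sketch (pure Python, illustrative only; a 1-D bias broadcast across rows is assumed):

```python
def addmm(bias, mat1, mat2, beta=1.0, alpha=1.0):
    """out[i][j] = beta * bias[j] + alpha * sum_k mat1[i][k] * mat2[k][j]."""
    rows, inner, cols = len(mat1), len(mat2), len(mat2[0])
    return [
        [
            beta * bias[j]
            + alpha * sum(mat1[i][k] * mat2[k][j] for k in range(inner))
            for j in range(cols)
        ]
        for i in range(rows)
    ]


# 1x2 @ 2x2 (identity), with the bias row scaled by beta and the product by alpha:
# 0.5 * [10, 20] + 2.0 * [1, 2]  ->  [7, 14]
out = addmm([10.0, 20.0], [[1.0, 2.0]], [[1.0, 0.0], [0.0, 1.0]], beta=0.5, alpha=2.0)
```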
@@ -263,21 +269,28 @@ Key APIs:
 Add an `exec_*` function in the `ops` namespace:
 
 ```cpp
-inline void exec_linear(const LinearNode& n, ExecutionState& st, StreamOrDevice s) {
-  const auto& X = st.const_tensor_ref(n.x);
-  auto W = transpose(st.const_tensor_ref(n.weight), {1, 0}, s);
-  array Y = n.bias
-      ? addmm(st.const_tensor_ref(*n.bias), X, W, 1.0f, 1.0f, s)
-      : matmul(X, W, s);
+inline void exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) {
+  const auto& mat1 = st.const_tensor_ref(n.mat1);
+  const auto& mat2 = st.const_tensor_ref(n.mat2);
+
+  array Y = n.bias ? addmm(
+                         st.const_tensor_ref(*n.bias),
+                         mat1,
+                         mat2,
+                         /*alpha=*/n.alpha,
+                         /*beta=*/n.beta,
+                         s)
+                   : matmul(mat1, mat2, s);
+
   st.set_tensor(n.out, std::move(Y));
 }
 ```
 
-Then add the dispatch case in `Interpreter::execute_instruction()`:
+Then add the dispatch case in `Interpreter::dispatch()`:
 
 ```cpp
-  case OpCode::LINEAR:
-    ops::exec_linear(std::get<LinearNode>(instr.node), st, s);
+  case OpCode::ADDMM:
+    ops::exec_addmm(std::get<AddmmNode>(instr.node), st, s);
     break;
 ```
 
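The C++ dispatch above is a plain switch over `OpCode`. Purely as an illustration of the pattern (the names below are hypothetical stand-ins, not the backend's actual API), the same shape in Python is a table from opcode to executor, with a tensor-id-to-value map playing the role of `ExecutionState`:

```python
from typing import Callable, NamedTuple


class AddmmIds(NamedTuple):
    """Stand-in for the deserialized node: just the tensor ids it references."""
    mat1: int
    mat2: int
    out: int


# tensor-id -> value; scalars stand in for arrays to keep the sketch tiny
state = {0: 3.0, 1: 4.0}


def exec_addmm(node: AddmmIds) -> None:
    # Stand-in for the real matmul/addmm: multiply and store under the out id.
    state[node.out] = state[node.mat1] * state[node.mat2]


# The switch statement, as a lookup table keyed by opcode name.
DISPATCH: dict[str, Callable] = {"ADDMM": exec_addmm}


def execute(opcode: str, node) -> None:
    DISPATCH[opcode](node)  # equivalent of `case OpCode::ADDMM: ... break;`


execute("ADDMM", AddmmIds(mat1=0, mat2=1, out=2))
```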
@@ -290,34 +303,60 @@ Each test follows a standard pattern:
 3. **Decorate with `@register_test`** to register it with the test runner.
 
 ```python
-class LinearModel(nn.Module):
-    def __init__(self, in_features=64, out_features=128, bias=True):
+class AddmmModel(nn.Module):
+    """Model that performs addmm: bias + (mat1 @ mat2)."""
+
+    def __init__(self, in_features, out_features, bias=True, alpha=1.0, beta=1.0):
         super().__init__()
-        self.linear = nn.Linear(in_features, out_features, bias=bias)
+        self.weight = nn.Parameter(torch.randn(out_features, in_features))
+        if bias:
+            self.bias = nn.Parameter(torch.randn(out_features))
+        else:
+            self.bias = None
+        self.alpha = alpha
+        self.beta = beta
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.linear(x)
+        if self.bias is not None:
+            return torch.addmm(
+                self.bias, x, self.weight.t(), beta=self.beta, alpha=self.alpha
+            )
+        else:
+            return torch.mm(x, self.weight.t())
 
 @register_test
-class LinearTest(OpTestCase):
-    name = "linear"
+class AddmmTest(OpTestCase):
+    name = "addmm"
     rtol = 1e-4
     atol = 1e-4
 
-    def __init__(self, in_features=64, out_features=128, bias=True):
+    def __init__(self, batch_size=2, in_features=64, out_features=32,
+                 bias=True, alpha=1.0, beta=1.0):
+        self.batch_size = batch_size
         self.in_features = in_features
         self.out_features = out_features
         self.bias = bias
+        self.alpha = alpha
+        self.beta = beta
+        self.name = f"addmm_{batch_size}x{in_features}x{out_features}_bias{int(bias)}_a{alpha}_b{beta}"
 
     @classmethod
     def get_test_configs(cls):
-        return [cls(), cls(bias=False)]
+        return [
+            cls(batch_size=2, in_features=64, out_features=32),
+            cls(batch_size=2, in_features=64, out_features=32, bias=False),
+            cls(batch_size=4, in_features=128, out_features=64),
+            cls(batch_size=2, in_features=64, out_features=32, alpha=2.0, beta=0.5),
+        ]
 
     def create_model(self):
-        return LinearModel(self.in_features, self.out_features, bias=self.bias)
+        return AddmmModel(
+            self.in_features, self.out_features,
+            bias=self.bias, alpha=self.alpha, beta=self.beta,
+        )
 
     def create_inputs(self):
-        return (torch.randn(2, 16, self.in_features),)
+        return (torch.randn(self.batch_size, self.in_features),)
 ```
 
 ### Step 6: Run Tests
@@ -327,7 +366,7 @@ outputs against PyTorch reference. Since adding a new op always involves C++
 changes, use `--rebuild` to recompile the runtime:
 
 ```bash
-python -m executorch.backends.mlx.test.run_all_tests --rebuild linear
+python -m executorch.backends.mlx.test.run_all_tests --rebuild addmm
 ```
 
 Run all tests in parallel:
@@ -356,7 +395,7 @@ architecture, prerequisites, and the `OpTestCase` API.
 - [ ] Run `python backends/mlx/serialization/generate.py`
 - [ ] Add `@REGISTRY.register` handler in `ops.py` (and import the new node class)
 - [ ] Add `exec_*` function in `runtime/MLXInterpreter.h`
-- [ ] Add `case OpCode::*` in `Interpreter::execute_instruction()`
+- [ ] Add `case OpCode::*` in `Interpreter::dispatch()`
 - [ ] Add test model + `OpTestCase` in `test/test_ops.py`
 - [ ] Run `python -m executorch.backends.mlx.test.run_all_tests --rebuild <test_name>`
 
backends/mlx/serialization/generate.py
1 addition & 1 deletion
@@ -1006,7 +1006,7 @@ def _fbs_type_to_cpp(
 
 
 def _table_name_to_opcode(name: str) -> str:
-    """Convert table name like 'LinearNode' to opcode like 'LINEAR'.
+    """Convert table name like 'AddNode' to opcode like 'ADD'.
 
     Uses regex-based camelCase → UPPER_SNAKE_CASE conversion with a small
     override dict for names whose conventional opcode doesn't follow the