@@ -193,23 +193,23 @@ ExportedProgram (subgraph)
193193
194194## How to Add a New Op
195195
196- This section walks through adding a new op end-to-end, using ** ` aten.linear ` **
196+ This section walks through adding a new op end-to-end, using ** ` aten.addmm ` **
197197as an example.
198198
199199### Step 1: Add the Node to ` schema.fbs `
200200
201201Add a new table in the "Op nodes" section and add it to the ` OpNode ` union:
202202
203203``` fbs
204- table LinearNode {
205- x : Tid (required);
206- weight : Tid (required);
204+ table AddmmNode {
205+ mat1 : Tid (required);
206+ mat2 : Tid (required);
207207 out: Tid (required);
208208 bias: Tid; // optional
209+ alpha: float = 1.0; // scale on mat1 @ mat2 (read as n.alpha by the runtime below)
210+ beta: float = 1.0; // scale on bias (read as n.beta by the runtime below)
209209}
210210```
211211
212- Then add ` LinearNode ` to the ` union OpNode { ... } ` list.
212+ Then add ` AddmmNode ` to the ` union OpNode { ... } ` list.
213213
214214### Step 2: Run the Code Generator
215215
@@ -219,34 +219,40 @@ python backends/mlx/serialization/generate.py
219219
220220This regenerates:
221221
222- - ` mlx_graph_schema.py ` — adds ` LinearNode ` Python dataclass
223- - ` _generated_serializers.py ` — adds ` _build_LinearNode ` serializer
224- - ` runtime/MLXLoader.h ` — adds ` LinearNode ` C++ struct, ` OpCode::LINEAR ` , loader
225- - ` runtime/MLXLoader.cpp ` — adds FlatBuffer → ` LinearNode ` deserialization
222+ - ` mlx_graph_schema.py ` — adds ` AddmmNode ` Python dataclass
223+ - ` _generated_serializers.py ` — adds ` _build_AddmmNode ` serializer
224+ - ` runtime/MLXLoader.h ` — adds ` AddmmNode ` C++ struct, ` OpCode::ADDMM ` , loader
225+ - ` runtime/MLXLoader.cpp ` — adds FlatBuffer → ` AddmmNode ` deserialization
226226- ` runtime/schema_generated.h ` — FlatBuffer C++ bindings
227227
228228### Step 3: Add the Python Op Handler (` ops.py ` )
229229
230230Register a handler that converts the ATen op to your new node. Make sure to
231- import ` LinearNode ` from ` mlx_graph_schema ` :
231+ import ` AddmmNode ` from ` mlx_graph_schema ` :
232232
233233``` python
234- from executorch.backends.mlx.serialization.mlx_graph_schema import LinearNode
234+ from executorch.backends.mlx.serialization.mlx_graph_schema import AddmmNode
235235
236- @REGISTRY.register (target = [torch.ops.aten.linear .default])
237- def _linear_handler (P : MLXProgramBuilder, n : Node) -> Slot:
236+ @REGISTRY.register (target = [torch.ops.aten.addmm .default])
237+ def _addmm_handler (P : MLXProgramBuilder, n : Node) -> Slot:
238238 args = P.args(n)
239- require_args(args, 2 , 3 , " aten.linear" )
240- require_kwargs(P.kwargs(n), set (), " aten.linear" )
241- x, w = args[0 ], args[1 ]
242- b = args[2 ] if len (args) > 2 else None
239+ kwargs = P.kwargs(n)
240+ require_args(args, 3 , 3 , " aten.addmm" )
241+ require_kwargs(kwargs, {" beta" , " alpha" }, " aten.addmm" )
242+ bias, mat1, mat2 = args[0 ], args[1 ], args[2 ]
243+
244+ beta = kwargs.get(" beta" , 1 )
245+ alpha = kwargs.get(" alpha" , 1 )
246+
243247 out = P.make_or_get_slot(n)
244248 P.emit(
245- LinearNode (
246- x = P.slot_to_tid(x ),
247- weight = P.slot_to_tid(w ),
249+ AddmmNode (
250+ mat1 = P.slot_to_tid(mat1 ),
251+ mat2 = P.slot_to_tid(mat2 ),
248252 out = P.slot_to_tid(out),
249- bias = P.slot_to_tid(b) if b else None ,
253+ bias = P.slot_to_tid(bias),
254+ alpha = float (alpha),
255+ beta = float (beta),
250256 )
251257 )
252258 return out
@@ -263,21 +269,28 @@ Key APIs:
263269Add an ` exec_* ` function in the ` ops ` namespace:
264270
265271``` cpp
266- inline void exec_linear (const LinearNode& n, ExecutionState& st, StreamOrDevice s) {
267- const auto& X = st.const_tensor_ref(n.x);
268- auto W = transpose(st.const_tensor_ref(n.weight), {1, 0}, s);
269- array Y = n.bias
270- ? addmm(st.const_tensor_ref(* n.bias), X, W, 1.0f, 1.0f, s)
271- : matmul(X, W, s);
272+ inline void exec_addmm (const AddmmNode& n, ExecutionState& st, StreamOrDevice s) {
273+ const auto& mat1 = st.const_tensor_ref(n.mat1);
274+ const auto& mat2 = st.const_tensor_ref(n.mat2);
275+
276+ array Y = n.bias ? addmm(
277+ st.const_tensor_ref(*n.bias),
278+ mat1,
279+ mat2,
280+ /*alpha=*/n.alpha,
281+ /*beta=*/n.beta,
282+ s)
283+ // no-bias fallback must still apply alpha, or alpha != 1 is silently dropped
284+ : multiply(array(n.alpha), matmul(mat1, mat2, s), s);
284+
272285 st.set_tensor(n.out, std::move(Y));
273286}
274287```
275288
276- Then add the dispatch case in `Interpreter::execute_instruction ()`:
289+ Then add the dispatch case in `Interpreter::dispatch ()`:
277290
278291```cpp
279- case OpCode::LINEAR :
280- ops::exec_linear (std::get<LinearNode >(instr.node), st, s);
292+ case OpCode::ADDMM :
293+ ops::exec_addmm (std::get<AddmmNode >(instr.node), st, s);
281294 break;
282295```
283296
@@ -290,34 +303,60 @@ Each test follows a standard pattern:
2903033 . ** Decorate with ` @register_test ` ** to register it with the test runner.
291304
292305``` python
293- class LinearModel (nn .Module ):
294- def __init__ (self , in_features = 64 , out_features = 128 , bias = True ):
306+ class AddmmModel (nn .Module ):
307+ """ Model that performs addmm: bias + (mat1 @ mat2)."""
308+
309+ def __init__ (self , in_features , out_features , bias = True , alpha = 1.0 , beta = 1.0 ):
295310 super ().__init__ ()
296- self .linear = nn.Linear(in_features, out_features, bias = bias)
311+ self .weight = nn.Parameter(torch.randn(out_features, in_features))
312+ if bias:
313+ self .bias = nn.Parameter(torch.randn(out_features))
314+ else :
315+ self .bias = None
316+ self .alpha = alpha
317+ self .beta = beta
297318
298319 def forward (self , x : torch.Tensor) -> torch.Tensor:
299- return self .linear(x)
320+ if self .bias is not None :
321+ return torch.addmm(
322+ self .bias, x, self .weight.t(), beta = self .beta, alpha = self .alpha
323+ )
324+ else :
325+ return torch.mm(x, self .weight.t())
300326
301327@register_test
302- class LinearTest (OpTestCase ):
303- name = " linear "
328+ class AddmmTest (OpTestCase ):
329+ name = " addmm "
304330 rtol = 1e-4
305331 atol = 1e-4
306332
307- def __init__ (self , in_features = 64 , out_features = 128 , bias = True ):
333+ def __init__ (self , batch_size = 2 , in_features = 64 , out_features = 32 ,
334+ bias = True , alpha = 1.0 , beta = 1.0 ):
335+ self .batch_size = batch_size
308336 self .in_features = in_features
309337 self .out_features = out_features
310338 self .bias = bias
339+ self .alpha = alpha
340+ self .beta = beta
341+ # Encode every knob in the name — several configs below share the same
342+ # in/out shapes and would otherwise register under a colliding test name.
343+ self .name = f " addmm_ { in_features} x { out_features} _bias { bias} _a { alpha} _b { beta} "
311342
312343 @ classmethod
313344 def get_test_configs (cls ):
314- return [cls (), cls (bias = False )]
345+ return [
346+ cls (batch_size = 2 , in_features = 64 , out_features = 32 ),
347+ cls (batch_size = 2 , in_features = 64 , out_features = 32 , bias = False ),
348+ cls (batch_size = 4 , in_features = 128 , out_features = 64 ),
349+ cls (batch_size = 2 , in_features = 64 , out_features = 32 , alpha = 2.0 , beta = 0.5 ),
350+ ]
315351
316352 def create_model (self ):
317- return LinearModel(self .in_features, self .out_features, bias = self .bias)
353+ return AddmmModel(
354+ self .in_features, self .out_features,
355+ bias = self .bias, alpha = self .alpha, beta = self .beta,
356+ )
318357
319358 def create_inputs (self ):
320- return (torch.randn(2 , 16 , self .in_features),)
359+ return (torch.randn(self .batch_size , self .in_features),)
321360```
322361
323362### Step 6: Run Tests
@@ -327,7 +366,7 @@ outputs against PyTorch reference. Since adding a new op always involves C++
327366changes, use ` --rebuild ` to recompile the runtime:
328367
329368``` bash
330- python -m executorch.backends.mlx.test.run_all_tests --rebuild linear
369+ python -m executorch.backends.mlx.test.run_all_tests --rebuild addmm
331370```
332371
333372Run all tests in parallel:
@@ -356,7 +395,7 @@ architecture, prerequisites, and the `OpTestCase` API.
356395- [ ] Run ` python backends/mlx/serialization/generate.py `
357396- [ ] Add ` @REGISTRY.register ` handler in ` ops.py ` (and import the new node class)
358397- [ ] Add ` exec_* ` function in ` runtime/MLXInterpreter.h `
359- - [ ] Add ` case OpCode::* ` in ` Interpreter::execute_instruction () `
398+ - [ ] Add ` case OpCode::* ` in ` Interpreter::dispatch () `
360399- [ ] Add test model + ` OpTestCase ` in ` test/test_ops.py `
361400- [ ] Run ` python -m executorch.backends.mlx.test.run_all_tests --rebuild <test_name> `
362401
0 commit comments