Skip to content

Commit 748da9d

Browse files
committed
Add docs for serialization with NX.serialize
1 parent 5bb8fbc commit 748da9d

5 files changed

Lines changed: 296 additions & 2 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515

1616
## v0.7.0 (2024-10-08)
1717

18+
### Breaking Changes
19+
20+
* **Removed `Axon.serialize/2` and `Axon.deserialize/2`** — Use `Nx.serialize/2` and `Nx.deserialize/2` for parameters instead. Axon recommends serializing only the trained parameters (weights) and keeping the model definition in code. See the [Saving and Loading](guides/serialization/saving_and_loading.livemd) guide.
21+
1822
### Bug Fixes
1923

2024
* Do not cast integers in in Axon.MixedPrecision.cast/2

guides/guides.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,5 @@ Axon is a library for creating and training neural networks in Elixir. The Axon
2828

2929
## Serialization
3030

31-
* [Converting ONNX models to Axon](serialization/onnx_to_axon.livemd)
31+
* [Saving and loading models](serialization/saving_and_loading.livemd)
3232

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Saving and Loading Models
2+
3+
## Section
4+
5+
```elixir
6+
Mix.install([
7+
{:axon, "~> 0.8"}
8+
])
9+
```
10+
11+
## Overview
12+
13+
Axon recommends a **parameters-only** approach to saving models: serialize only the trained parameters (weights) using `Nx.serialize/2` and `Nx.deserialize/2`, and keep the model definition in your code. This approach:
14+
15+
* Avoids serialization issues with anonymous functions and complex model structures
16+
* Makes the model structure explicit and version-controlled in code
17+
* Works reliably across processes and deployments
18+
19+
The model itself is just code, you define it once and reuse it. Only the learned parameters need to be persisted.
20+
21+
## Saving a Model After Training
22+
23+
When you run a training loop, it returns the trained model state by default. Extract the parameters and serialize them:
24+
25+
```elixir
26+
model =
27+
Axon.input("data")
28+
|> Axon.dense(8)
29+
|> Axon.relu()
30+
|> Axon.dense(4)
31+
|> Axon.relu()
32+
|> Axon.dense(1)
33+
34+
loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd)
35+
36+
train_data =
37+
Stream.repeatedly(fn ->
38+
{xs, _} = Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1})
39+
{xs, Nx.sin(xs)}
40+
end)
41+
42+
trained_model_state = Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 100)
43+
```
44+
45+
The training loop returns `model_state` by default (from `Axon.Loop.trainer/3`). For inference, we need the parameters—extract the `data` field from `ModelState`:
46+
47+
```elixir
48+
# Extract parameters - trained_model_state.data contains the nested map of weights
49+
params = trained_model_state.data
50+
51+
# Serialize and save
52+
params_bytes = Nx.serialize(params)
53+
File.write!("model_params.axon", params_bytes)
54+
```
55+
56+
## Loading a Model for Inference
57+
58+
To load and run inference, you need:
59+
60+
1. The model definition (in code—the same structure you trained)
61+
2. The saved parameters
62+
63+
```elixir
64+
# 1. Define the same model structure (must match training)
65+
model =
66+
Axon.input("data")
67+
|> Axon.dense(8)
68+
|> Axon.relu()
69+
|> Axon.dense(4)
70+
|> Axon.relu()
71+
|> Axon.dense(1)
72+
73+
# 2. Load parameters
74+
params = File.read!("model_params.axon") |> Nx.deserialize()
75+
76+
# 3. Run inference
77+
input = Nx.tensor([[1.0]]) # shape {1, 1}: 1 sample with 1 feature (matches model input)
78+
Axon.predict(model, params, %{"data" => input})
79+
```
80+
81+
## Checkpointing During Training
82+
83+
To save checkpoints during training (e.g., every epoch or when validation improves), use `Axon.Loop.checkpoint/2`. This serializes the full loop state—including model parameters and optimizer state—so you can resume training later.
84+
85+
```elixir
86+
model =
87+
Axon.input("data")
88+
|> Axon.dense(8)
89+
|> Axon.relu()
90+
|> Axon.dense(1)
91+
92+
loop =
93+
model
94+
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
95+
|> Axon.Loop.checkpoint(path: "checkpoints", event: :epoch_completed)
96+
97+
train_data =
98+
Stream.repeatedly(fn ->
99+
{xs, _} = Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1})
100+
{xs, Nx.sin(xs)}
101+
end)
102+
103+
Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 3, iterations: 50)
104+
```
105+
106+
Checkpoints are saved to the `checkpoints/` directory, as configured above. Each file contains the serialized loop state from `Axon.Loop.serialize_state/2`.
107+
108+
## Resuming from a Checkpoint
109+
110+
To resume training from a saved checkpoint:
111+
112+
1. Load the checkpoint with `Axon.Loop.deserialize_state/2`
113+
2. Attach it to your loop with `Axon.Loop.from_state/2`
114+
3. Run the loop as usual
115+
116+
```elixir
117+
# Load the checkpoint (use the path from your checkpoint files)
118+
checkpoint_path = "checkpoints/checkpoint_2_50.ckpt"
119+
serialized = File.read!(checkpoint_path)
120+
state = Axon.Loop.deserialize_state(serialized)
121+
122+
# Resume training
123+
model =
124+
Axon.input("data")
125+
|> Axon.dense(8)
126+
|> Axon.relu()
127+
|> Axon.dense(1)
128+
129+
loop =
130+
model
131+
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
132+
|> Axon.Loop.from_state(state)
133+
134+
Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 5, iterations: 50)
135+
```
136+
137+
## Saving Only Parameters from a Checkpoint
138+
139+
If you have a checkpoint file and want to extract parameters for inference (without optimizer state):
140+
141+
```elixir
142+
checkpoint_path = "checkpoints/checkpoint_2_50.ckpt"
143+
state = File.read!(checkpoint_path) |> Axon.Loop.deserialize_state()
144+
145+
# Extract model parameters from step_state
146+
%{model_state: model_state} = state.step_state
147+
params = model_state.data
148+
149+
# Save for inference
150+
File.write!("model_params.axon", Nx.serialize(params))
151+
```
152+
153+
## Summary
154+
155+
| Use Case | Save | Load |
156+
| ------------------------------ | --------------------------------------------------------- | ---------------------------------------------------------- |
157+
| Inference only | `Nx.serialize(params)` → file | `Nx.deserialize(file)` + model in code |
158+
| Checkpoint (resume training) | `Axon.Loop.checkpoint/2` or `Axon.Loop.serialize_state/2` | `Axon.Loop.deserialize_state/2` + `Axon.Loop.from_state/2` |
159+
| Extract params from checkpoint | `state.step_state.model_state.data``Nx.serialize` | Use with model in code |

lib/axon/loop.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1511,7 +1511,7 @@ defmodule Axon.Loop do
15111511
15121512
It is the opposite of `Axon.Loop.serialize_state/2`.
15131513
1514-
By default, the step state is deserialized using `Nx.deserialize.2`;
1514+
By default, the step state is deserialized using `Nx.deserialize/2`;
15151515
however, this behavior can be changed if step state is an application
15161516
specific container. For example, if you introduce your own data
15171517
structure into step_state and you customized the serialization logic,
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
defmodule Axon.SerializationGuideTest do
2+
@moduledoc """
3+
Tests that validate the examples in guides/serialization/saving_and_loading.livemd.
4+
Run with: mix test test/axon/serialization_guide_test.exs
5+
"""
6+
use Axon.Case, async: false
7+
8+
@tmp_path Path.join(System.tmp_dir!(), "axon_serialization_guide_test_#{:erlang.unique_integer([:positive])}")
9+
10+
setup do
11+
File.mkdir_p!(@tmp_path)
12+
on_exit(fn -> File.rm_rf!(@tmp_path) end)
13+
[tmp_path: @tmp_path]
14+
end
15+
16+
describe "saving and loading guide examples" do
17+
test "full flow: train → save params → load → predict", %{tmp_path: tmp_path} do
18+
# Same model as the guide
19+
model =
20+
Axon.input("data")
21+
|> Axon.dense(8)
22+
|> Axon.relu()
23+
|> Axon.dense(4)
24+
|> Axon.relu()
25+
|> Axon.dense(1)
26+
27+
loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd, log: 0)
28+
29+
train_data =
30+
Stream.repeatedly(fn ->
31+
{xs, _} = Nx.Random.normal(Nx.Random.key(:erlang.phash2({self(), System.unique_integer([:monotonic])})), shape: {8, 1})
32+
{xs, Nx.sin(xs)}
33+
end)
34+
35+
# Train
36+
trained_model_state = Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 50)
37+
38+
# Extract and save params (as in guide)
39+
params =
40+
case trained_model_state do
41+
%Axon.ModelState{data: data} -> data
42+
params when is_map(params) -> params
43+
end
44+
45+
params_path = Path.join(tmp_path, "model_params.axon")
46+
params = Nx.backend_transfer(params)
47+
params_bytes = Nx.serialize(params)
48+
File.write!(params_path, params_bytes)
49+
50+
# Load and predict (input shape must match training: {batch, 1} for 1 feature)
51+
loaded_params = File.read!(params_path) |> Nx.deserialize()
52+
input = Nx.tensor([[1.0]])
53+
54+
prediction = Axon.predict(model, loaded_params, %{"data" => input})
55+
56+
assert Nx.rank(prediction) == 2
57+
assert Nx.shape(prediction) == {1, 1}
58+
end
59+
60+
test "checkpoint and resume flow", %{tmp_path: tmp_path} do
61+
model =
62+
Axon.input("data")
63+
|> Axon.dense(4)
64+
|> Axon.relu()
65+
|> Axon.dense(1)
66+
67+
checkpoint_path = Path.join(tmp_path, "checkpoints")
68+
69+
loop =
70+
model
71+
|> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0)
72+
|> Axon.Loop.checkpoint(path: checkpoint_path, event: :epoch_completed)
73+
74+
train_data = [
75+
{Nx.tensor([[1.0, 2.0, 3.0, 4.0]]), Nx.tensor([[1.0]])},
76+
{Nx.tensor([[2.0, 3.0, 4.0, 5.0]]), Nx.tensor([[2.0]])}
77+
]
78+
79+
Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2)
80+
81+
# Verify checkpoint was saved
82+
ckpt_files = File.ls!(checkpoint_path) |> Enum.sort()
83+
assert length(ckpt_files) == 2
84+
assert Enum.any?(ckpt_files, &String.contains?(&1, "checkpoint_"))
85+
86+
# Load checkpoint and extract params for inference
87+
ckpt_file = Path.join(checkpoint_path, List.first(ckpt_files))
88+
state = File.read!(ckpt_file) |> Axon.Loop.deserialize_state()
89+
90+
%{model_state: model_state} = state.step_state
91+
params = model_state.data
92+
93+
# Run inference with extracted params
94+
input = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
95+
prediction = Axon.predict(model, params, %{"data" => input})
96+
97+
assert Nx.rank(prediction) == 2
98+
assert Nx.shape(prediction) == {1, 1}
99+
end
100+
101+
test "resume from checkpoint with from_state", %{tmp_path: tmp_path} do
102+
model =
103+
Axon.input("data")
104+
|> Axon.dense(2)
105+
|> Axon.dense(1)
106+
107+
checkpoint_path = Path.join(tmp_path, "checkpoints_resume")
108+
109+
loop =
110+
model
111+
|> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0)
112+
|> Axon.Loop.checkpoint(path: checkpoint_path, event: :epoch_completed)
113+
114+
train_data = [{Nx.tensor([[1.0, 2.0]]), Nx.tensor([[1.0]])}]
115+
116+
# Run for 1 epoch
117+
Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 1)
118+
119+
# Load checkpoint and resume
120+
[ckpt_file] = File.ls!(checkpoint_path)
121+
state = File.read!(Path.join(checkpoint_path, ckpt_file)) |> Axon.Loop.deserialize_state()
122+
123+
resumed_loop = model |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0) |> Axon.Loop.from_state(state)
124+
125+
# Resume - should complete without error
126+
result = Axon.Loop.run(resumed_loop, train_data, Axon.ModelState.empty(), epochs: 2)
127+
128+
assert %Axon.ModelState{} = result
129+
end
130+
end
131+
end

0 commit comments

Comments
 (0)