|
defmodule Axon.SerializationGuideTest do
  @moduledoc """
  Tests that validate the examples in guides/serialization/saving_and_loading.livemd.

  Each test gets its own scratch directory (exposed as `:tmp_path` in the test
  context), which is removed again once the test exits.

  Run with: mix test test/axon/serialization_guide_test.exs
  """
  use Axon.Case, async: false

  # The scratch directory is created at runtime, once per test. The name
  # combines the OS pid with a VM-unique integer: `unique_integer` alone is
  # only unique inside a single VM, so two concurrent `mix test` runs on the
  # same host could otherwise pick the same directory and delete each other's
  # files in `on_exit`.
  setup do
    tmp_path =
      Path.join(
        System.tmp_dir!(),
        "axon_serialization_guide_test_#{System.pid()}_#{System.unique_integer([:positive])}"
      )

    File.mkdir_p!(tmp_path)
    on_exit(fn -> File.rm_rf!(tmp_path) end)
    [tmp_path: tmp_path]
  end

  describe "saving and loading guide examples" do
    test "full flow: train → save params → load → predict", %{tmp_path: tmp_path} do
      # Same model as the guide
      model =
        Axon.input("data")
        |> Axon.dense(8)
        |> Axon.relu()
        |> Axon.dense(4)
        |> Axon.relu()
        |> Axon.dense(1)

      loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd, log: 0)

      # Infinite stream of {input, target} batches. The random key is fixed and
      # threaded through the stream so every run trains on the same data and a
      # failure can be reproduced.
      train_data =
        Stream.unfold(Nx.Random.key(42), fn key ->
          {xs, next_key} = Nx.Random.normal(key, shape: {8, 1})
          {{xs, Nx.sin(xs)}, next_key}
        end)

      # Train
      trained_model_state =
        Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 50)

      # Extract and save params (as in guide). The struct clause must come
      # first because a struct also satisfies `is_map/1`.
      params =
        case trained_model_state do
          %Axon.ModelState{data: data} -> data
          params when is_map(params) -> params
        end

      params_path = Path.join(tmp_path, "model_params.axon")
      params = Nx.backend_transfer(params)
      params_bytes = Nx.serialize(params)
      File.write!(params_path, params_bytes)

      # Load and predict (input shape must match training: {batch, 1} for 1 feature)
      loaded_params = File.read!(params_path) |> Nx.deserialize()
      input = Nx.tensor([[1.0]])

      prediction = Axon.predict(model, loaded_params, %{"data" => input})

      assert Nx.rank(prediction) == 2
      assert Nx.shape(prediction) == {1, 1}

      # The round trip must preserve the parameters: predicting with the
      # reloaded params has to agree with predicting from the in-memory ones.
      expected = Axon.predict(model, params, %{"data" => input})
      assert Nx.to_number(Nx.all_close(expected, prediction)) == 1
    end

    test "checkpoint and resume flow", %{tmp_path: tmp_path} do
      model =
        Axon.input("data")
        |> Axon.dense(4)
        |> Axon.relu()
        |> Axon.dense(1)

      checkpoint_path = Path.join(tmp_path, "checkpoints")

      loop =
        model
        |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0)
        |> Axon.Loop.checkpoint(path: checkpoint_path, event: :epoch_completed)

      train_data = [
        {Nx.tensor([[1.0, 2.0, 3.0, 4.0]]), Nx.tensor([[1.0]])},
        {Nx.tensor([[2.0, 3.0, 4.0, 5.0]]), Nx.tensor([[2.0]])}
      ]

      Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2)

      # Verify checkpoints were saved: one per completed epoch, two epochs run
      ckpt_files = File.ls!(checkpoint_path) |> Enum.sort()
      assert length(ckpt_files) == 2
      assert Enum.any?(ckpt_files, &String.contains?(&1, "checkpoint_"))

      # Load the first checkpoint (in sorted order) and extract params for inference
      ckpt_file = Path.join(checkpoint_path, List.first(ckpt_files))
      state = File.read!(ckpt_file) |> Axon.Loop.deserialize_state()

      %{model_state: model_state} = state.step_state
      params = model_state.data

      # Run inference with extracted params
      input = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
      prediction = Axon.predict(model, params, %{"data" => input})

      assert Nx.rank(prediction) == 2
      assert Nx.shape(prediction) == {1, 1}
    end

    test "resume from checkpoint with from_state", %{tmp_path: tmp_path} do
      model =
        Axon.input("data")
        |> Axon.dense(2)
        |> Axon.dense(1)

      checkpoint_path = Path.join(tmp_path, "checkpoints_resume")

      loop =
        model
        |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0)
        |> Axon.Loop.checkpoint(path: checkpoint_path, event: :epoch_completed)

      train_data = [{Nx.tensor([[1.0, 2.0]]), Nx.tensor([[1.0]])}]

      # Run for 1 epoch
      Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 1)

      # Load checkpoint and resume. The match asserts that exactly one
      # checkpoint file was written for the single epoch.
      [ckpt_file] = File.ls!(checkpoint_path)
      state = File.read!(Path.join(checkpoint_path, ckpt_file)) |> Axon.Loop.deserialize_state()

      resumed_loop =
        model
        |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0)
        |> Axon.Loop.from_state(state)

      # Resume - should complete without error
      result = Axon.Loop.run(resumed_loop, train_data, Axon.ModelState.empty(), epochs: 2)

      assert %Axon.ModelState{} = result
    end
  end
end
0 commit comments