@@ -5,7 +5,10 @@ defmodule Axon.SerializationGuideTest do
55 """
66 use Axon.Case , async: false
77
8- @ tmp_path Path . join ( System . tmp_dir! ( ) , "axon_serialization_guide_test_#{ :erlang . unique_integer ( [ :positive ] ) } " )
8+ @ tmp_path Path . join (
9+ System . tmp_dir! ( ) ,
10+ "axon_serialization_guide_test_#{ :erlang . unique_integer ( [ :positive ] ) } "
11+ )
912
1013 setup do
1114 File . mkdir_p! ( @ tmp_path )
@@ -28,12 +31,18 @@ defmodule Axon.SerializationGuideTest do
2831
2932 train_data =
3033 Stream . repeatedly ( fn ->
31- { xs , _ } = Nx.Random . normal ( Nx.Random . key ( :erlang . phash2 ( { self ( ) , System . unique_integer ( [ :monotonic ] ) } ) ) , shape: { 8 , 1 } )
34+ { xs , _ } =
35+ Nx.Random . normal (
36+ Nx.Random . key ( :erlang . phash2 ( { self ( ) , System . unique_integer ( [ :monotonic ] ) } ) ) ,
37+ shape: { 8 , 1 }
38+ )
39+
3240 { xs , Nx . sin ( xs ) }
3341 end )
3442
3543 # Train
36- trained_model_state = Axon.Loop . run ( loop , train_data , Axon.ModelState . empty ( ) , epochs: 2 , iterations: 50 )
44+ trained_model_state =
45+ Axon.Loop . run ( loop , train_data , Axon.ModelState . empty ( ) , epochs: 2 , iterations: 50 )
3746
3847 # Extract and save params (as in guide)
3948 params =
@@ -120,7 +129,10 @@ defmodule Axon.SerializationGuideTest do
120129 [ ckpt_file ] = File . ls! ( checkpoint_path )
121130 state = File . read! ( Path . join ( checkpoint_path , ckpt_file ) ) |> Axon.Loop . deserialize_state ( )
122131
123- resumed_loop = model |> Axon.Loop . trainer ( :mean_squared_error , :sgd , log: 0 ) |> Axon.Loop . from_state ( state )
132+ resumed_loop =
133+ model
134+ |> Axon.Loop . trainer ( :mean_squared_error , :sgd , log: 0 )
135+ |> Axon.Loop . from_state ( state )
124136
125137 # Resume - should complete without error
126138 result = Axon.Loop . run ( resumed_loop , train_data , Axon.ModelState . empty ( ) , epochs: 2 )
0 commit comments