@@ -5538,6 +5538,120 @@ defmodule CompilerTest do
55385538 end
55395539 end
55405540
5541+ describe "weight tying" do
test "tied parameter uses source parameter value" do
  # Two bias-free 4 -> 4 dense layers: their kernels have the same shape, so
  # one kernel can stand in for the other.
  {init_fn, predict_fn} =
    Axon.input("input", shape: {nil, 4})
    |> Axon.dense(4, name: "dense_0", use_bias: false)
    |> Axon.dense(4, name: "dense_1", use_bias: false)
    |> Axon.build()

  input = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])

  # dense_0 gets the identity kernel and dense_1 an all-zero kernel, which
  # makes the untied and tied outputs easy to tell apart.
  untied_state =
    init_fn.(input, ModelState.empty())
    |> put_in([Access.key!(:data), "dense_0", "kernel"], Nx.eye(4))
    |> put_in([Access.key!(:data), "dense_1", "kernel"], Nx.broadcast(0.0, {4, 4}))

  # Untied: input -> identity -> zeros, so everything collapses to zero.
  expected_untied = Nx.tensor([[0.0, 0.0, 0.0, 0.0]])
  assert_equal(predict_fn.(untied_state, input), expected_untied)

  # Tied: dense_1 is expected to use dense_0's identity kernel, so the
  # input should pass through both layers unchanged.
  tied_state = ModelState.tie(untied_state, ["dense_1", "kernel"], ["dense_0", "kernel"])
  assert_equal(predict_fn.(tied_state, input), input)
end
5570+
test "tied parameter with transform applies transformation" do
  # Bias-free 2 -> 4 -> 2 stack: dense_1 needs a {4, 2} kernel, which is the
  # shape of dense_0's {2, 4} kernel after transposition.
  {init_fn, predict_fn} =
    Axon.input("input", shape: {nil, 2})
    |> Axon.dense(4, name: "dense_0", use_bias: false)
    |> Axon.dense(2, name: "dense_1", use_bias: false)
    |> Axon.build()

  input = Nx.tensor([[1.0, 2.0]])

  # Known source kernel: copies the two inputs into the first two of four
  # columns and leaves the remaining columns at zero.
  source_kernel = Nx.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])

  # Tie dense_1's kernel to dense_0's kernel, passing a transpose as the transform.
  tied_state =
    init_fn.(input, ModelState.empty())
    |> put_in([Access.key!(:data), "dense_0", "kernel"], source_kernel)
    |> ModelState.tie(
      ["dense_1", "kernel"],
      ["dense_0", "kernel"],
      transform: &Nx.transpose/1
    )

  # Expected arithmetic:
  #   [[1, 2]]       @ [[1, 0, 0, 0], [0, 1, 0, 0]]     = [[1, 2, 0, 0]]
  #   [[1, 2, 0, 0]] @ [[1, 0], [0, 1], [0, 0], [0, 0]] = [[1, 2]]
  # so the prediction should equal the input.
  assert_equal(predict_fn.(tied_state, input), input)
end
5601+
test "modifying source parameter affects tied layers" do
  # Two bias-free 2 -> 2 dense layers; dense_1's kernel is tied to dense_0's
  # before any concrete kernel value is written.
  {init_fn, predict_fn} =
    Axon.input("input", shape: {nil, 2})
    |> Axon.dense(2, name: "dense_0", use_bias: false)
    |> Axon.dense(2, name: "dense_1", use_bias: false)
    |> Axon.build()

  input = Nx.tensor([[1.0, 0.0]])
  source_path = [Access.key!(:data), "dense_0", "kernel"]

  tied_state =
    init_fn.(input, ModelState.empty())
    |> ModelState.tie(["dense_1", "kernel"], ["dense_0", "kernel"])

  # First run: the source kernel is the 2x2 identity.
  state_v1 = put_in(tied_state, source_path, Nx.tensor([[1.0, 0.0], [0.0, 1.0]]))
  output_v1 = predict_fn.(state_v1, input)

  # Second run: only the source kernel is overwritten (with 2 * identity);
  # the tied layer is expected to pick up the new value.
  state_v2 = put_in(state_v1, source_path, Nx.tensor([[2.0, 0.0], [0.0, 2.0]]))
  output_v2 = predict_fn.(state_v2, input)

  # The shared kernel changed, so the two outputs must not be element-wise identical.
  all_equal_flag =
    output_v1
    |> Nx.equal(output_v2)
    |> Nx.all()
    |> Nx.to_number()

  refute all_equal_flag == 1

  # Exact values for input @ kernel @ kernel:
  #   v1: [1, 0] @ I  @ I  = [1, 0]
  #   v2: [1, 0] @ 2I @ 2I = [4, 0]
  assert_equal(output_v1, Nx.tensor([[1.0, 0.0]]))
  assert_equal(output_v2, Nx.tensor([[4.0, 0.0]]))
end
5635+
test "raises on non-existent shared parameter source" do
  # Single dense layer; its kernel gets tied to a layer name the model never defines.
  {init_fn, predict_fn} =
    Axon.input("input", shape: {nil, 2})
    |> Axon.dense(4, name: "dense_0")
    |> Axon.build()

  input = Nx.tensor([[1.0, 2.0]])

  # The tie call sits outside the assertion below: the test expects the
  # failure to surface at prediction time, not when the tie is declared.
  dangling_state =
    init_fn.(input, ModelState.empty())
    |> ModelState.tie(["dense_0", "kernel"], ["nonexistent", "kernel"])

  # NOTE(review): the pattern is reproduced exactly as found, including the
  # space right after the opening delimiter - confirm that space is intended.
  assert_raise ArgumentError, ~r/ shared parameter.*references non-existent/, fn ->
    predict_fn.(dangling_state, input)
  end
end
5653+ end
5654+
55415655 describe "instrumentation" do
55425656 @ describetag :capture_log
55435657
0 commit comments