Commit 0562126

Correct implementation of weight tying (#625)
1 parent dfec027 commit 0562126

5 files changed

Lines changed: 315 additions & 16 deletions

lib/axon/compiler.ex

Lines changed: 40 additions & 15 deletions
@@ -916,23 +916,11 @@ defmodule Axon.Compiler do
       else
         # Parameters are just accessed in the layer sub-map of the nested
         # parameter map, so we just need to extract them and then apply
-        # freezing and dtype policy
+        # freezing and dtype policy. Parameters may be SharedParameter
+        # structs for tied weights, which are resolved to their source.
         parameter_inputs =
           Enum.map(layer_params, fn %{name: v, frozen: frz} ->
-            param = params[name][v]
-
-            cond do
-              param != nil ->
-                safe_policy_cast(maybe_freeze(param, frz), policy, :compute)
-
-              true ->
-                raise ArgumentError,
-                      "parameter #{inspect(v)} for layer: #{inspect(name)} in" <>
-                        " was not present in the given parameter map, this can" <>
-                        " happen if you are using parameters intended for another" <>
-                        " model or did not initialize portions of your model with" <>
-                        " Axon.init/3"
-            end
+            resolve_parameter!(params, name, v, frz, policy)
           end)

         # Reorder the inputs according to the original input ordering
@@ -1188,5 +1176,42 @@
   defp propagating_none?(%Axon.None{__propagate__: true}), do: true
   defp propagating_none?(_), do: false

+  defp resolve_parameter!(params, layer_name, param_name, freeze?, policy) do
+    # Special case where this is a SharedParameter at the layer level, so we
+    # need to resolve that before forwarding. Otherwise this falls through and
+    # is handled at the next step
+    layer_params =
+      with %Axon.ModelState.SharedParameter{path: path} <- params[layer_name] do
+        get_in(params, path)
+      end
+
+    parameter =
+      case layer_params[param_name] do
+        nil ->
+          raise ArgumentError,
+                "parameter #{inspect(param_name)} for layer: #{inspect(layer_name)}" <>
+                  " was not present in the given parameter map, this can" <>
+                  " happen if you are using parameters intended for another" <>
+                  " model or did not initialize portions of your model with" <>
+                  " Axon.init/3"
+
+        %Axon.ModelState.SharedParameter{path: path, transform: transform} ->
+          tensor =
+            with nil <- get_in(params, path) do
+              raise ArgumentError,
+                    "shared parameter for #{inspect(param_name)} in layer:" <>
+                      " #{inspect(layer_name)}, references non-existent parameter" <>
+                      " #{inspect(path)}"
+            end
+
+          if transform, do: transform.(tensor), else: tensor
+
+        parameter ->
+          parameter
+      end
+
+    safe_policy_cast(maybe_freeze(parameter, freeze?), policy, :compute)
+  end
+
   defp us_to_ms(time), do: Float.round(time / 1000, 1)
 end
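For reference, a minimal sketch of the resolution step this hunk factors out, reduced to plain nested maps and stripped of the freezing, policy-cast, and layer-level SharedParameter handling (the TieDemo module and demo_resolve/3 are illustrative names, not part of the commit):

# Sketch only: resolve a (possibly shared) parameter from a nested params map.
defmodule TieDemo do
  alias Axon.ModelState.SharedParameter

  def demo_resolve(params, layer_name, param_name) do
    case params[layer_name][param_name] do
      %SharedParameter{path: path, transform: transform} ->
        # Follow the access path back to the source tensor ...
        tensor = get_in(params, path)
        # ... and apply the optional transform, e.g. Nx.transpose/1
        if transform, do: transform.(tensor), else: tensor

      tensor ->
        # Ordinary parameter: use it directly
        tensor
    end
  end
end

params = %{
  "embed" => %{"kernel" => Nx.iota({2, 3})},
  "output" => %{
    "kernel" =>
      Axon.ModelState.SharedParameter.new(["embed", "kernel"], transform: &Nx.transpose/1)
  }
}

TieDemo.demo_resolve(params, "output", "kernel")
#=> the {3, 2} transpose of the embedding kernel, shared rather than copied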

lib/axon/model_state.ex

Lines changed: 56 additions & 1 deletion
@@ -191,6 +191,54 @@ defmodule Axon.ModelState do
     }
   end

+  @doc """
+  Ties a parameter to another parameter, enabling weight sharing.
+
+  The destination parameter will reference the source parameter's tensor,
+  optionally applying a transformation. Both `destination` and `source`
+  are access paths (lists of strings) into the model state data.
+
+  ## Options
+
+    * `:transform` - a function to transform the source tensor before
+      use at the destination. For example, `&Nx.transpose/1` for tying
+      an embedding layer to an output projection.
+
+  ## Examples
+
+      # Tie output projection to embedding weights (transposed)
+      model_state = Axon.ModelState.tie(
+        model_state,
+        ["output", "kernel"],
+        ["embed", "kernel"],
+        transform: &Nx.transpose/1
+      )
+
+  """
+  def tie(model_state, destination, source, opts \\ []) do
+    update_in(model_state, [Access.key!(:data)], fn data ->
+      shared = Axon.ModelState.SharedParameter.new(source, opts)
+      [key | rest] = Enum.reverse(destination)
+
+      shared =
+        Enum.reduce(rest, %{key => shared}, fn next, acc ->
+          %{next => acc}
+        end)
+
+      deep_merge(data, shared)
+    end)
+  end
+
+  defp deep_merge(left, right) do
+    Map.merge(left, right, fn
+      _key, left_val, right_val when is_map(left_val) and is_map(right_val) ->
+        deep_merge(left_val, right_val)
+
+      _key, _left_val, right_val ->
+        right_val
+    end)
+  end
+
   defp transform_to_parameters(%Nx.Tensor{}), do: nil

   defp transform_to_parameters(map) when is_map(map) do
@@ -249,6 +297,9 @@ defmodule Axon.ModelState do
   defp tree_get(data, access) when is_list(access) do
     Enum.reduce(access, %{}, fn key, acc ->
       case data do
+        %{^key => %Axon.ModelState.SharedParameter{}} ->
+          acc
+
         %{^key => val} ->
           Map.put(acc, key, val)

@@ -261,9 +312,13 @@ defmodule Axon.ModelState do
   defp tree_get(data, access) when is_map(access) do
     Enum.reduce(access, %{}, fn {key, value}, acc ->
       case data do
+        %{^key => %Axon.ModelState.SharedParameter{}} ->
+          # Skip shared parameters - they reference another parameter
+          acc
+
         %{^key => val} ->
           tree = tree_get(val, value)
-          Map.put(acc, key, tree)
+          if map_size(tree) == 0, do: acc, else: Map.put(acc, key, tree)

         %{} ->
           acc
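To make the destination-path handling in tie/4 above concrete, here is the intermediate map it builds before deep-merging into the state data (a sketch restating the diff's own logic; the paths shown are placeholders):

destination = ["dense_1", "kernel"]
shared = Axon.ModelState.SharedParameter.new(["dense_0", "kernel"])

# Reverse the path and fold it into a singleton nested map
[key | rest] = Enum.reverse(destination)
# key = "kernel", rest = ["dense_1"]

nested =
  Enum.reduce(rest, %{key => shared}, fn next, acc ->
    %{next => acc}
  end)
#=> %{"dense_1" => %{"kernel" => %Axon.ModelState.SharedParameter{...}}}

# deep_merge then overwrites only data["dense_1"]["kernel"], leaving
# sibling entries such as data["dense_1"]["bias"] untouched.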
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+defmodule Axon.ModelState.SharedParameter do
+  # Represents a tied or shared parameter for layers whose
+  # weights are connected but don't necessarily perform the
+  # same operation. This implements the Nx.Container behavior
+  # and contains an access path to the parameter that holds the
+  # original weight.
+
+  @moduledoc false
+
+  @derive {Nx.Container, containers: [], keep: [:path, :transform]}
+  defstruct [:path, :transform]
+
+  def new(path, opts \\ []) do
+    %__MODULE__{
+      path: path,
+      transform: Keyword.get(opts, :transform)
+    }
+  end
+end
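Worth noting: because the struct derives Nx.Container with containers: [] and keeps both fields, Nx tree traversals see no tensor leaves inside it, so the marker rides through transformations intact as metadata. A small illustrative check (an assumption about usage, not from the commit):

shared = Axon.ModelState.SharedParameter.new(["embed", "kernel"])

# Traversal visits zero leaves; the struct and the accumulator come back unchanged
{^shared, []} =
  Nx.Container.traverse(shared, [], fn leaf, acc -> {leaf, [leaf | acc]} end)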

test/axon/compiler_test.exs

Lines changed: 114 additions & 0 deletions
@@ -5538,6 +5538,120 @@ defmodule CompilerTest do
     end
   end

+  describe "weight tying" do
+    test "tied parameter uses source parameter value" do
+      # Both dense layers have same input/output size so kernels are compatible
+      model =
+        Axon.input("input", shape: {nil, 4})
+        |> Axon.dense(4, name: "dense_0", use_bias: false)
+        |> Axon.dense(4, name: "dense_1", use_bias: false)
+
+      {init_fn, predict_fn} = Axon.build(model)
+      input = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
+
+      model_state = init_fn.(input, ModelState.empty())
+
+      # Set dense_0 kernel to identity matrix so we can trace the computation
+      identity = Nx.eye(4)
+      model_state = put_in(model_state.data["dense_0"]["kernel"], identity)
+      model_state = put_in(model_state.data["dense_1"]["kernel"], Nx.broadcast(0.0, {4, 4}))
+
+      # Without tying: input -> identity -> zeros = zeros
+      output_untied = predict_fn.(model_state, input)
+      assert_equal(output_untied, Nx.tensor([[0.0, 0.0, 0.0, 0.0]]))
+
+      # With tying: input -> identity -> identity = input
+      tied_state =
+        ModelState.tie(model_state, ["dense_1", "kernel"], ["dense_0", "kernel"])
+
+      output_tied = predict_fn.(tied_state, input)
+      assert_equal(output_tied, input)
+    end
+
+    test "tied parameter with transform applies transformation" do
+      model =
+        Axon.input("input", shape: {nil, 2})
+        |> Axon.dense(4, name: "dense_0", use_bias: false)
+        |> Axon.dense(2, name: "dense_1", use_bias: false)
+
+      {init_fn, predict_fn} = Axon.build(model)
+      input = Nx.tensor([[1.0, 2.0]])
+
+      model_state = init_fn.(input, ModelState.empty())
+
+      # Set a known kernel value
+      kernel = Nx.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
+      model_state = put_in(model_state.data["dense_0"]["kernel"], kernel)
+
+      # Tie with transpose: dense_1 uses kernel^T which is {4, 2}
+      tied_state =
+        ModelState.tie(
+          model_state,
+          ["dense_1", "kernel"],
+          ["dense_0", "kernel"],
+          transform: &Nx.transpose/1
+        )
+
+      # input {1,2} @ kernel {2,4} = {1,4}, then @ kernel^T {4,2} = {1,2}
+      # [[1,2]] @ [[1,0,0,0],[0,1,0,0]] = [[1,2,0,0]]
+      # [[1,2,0,0]] @ [[1,0],[0,1],[0,0],[0,0]] = [[1,2]]
+      output = predict_fn.(tied_state, input)
+      assert_equal(output, input)
+    end
+
+    test "modifying source parameter affects tied layers" do
+      model =
+        Axon.input("input", shape: {nil, 2})
+        |> Axon.dense(2, name: "dense_0", use_bias: false)
+        |> Axon.dense(2, name: "dense_1", use_bias: false)
+
+      {init_fn, predict_fn} = Axon.build(model)
+      input = Nx.tensor([[1.0, 0.0]])
+
+      model_state = init_fn.(input, ModelState.empty())
+
+      tied_state =
+        ModelState.tie(model_state, ["dense_1", "kernel"], ["dense_0", "kernel"])
+
+      # Set source kernel to a specific value
+      kernel_v1 = Nx.tensor([[1.0, 0.0], [0.0, 1.0]])
+      tied_state = put_in(tied_state.data["dense_0"]["kernel"], kernel_v1)
+      output_v1 = predict_fn.(tied_state, input)
+
+      # Change source kernel - tied layer should see the change
+      kernel_v2 = Nx.tensor([[2.0, 0.0], [0.0, 2.0]])
+      tied_state = put_in(tied_state.data["dense_0"]["kernel"], kernel_v2)
+      output_v2 = predict_fn.(tied_state, input)
+
+      # Outputs should differ because the shared kernel changed
+      refute Nx.all(Nx.equal(output_v1, output_v2)) |> Nx.to_number() == 1
+
+      # Verify expected values: input @ kernel @ kernel
+      # v1: [1,0] @ I @ I = [1,0]
+      # v2: [1,0] @ 2I @ 2I = [4,0]
+      assert_equal(output_v1, Nx.tensor([[1.0, 0.0]]))
+      assert_equal(output_v2, Nx.tensor([[4.0, 0.0]]))
+    end
+
+    test "raises on non-existent shared parameter source" do
+      model =
+        Axon.input("input", shape: {nil, 2})
+        |> Axon.dense(4, name: "dense_0")
+
+      {init_fn, predict_fn} = Axon.build(model)
+      input = Nx.tensor([[1.0, 2.0]])
+
+      model_state = init_fn.(input, ModelState.empty())
+
+      tied_state =
+        ModelState.tie(model_state, ["dense_0", "kernel"], ["nonexistent", "kernel"])
+
+      assert_raise ArgumentError, ~r/shared parameter.*references non-existent/, fn ->
+        predict_fn.(tied_state, input)
+      end
+    end
+  end
+
   describe "instrumentation" do
     @describetag :capture_log


Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
defmodule Axon.ModelStateTest do
2+
use ExUnit.Case, async: true
3+
4+
alias Axon.ModelState
5+
6+
describe "tie/4" do
7+
test "creates shared parameter at destination path" do
8+
model =
9+
Axon.input("input", shape: {nil, 2})
10+
|> Axon.dense(4, name: "dense_0")
11+
|> Axon.dense(4, name: "dense_1")
12+
13+
{init_fn, _} = Axon.build(model)
14+
model_state = init_fn.(Nx.template({1, 2}, :f32), ModelState.empty())
15+
16+
tied = ModelState.tie(model_state, ["dense_1", "kernel"], ["dense_0", "kernel"])
17+
18+
assert %Axon.ModelState.SharedParameter{path: ["dense_0", "kernel"], transform: nil} =
19+
tied.data["dense_1"]["kernel"]
20+
end
21+
22+
test "stores transform function" do
23+
model =
24+
Axon.input("input", shape: {nil, 2})
25+
|> Axon.dense(4, name: "dense_0")
26+
|> Axon.dense(4, name: "dense_1")
27+
28+
{init_fn, _} = Axon.build(model)
29+
model_state = init_fn.(Nx.template({1, 2}, :f32), ModelState.empty())
30+
31+
tied =
32+
ModelState.tie(
33+
model_state,
34+
["dense_1", "kernel"],
35+
["dense_0", "kernel"],
36+
transform: &Nx.transpose/1
37+
)
38+
39+
assert %Axon.ModelState.SharedParameter{transform: transform} =
40+
tied.data["dense_1"]["kernel"]
41+
42+
assert is_function(transform, 1)
43+
end
44+
end
45+
46+
describe "trainable_parameters/1 with tied weights" do
47+
test "excludes tied parameters" do
48+
model =
49+
Axon.input("input", shape: {nil, 2})
50+
|> Axon.dense(4, name: "dense_0")
51+
|> Axon.dense(4, name: "dense_1")
52+
53+
{init_fn, _} = Axon.build(model)
54+
model_state = init_fn.(Nx.template({1, 2}, :f32), ModelState.empty())
55+
56+
# Before tying, both layers have kernel in trainable params
57+
trainable_before = ModelState.trainable_parameters(model_state)
58+
assert Map.has_key?(trainable_before["dense_0"], "kernel")
59+
assert Map.has_key?(trainable_before["dense_1"], "kernel")
60+
61+
# After tying, dense_1 kernel should be excluded
62+
tied = ModelState.tie(model_state, ["dense_1", "kernel"], ["dense_0", "kernel"])
63+
trainable_after = ModelState.trainable_parameters(tied)
64+
65+
assert Map.has_key?(trainable_after["dense_0"], "kernel")
66+
refute Map.has_key?(trainable_after["dense_1"], "kernel")
67+
assert Map.has_key?(trainable_after["dense_1"], "bias")
68+
end
69+
70+
test "excludes layer when all parameters are tied" do
71+
model =
72+
Axon.input("input", shape: {nil, 2})
73+
|> Axon.dense(4, name: "dense_0", use_bias: false)
74+
|> Axon.dense(4, name: "dense_1", use_bias: false)
75+
76+
{init_fn, _} = Axon.build(model)
77+
model_state = init_fn.(Nx.template({1, 2}, :f32), ModelState.empty())
78+
79+
tied = ModelState.tie(model_state, ["dense_1", "kernel"], ["dense_0", "kernel"])
80+
trainable = ModelState.trainable_parameters(tied)
81+
82+
assert Map.has_key?(trainable, "dense_0")
83+
refute Map.has_key?(trainable, "dense_1")
84+
end
85+
end
86+
end
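Taken together, a typical end-to-end use of the new API might look like the following sketch, tying an output projection to an embedding table (the layer names, sizes, and this exact model are illustrative assumptions, not from the commit):

model =
  Axon.input("tokens", shape: {nil, 8})
  |> Axon.embedding(100, 32, name: "embed")
  |> Axon.dense(100, name: "output", use_bias: false)

{init_fn, predict_fn} = Axon.build(model)
state = init_fn.(Nx.template({1, 8}, :s64), Axon.ModelState.empty())

# Tie the output kernel to the transposed embedding kernel
state =
  Axon.ModelState.tie(
    state,
    ["output", "kernel"],
    ["embed", "kernel"],
    transform: &Nx.transpose/1
  )

# Inference now shares a single weight tensor between the two layers,
# and trainable_parameters/1 excludes the tied output kernel.
predict_fn.(state, Nx.iota({1, 8}))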
