diff --git a/Project.toml b/Project.toml index 9a0b3b9b..fc9fb0d6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "InvertibleNetworks" uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3" authors = ["Philipp Witte ", "Ali Siahkoohi ", "Mathias Louboutin ", "Gabrio Rizzuti ", "Rafael Orozco ", "Felix J. herrmann "] -version = "2.2.5" +version = "2.2.6" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/src/conditional_layers/conditional_layer_glow.jl b/src/conditional_layers/conditional_layer_glow.jl index ab36e064..9be591ae 100644 --- a/src/conditional_layers/conditional_layer_glow.jl +++ b/src/conditional_layers/conditional_layer_glow.jl @@ -75,10 +75,13 @@ end # Constructor from input dimensions function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64;freeze_conv=false, k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false, activation::ActivationFunction=SigmoidLayer(), rb_activation::ActivationFunction=RELUlayer(), ndims=2) - - # 1x1 Convolution and residual block for invertible layers C = Conv1x1(n_in; freeze=freeze_conv) - RB = ResidualBlock(Int(n_in/2)+n_cond, n_hidden; n_out=n_in, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims) + + split_num = Int(round(n_in/2)) + in_chan = n_in-split_num + out_chan = 2*split_num + + RB = ResidualBlock(in_chan+n_cond, n_hidden; n_out=out_chan, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims) return ConditionalLayerGlow(C, RB, logdet, activation) end @@ -143,7 +146,10 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractA # Backpropagate RB ΔX2_ΔC = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), (tensor_cat(X2, C))) - ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=Int(size(ΔY)[N-1]/2)) + + n_in = size(ΔY)[N-1] + split_num = Int(round(n_in/2)) + ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=n_in-split_num) ΔX2 += ΔY2 # Backpropagate 1x1 conv diff --git a/src/networks/invertible_network_conditional_glow.jl b/src/networks/invertible_network_conditional_glow.jl index 7bd57c69..d700a869 100644 --- a/src/networks/invertible_network_conditional_glow.jl +++ b/src/networks/invertible_network_conditional_glow.jl @@ -6,14 +6,14 @@ export NetworkConditionalGlow, NetworkConditionalGlow3D """ - G = NetworkGlow(n_in, n_cond, n_hidden, L, K; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1) + G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=fals) - G = NetworkGlow3D(n_in, n_cond, n_hidden, L, K; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1) + G = NetworkConditionalGlow3D(n_in, n_cond, n_hidden, L, K; split_scales=false) Create a conditional invertible network based on the Glow architecture. Each flow step in the inner loop - consists of an activation normalization layer, followed by an invertible coupling layer with - 1x1 convolutions and a residual block. The outer loop performs a squeezing operation prior - to the inner loop, and a splitting operation afterwards. + consists of an activation normalization layer, followed by an invertible glow conditional coupling layer with + 1x1 convolutions and a residual block that takes the condition as an input. + The outer loop performs a squeezing operation prior to the inner loop, and a splitting operation afterwards. *Input*: @@ -44,7 +44,7 @@ export NetworkConditionalGlow, NetworkConditionalGlow3D *Output*: - - `G`: invertible Glow network. + - `G`: invertible conditional Glow network. *Usage:* @@ -56,14 +56,13 @@ export NetworkConditionalGlow, NetworkConditionalGlow3D - None in `G` itself - - Trainable parameters in activation normalizations `G.AN[i,j]` and coupling layers `G.C[i,j]`, + - Trainable parameters in activation normalizations `G.AN[i,j]` and coupling layers `G.CL[i,j]`, where `i` and `j` range from `1` to `L` and `K` respectively. - See also: [`ActNorm`](@ref), [`CouplingLayerGlow!`](@ref), [`get_params`](@ref), [`clear_grad!`](@ref) + See also: [`ActNorm`](@ref), [`ConditionalLayerGlow!`](@ref), [`get_params`](@ref), [`clear_grad!`](@ref) """ struct NetworkConditionalGlow <: InvertibleNetwork AN::AbstractArray{ActNorm, 2} - AN_C::ActNorm CL::AbstractArray{ConditionalLayerGlow, 2} Z_dims::Union{Array{Array, 1}, Nothing} L::Int64 @@ -77,7 +76,6 @@ end # Constructor function NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; freeze_conv=false, split_scales=false, rb_activation::ActivationFunction=ReLUlayer(), k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer()) AN = Array{ActNorm}(undef, L, K) # activation normalization - AN_C = ActNorm(n_cond; logdet=false) # activation normalization for condition CL = Array{ConditionalLayerGlow}(undef, L, K) # coupling layers w/ 1x1 convolution and residual block if split_scales @@ -98,7 +96,7 @@ function NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; freeze_conv=false, (i < L && split_scales) && (n_in = Int64(n_in/2)) # split end - return NetworkConditionalGlow(AN, AN_C, CL, Z_dims, L, K, squeezer, split_scales) + return NetworkConditionalGlow(AN, CL, Z_dims, L, K, squeezer, split_scales) end NetworkConditionalGlow3D(args; kw...) = NetworkConditionalGlow(args...; kw..., ndims=3) @@ -108,8 +106,6 @@ function forward(X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkCondi G.split_scales && (Z_save = array_of_array(X, G.L-1)) orig_shape = size(X) - C = G.AN_C.forward(C) - logdet = 0 for i=1:G.L (G.split_scales) && (X = G.squeezer.forward(X)) @@ -176,6 +172,5 @@ function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractA end end - ΔC, C = G.AN_C.backward(ΔC, C) return ΔX, X, ΔC -end +end \ No newline at end of file diff --git a/test/test_networks/test_conditional_glow_network.jl b/test/test_networks/test_conditional_glow_network.jl index 642f76f2..77e3d6df 100644 --- a/test/test_networks/test_conditional_glow_network.jl +++ b/test/test_networks/test_conditional_glow_network.jl @@ -3,415 +3,221 @@ # Date: January 2020 using InvertibleNetworks, LinearAlgebra, Test, Random -using Flux +using Statistics -# Random seed -Random.seed!(3); - -# Define network -nx = 32 -ny = 32 -nz = 32 -n_in = 2 -n_cond = 2 -n_hidden = 4 -batchsize = 2 -L = 2 -K = 2 -split_scales = true -N = (nx,ny) - -########################################### Test with split_scales = true N = (nx,ny) ######################### -# Invertibility - -# Network and input -G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) -X = rand(Float32, N..., n_in, batchsize) -Cond = rand(Float32, N..., n_cond, batchsize) - -Y, Cond = G.forward(X,Cond) -X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes - -@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5) - -# Test gradients are set and cleared -G.backward(Y, Y, Cond) - -P = get_params(G) -gsum = 0 -for p in P - ~isnothing(p.grad) && (global gsum += 1) -end -@test isequal(gsum, L*K*10+2) - -clear_grad!(G) -gsum = 0 -for p in P - ~isnothing(p.grad) && (global gsum += 1) -end -@test isequal(gsum, 0) - -################################################################################################### -# Gradient test - -function loss(G, X, Cond) - Y, ZC, logdet = G.forward(X, Cond) - f = -log_likelihood(Y) - logdet - ΔY = -∇log_likelihood(Y) - ΔX, X_ = G.backward(ΔY, Y, ZC) - return f, ΔX, G.CL[1,1].RB.W1.grad, G.CL[1,1].C.v1.grad -end - - -# Gradient test w.r.t. input -G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) -X = rand(Float32, N..., n_in, batchsize) -Cond = rand(Float32, N..., n_cond, batchsize) -X0 = rand(Float32, N..., n_in, batchsize) -Cond0 = rand(Float32, N..., n_cond, batchsize) - -dX = X - X0 - -f0, ΔX = loss(G, X0, Cond0)[1:2] -h = 0.1f0 -maxiter = 4 -err1 = zeros(Float32, maxiter) -err2 = zeros(Float32, maxiter) - -print("\nGradient test glow: input\n") -for j=1:maxiter - f = loss(G, X0 + h*dX, Cond0)[1] - err1[j] = abs(f - f0) - err2[j] = abs(f - f0 - h*dot(dX, ΔX)) - print(err1[j], "; ", err2[j], "\n") - global h = h/2f0 -end - -@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0) -@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0) - - -# Gradient test w.r.t. parameters -X = rand(Float32, N..., n_in, batchsize) -G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) -G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) -Gini = deepcopy(G0) - -# Test one parameter from residual block and 1x1 conv -dW = G.CL[1,1].RB.W1.data - G0.CL[1,1].RB.W1.data -dv = G.CL[1,1].C.v1.data - G0.CL[1,1].C.v1.data - -f0, ΔX, ΔW, Δv = loss(G0, X, Cond) -h = 0.1f0 -maxiter = 4 -err3 = zeros(Float32, maxiter) -err4 = zeros(Float32, maxiter) - -print("\nGradient test glow: input\n") -for j=1:maxiter - G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW - G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv - - f = loss(G0, X, Cond)[1] - err3[j] = abs(f - f0) - err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv)) - print(err3[j], "; ", err4[j], "\n") - global h = h/2f0 -end - -@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0) -@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0) - - - -########################################### Test with split_scales = true N = (nx,ny) and summary network ######################### -# Invertibility -sum_net = ResNet(n_cond, 16, 3; norm=nothing) # make sure it doesnt have any weird normalizations - -# Network and input -flow = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) -G = SummarizedNet(flow, sum_net) - -X = rand(Float32, N..., n_in, batchsize); -Cond = rand(Float32, N..., n_cond, batchsize); - -Y, ZCond = G.forward(X,Cond) -X_ = G.inverse(Y,ZCond) # saving the cond is important in split scales because of reshapes - -@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5) - -# Test gradients are set and cleared -G.backward(Y, Y, ZCond; Y_save = Cond) - -P = get_params(G) -gsum = 0 -for p in P - ~isnothing(p.grad) && (global gsum += 1) -end -@test isequal(gsum, L*K*10+2+12) # depends on summary net you use - -clear_grad!(G) -gsum = 0 -for p in P - ~isnothing(p.grad) && (global gsum += 1) -end -@test isequal(gsum, 0) +# Random seed +Random.seed!(36); -# Gradient test -function loss_sum(G, X, Cond) +function loss(G, X, Cond;summarized=false) Y, ZC, logdet = G.forward(X, Cond) f = -log_likelihood(Y) - logdet ΔY = -∇log_likelihood(Y) - ΔX, X_ = G.backward(ΔY, Y, ZC; Y_save=Cond) - return f, ΔX, G.cond_net.CL[1,1].RB.W1.grad, G.cond_net.CL[1,1].C.v1.grad -end - -# Gradient test w.r.t. input -X = rand(Float32, N..., n_in, batchsize); -Cond = rand(Float32, N..., n_cond, batchsize); -X0 = rand(Float32, N..., n_in, batchsize); -Cond0 = rand(Float32, N..., n_cond, batchsize); - -dX = X - X0 - -f0, ΔX = loss_sum(G, X0, Cond0)[1:2] -h = 0.1f0 -maxiter = 4 -err1 = zeros(Float32, maxiter) -err2 = zeros(Float32, maxiter) - -print("\nGradient test glow: input\n") -for j=1:maxiter - f = loss_sum(G, X0 + h*dX, Cond0)[1] - err1[j] = abs(f - f0) - err2[j] = abs(f - f0 - h*dot(dX, ΔX)) - print(err1[j], "; ", err2[j], "\n") - global h = h/2f0 + if summarized + ΔX = G.backward(ΔY, Y, ZC; Y_save=Cond)[1] + return f, ΔX, G.cond_net.CL[1,1].RB.W1.grad + else + ΔX = G.backward(ΔY, Y, ZC)[1] + return f, ΔX, G.CL[1,1].RB.W1.grad + end +end + +function gradients_set(G, n_in,n_cond,N; summarized=false) + X = rand(Float32, N..., n_in, batchsize) + Cond = rand(Float32, N..., n_cond, batchsize) + + XZ, CondZ = G.forward(X,Cond) + + # Set gradients + summarized ? G.backward(XZ, XZ, CondZ; Y_save=Cond) : G.backward(XZ, XZ, CondZ) + + P = get_params(G) + gsum = 0 + for p in P + ~isnothing(p.grad) && (gsum += 1) + end + summarized ? (@test isequal(gsum, L*K*10+12)) : (@test isequal(gsum, L*K*10)) + + clear_grad!(G) + gsum = 0 + for p in P + ~isnothing(p.grad) && (gsum += 1) + end + @test isequal(gsum, 0) end -@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0) -@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0) - - -# Gradient test w.r.t. parameters -X = rand(Float32, N..., n_in, batchsize) -flow0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) -G0 = SummarizedNet(flow0, sum_net) -Gini = deepcopy(G0) - -# Test one parameter from residual block and 1x1 conv -dW = G.cond_net.CL[1,1].RB.W1.data - G0.cond_net.CL[1,1].RB.W1.data -dv = G.cond_net.CL[1,1].C.v1.data - G0.cond_net.CL[1,1].C.v1.data - -f0, ΔX, ΔW, Δv = loss_sum(G0, X, Cond) -h = 0.1f0 -maxiter = 4 -err3 = zeros(Float32, maxiter) -err4 = zeros(Float32, maxiter) - -print("\nGradient test glow: input\n") -for j=1:maxiter - G0.cond_net.CL[1,1].RB.W1.data = Gini.cond_net.CL[1,1].RB.W1.data + h*dW - G0.cond_net.CL[1,1].C.v1.data = Gini.cond_net.CL[1,1].C.v1.data + h*dv - - f = loss_sum(G0, X, Cond)[1] - err3[j] = abs(f - f0) - err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv)) - print(err3[j], "; ", err4[j], "\n") - global h = h/2f0 -end - -@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0) -@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0) - - -N = (nx,ny,nz) -########################################### Test with split_scales = true N = (nx,ny,nz) ######################### -# Invertibility - -# Network and input -G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) -X = rand(Float32, N..., n_in, batchsize) -Cond = rand(Float32, N..., n_cond, batchsize) - -Y, Cond = G.forward(X,Cond) -X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes - -@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5) - -# Test gradients are set and cleared -G.backward(Y, Y, Cond) - -P = get_params(G) -gsum = 0 -for p in P - ~isnothing(p.grad) && (global gsum += 1) -end -@test isequal(gsum, L*K*10+2) - -clear_grad!(G) -gsum = 0 -for p in P - ~isnothing(p.grad) && (global gsum += 1) -end -@test isequal(gsum, 0) - - -# Gradient test - - -# Gradient test w.r.t. input -G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) -X = rand(Float32, N..., n_in, batchsize) -Cond = rand(Float32, N..., n_cond, batchsize) -X0 = rand(Float32, N..., n_in, batchsize) -Cond0 = rand(Float32, N..., n_cond, batchsize) - -dX = X - X0 - -f0, ΔX = loss(G, X0, Cond0)[1:2] -h = 0.1f0 -maxiter = 4 -err1 = zeros(Float32, maxiter) -err2 = zeros(Float32, maxiter) - -print("\nGradient test glow: input\n") -for j=1:maxiter - f = loss(G, X0 + h*dX, Cond0)[1] - err1[j] = abs(f - f0) - err2[j] = abs(f - f0 - h*dot(dX, ΔX)) - print(err1[j], "; ", err2[j], "\n") - global h = h/2f0 -end - -@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0) -@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0) - - -# Gradient test w.r.t. parameters -X = rand(Float32, N..., n_in, batchsize) -G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) -G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) -Gini = deepcopy(G0) - -# Test one parameter from residual block and 1x1 conv -dW = G.CL[1,1].RB.W1.data - G0.CL[1,1].RB.W1.data -dv = G.CL[1,1].C.v1.data - G0.CL[1,1].C.v1.data - -f0, ΔX, ΔW, Δv = loss(G0, X, Cond) -h = 0.1f0 -maxiter = 4 -err3 = zeros(Float32, maxiter) -err4 = zeros(Float32, maxiter) - -print("\nGradient test glow: input\n") -for j=1:maxiter - G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW - G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv - - f = loss(G0, X, Cond)[1] - err3[j] = abs(f - f0) - err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv)) - print(err3[j], "; ", err4[j], "\n") - global h = h/2f0 -end - -@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0) -@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0) - - -########################################### Test with split_scales = true N = (nx,ny,nz) and Summary network ######################### -# Invertibility -sum_net_3d = ResNet(n_cond, 16, 3; ndims=3, norm=nothing) # make sure it doesnt have any weird normalizati8ons - -# Network and input -flow = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)); -G = SummarizedNet(flow, sum_net_3d) - -X = rand(Float32, N..., n_in, batchsize); -Cond = rand(Float32, N..., n_cond, batchsize); - -Y, ZCond = G.forward(X,Cond); -X_ = G.inverse(Y,ZCond); # saving the cond is important in split scales because of reshapes - -@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5) - -# Test gradients are set and cleared -G.backward(Y, Y, ZCond; Y_save=Cond) - -P = get_params(G) -gsum = 0 -for p in P - ~isnothing(p.grad) && (global gsum += 1) -end -@test isequal(gsum, L*K*10+2+12) - -clear_grad!(G) -gsum = 0 -for p in P - ~isnothing(p.grad) && (global gsum += 1) -end -@test isequal(gsum, 0) - - -# Gradient test - - -# Gradient test w.r.t. input -X = rand(Float32, N..., n_in, batchsize); -Cond = rand(Float32, N..., n_cond, batchsize); -X0 = rand(Float32, N..., n_in, batchsize); -Cond0 = rand(Float32, N..., n_cond, batchsize); - -dX = X - X0; - -f0, ΔX = loss_sum(G, X0, Cond0)[1:2]; -h = 0.1f0 -maxiter = 4 -err1 = zeros(Float32, maxiter) -err2 = zeros(Float32, maxiter) - -print("\nGradient test glow: input\n") -for j=1:maxiter - f = loss_sum(G, X0 + h*dX, Cond0)[1] - err1[j] = abs(f - f0) - err2[j] = abs(f - f0 - h*dot(dX, ΔX)) - print(err1[j], "; ", err2[j], "\n") - global h = h/2f0 -end - -@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0) -@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0) - -# Gradient test w.r.t. parameters -X = rand(Float32, N..., n_in, batchsize) -flow0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) -G0 = SummarizedNet(flow0, sum_net_3d) -Gini = deepcopy(G0) - -# Test one parameter from residual block and 1x1 conv -dW = G.cond_net.CL[1,1].RB.W1.data - G0.cond_net.CL[1,1].RB.W1.data -dv = G.cond_net.CL[1,1].C.v1.data - G0.cond_net.CL[1,1].C.v1.data - -f0, ΔX, ΔW, Δv = loss_sum(G0, X, Cond); -h = 0.1f0 -maxiter = 4 -err3 = zeros(Float32, maxiter) -err4 = zeros(Float32, maxiter) - -print("\nGradient test glow: input\n") -for j=1:maxiter - G0.cond_net.CL[1,1].RB.W1.data = Gini.cond_net.CL[1,1].RB.W1.data + h*dW - G0.cond_net.CL[1,1].C.v1.data = Gini.cond_net.CL[1,1].C.v1.data + h*dv - - f = loss_sum(G0, X, Cond)[1] - err3[j] = abs(f - f0) - err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv)) - print(err3[j], "; ", err4[j], "\n") - global h = h/2f0 -end - -@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0) -@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0) +# Define network +nx = 16 +ny = 16 +nz = 16 +n_in = 4 +n_cond = 2 +n_hidden = 4 +batchsize = 4 +L = 2 +K = 2 +stol = 1.5f0 +for split_scales in [false,true] + for N in [(16*nx),(nx,ny),(nx,ny,nz)] + println("Test with split_scales = $(split_scales) N = $(N)") + + # Network and inputs + G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales, ndims=length(N)) + + X = randn(Float32, N..., n_in, batchsize) + Cond = rand(Float32, N..., n_cond, batchsize) + + # Invertibility + XZ, CondZ = G.forward(X,Cond) + X_ = G.inverse(XZ, CondZ) # saving the cond output is important in split scales because of reshapes + @test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5) + + ################################################################################################### + # Test gradients are set and cleared + gradients_set(G, n_in, n_cond,N;) + + ################################################################################################### + # Gradient test w.r.t. input + X0 = randn(Float32, N..., n_in, batchsize) + Cond0 = randn(Float32, N..., n_cond, batchsize) + + dX = X - X0 + + f0, ΔX = loss(G, X0, Cond0)[1:2] + h = 0.1f0 + maxiter = 4 + err1 = zeros(Float32, maxiter) + err2 = zeros(Float32, maxiter) + + print("\nGradient test glow: input\n") + for j=1:maxiter + f = loss(G, X0 + h*dX, Cond0)[1] + err1[j] = abs(f - f0) + err2[j] = abs(f - f0 - h*dot(dX, ΔX)) + print(err1[j], "; ", err2[j], "\n") + h = h/2f0 + end + + rate1 = err1[1:end-1]./err1[2:end] + rate2 = err2[1:end-1]./err2[2:end] + + @test isapprox(mean(rate1), 2f0; atol=stol) + @test isapprox(mean(rate2), 4f0; atol=stol) + + # Gradient test w.r.t. parameters + G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales, ndims=length(N)) + G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales, ndims=length(N)) + Gini = deepcopy(G0) + + # Test one parameter from residual block + dW = G.CL[1,1].RB.W1.data - G0.CL[1,1].RB.W1.data + + f0, ΔX, ΔW = loss(G0, X, Cond) + h = 0.1f0 + maxiter = 4 + err1 = zeros(Float32, maxiter) + err2 = zeros(Float32, maxiter) + + print("\nGradient test glow: parameter\n") + for j=1:maxiter + G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW + + f = loss(G0, X, Cond)[1] + err1[j] = abs(f - f0) + err2[j] = abs(f - f0 - h*dot(dW, ΔW)) + print(err1[j], "; ", err2[j], "\n") + h = h/2f0 + end + + rate1 = err1[1:end-1]./err1[2:end] + rate2 = err2[1:end-1]./err2[2:end] + + @test isapprox(mean(rate1),2f0; atol=stol) + @test isapprox(mean(rate2), 4f0; atol=stol) + end +end + +# with summary network +for split_scales in [false,true] + for N in [(16*nx),(nx,ny),(nx,ny,nz)] + println("Test with split_scales = $(split_scales) N = $(N) and summarized=$(true)") + + # Network and inputs + G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales, ndims=length(N)) + sum_net = ResNet(n_cond, 16, 3; norm=nothing,ndims=length(N)) # make sure it doesnt have any weird normalizations + G = SummarizedNet(G, sum_net) + + X = randn(Float32, N..., n_in, batchsize) + Cond = randn(Float32, N..., n_cond, batchsize) + + # Invertibility + XZ, CondZ = G.forward(X,Cond) + X_ = G.inverse(XZ, CondZ) # saving the cond output is important in split scales because of reshapes + @test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5) + + ################################################################################################### + # Test gradients are set and cleared + gradients_set(G, n_in, n_cond,N; summarized=true) + + ################################################################################################### + # Gradient test w.r.t. input + X0 = randn(Float32, N..., n_in, batchsize) + Cond0 = randn(Float32, N..., n_cond, batchsize) + + dX = X - X0 + + f0, ΔX = loss(G, X0, Cond0; summarized=true)[1:2] + h = 0.1f0 + maxiter = 4 + err1 = zeros(Float32, maxiter) + err2 = zeros(Float32, maxiter) + + print("\nGradient test glow: input\n") + for j=1:maxiter + f = loss(G, X0 + h*dX, Cond0; summarized=true)[1] + err1[j] = abs(f - f0) + err2[j] = abs(f - f0 - h*dot(dX, ΔX)) + print(err1[j], "; ", err2[j], "\n") + h = h/2f0 + end + + rate1 = err1[1:end-1]./err1[2:end] + rate2 = err2[1:end-1]./err2[2:end] + + @test isapprox(mean(rate1),2f0; atol=stol) + @test isapprox(mean(rate2), 4f0; atol=stol) + + # Gradient test w.r.t. parameters + G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales, ndims=length(N)) + sum_net = ResNet(n_cond, 16, 3; norm=nothing,ndims=length(N)) # make sure it doesnt have any weird normalizations + G0 = SummarizedNet(G0, sum_net) + Gini = deepcopy(G0) + + # Test one parameter from residual block + dW = G.cond_net.CL[1,1].RB.W1.data - G0.cond_net.CL[1,1].RB.W1.data + + f0, ΔX, ΔW = loss(G0, X, Cond; summarized=true) + h = 0.1f0 + maxiter = 4 + err1 = zeros(Float32, maxiter) + err2 = zeros(Float32, maxiter) + + print("\nGradient test glow: parameter\n") + for j=1:maxiter + G0.cond_net.CL[1,1].RB.W1.data = Gini.cond_net.CL[1,1].RB.W1.data + h*dW + + f = loss(G0, X, Cond; summarized=true)[1] + err1[j] = abs(f - f0) + err2[j] = abs(f - f0 - h*dot(dW, ΔW)) + print(err1[j], "; ", err2[j], "\n") + h = h/2f0 + end + + rate1 = err1[1:end-1]./err1[2:end] + rate2 = err2[1:end-1]./err2[2:end] + + @test isapprox(mean(rate1),2f0; atol=stol) + @test isapprox(mean(rate2),4f0; atol=stol) + end +end \ No newline at end of file