diff --git a/Project.toml b/Project.toml
index 9a0b3b9b..fc9fb0d6 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
name = "InvertibleNetworks"
uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
authors = ["Philipp Witte
", "Ali Siahkoohi ", "Mathias Louboutin ", "Gabrio Rizzuti ", "Rafael Orozco ", "Felix J. herrmann "]
-version = "2.2.5"
+version = "2.2.6"
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
diff --git a/src/conditional_layers/conditional_layer_glow.jl b/src/conditional_layers/conditional_layer_glow.jl
index ab36e064..9be591ae 100644
--- a/src/conditional_layers/conditional_layer_glow.jl
+++ b/src/conditional_layers/conditional_layer_glow.jl
@@ -75,10 +75,13 @@ end
# Constructor from input dimensions
function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64;freeze_conv=false, k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false, activation::ActivationFunction=SigmoidLayer(), rb_activation::ActivationFunction=RELUlayer(), ndims=2)
-
- # 1x1 Convolution and residual block for invertible layers
C = Conv1x1(n_in; freeze=freeze_conv)
- RB = ResidualBlock(Int(n_in/2)+n_cond, n_hidden; n_out=n_in, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims)
+
+ split_num = Int(round(n_in/2))
+ in_chan = n_in-split_num
+ out_chan = 2*split_num
+
+ RB = ResidualBlock(in_chan+n_cond, n_hidden; n_out=out_chan, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims)
return ConditionalLayerGlow(C, RB, logdet, activation)
end
@@ -143,7 +146,10 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractA
# Backpropagate RB
ΔX2_ΔC = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), (tensor_cat(X2, C)))
- ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=Int(size(ΔY)[N-1]/2))
+
+ n_in = size(ΔY)[N-1]
+ split_num = Int(round(n_in/2))
+ ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=n_in-split_num)
ΔX2 += ΔY2
# Backpropagate 1x1 conv
diff --git a/src/networks/invertible_network_conditional_glow.jl b/src/networks/invertible_network_conditional_glow.jl
index 7bd57c69..d700a869 100644
--- a/src/networks/invertible_network_conditional_glow.jl
+++ b/src/networks/invertible_network_conditional_glow.jl
@@ -6,14 +6,14 @@
export NetworkConditionalGlow, NetworkConditionalGlow3D
"""
- G = NetworkGlow(n_in, n_cond, n_hidden, L, K; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1)
+ G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=fals)
- G = NetworkGlow3D(n_in, n_cond, n_hidden, L, K; k1=3, k2=1, p1=1, p2=0, s1=1, s2=1)
+ G = NetworkConditionalGlow3D(n_in, n_cond, n_hidden, L, K; split_scales=false)
Create a conditional invertible network based on the Glow architecture. Each flow step in the inner loop
- consists of an activation normalization layer, followed by an invertible coupling layer with
- 1x1 convolutions and a residual block. The outer loop performs a squeezing operation prior
- to the inner loop, and a splitting operation afterwards.
+ consists of an activation normalization layer, followed by an invertible glow conditional coupling layer with
+ 1x1 convolutions and a residual block that takes the condition as an input.
+ The outer loop performs a squeezing operation prior to the inner loop, and a splitting operation afterwards.
*Input*:
@@ -44,7 +44,7 @@ export NetworkConditionalGlow, NetworkConditionalGlow3D
*Output*:
- - `G`: invertible Glow network.
+ - `G`: invertible conditional Glow network.
*Usage:*
@@ -56,14 +56,13 @@ export NetworkConditionalGlow, NetworkConditionalGlow3D
- None in `G` itself
- - Trainable parameters in activation normalizations `G.AN[i,j]` and coupling layers `G.C[i,j]`,
+ - Trainable parameters in activation normalizations `G.AN[i,j]` and coupling layers `G.CL[i,j]`,
where `i` and `j` range from `1` to `L` and `K` respectively.
- See also: [`ActNorm`](@ref), [`CouplingLayerGlow!`](@ref), [`get_params`](@ref), [`clear_grad!`](@ref)
+ See also: [`ActNorm`](@ref), [`ConditionalLayerGlow!`](@ref), [`get_params`](@ref), [`clear_grad!`](@ref)
"""
struct NetworkConditionalGlow <: InvertibleNetwork
AN::AbstractArray{ActNorm, 2}
- AN_C::ActNorm
CL::AbstractArray{ConditionalLayerGlow, 2}
Z_dims::Union{Array{Array, 1}, Nothing}
L::Int64
@@ -77,7 +76,6 @@ end
# Constructor
function NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; freeze_conv=false, split_scales=false, rb_activation::ActivationFunction=ReLUlayer(), k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer())
AN = Array{ActNorm}(undef, L, K) # activation normalization
- AN_C = ActNorm(n_cond; logdet=false) # activation normalization for condition
CL = Array{ConditionalLayerGlow}(undef, L, K) # coupling layers w/ 1x1 convolution and residual block
if split_scales
@@ -98,7 +96,7 @@ function NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; freeze_conv=false,
(i < L && split_scales) && (n_in = Int64(n_in/2)) # split
end
- return NetworkConditionalGlow(AN, AN_C, CL, Z_dims, L, K, squeezer, split_scales)
+ return NetworkConditionalGlow(AN, CL, Z_dims, L, K, squeezer, split_scales)
end
NetworkConditionalGlow3D(args; kw...) = NetworkConditionalGlow(args...; kw..., ndims=3)
@@ -108,8 +106,6 @@ function forward(X::AbstractArray{T, N}, C::AbstractArray{T, N}, G::NetworkCondi
G.split_scales && (Z_save = array_of_array(X, G.L-1))
orig_shape = size(X)
- C = G.AN_C.forward(C)
-
logdet = 0
for i=1:G.L
(G.split_scales) && (X = G.squeezer.forward(X))
@@ -176,6 +172,5 @@ function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, C::AbstractA
end
end
- ΔC, C = G.AN_C.backward(ΔC, C)
return ΔX, X, ΔC
-end
+end
\ No newline at end of file
diff --git a/test/test_networks/test_conditional_glow_network.jl b/test/test_networks/test_conditional_glow_network.jl
index 642f76f2..77e3d6df 100644
--- a/test/test_networks/test_conditional_glow_network.jl
+++ b/test/test_networks/test_conditional_glow_network.jl
@@ -3,415 +3,221 @@
# Date: January 2020
using InvertibleNetworks, LinearAlgebra, Test, Random
-using Flux
+using Statistics
-# Random seed
-Random.seed!(3);
-
-# Define network
-nx = 32
-ny = 32
-nz = 32
-n_in = 2
-n_cond = 2
-n_hidden = 4
-batchsize = 2
-L = 2
-K = 2
-split_scales = true
-N = (nx,ny)
-
-########################################### Test with split_scales = true N = (nx,ny) #########################
-# Invertibility
-
-# Network and input
-G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
-X = rand(Float32, N..., n_in, batchsize)
-Cond = rand(Float32, N..., n_cond, batchsize)
-
-Y, Cond = G.forward(X,Cond)
-X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes
-
-@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)
-
-# Test gradients are set and cleared
-G.backward(Y, Y, Cond)
-
-P = get_params(G)
-gsum = 0
-for p in P
- ~isnothing(p.grad) && (global gsum += 1)
-end
-@test isequal(gsum, L*K*10+2)
-
-clear_grad!(G)
-gsum = 0
-for p in P
- ~isnothing(p.grad) && (global gsum += 1)
-end
-@test isequal(gsum, 0)
-
-###################################################################################################
-# Gradient test
-
-function loss(G, X, Cond)
- Y, ZC, logdet = G.forward(X, Cond)
- f = -log_likelihood(Y) - logdet
- ΔY = -∇log_likelihood(Y)
- ΔX, X_ = G.backward(ΔY, Y, ZC)
- return f, ΔX, G.CL[1,1].RB.W1.grad, G.CL[1,1].C.v1.grad
-end
-
-
-# Gradient test w.r.t. input
-G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
-X = rand(Float32, N..., n_in, batchsize)
-Cond = rand(Float32, N..., n_cond, batchsize)
-X0 = rand(Float32, N..., n_in, batchsize)
-Cond0 = rand(Float32, N..., n_cond, batchsize)
-
-dX = X - X0
-
-f0, ΔX = loss(G, X0, Cond0)[1:2]
-h = 0.1f0
-maxiter = 4
-err1 = zeros(Float32, maxiter)
-err2 = zeros(Float32, maxiter)
-
-print("\nGradient test glow: input\n")
-for j=1:maxiter
- f = loss(G, X0 + h*dX, Cond0)[1]
- err1[j] = abs(f - f0)
- err2[j] = abs(f - f0 - h*dot(dX, ΔX))
- print(err1[j], "; ", err2[j], "\n")
- global h = h/2f0
-end
-
-@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0)
-@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)
-
-
-# Gradient test w.r.t. parameters
-X = rand(Float32, N..., n_in, batchsize)
-G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
-G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
-Gini = deepcopy(G0)
-
-# Test one parameter from residual block and 1x1 conv
-dW = G.CL[1,1].RB.W1.data - G0.CL[1,1].RB.W1.data
-dv = G.CL[1,1].C.v1.data - G0.CL[1,1].C.v1.data
-
-f0, ΔX, ΔW, Δv = loss(G0, X, Cond)
-h = 0.1f0
-maxiter = 4
-err3 = zeros(Float32, maxiter)
-err4 = zeros(Float32, maxiter)
-
-print("\nGradient test glow: input\n")
-for j=1:maxiter
- G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW
- G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv
-
- f = loss(G0, X, Cond)[1]
- err3[j] = abs(f - f0)
- err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv))
- print(err3[j], "; ", err4[j], "\n")
- global h = h/2f0
-end
-
-@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0)
-@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0)
-
-
-
-########################################### Test with split_scales = true N = (nx,ny) and summary network #########################
-# Invertibility
-sum_net = ResNet(n_cond, 16, 3; norm=nothing) # make sure it doesnt have any weird normalizations
-
-# Network and input
-flow = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N))
-G = SummarizedNet(flow, sum_net)
-
-X = rand(Float32, N..., n_in, batchsize);
-Cond = rand(Float32, N..., n_cond, batchsize);
-
-Y, ZCond = G.forward(X,Cond)
-X_ = G.inverse(Y,ZCond) # saving the cond is important in split scales because of reshapes
-
-@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)
-
-# Test gradients are set and cleared
-G.backward(Y, Y, ZCond; Y_save = Cond)
-
-P = get_params(G)
-gsum = 0
-for p in P
- ~isnothing(p.grad) && (global gsum += 1)
-end
-@test isequal(gsum, L*K*10+2+12) # depends on summary net you use
-
-clear_grad!(G)
-gsum = 0
-for p in P
- ~isnothing(p.grad) && (global gsum += 1)
-end
-@test isequal(gsum, 0)
+# Random seed
+Random.seed!(36);
-# Gradient test
-function loss_sum(G, X, Cond)
+function loss(G, X, Cond;summarized=false)
Y, ZC, logdet = G.forward(X, Cond)
f = -log_likelihood(Y) - logdet
ΔY = -∇log_likelihood(Y)
- ΔX, X_ = G.backward(ΔY, Y, ZC; Y_save=Cond)
- return f, ΔX, G.cond_net.CL[1,1].RB.W1.grad, G.cond_net.CL[1,1].C.v1.grad
-end
-
-# Gradient test w.r.t. input
-X = rand(Float32, N..., n_in, batchsize);
-Cond = rand(Float32, N..., n_cond, batchsize);
-X0 = rand(Float32, N..., n_in, batchsize);
-Cond0 = rand(Float32, N..., n_cond, batchsize);
-
-dX = X - X0
-
-f0, ΔX = loss_sum(G, X0, Cond0)[1:2]
-h = 0.1f0
-maxiter = 4
-err1 = zeros(Float32, maxiter)
-err2 = zeros(Float32, maxiter)
-
-print("\nGradient test glow: input\n")
-for j=1:maxiter
- f = loss_sum(G, X0 + h*dX, Cond0)[1]
- err1[j] = abs(f - f0)
- err2[j] = abs(f - f0 - h*dot(dX, ΔX))
- print(err1[j], "; ", err2[j], "\n")
- global h = h/2f0
+ if summarized
+ ΔX = G.backward(ΔY, Y, ZC; Y_save=Cond)[1]
+ return f, ΔX, G.cond_net.CL[1,1].RB.W1.grad
+ else
+ ΔX = G.backward(ΔY, Y, ZC)[1]
+ return f, ΔX, G.CL[1,1].RB.W1.grad
+ end
+end
+
+function gradients_set(G, n_in,n_cond,N; summarized=false)
+ X = rand(Float32, N..., n_in, batchsize)
+ Cond = rand(Float32, N..., n_cond, batchsize)
+
+ XZ, CondZ = G.forward(X,Cond)
+
+ # Set gradients
+ summarized ? G.backward(XZ, XZ, CondZ; Y_save=Cond) : G.backward(XZ, XZ, CondZ)
+
+ P = get_params(G)
+ gsum = 0
+ for p in P
+ ~isnothing(p.grad) && (gsum += 1)
+ end
+ summarized ? (@test isequal(gsum, L*K*10+12)) : (@test isequal(gsum, L*K*10))
+
+ clear_grad!(G)
+ gsum = 0
+ for p in P
+ ~isnothing(p.grad) && (gsum += 1)
+ end
+ @test isequal(gsum, 0)
end
-@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0)
-@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)
-
-
-# Gradient test w.r.t. parameters
-X = rand(Float32, N..., n_in, batchsize)
-flow0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N))
-G0 = SummarizedNet(flow0, sum_net)
-Gini = deepcopy(G0)
-
-# Test one parameter from residual block and 1x1 conv
-dW = G.cond_net.CL[1,1].RB.W1.data - G0.cond_net.CL[1,1].RB.W1.data
-dv = G.cond_net.CL[1,1].C.v1.data - G0.cond_net.CL[1,1].C.v1.data
-
-f0, ΔX, ΔW, Δv = loss_sum(G0, X, Cond)
-h = 0.1f0
-maxiter = 4
-err3 = zeros(Float32, maxiter)
-err4 = zeros(Float32, maxiter)
-
-print("\nGradient test glow: input\n")
-for j=1:maxiter
- G0.cond_net.CL[1,1].RB.W1.data = Gini.cond_net.CL[1,1].RB.W1.data + h*dW
- G0.cond_net.CL[1,1].C.v1.data = Gini.cond_net.CL[1,1].C.v1.data + h*dv
-
- f = loss_sum(G0, X, Cond)[1]
- err3[j] = abs(f - f0)
- err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv))
- print(err3[j], "; ", err4[j], "\n")
- global h = h/2f0
-end
-
-@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0)
-@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0)
-
-
-N = (nx,ny,nz)
-########################################### Test with split_scales = true N = (nx,ny,nz) #########################
-# Invertibility
-
-# Network and input
-G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
-X = rand(Float32, N..., n_in, batchsize)
-Cond = rand(Float32, N..., n_cond, batchsize)
-
-Y, Cond = G.forward(X,Cond)
-X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes
-
-@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)
-
-# Test gradients are set and cleared
-G.backward(Y, Y, Cond)
-
-P = get_params(G)
-gsum = 0
-for p in P
- ~isnothing(p.grad) && (global gsum += 1)
-end
-@test isequal(gsum, L*K*10+2)
-
-clear_grad!(G)
-gsum = 0
-for p in P
- ~isnothing(p.grad) && (global gsum += 1)
-end
-@test isequal(gsum, 0)
-
-
-# Gradient test
-
-
-# Gradient test w.r.t. input
-G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
-X = rand(Float32, N..., n_in, batchsize)
-Cond = rand(Float32, N..., n_cond, batchsize)
-X0 = rand(Float32, N..., n_in, batchsize)
-Cond0 = rand(Float32, N..., n_cond, batchsize)
-
-dX = X - X0
-
-f0, ΔX = loss(G, X0, Cond0)[1:2]
-h = 0.1f0
-maxiter = 4
-err1 = zeros(Float32, maxiter)
-err2 = zeros(Float32, maxiter)
-
-print("\nGradient test glow: input\n")
-for j=1:maxiter
- f = loss(G, X0 + h*dX, Cond0)[1]
- err1[j] = abs(f - f0)
- err2[j] = abs(f - f0 - h*dot(dX, ΔX))
- print(err1[j], "; ", err2[j], "\n")
- global h = h/2f0
-end
-
-@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0)
-@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)
-
-
-# Gradient test w.r.t. parameters
-X = rand(Float32, N..., n_in, batchsize)
-G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
-G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
-Gini = deepcopy(G0)
-
-# Test one parameter from residual block and 1x1 conv
-dW = G.CL[1,1].RB.W1.data - G0.CL[1,1].RB.W1.data
-dv = G.CL[1,1].C.v1.data - G0.CL[1,1].C.v1.data
-
-f0, ΔX, ΔW, Δv = loss(G0, X, Cond)
-h = 0.1f0
-maxiter = 4
-err3 = zeros(Float32, maxiter)
-err4 = zeros(Float32, maxiter)
-
-print("\nGradient test glow: input\n")
-for j=1:maxiter
- G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW
- G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv
-
- f = loss(G0, X, Cond)[1]
- err3[j] = abs(f - f0)
- err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv))
- print(err3[j], "; ", err4[j], "\n")
- global h = h/2f0
-end
-
-@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0)
-@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0)
-
-
-########################################### Test with split_scales = true N = (nx,ny,nz) and Summary network #########################
-# Invertibility
-sum_net_3d = ResNet(n_cond, 16, 3; ndims=3, norm=nothing) # make sure it doesnt have any weird normalizati8ons
-
-# Network and input
-flow = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N));
-G = SummarizedNet(flow, sum_net_3d)
-
-X = rand(Float32, N..., n_in, batchsize);
-Cond = rand(Float32, N..., n_cond, batchsize);
-
-Y, ZCond = G.forward(X,Cond);
-X_ = G.inverse(Y,ZCond); # saving the cond is important in split scales because of reshapes
-
-@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)
-
-# Test gradients are set and cleared
-G.backward(Y, Y, ZCond; Y_save=Cond)
-
-P = get_params(G)
-gsum = 0
-for p in P
- ~isnothing(p.grad) && (global gsum += 1)
-end
-@test isequal(gsum, L*K*10+2+12)
-
-clear_grad!(G)
-gsum = 0
-for p in P
- ~isnothing(p.grad) && (global gsum += 1)
-end
-@test isequal(gsum, 0)
-
-
-# Gradient test
-
-
-# Gradient test w.r.t. input
-X = rand(Float32, N..., n_in, batchsize);
-Cond = rand(Float32, N..., n_cond, batchsize);
-X0 = rand(Float32, N..., n_in, batchsize);
-Cond0 = rand(Float32, N..., n_cond, batchsize);
-
-dX = X - X0;
-
-f0, ΔX = loss_sum(G, X0, Cond0)[1:2];
-h = 0.1f0
-maxiter = 4
-err1 = zeros(Float32, maxiter)
-err2 = zeros(Float32, maxiter)
-
-print("\nGradient test glow: input\n")
-for j=1:maxiter
- f = loss_sum(G, X0 + h*dX, Cond0)[1]
- err1[j] = abs(f - f0)
- err2[j] = abs(f - f0 - h*dot(dX, ΔX))
- print(err1[j], "; ", err2[j], "\n")
- global h = h/2f0
-end
-
-@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0)
-@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)
-
-# Gradient test w.r.t. parameters
-X = rand(Float32, N..., n_in, batchsize)
-flow0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N))
-G0 = SummarizedNet(flow0, sum_net_3d)
-Gini = deepcopy(G0)
-
-# Test one parameter from residual block and 1x1 conv
-dW = G.cond_net.CL[1,1].RB.W1.data - G0.cond_net.CL[1,1].RB.W1.data
-dv = G.cond_net.CL[1,1].C.v1.data - G0.cond_net.CL[1,1].C.v1.data
-
-f0, ΔX, ΔW, Δv = loss_sum(G0, X, Cond);
-h = 0.1f0
-maxiter = 4
-err3 = zeros(Float32, maxiter)
-err4 = zeros(Float32, maxiter)
-
-print("\nGradient test glow: input\n")
-for j=1:maxiter
- G0.cond_net.CL[1,1].RB.W1.data = Gini.cond_net.CL[1,1].RB.W1.data + h*dW
- G0.cond_net.CL[1,1].C.v1.data = Gini.cond_net.CL[1,1].C.v1.data + h*dv
-
- f = loss_sum(G0, X, Cond)[1]
- err3[j] = abs(f - f0)
- err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv))
- print(err3[j], "; ", err4[j], "\n")
- global h = h/2f0
-end
-
-@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0)
-@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0)
+# Define network
+nx = 16
+ny = 16
+nz = 16
+n_in = 4
+n_cond = 2
+n_hidden = 4
+batchsize = 4
+L = 2
+K = 2
+stol = 1.5f0
+for split_scales in [false,true]
+ for N in [(16*nx),(nx,ny),(nx,ny,nz)]
+ println("Test with split_scales = $(split_scales) N = $(N)")
+
+ # Network and inputs
+ G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales, ndims=length(N))
+
+ X = randn(Float32, N..., n_in, batchsize)
+ Cond = rand(Float32, N..., n_cond, batchsize)
+
+ # Invertibility
+ XZ, CondZ = G.forward(X,Cond)
+ X_ = G.inverse(XZ, CondZ) # saving the cond output is important in split scales because of reshapes
+ @test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)
+
+ ###################################################################################################
+ # Test gradients are set and cleared
+ gradients_set(G, n_in, n_cond,N;)
+
+ ###################################################################################################
+ # Gradient test w.r.t. input
+ X0 = randn(Float32, N..., n_in, batchsize)
+ Cond0 = randn(Float32, N..., n_cond, batchsize)
+
+ dX = X - X0
+
+ f0, ΔX = loss(G, X0, Cond0)[1:2]
+ h = 0.1f0
+ maxiter = 4
+ err1 = zeros(Float32, maxiter)
+ err2 = zeros(Float32, maxiter)
+
+ print("\nGradient test glow: input\n")
+ for j=1:maxiter
+ f = loss(G, X0 + h*dX, Cond0)[1]
+ err1[j] = abs(f - f0)
+ err2[j] = abs(f - f0 - h*dot(dX, ΔX))
+ print(err1[j], "; ", err2[j], "\n")
+ h = h/2f0
+ end
+
+ rate1 = err1[1:end-1]./err1[2:end]
+ rate2 = err2[1:end-1]./err2[2:end]
+
+ @test isapprox(mean(rate1), 2f0; atol=stol)
+ @test isapprox(mean(rate2), 4f0; atol=stol)
+
+ # Gradient test w.r.t. parameters
+ G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales, ndims=length(N))
+ G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales, ndims=length(N))
+ Gini = deepcopy(G0)
+
+ # Test one parameter from residual block
+ dW = G.CL[1,1].RB.W1.data - G0.CL[1,1].RB.W1.data
+
+ f0, ΔX, ΔW = loss(G0, X, Cond)
+ h = 0.1f0
+ maxiter = 4
+ err1 = zeros(Float32, maxiter)
+ err2 = zeros(Float32, maxiter)
+
+ print("\nGradient test glow: parameter\n")
+ for j=1:maxiter
+ G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW
+
+ f = loss(G0, X, Cond)[1]
+ err1[j] = abs(f - f0)
+ err2[j] = abs(f - f0 - h*dot(dW, ΔW))
+ print(err1[j], "; ", err2[j], "\n")
+ h = h/2f0
+ end
+
+ rate1 = err1[1:end-1]./err1[2:end]
+ rate2 = err2[1:end-1]./err2[2:end]
+
+ @test isapprox(mean(rate1),2f0; atol=stol)
+ @test isapprox(mean(rate2), 4f0; atol=stol)
+ end
+end
+
+# with summary network
+for split_scales in [false,true]
+ for N in [(16*nx),(nx,ny),(nx,ny,nz)]
+ println("Test with split_scales = $(split_scales) N = $(N) and summarized=$(true)")
+
+ # Network and inputs
+ G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales, ndims=length(N))
+ sum_net = ResNet(n_cond, 16, 3; norm=nothing,ndims=length(N)) # make sure it doesnt have any weird normalizations
+ G = SummarizedNet(G, sum_net)
+
+ X = randn(Float32, N..., n_in, batchsize)
+ Cond = randn(Float32, N..., n_cond, batchsize)
+
+ # Invertibility
+ XZ, CondZ = G.forward(X,Cond)
+ X_ = G.inverse(XZ, CondZ) # saving the cond output is important in split scales because of reshapes
+ @test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)
+
+ ###################################################################################################
+ # Test gradients are set and cleared
+ gradients_set(G, n_in, n_cond,N; summarized=true)
+
+ ###################################################################################################
+ # Gradient test w.r.t. input
+ X0 = randn(Float32, N..., n_in, batchsize)
+ Cond0 = randn(Float32, N..., n_cond, batchsize)
+
+ dX = X - X0
+
+ f0, ΔX = loss(G, X0, Cond0; summarized=true)[1:2]
+ h = 0.1f0
+ maxiter = 4
+ err1 = zeros(Float32, maxiter)
+ err2 = zeros(Float32, maxiter)
+
+ print("\nGradient test glow: input\n")
+ for j=1:maxiter
+ f = loss(G, X0 + h*dX, Cond0; summarized=true)[1]
+ err1[j] = abs(f - f0)
+ err2[j] = abs(f - f0 - h*dot(dX, ΔX))
+ print(err1[j], "; ", err2[j], "\n")
+ h = h/2f0
+ end
+
+ rate1 = err1[1:end-1]./err1[2:end]
+ rate2 = err2[1:end-1]./err2[2:end]
+
+ @test isapprox(mean(rate1),2f0; atol=stol)
+ @test isapprox(mean(rate2), 4f0; atol=stol)
+
+ # Gradient test w.r.t. parameters
+ G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales, ndims=length(N))
+ sum_net = ResNet(n_cond, 16, 3; norm=nothing,ndims=length(N)) # make sure it doesnt have any weird normalizations
+ G0 = SummarizedNet(G0, sum_net)
+ Gini = deepcopy(G0)
+
+ # Test one parameter from residual block
+ dW = G.cond_net.CL[1,1].RB.W1.data - G0.cond_net.CL[1,1].RB.W1.data
+
+ f0, ΔX, ΔW = loss(G0, X, Cond; summarized=true)
+ h = 0.1f0
+ maxiter = 4
+ err1 = zeros(Float32, maxiter)
+ err2 = zeros(Float32, maxiter)
+
+ print("\nGradient test glow: parameter\n")
+ for j=1:maxiter
+ G0.cond_net.CL[1,1].RB.W1.data = Gini.cond_net.CL[1,1].RB.W1.data + h*dW
+
+ f = loss(G0, X, Cond; summarized=true)[1]
+ err1[j] = abs(f - f0)
+ err2[j] = abs(f - f0 - h*dot(dW, ΔW))
+ print(err1[j], "; ", err2[j], "\n")
+ h = h/2f0
+ end
+
+ rate1 = err1[1:end-1]./err1[2:end]
+ rate2 = err2[1:end-1]./err2[2:end]
+
+ @test isapprox(mean(rate1),2f0; atol=stol)
+ @test isapprox(mean(rate2),4f0; atol=stol)
+ end
+end
\ No newline at end of file