diff --git a/src/conditional_layers/conditional_layer_glow.jl b/src/conditional_layers/conditional_layer_glow.jl index 2c3b895e..5fb6c5cd 100644 --- a/src/conditional_layers/conditional_layer_glow.jl +++ b/src/conditional_layers/conditional_layer_glow.jl @@ -80,6 +80,9 @@ function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64;freeze C = Conv1x1(n_in; freeze=freeze_conv) split_num = Int(round(n_in/2)) + if split_num == 0 + split_num = 1 + end in_split = n_in-split_num out_chan = 2*split_num @@ -95,6 +98,9 @@ function forward(X::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalL X_ = L.C.forward(X) X1, X2 = tensor_split(X_) + if length(X1) == 0 + X1, X2 = X2, X1 + end Y2 = copy(X2) @@ -115,6 +121,9 @@ end function inverse(Y::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalLayerGlow; save=false) where {T,N} Y1, Y2 = tensor_split(Y) + if length(Y1) == 0 + Y1, Y2 = Y2, Y1 + end X2 = copy(Y2) logS_T = L.RB.forward(tensor_cat(X2,C)) @@ -138,6 +147,9 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractA # Backpropagate residual ΔY1, ΔY2 = tensor_split(ΔY) + if length(ΔY1) == 0 + ΔY1, ΔY2 = ΔY2, ΔY1 + end ΔT = copy(ΔY1) ΔS = ΔY1 .* X1 ΔX1 = ΔY1 .* S diff --git a/test/test_layers/test_coupling_layer_irim.jl b/test/test_layers/test_coupling_layer_irim.jl index b2522a3a..36c28126 100644 --- a/test/test_layers/test_coupling_layer_irim.jl +++ b/test/test_layers/test_coupling_layer_irim.jl @@ -105,7 +105,7 @@ dv2 = L.C.v2.data - L01.C.v2.data dv3 = L.C.v3.data - L01.C.v3.data f0, ΔX, Δv1, Δv2, Δv3, ΔW1, ΔW2, ΔW3 = loss(L01, X, Y) -h = 0.1f0 +h = 0.2f0 maxiter = 4 err5 = zeros(Float32, maxiter) err6 = zeros(Float32, maxiter) diff --git a/test/test_networks/test_conditional_glow_network.jl b/test/test_networks/test_conditional_glow_network.jl index 946b39b0..cb13d17b 100644 --- a/test/test_networks/test_conditional_glow_network.jl +++ b/test/test_networks/test_conditional_glow_network.jl @@ -6,6 +6,122 @@ using InvertibleNetworks, LinearAlgebra, Test,Flux, Random device = InvertibleNetworks.CUDA.functional() ? gpu : cpu (device == gpu) && println("Testing on GPU"); +########################################### Test with split_scales = false N = (nx,ny) ######################### +Random.seed!(3); +# Define network +nx = 1; ny = 1; +n_in = 1 +n_cond = 1 +n_hidden = 4 +batchsize = 7 +L = 2 +K = 2 +split_scales = false +N = (nx,ny) + +# Invertibility + +# Network and input +G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) |> device +X = rand(Float32, N..., n_in, batchsize) |> device +Cond = rand(Float32, N..., n_cond, batchsize) |> device + +Y, Cond = G.forward(X,Cond) +X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes + +@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5) + +# Test gradients are set and cleared +G.backward(Y, Y, Cond) + +P = get_params(G) +gsum = 0 +for p in P + ~isnothing(p.grad) && (global gsum += 1) +end +@test isequal(gsum, L*K*10+2) + +clear_grad!(G) +gsum = 0 +for p in P + ~isnothing(p.grad) && (global gsum += 1) +end +@test isequal(gsum, 0) + +# Gradient test + +function loss(G, X, Cond) + Y, ZC, logdet = G.forward(X, Cond) + f = -log_likelihood(Y) - logdet + ΔY = -∇log_likelihood(Y) + ΔX, X_ = G.backward(ΔY, Y, ZC) + return f, ΔX, G.CL[1,1].RB.W1.grad, G.CL[1,1].C.v1.grad +end + + +# Gradient test w.r.t. input +Random.seed!(4); +G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device +X = rand(Float32, N..., n_in, batchsize) |> device +Cond = rand(Float32, N..., n_cond, batchsize) |> device +X0 = rand(Float32, N..., n_in, batchsize) |> device +Cond0 = rand(Float32, N..., n_cond, batchsize) |> device + +dX = X - X0 + +f0, ΔX = loss(G, X0, Cond0)[1:2] +h = 0.1f0 +maxiter = 4 +err1 = zeros(Float32, maxiter) +err2 = zeros(Float32, maxiter) + +print("\nGradient test glow: input\n") +for j=1:maxiter + f = loss(G, X0 + h*dX, Cond0)[1] + err1[j] = abs(f - f0) + err2[j] = abs(f - f0 - h*dot(dX, ΔX)) + print(err1[j], "; ", err2[j], "\n") + global h = h/2f0 +end + +@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0) +@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0) + + +# Gradient test w.r.t. parameters +Random.seed!(5); +X = rand(Float32, N..., n_in, batchsize) |> device +G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device +G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device +Gini = deepcopy(G0) + +# Test one parameter from residual block and 1x1 conv +dW = G.CL[1,1].RB.W1.data - G0.CL[1,1].RB.W1.data +dv = G.CL[1,1].C.v1.data - G0.CL[1,1].C.v1.data + +f0, ΔX, ΔW, Δv = loss(G0, X, Cond) +h = 0.1f0 +maxiter = 4 +err3 = zeros(Float32, maxiter) +err4 = zeros(Float32, maxiter) + +print("\nGradient test glow: params\n") +for j=1:maxiter + G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW + G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv + + f = loss(G0, X, Cond)[1] + err3[j] = abs(f - f0) + err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv)) + print(err3[j], "; ", err4[j], "\n") + global h = h/2f0 +end + +@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0) +@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0) + +################################################################################################### + # Random seed Random.seed!(3); @@ -51,7 +167,9 @@ end @test isequal(gsum, 0) +# Random seed Random.seed!(3); + # Define network nx = 32; ny = 32; nz = 32 n_in = 2 @@ -106,6 +224,7 @@ end # Gradient test w.r.t. input +Random.seed!(4); G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device X = rand(Float32, N..., n_in, batchsize) |> device Cond = rand(Float32, N..., n_cond, batchsize) |> device @@ -132,8 +251,8 @@ end @test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0) @test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0) - # Gradient test w.r.t. parameters +Random.seed!(5); X = rand(Float32, N..., n_in, batchsize) |> device G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device @@ -149,7 +268,7 @@ maxiter = 4 err3 = zeros(Float32, maxiter) err4 = zeros(Float32, maxiter) -print("\nGradient test glow: input\n") +print("\nGradient test glow: params\n") for j=1:maxiter G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv @@ -167,6 +286,8 @@ end ########################################### Test with split_scales = true N = (nx,ny) and summary network ######################### +Random.seed!(3); + # Invertibility sum_net = ResNet(n_cond, 16, 3; norm=nothing) # make sure it doesnt have any weird normalizations @@ -210,6 +331,7 @@ function loss_sum(G, X, Cond) end # Gradient test w.r.t. input +Random.seed!(4); X = rand(Float32, N..., n_in, batchsize) |> device; Cond = rand(Float32, N..., n_cond, batchsize) |> device; X0 = rand(Float32, N..., n_in, batchsize) |> device; @@ -237,6 +359,7 @@ end # Gradient test w.r.t. parameters +Random.seed!(5); X = rand(Float32, N..., n_in, batchsize) |> device flow0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) |> device G0 = SummarizedNet(flow0, sum_net) |> device @@ -252,7 +375,7 @@ maxiter = 4 err3 = zeros(Float32, maxiter) err4 = zeros(Float32, maxiter) -print("\nGradient test glow: input\n") +print("\nGradient test glow: params\n") for j=1:maxiter G0.cond_net.CL[1,1].RB.W1.data = Gini.cond_net.CL[1,1].RB.W1.data + h*dW G0.cond_net.CL[1,1].C.v1.data = Gini.cond_net.CL[1,1].C.v1.data + h*dv @@ -270,6 +393,8 @@ end N = (nx,ny,nz) ########################################### Test with split_scales = true N = (nx,ny,nz) ######################### +Random.seed!(3); + # Invertibility # Network and input @@ -304,6 +429,7 @@ end # Gradient test w.r.t. input +Random.seed!(4); G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device X = rand(Float32, N..., n_in, batchsize) |> device Cond = rand(Float32, N..., n_cond, batchsize) |> device @@ -332,6 +458,7 @@ end # Gradient test w.r.t. parameters +Random.seed!(5); X = rand(Float32, N..., n_in, batchsize) |> device G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device @@ -347,7 +474,7 @@ maxiter = 4 err3 = zeros(Float32, maxiter) err4 = zeros(Float32, maxiter) -print("\nGradient test glow: input\n") +print("\nGradient test glow: params\n") for j=1:maxiter G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv @@ -364,6 +491,8 @@ end ########################################### Test with split_scales = true N = (nx,ny,nz) and Summary network ######################### +Random.seed!(3); + # Invertibility sum_net_3d = ResNet(n_cond, 16, 3; ndims=3, norm=nothing) |> device# make sure it doesnt have any weird normalizati8ons @@ -401,6 +530,7 @@ end # Gradient test w.r.t. input +Random.seed!(4); X = rand(Float32, N..., n_in, batchsize) |> device; Cond = rand(Float32, N..., n_cond, batchsize) |> device; X0 = rand(Float32, N..., n_in, batchsize) |> device; @@ -427,6 +557,7 @@ end @test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0) # Gradient test w.r.t. parameters +Random.seed!(5); X = rand(Float32, N..., n_in, batchsize) |> device flow0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) |> device G0 = SummarizedNet(flow0, sum_net_3d) |> device @@ -442,7 +573,7 @@ maxiter = 4 err3 = zeros(Float32, maxiter) err4 = zeros(Float32, maxiter) -print("\nGradient test glow: input\n") +print("\nGradient test glow: params\n") for j=1:maxiter G0.cond_net.CL[1,1].RB.W1.data = Gini.cond_net.CL[1,1].RB.W1.data + h*dW G0.cond_net.CL[1,1].C.v1.data = Gini.cond_net.CL[1,1].C.v1.data + h*dv @@ -456,4 +587,3 @@ end @test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0) @test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0) - diff --git a/test/test_networks/test_multiscale_conditional_hint_network.jl b/test/test_networks/test_multiscale_conditional_hint_network.jl index 260906ee..d94536aa 100644 --- a/test/test_networks/test_multiscale_conditional_hint_network.jl +++ b/test/test_networks/test_multiscale_conditional_hint_network.jl @@ -76,7 +76,7 @@ function grad_test_X(nx, ny, n_channel, batchsize, logdet, squeeze_type, split_s f0, gX, gY = loss(CH, X0, Y0)[1:3] maxiter = 5 - h = 0.1f0 + h = 5f-2 err1 = zeros(Float32, maxiter) err2 = zeros(Float32, maxiter)