Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/conditional_layers/conditional_layer_glow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64;freeze
C = Conv1x1(n_in; freeze=freeze_conv)

split_num = Int(round(n_in/2))
if split_num == 0
split_num = 1
end
in_split = n_in-split_num
out_chan = 2*split_num

Expand All @@ -95,6 +98,9 @@ function forward(X::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalL

X_ = L.C.forward(X)
X1, X2 = tensor_split(X_)
if length(X1) == 0
X1, X2 = X2, X1
end

Y2 = copy(X2)

Expand All @@ -115,6 +121,9 @@ end
function inverse(Y::AbstractArray{T, N}, C::AbstractArray{T, N}, L::ConditionalLayerGlow; save=false) where {T,N}

Y1, Y2 = tensor_split(Y)
if length(Y1) == 0
Y1, Y2 = Y2, Y1
end

X2 = copy(Y2)
logS_T = L.RB.forward(tensor_cat(X2,C))
Expand All @@ -138,6 +147,9 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractA

# Backpropagate residual
ΔY1, ΔY2 = tensor_split(ΔY)
if length(ΔY1) == 0
ΔY1, ΔY2 = ΔY2, ΔY1
end
ΔT = copy(ΔY1)
ΔS = ΔY1 .* X1
ΔX1 = ΔY1 .* S
Expand Down
2 changes: 1 addition & 1 deletion test/test_layers/test_coupling_layer_irim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ dv2 = L.C.v2.data - L01.C.v2.data
dv3 = L.C.v3.data - L01.C.v3.data

f0, ΔX, Δv1, Δv2, Δv3, ΔW1, ΔW2, ΔW3 = loss(L01, X, Y)
h = 0.1f0
h = 0.2f0
maxiter = 4
err5 = zeros(Float32, maxiter)
err6 = zeros(Float32, maxiter)
Expand Down
142 changes: 136 additions & 6 deletions test/test_networks/test_conditional_glow_network.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,122 @@ using InvertibleNetworks, LinearAlgebra, Test,Flux, Random
device = InvertibleNetworks.CUDA.functional() ? gpu : cpu
(device == gpu) && println("Testing on GPU");

########################################### Test with split_scales = false N = (nx,ny) #########################
# Seed the global RNG before the network and the random inputs are created.
Random.seed!(3)

# Network configuration for this section: a single pixel, a single data channel
# and a single condition channel, without multiscale splitting.
nx, ny = 1, 1
n_in = 1
n_cond = 1
n_hidden = 4
batchsize = 7
L = 2
K = 2
split_scales = false
N = (nx, ny)

# ---- Invertibility ----

# Build the network first, then draw the data and the condition (same order of
# RNG use as the rest of this file).
G = device(NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;
                                  split_scales=split_scales, ndims=length(N)))
X = device(rand(Float32, N..., n_in, batchsize))
Cond = device(rand(Float32, N..., n_cond, batchsize))

# The condition returned by the forward pass is the one handed to the inverse
# (saving the cond is important in split scales because of reshapes).
Y, Cond = G.forward(X, Cond)
X_ = G.inverse(Y, Cond)

# Relative reconstruction error of inverse(forward(X)) must be (numerically) zero.
@test norm(X - X_) / norm(X) ≈ 0f0 atol=1f-5

# ---- Gradients are set by backward and removed by clear_grad! ----
G.backward(Y, Y, Cond)

P = get_params(G)

# Number of parameters that carry a gradient after the backward pass.
gsum = count(p -> !isnothing(p.grad), P)
@test gsum == L * K * 10 + 2

# After clearing, no parameter may hold a gradient any more.
clear_grad!(G)
gsum = count(p -> !isnothing(p.grad), P)
@test gsum == 0

# Gradient test

# Objective used by the gradient tests below.
#
# Runs `G.forward(X, Cond)`, forms `f = -log_likelihood(Y) - logdet`, and
# backpropagates `-∇log_likelihood(Y)` through `G.backward`.
#
# Returns `(f, ΔX, ΔW1, Δv1)`, where `ΔX` is the first output of `G.backward`
# and `ΔW1` / `Δv1` are the `.grad` fields of `G.CL[1,1].RB.W1` and
# `G.CL[1,1].C.v1` as they stand after the backward call.
function loss(G, X, Cond)
    Z, Zcond, lgdet = G.forward(X, Cond)

    objective = -log_likelihood(Z) - lgdet
    ΔZ = -∇log_likelihood(Z)

    # Only the first output of backward is needed here.
    ΔX, _ = G.backward(ΔZ, Z, Zcond)

    probed_layer = G.CL[1, 1]
    return objective, ΔX, probed_layer.RB.W1.grad, probed_layer.C.v1.grad
end


# Gradient test w.r.t. input
# Re-seed so the network and the random draws below are reproducible.
Random.seed!(4);
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
# X / X0 define the perturbation direction; Cond0 is the condition used in every
# loss evaluation of this section. Cond is drawn here but not passed to `loss`
# in this section.
X = rand(Float32, N..., n_in, batchsize) |> device
Cond = rand(Float32, N..., n_cond, batchsize) |> device
X0 = rand(Float32, N..., n_in, batchsize) |> device
Cond0 = rand(Float32, N..., n_cond, batchsize) |> device

# Perturbation direction applied to the input.
dX = X - X0

# Reference objective value and gradient w.r.t. the input at X0.
f0, ΔX = loss(G, X0, Cond0)[1:2]
h = 0.1f0
maxiter = 4
# err1[j]: |f(X0 + h*dX) - f0|                     (zeroth-order Taylor remainder)
# err2[j]: |f(X0 + h*dX) - f0 - h*<dX, ΔX>|        (first-order Taylor remainder)
err1 = zeros(Float32, maxiter)
err2 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
for j=1:maxiter
f = loss(G, X0 + h*dX, Cond0)[1]
err1[j] = abs(f - f0)
err2[j] = abs(f - f0 - h*dot(dX, ΔX))
print(err1[j], "; ", err2[j], "\n")
# Halve the step size; `global` is required because h is a script-level variable.
global h = h/2f0
end

# Compare the last error with the first error scaled by 2^-(maxiter-1) (err1)
# and 4^-(maxiter-1) (err2); the ratio has to be within 1 ± 1.
@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0)
@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)


# Gradient test w.r.t. parameters
# Re-seed so the input and both networks below are reproducible.
Random.seed!(5);
X = rand(Float32, N..., n_in, batchsize) |> device
# G only supplies the target parameter values for the perturbation direction;
# G0 is the network that is evaluated, and Gini keeps a copy of G0's initial
# parameters so each step perturbs from the same base point.
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
Gini = deepcopy(G0)

# Test one parameter from residual block and 1x1 conv
dW = G.CL[1,1].RB.W1.data - G0.CL[1,1].RB.W1.data
dv = G.CL[1,1].C.v1.data - G0.CL[1,1].C.v1.data

# NOTE: Cond is not assigned in this section; whatever value is bound to it
# earlier in the script is used for every loss evaluation here.
f0, ΔX, ΔW, Δv = loss(G0, X, Cond)
h = 0.1f0
maxiter = 4
# err3[j]: |f(θ0 + h*dθ) - f0|                               (zeroth-order remainder)
# err4[j]: |f(θ0 + h*dθ) - f0 - h*<dW, ΔW> - h*<dv, Δv>|     (first-order remainder)
err3 = zeros(Float32, maxiter)
err4 = zeros(Float32, maxiter)

print("\nGradient test glow: params\n")
for j=1:maxiter
# Overwrite the two probed parameters of G0 with the perturbed initial values.
G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW
G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv

f = loss(G0, X, Cond)[1]
err3[j] = abs(f - f0)
err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv))
print(err3[j], "; ", err4[j], "\n")
# Halve the step size; `global` is required because h is a script-level variable.
global h = h/2f0
end

# Compare the last error with the first error scaled by 2^-(maxiter-1) (err3)
# and 4^-(maxiter-1) (err4); the ratio has to be within 1 ± 1.
@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0)
@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0)

###################################################################################################

# Random seed
Random.seed!(3);

Expand Down Expand Up @@ -51,7 +167,9 @@ end
@test isequal(gsum, 0)


# Random seed
Random.seed!(3);

# Define network
nx = 32; ny = 32; nz = 32
n_in = 2
Expand Down Expand Up @@ -106,6 +224,7 @@ end


# Gradient test w.r.t. input
Random.seed!(4);
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
X = rand(Float32, N..., n_in, batchsize) |> device
Cond = rand(Float32, N..., n_cond, batchsize) |> device
Expand All @@ -132,8 +251,8 @@ end
@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0)
@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)


# Gradient test w.r.t. parameters
Random.seed!(5);
X = rand(Float32, N..., n_in, batchsize) |> device
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
Expand All @@ -149,7 +268,7 @@ maxiter = 4
err3 = zeros(Float32, maxiter)
err4 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
print("\nGradient test glow: params\n")
for j=1:maxiter
G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW
G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv
Expand All @@ -167,6 +286,8 @@ end


########################################### Test with split_scales = true N = (nx,ny) and summary network #########################
Random.seed!(3);

# Invertibility
sum_net = ResNet(n_cond, 16, 3; norm=nothing) # make sure it doesn't have any weird normalizations

Expand Down Expand Up @@ -210,6 +331,7 @@ function loss_sum(G, X, Cond)
end

# Gradient test w.r.t. input
Random.seed!(4);
X = rand(Float32, N..., n_in, batchsize) |> device;
Cond = rand(Float32, N..., n_cond, batchsize) |> device;
X0 = rand(Float32, N..., n_in, batchsize) |> device;
Expand Down Expand Up @@ -237,6 +359,7 @@ end


# Gradient test w.r.t. parameters
Random.seed!(5);
X = rand(Float32, N..., n_in, batchsize) |> device
flow0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) |> device
G0 = SummarizedNet(flow0, sum_net) |> device
Expand All @@ -252,7 +375,7 @@ maxiter = 4
err3 = zeros(Float32, maxiter)
err4 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
print("\nGradient test glow: params\n")
for j=1:maxiter
G0.cond_net.CL[1,1].RB.W1.data = Gini.cond_net.CL[1,1].RB.W1.data + h*dW
G0.cond_net.CL[1,1].C.v1.data = Gini.cond_net.CL[1,1].C.v1.data + h*dv
Expand All @@ -270,6 +393,8 @@ end

N = (nx,ny,nz)
########################################### Test with split_scales = true N = (nx,ny,nz) #########################
Random.seed!(3);

# Invertibility

# Network and input
Expand Down Expand Up @@ -304,6 +429,7 @@ end


# Gradient test w.r.t. input
Random.seed!(4);
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
X = rand(Float32, N..., n_in, batchsize) |> device
Cond = rand(Float32, N..., n_cond, batchsize) |> device
Expand Down Expand Up @@ -332,6 +458,7 @@ end


# Gradient test w.r.t. parameters
Random.seed!(5);
X = rand(Float32, N..., n_in, batchsize) |> device
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N)) |> device
Expand All @@ -347,7 +474,7 @@ maxiter = 4
err3 = zeros(Float32, maxiter)
err4 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
print("\nGradient test glow: params\n")
for j=1:maxiter
G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW
G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv
Expand All @@ -364,6 +491,8 @@ end


########################################### Test with split_scales = true N = (nx,ny,nz) and Summary network #########################
Random.seed!(3);

# Invertibility
sum_net_3d = ResNet(n_cond, 16, 3; ndims=3, norm=nothing) |> device # make sure it doesn't have any weird normalizations

Expand Down Expand Up @@ -401,6 +530,7 @@ end


# Gradient test w.r.t. input
Random.seed!(4);
X = rand(Float32, N..., n_in, batchsize) |> device;
Cond = rand(Float32, N..., n_cond, batchsize) |> device;
X0 = rand(Float32, N..., n_in, batchsize) |> device;
Expand All @@ -427,6 +557,7 @@ end
@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)

# Gradient test w.r.t. parameters
Random.seed!(5);
X = rand(Float32, N..., n_in, batchsize) |> device
flow0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales,ndims=length(N)) |> device
G0 = SummarizedNet(flow0, sum_net_3d) |> device
Expand All @@ -442,7 +573,7 @@ maxiter = 4
err3 = zeros(Float32, maxiter)
err4 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
print("\nGradient test glow: params\n")
for j=1:maxiter
G0.cond_net.CL[1,1].RB.W1.data = Gini.cond_net.CL[1,1].RB.W1.data + h*dW
G0.cond_net.CL[1,1].C.v1.data = Gini.cond_net.CL[1,1].C.v1.data + h*dv
Expand All @@ -456,4 +587,3 @@ end

@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0)
@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0)

Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ function grad_test_X(nx, ny, n_channel, batchsize, logdet, squeeze_type, split_s
f0, gX, gY = loss(CH, X0, Y0)[1:3]

maxiter = 5
h = 0.1f0
h = 5f-2
err1 = zeros(Float32, maxiter)
err2 = zeros(Float32, maxiter)

Expand Down