From 35cf1e4eead8ca554529f90a89638d67bdc702f2 Mon Sep 17 00:00:00 2001 From: rafaelorozco Date: Tue, 1 Mar 2022 08:54:02 -0500 Subject: [PATCH 1/9] take away view from adjoint 1x1conv --- src/layers/invertible_layer_conv1x1.jl | 11 ++++-- src/networks/invertible_network_irim.jl | 47 +++++++++++++++++++++++-- 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/src/layers/invertible_layer_conv1x1.jl b/src/layers/invertible_layer_conv1x1.jl index a67b4b76..6422912e 100644 --- a/src/layers/invertible_layer_conv1x1.jl +++ b/src/layers/invertible_layer_conv1x1.jl @@ -153,14 +153,19 @@ function conv1x1_grad_v(X::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, for i=1:k # ∂V1 mul!(tmp, ∂V1[i, :, :], M1) - @views adjoint ? adjoint!(∂V1[i, :, :], tmp) : copyto!(∂V1[i, :, :], tmp) + #@views adjoint ? adjoint!(∂V1[i, :, :], tmp) : copyto!(∂V1[i, :, :], tmp) + adjoint ? adjoint!(∂V1[i, :, :], tmp) : copyto!(∂V1[i, :, :], tmp) + # ∂V2 v2 = ∂V2[i, :, :] broadcast!(+, tmp, v2, 4 * V1 * v2 * V3 - 2 * (V1 * v2 + v2 * V3)) - @views adjoint ? adjoint!(∂V2[i, :, :], tmp) : copyto!(∂V2[i, :, :], tmp) + #@views adjoint ? adjoint!(∂V2[i, :, :], tmp) : copyto!(∂V2[i, :, :], tmp) + adjoint ? adjoint!(∂V2[i, :, :], tmp) : copyto!(∂V2[i, :, :], tmp) + # ∂V3 mul!(tmp, M3, ∂V3[i, :, :]) - @views adjoint ? adjoint!(∂V3[i, :, :], tmp) : copyto!(∂V3[i, :, :], tmp) + #@views adjoint ? adjoint!(∂V3[i, :, :], tmp) : copyto!(∂V3[i, :, :], tmp) + adjoint ? adjoint!(∂V3[i, :, :], tmp) : copyto!(∂V3[i, :, :], tmp) end prod_res = cuzeros(X, size(∂V1, 1), prod(size(X)[1:N-2]), n_in) diff --git a/src/networks/invertible_network_irim.jl b/src/networks/invertible_network_irim.jl index 1effe9b7..d0c115d5 100644 --- a/src/networks/invertible_network_irim.jl +++ b/src/networks/invertible_network_irim.jl @@ -92,7 +92,7 @@ end NetworkLoop3D(args...; kw...) 
= NetworkLoop(args...; kw..., ndims=3) # 2D Forward loop: Input (η, s), Output (η, s) -function forward(η::AbstractArray{T, N}, s::AbstractArray{T, N}, d::AbstractArray, J, UL::NetworkLoop) where {T, N} +function forward(η::AbstractArray{T, N}, s::AbstractArray{T, N}, d::Union{AbstractArray, Nothing}, J, UL::NetworkLoop; g=nothing) where {T, N} # Dimensions n_in = size(s, N-1) + 1 @@ -101,8 +101,10 @@ function forward(η::AbstractArray{T, N}, s::AbstractArray{T, N}, d::AbstractArr maxiter = length(UL.L) N0 = cuzeros(η, nn..., n_in-2, batchsize) + isnothing(g) && (g = 1) + for j=1:maxiter - g = J'*(J*reshape(UL.Ψ(η), :, batchsize) - reshape(d, :, batchsize)) + isnothing(g) && (g = J'*(J*reshape(UL.Ψ(η), :, batchsize) - reshape(d, :, batchsize))) g = reshape(g, nn..., 1, batchsize) gn = UL.AN[j].forward(g) # normalize s_ = s + tensor_cat(gn, N0) @@ -176,6 +178,47 @@ function backward(Δη::AbstractArray{T, N}, Δs::AbstractArray{T, N}, set_grad ? (return Δη, Δs, η, s) : (Δη, Δs, Δθ, η, s) end +# 2D Backward loop: Input (Δη, Δs, η, s), Output (Δη, Δs, η, s) +function backward(Δη::AbstractArray{T, N}, Δs::AbstractArray{T, N}, + η::AbstractArray{T, N}, s::AbstractArray{T, N}, J, UL::NetworkLoop; g=nothing, set_grad::Bool=true) where {T, N} + + # Dimensions + n_in = size(s, N-1) + 1 + batchsize = size(s)[end] + nn = size(s)[1:N-2] + maxiter = length(UL.L) + + N0 = cuzeros(Δη, nn..., n_in-2, batchsize) + typeof(Δs) == T && (Δs = 0 .* s) # make Δs zero tensor + + # Initialize net parameters + set_grad && (Δθ = Array{Parameter, 1}(undef, 0)) + + j = 1 + #for j = maxiter:-1:1 + if set_grad + Δηs_, ηs_ = UL.L[j].backward(tensor_cat(Δη, Δs), tensor_cat(η, s)) + else + Δηs_, Δθ_L, ηs_ = UL.L[j].backward(tensor_cat(Δη, Δs), tensor_cat(η, s); set_grad=set_grad) + push!(Δθ, Δθ_L) + end + + # Inverse pass + η, s_ = tensor_split(ηs_; split_index=1) + #g = J'*(J*reshape(UL.Ψ(η), :, batchsize) - reshape(d, :, batchsize)) + g = reshape(g, nn..., 1, batchsize) + gn = UL.AN[j].forward(g) # 
normalize + s = s_ - tensor_cat(gn, N0) + + # Gradients + #Δs2, Δs = tensor_split(Δηs_; split_index=1) + #Δgn = tensor_split(Δs; split_index=1)[1] + #Δg = UL.AN[j].backward(Δgn, gn)[1] + #Δη = reshape(J'*J*reshape(Δg, :, batchsize), nn..., 1, batchsize) + Δs2 +#end + set_grad ? (return Δη, Δs, η, s) : (Δη, Δs, Δθ, η, s) +end + ## Jacobian-related utils jacobian(::AbstractArray{T, 5}, ::AbstractArray{T, 5}, d::AbstractArray, J, UL::NetworkLoop) where T = throw(ArgumentError("Jacobian for NetworkLoop not yet implemented")) From 2ad877262b0b05d41c4e18588570142b3c2d1b64 Mon Sep 17 00:00:00 2001 From: rafaelorozco Date: Tue, 1 Mar 2022 11:19:36 -0500 Subject: [PATCH 2/9] first commit with different dilations --- src/layers/invertible_layer_irim.jl | 111 +++++++++++++++--------- src/layers/layer_residual_block.jl | 11 ++- src/networks/invertible_network_irim.jl | 4 +- 3 files changed, 79 insertions(+), 47 deletions(-) diff --git a/src/layers/invertible_layer_irim.jl b/src/layers/invertible_layer_irim.jl index 4373b9f9..ac03fdad 100644 --- a/src/layers/invertible_layer_irim.jl +++ b/src/layers/invertible_layer_irim.jl @@ -58,19 +58,25 @@ or See also: [`Conv1x1`](@ref), [`ResidualBlock!`](@ref), [`get_params`](@ref), [`clear_grad!`](@ref) """ struct CouplingLayerIRIM <: NeuralNetLayer - C::Conv1x1 - RB::Union{ResidualBlock, FluxBlock} + C::AbstractArray{Conv1x1, 1} + RB::AbstractArray{ResidualBlock, 1} end @Flux.functor CouplingLayerIRIM # 2D Constructor from input dimensions -function CouplingLayerIRIM(n_in::Int64, n_hidden::Int64; +function CouplingLayerIRIM(n_in::Int64, n_hidden::Int64; ds=nothing, k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, ndims=2) - # 1x1 Convolution and residual block for invertible layer - C = Conv1x1(n_in) - RB = ResidualBlock(n_in÷2, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims) + + num_downsamp = length(ds) + C = Array{Conv1x1}(undef, num_downsamp) + RB = Array{ResidualBlock}(undef, num_downsamp) + + for j=1:num_downsamp + C[j] = 
Conv1x1(n_in) + RB[j] = ResidualBlock(n_in÷2, n_hidden; d=ds[j], k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=false, ndims=ndims) + end return CouplingLayerIRIM(C, RB) end @@ -79,28 +85,37 @@ CouplingLayerIRIM3D(args...;kw...) = CouplingLayerIRIM(args...; kw..., ndims=3) # 2D Forward pass: Input X, Output Y function forward(X::AbstractArray{T, N}, L::CouplingLayerIRIM) where {T, N} - X_ = L.C.forward(X) - X1_, X2_ = tensor_split(X_) - Y1_ = X1_ - Y2_ = X2_ + L.RB.forward(Y1_) + num_downsamp = length(L.C) + for j=1:num_downsamp + println("here at $(j)") + X_ = L.C[j].forward(X) + X1_, X2_ = tensor_split(X_) - Y_ = tensor_cat(Y1_, Y2_) - Y = L.C.inverse(Y_) + Y1_ = X1_ + Y2_ = X2_ + L.RB[j].forward(Y1_) + + Y_ = tensor_cat(Y1_, Y2_) + X = L.C[j].inverse(Y_) + end - return Y + return X end # 2D Inverse pass: Input Y, Output X function inverse(Y::AbstractArray{T, N}, L::CouplingLayerIRIM; save=false) where {T, N} - Y_ = L.C.forward(Y) - Y1_, Y2_ = tensor_split(Y_) + + num_downsamp = length(L.C) + for j=1:num_downsamp + Y_ = L.C[j].forward(Y) + Y1_, Y2_ = tensor_split(Y_) - X1_ = Y1_ - X2_ = Y2_ - L.RB.forward(Y1_) + X1_ = Y1_ + X2_ = Y2_ - L.RB[j].forward(Y1_) - X_ = tensor_cat(X1_, X2_) - X = L.C.inverse(X_) + X_ = tensor_cat(X1_, X2_) + X = L.C[j].inverse(X_) + end if save == false return X @@ -113,31 +128,32 @@ end function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingLayerIRIM; set_grad::Bool=true) where {T, N} # Recompute forward state - k = Int(L.C.k/2) - X, X_, Y1_ = inverse(Y, L; save=true) + #X, X_, Y1_ = inverse(Y, L; save=true) - # Backpropagate residual - if set_grad - ΔY_ = L.C.forward((ΔY, Y))[1] - else - ΔY_, Δθ_C1 = L.C.forward((ΔY, Y); set_grad=set_grad)[1:2] - end - ΔYl_, ΔYr_ = tensor_split(ΔY_) - if set_grad - ΔY1_ = L.RB.backward(ΔYr_, Y1_) + ΔYl_ - else - ΔY1_, Δθ_RB = L.RB.backward(ΔYr_, Y1_; set_grad=set_grad) - ΔY1_ = ΔY1_ + ΔYl_ - end - - ΔX_ = tensor_cat(ΔY1_, ΔYr_) - if set_grad - ΔX = L.C.inverse((ΔX_, X_))[1] - 
else - ΔX, Δθ_C2 = L.C.inverse((ΔX_, X_); set_grad=set_grad)[1:2] + num_downsamp = length(L.C) + for j=1:num_downsamp + + # Recompute forward state + Y_ = L.C[j].forward(Y) + Y1_, Y2_ = tensor_split(Y_) + + X1_ = Y1_ + X2_ = Y2_ - L.RB[j].forward(Y1_) + X_ = tensor_cat(X1_, X2_) + X = L.C[j].inverse(X_) + + + # Backpropagate residual + ΔY_, Y_ = L.C[j].forward((ΔY, Y)) + ΔYl_, ΔYr_ = tensor_split(ΔY_) + + ΔY1_ = L.RB[j].backward(ΔYr_, Y1_) + ΔYl_ + ΔX_ = tensor_cat(ΔY1_, ΔYr_) + + ΔY, Y = L.C[j].inverse((ΔX_, X_)) end - set_grad ? (return ΔX, X) : (return ΔX, cat(Δθ_C1+Δθ_C2, Δθ_RB; dims=1), X) + return ΔY, Y end ## Jacobian utilities @@ -180,7 +196,16 @@ end # Get parameters function get_params(L::CouplingLayerIRIM) - p1 = get_params(L.C) - p2 = get_params(L.RB) + maxiter = length(L.C) + + p1 = get_params(L.C[1]) + p2 = get_params(L.RB[1]) + if maxiter > 1 + for j=2:maxiter + p1 = cat(p1, get_params(L.C[j]); dims=1) + p2 = cat(p2, get_params(L.RB[j]); dims=1) + end + end + return cat(p1, p2; dims=1) end \ No newline at end of file diff --git a/src/layers/layer_residual_block.jl b/src/layers/layer_residual_block.jl index d7cce4cb..72752499 100644 --- a/src/layers/layer_residual_block.jl +++ b/src/layers/layer_residual_block.jl @@ -75,14 +75,21 @@ end # Constructors # Constructor -function ResidualBlock(n_in, n_hidden; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2) +function ResidualBlock(n_in, n_hidden; d=nothing, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2) + + if !isnothing(d) + k1 = d + s1 = d + p1 = 0 + end + k1 = Tuple(k1 for i=1:ndims) k2 = Tuple(k2 for i=1:ndims) # Initialize weights W1 = Parameter(glorot_uniform(k1..., n_in, n_hidden)) W2 = Parameter(glorot_uniform(k2..., n_hidden, n_hidden)) - W3 = Parameter(glorot_uniform(k1..., 2*n_in, n_hidden)) + W3 = Parameter(glorot_uniform(k1..., 2*n_in, n_hidden)) # will be a transpose so output is 2*in b1 = Parameter(zeros(Float32, n_hidden)) b2 = Parameter(zeros(Float32, n_hidden)) diff --git 
a/src/networks/invertible_network_irim.jl b/src/networks/invertible_network_irim.jl index d0c115d5..1ca2f81e 100644 --- a/src/networks/invertible_network_irim.jl +++ b/src/networks/invertible_network_irim.jl @@ -66,7 +66,7 @@ end @Flux.functor NetworkLoop # 2D Constructor -function NetworkLoop(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2) +function NetworkLoop(n_in, n_hidden, maxiter, Ψ; ds=nothing, k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2) if type == "additive" L = Array{CouplingLayerIRIM}(undef, maxiter) @@ -77,7 +77,7 @@ function NetworkLoop(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4, AN = Array{ActNorm}(undef, maxiter) for j=1:maxiter if type == "additive" - L[j] = CouplingLayerIRIM(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims) + L[j] = CouplingLayerIRIM(n_in, n_hidden; ds=ds, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims) elseif type == "HINT" L[j] = CouplingLayerHINT(n_in, n_hidden; logdet=false, permute="both", k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims) From 2b517def0cb5388f57a71b9cc3862f83d6238c4b Mon Sep 17 00:00:00 2001 From: rafaelorozco Date: Tue, 1 Mar 2022 18:31:27 -0500 Subject: [PATCH 3/9] fix loop index --- src/layers/invertible_layer_irim.jl | 18 +++++++++++------- src/networks/invertible_network_irim.jl | 4 ++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/layers/invertible_layer_irim.jl b/src/layers/invertible_layer_irim.jl index ac03fdad..fc00025a 100644 --- a/src/layers/invertible_layer_irim.jl +++ b/src/layers/invertible_layer_irim.jl @@ -65,7 +65,7 @@ end @Flux.functor CouplingLayerIRIM # 2D Constructor from input dimensions -function CouplingLayerIRIM(n_in::Int64, n_hidden::Int64; ds=nothing, +function CouplingLayerIRIM(n_in::Int64;n_hiddens=nothing, ds=nothing, k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, ndims=2) @@ -75,7 +75,7 @@ function CouplingLayerIRIM(n_in::Int64, n_hidden::Int64; ds=nothing, for 
j=1:num_downsamp C[j] = Conv1x1(n_in) - RB[j] = ResidualBlock(n_in÷2, n_hidden; d=ds[j], k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=false, ndims=ndims) + RB[j] = ResidualBlock(n_in÷2, n_hiddens[j]; d=ds[j], k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=false, ndims=ndims) end return CouplingLayerIRIM(C, RB) @@ -88,7 +88,6 @@ function forward(X::AbstractArray{T, N}, L::CouplingLayerIRIM) where {T, N} num_downsamp = length(L.C) for j=1:num_downsamp - println("here at $(j)") X_ = L.C[j].forward(X) X1_, X2_ = tensor_split(X_) @@ -106,7 +105,7 @@ end function inverse(Y::AbstractArray{T, N}, L::CouplingLayerIRIM; save=false) where {T, N} num_downsamp = length(L.C) - for j=1:num_downsamp + for j=num_downsamp:-1:1 Y_ = L.C[j].forward(Y) Y1_, Y2_ = tensor_split(Y_) @@ -131,7 +130,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingL #X, X_, Y1_ = inverse(Y, L; save=true) num_downsamp = length(L.C) - for j=1:num_downsamp + for j=num_downsamp:-1:1 # Recompute forward state Y_ = L.C[j].forward(Y) @@ -190,8 +189,13 @@ end # Clear gradients function clear_grad!(L::CouplingLayerIRIM) - clear_grad!(L.C) - clear_grad!(L.RB) + + maxiter = length(L.C) + + for j=1:maxiter + clear_grad!(L.C[j]) + clear_grad!(L.RB[j]) + end end # Get parameters diff --git a/src/networks/invertible_network_irim.jl b/src/networks/invertible_network_irim.jl index 1ca2f81e..8538b1d1 100644 --- a/src/networks/invertible_network_irim.jl +++ b/src/networks/invertible_network_irim.jl @@ -66,7 +66,7 @@ end @Flux.functor NetworkLoop # 2D Constructor -function NetworkLoop(n_in, n_hidden, maxiter, Ψ; ds=nothing, k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2) +function NetworkLoop(n_in, maxiter, Ψ; n_hiddens=nothing, ds=nothing, k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2) if type == "additive" L = Array{CouplingLayerIRIM}(undef, maxiter) @@ -77,7 +77,7 @@ function NetworkLoop(n_in, n_hidden, maxiter, Ψ; ds=nothing, k1=4, k2=3, p1=0, AN = 
Array{ActNorm}(undef, maxiter) for j=1:maxiter if type == "additive" - L[j] = CouplingLayerIRIM(n_in, n_hidden; ds=ds, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims) + L[j] = CouplingLayerIRIM(n_in; n_hiddens=n_hiddens, ds=ds, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims) elseif type == "HINT" L[j] = CouplingLayerHINT(n_in, n_hidden; logdet=false, permute="both", k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims) From 4fbd20310873fc537dcfe44d354cf78c99dbc276 Mon Sep 17 00:00:00 2001 From: rafaelorozco Date: Mon, 25 Apr 2022 10:53:23 -0400 Subject: [PATCH 4/9] clean up for PR --- src/layers/invertible_layer_conv1x1.jl | 5 -- src/layers/invertible_layer_irim.jl | 93 +++++++++++++------------ src/layers/layer_residual_block.jl | 5 +- src/networks/invertible_network_irim.jl | 51 ++------------ 4 files changed, 54 insertions(+), 100 deletions(-) diff --git a/src/layers/invertible_layer_conv1x1.jl b/src/layers/invertible_layer_conv1x1.jl index 6422912e..335ea803 100644 --- a/src/layers/invertible_layer_conv1x1.jl +++ b/src/layers/invertible_layer_conv1x1.jl @@ -153,18 +153,15 @@ function conv1x1_grad_v(X::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, for i=1:k # ∂V1 mul!(tmp, ∂V1[i, :, :], M1) - #@views adjoint ? adjoint!(∂V1[i, :, :], tmp) : copyto!(∂V1[i, :, :], tmp) adjoint ? adjoint!(∂V1[i, :, :], tmp) : copyto!(∂V1[i, :, :], tmp) # ∂V2 v2 = ∂V2[i, :, :] broadcast!(+, tmp, v2, 4 * V1 * v2 * V3 - 2 * (V1 * v2 + v2 * V3)) - #@views adjoint ? adjoint!(∂V2[i, :, :], tmp) : copyto!(∂V2[i, :, :], tmp) adjoint ? adjoint!(∂V2[i, :, :], tmp) : copyto!(∂V2[i, :, :], tmp) # ∂V3 mul!(tmp, M3, ∂V3[i, :, :]) - #@views adjoint ? adjoint!(∂V3[i, :, :], tmp) : copyto!(∂V3[i, :, :], tmp) adjoint ? adjoint!(∂V3[i, :, :], tmp) : copyto!(∂V3[i, :, :], tmp) end @@ -255,9 +252,7 @@ function inverse(Y_tuple::Tuple, C::Conv1x1; set_grad::Bool=true) set_grad ? 
(return ΔX, X) : (return ΔX, Δθ, X) end - ## Jacobian-related functions - function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X::AbstractArray{T, N}, C::Conv1x1) where {T, N} Y = cuzeros(X, size(X)...) ΔY = cuzeros(ΔX, size(ΔX)...) diff --git a/src/layers/invertible_layer_irim.jl b/src/layers/invertible_layer_irim.jl index fc00025a..14f2a5ab 100644 --- a/src/layers/invertible_layer_irim.jl +++ b/src/layers/invertible_layer_irim.jl @@ -65,11 +65,23 @@ end @Flux.functor CouplingLayerIRIM # 2D Constructor from input dimensions -function CouplingLayerIRIM(n_in::Int64;n_hiddens=nothing, ds=nothing, +function CouplingLayerIRIM(n_in::Int64, n_hidden::Int64; n_hiddens=nothing, ds=nothing, k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, ndims=2) + # Check if unet structure defined + if isnothing(n_hiddens) + num_downsamp = 1 + n_hiddens = [n_hidden] + ds = [4] + else + # Use user defined hidden channels and downsampling factors in ds + num_downsamp = length(n_hiddens) + end + + if num_downsamp != length(ds) + throw("Number of downsampling factors in ds must be the same defined hidden channels in n_hidden") + end - num_downsamp = length(ds) C = Array{Conv1x1}(undef, num_downsamp) RB = Array{ResidualBlock}(undef, num_downsamp) @@ -126,30 +138,20 @@ end # 2D Backward pass: Input (ΔY, Y), Output (ΔX, X) function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingLayerIRIM; set_grad::Bool=true) where {T, N} - # Recompute forward state - #X, X_, Y1_ = inverse(Y, L; save=true) - num_downsamp = length(L.C) - for j=num_downsamp:-1:1 - - # Recompute forward state - Y_ = L.C[j].forward(Y) - Y1_, Y2_ = tensor_split(Y_) - - X1_ = Y1_ - X2_ = Y2_ - L.RB[j].forward(Y1_) - X_ = tensor_cat(X1_, X2_) - X = L.C[j].inverse(X_) - - - # Backpropagate residual + for j=num_downsamp:-1:1ß ΔY_, Y_ = L.C[j].forward((ΔY, Y)) + ΔYl_, ΔYr_ = tensor_split(ΔY_) - - ΔY1_ = L.RB[j].backward(ΔYr_, Y1_) + ΔYl_ - ΔX_ = tensor_cat(ΔY1_, ΔYr_) + Y1_, Y2_ = tensor_split(Y_) + + ΔYl_ .= 
L.RB[j].backwards(ΔYr_, Y1_) + ΔYl_ + Y2_ .-= L.RB[j].forward(Y1_) - ΔY, Y = L.C[j].inverse((ΔX_, X_)) + ΔY_ = tensor_cat(ΔYl_, ΔYr_) + Y_ = tensor_cat(Y1_, Y2_) + + ΔY, Y = L.C[j].inverse((ΔY_,Y_)) end return ΔY, Y @@ -160,20 +162,20 @@ end # 2D function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X::AbstractArray{T, N}, L::CouplingLayerIRIM) where {T, N} - # Get dimensions - k = Int(L.C.k/2) - - ΔX_, X_ = L.C.jacobian(ΔX, Δθ[1:3], X) - X1_, X2_ = tensor_split(X_) - ΔX1_, ΔX2_ = tensor_split(ΔX_) + num_downsamp = length(L.C) + for j=num_downsamp:-1:1 + ΔX_, X_ = L.C[j].jacobian(ΔX, Δθ[1:3], X) + X1_, X2_ = tensor_split(X_) + ΔX1_, ΔX2_ = tensor_split(ΔX_) - ΔY1_, Y1__ = L.RB.jacobian(ΔX1_, Δθ[4:end], X1_) - Y2_ = X2_ + Y1__ - ΔY2_ = ΔX2_ + ΔY1_ - - Y_ = tensor_cat(X1_, Y2_) - ΔY_ = tensor_cat(ΔX1_, ΔY2_) - ΔY, Y = L.C.jacobianInverse(ΔY_, Δθ[1:3], Y_) + ΔY1_, Y1__ = L.RB[j].jacobian(ΔX1_, Δθ[4:end], X1_) + Y2_ = X2_ + Y1__ + ΔY2_ = ΔX2_ + ΔY1_ + + Y_ = tensor_cat(X1_, Y2_) + ΔY_ = tensor_cat(ΔX1_, ΔY2_) + ΔY, Y = L.C[j].jacobianInverse(ΔY_, Δθ[1:3], Y_) + end return ΔY, Y @@ -190,9 +192,9 @@ end # Clear gradients function clear_grad!(L::CouplingLayerIRIM) - maxiter = length(L.C) + num_downsamp = length(L.C) - for j=1:maxiter + for j=1:num_downsamp clear_grad!(L.C[j]) clear_grad!(L.RB[j]) end @@ -200,15 +202,14 @@ end # Get parameters function get_params(L::CouplingLayerIRIM) - maxiter = length(L.C) - - p1 = get_params(L.C[1]) - p2 = get_params(L.RB[1]) - if maxiter > 1 - for j=2:maxiter - p1 = cat(p1, get_params(L.C[j]); dims=1) - p2 = cat(p2, get_params(L.RB[j]); dims=1) - end + num_downsamp = length(L.C) + + p1 = Array{Parameter, 1}(undef, 0) + p2 = Array{Parameter, 1}(undef, 0) + + for j=1:num_downsamp + p1 = cat(p1, get_params(L.C[j]); dims=1) + p2 = cat(p2, get_params(L.RB[j]); dims=1) end return cat(p1, p2; dims=1) diff --git a/src/layers/layer_residual_block.jl b/src/layers/layer_residual_block.jl index 72752499..757feb6d 100644 --- 
a/src/layers/layer_residual_block.jl +++ b/src/layers/layer_residual_block.jl @@ -77,19 +77,20 @@ end # Constructor function ResidualBlock(n_in, n_hidden; d=nothing, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2) + # Check if downsampling factor d is defined if !isnothing(d) k1 = d s1 = d p1 = 0 end - k1 = Tuple(k1 for i=1:ndims) k2 = Tuple(k2 for i=1:ndims) + # Initialize weights W1 = Parameter(glorot_uniform(k1..., n_in, n_hidden)) W2 = Parameter(glorot_uniform(k2..., n_hidden, n_hidden)) - W3 = Parameter(glorot_uniform(k1..., 2*n_in, n_hidden)) # will be a transpose so output is 2*in + W3 = Parameter(glorot_uniform(k1..., 2*n_in, n_hidden)) b1 = Parameter(zeros(Float32, n_hidden)) b2 = Parameter(zeros(Float32, n_hidden)) diff --git a/src/networks/invertible_network_irim.jl b/src/networks/invertible_network_irim.jl index 8538b1d1..1323aa8f 100644 --- a/src/networks/invertible_network_irim.jl +++ b/src/networks/invertible_network_irim.jl @@ -66,7 +66,7 @@ end @Flux.functor NetworkLoop # 2D Constructor -function NetworkLoop(n_in, maxiter, Ψ; n_hiddens=nothing, ds=nothing, k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2) +function NetworkLoop(n_in, n_hidden, maxiter, Ψ; n_hiddens=nothing, ds=nothing, k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2) if type == "additive" L = Array{CouplingLayerIRIM}(undef, maxiter) @@ -77,7 +77,7 @@ function NetworkLoop(n_in, maxiter, Ψ; n_hiddens=nothing, ds=nothing, k1=4, k2= AN = Array{ActNorm}(undef, maxiter) for j=1:maxiter if type == "additive" - L[j] = CouplingLayerIRIM(n_in; n_hiddens=n_hiddens, ds=ds, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims) + L[j] = CouplingLayerIRIM(n_in, n_hidden; n_hiddens=n_hiddens, ds=ds, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims) elseif type == "HINT" L[j] = CouplingLayerHINT(n_in, n_hidden; logdet=false, permute="both", k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims) @@ -92,7 +92,7 @@ end NetworkLoop3D(args...; kw...) 
= NetworkLoop(args...; kw..., ndims=3) # 2D Forward loop: Input (η, s), Output (η, s) -function forward(η::AbstractArray{T, N}, s::AbstractArray{T, N}, d::Union{AbstractArray, Nothing}, J, UL::NetworkLoop; g=nothing) where {T, N} +function forward(η::AbstractArray{T, N}, s::AbstractArray{T, N}, d::AbstractArray, J, UL::NetworkLoop) where {T, N} # Dimensions n_in = size(s, N-1) + 1 @@ -101,10 +101,8 @@ function forward(η::AbstractArray{T, N}, s::AbstractArray{T, N}, d::Union{Abstr maxiter = length(UL.L) N0 = cuzeros(η, nn..., n_in-2, batchsize) - isnothing(g) && (g = 1) - for j=1:maxiter - isnothing(g) && (g = J'*(J*reshape(UL.Ψ(η), :, batchsize) - reshape(d, :, batchsize))) + g = J'*(J*reshape(UL.Ψ(η), :, batchsize) - reshape(d, :, batchsize)) g = reshape(g, nn..., 1, batchsize) gn = UL.AN[j].forward(g) # normalize s_ = s + tensor_cat(gn, N0) @@ -178,47 +176,6 @@ function backward(Δη::AbstractArray{T, N}, Δs::AbstractArray{T, N}, set_grad ? (return Δη, Δs, η, s) : (Δη, Δs, Δθ, η, s) end -# 2D Backward loop: Input (Δη, Δs, η, s), Output (Δη, Δs, η, s) -function backward(Δη::AbstractArray{T, N}, Δs::AbstractArray{T, N}, - η::AbstractArray{T, N}, s::AbstractArray{T, N}, J, UL::NetworkLoop; g=nothing, set_grad::Bool=true) where {T, N} - - # Dimensions - n_in = size(s, N-1) + 1 - batchsize = size(s)[end] - nn = size(s)[1:N-2] - maxiter = length(UL.L) - - N0 = cuzeros(Δη, nn..., n_in-2, batchsize) - typeof(Δs) == T && (Δs = 0 .* s) # make Δs zero tensor - - # Initialize net parameters - set_grad && (Δθ = Array{Parameter, 1}(undef, 0)) - - j = 1 - #for j = maxiter:-1:1 - if set_grad - Δηs_, ηs_ = UL.L[j].backward(tensor_cat(Δη, Δs), tensor_cat(η, s)) - else - Δηs_, Δθ_L, ηs_ = UL.L[j].backward(tensor_cat(Δη, Δs), tensor_cat(η, s); set_grad=set_grad) - push!(Δθ, Δθ_L) - end - - # Inverse pass - η, s_ = tensor_split(ηs_; split_index=1) - #g = J'*(J*reshape(UL.Ψ(η), :, batchsize) - reshape(d, :, batchsize)) - g = reshape(g, nn..., 1, batchsize) - gn = UL.AN[j].forward(g) # 
normalize - s = s_ - tensor_cat(gn, N0) - - # Gradients - #Δs2, Δs = tensor_split(Δηs_; split_index=1) - #Δgn = tensor_split(Δs; split_index=1)[1] - #Δg = UL.AN[j].backward(Δgn, gn)[1] - #Δη = reshape(J'*J*reshape(Δg, :, batchsize), nn..., 1, batchsize) + Δs2 -#end - set_grad ? (return Δη, Δs, η, s) : (Δη, Δs, Δθ, η, s) -end - ## Jacobian-related utils jacobian(::AbstractArray{T, 5}, ::AbstractArray{T, 5}, d::AbstractArray, J, UL::NetworkLoop) where T = throw(ArgumentError("Jacobian for NetworkLoop not yet implemented")) From 8a91549c1641e5819784b7eabb758b39f648d2a6 Mon Sep 17 00:00:00 2001 From: rafaelorozco Date: Mon, 25 Apr 2022 14:13:01 -0400 Subject: [PATCH 5/9] fix conv without @views --- src/layers/invertible_layer_conv1x1.jl | 6 +- src/layers/invertible_layer_irim.jl | 77 ++++++++++++-------- test/test_layers/test_coupling_layer_irim.jl | 70 +++++++++--------- 3 files changed, 86 insertions(+), 67 deletions(-) diff --git a/src/layers/invertible_layer_conv1x1.jl b/src/layers/invertible_layer_conv1x1.jl index 335ea803..df12693c 100644 --- a/src/layers/invertible_layer_conv1x1.jl +++ b/src/layers/invertible_layer_conv1x1.jl @@ -153,16 +153,16 @@ function conv1x1_grad_v(X::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, for i=1:k # ∂V1 mul!(tmp, ∂V1[i, :, :], M1) - adjoint ? adjoint!(∂V1[i, :, :], tmp) : copyto!(∂V1[i, :, :], tmp) + adjoint ? ∂V1[i, :, :] = tmp' : ∂V1[i, :, :] = tmp # ∂V2 v2 = ∂V2[i, :, :] broadcast!(+, tmp, v2, 4 * V1 * v2 * V3 - 2 * (V1 * v2 + v2 * V3)) - adjoint ? adjoint!(∂V2[i, :, :], tmp) : copyto!(∂V2[i, :, :], tmp) + adjoint ? ∂V2[i, :, :] = tmp' : ∂V2[i, :, :] = tmp # ∂V3 mul!(tmp, M3, ∂V3[i, :, :]) - adjoint ? adjoint!(∂V3[i, :, :], tmp) : copyto!(∂V3[i, :, :], tmp) + adjoint ? 
∂V3[i, :, :] = tmp' : ∂V3[i, :, :] = tmp end prod_res = cuzeros(X, size(∂V1, 1), prod(size(X)[1:N-2]), n_in) diff --git a/src/layers/invertible_layer_irim.jl b/src/layers/invertible_layer_irim.jl index 14f2a5ab..96962011 100644 --- a/src/layers/invertible_layer_irim.jl +++ b/src/layers/invertible_layer_irim.jl @@ -65,23 +65,13 @@ end @Flux.functor CouplingLayerIRIM # 2D Constructor from input dimensions -function CouplingLayerIRIM(n_in::Int64, n_hidden::Int64; n_hiddens=nothing, ds=nothing, +function CouplingLayerIRIM(n_in::Int64, n_hiddens::Array{Int64,1}, ds::Array{Int64,1}; k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, ndims=2) - - # Check if unet structure defined - if isnothing(n_hiddens) - num_downsamp = 1 - n_hiddens = [n_hidden] - ds = [4] - else - # Use user defined hidden channels and downsampling factors in ds - num_downsamp = length(n_hiddens) - end - - if num_downsamp != length(ds) + if length(n_hiddens) != length(ds) throw("Number of downsampling factors in ds must be the same defined hidden channels in n_hidden") end + num_downsamp = length(n_hiddens) C = Array{Conv1x1}(undef, num_downsamp) RB = Array{ResidualBlock}(undef, num_downsamp) @@ -125,36 +115,55 @@ function inverse(Y::AbstractArray{T, N}, L::CouplingLayerIRIM; save=false) where X2_ = Y2_ - L.RB[j].forward(Y1_) X_ = tensor_cat(X1_, X2_) - X = L.C[j].inverse(X_) + Y = L.C[j].inverse(X_) end - if save == false - return X - else - return X, X_, Y1_ - end + return Y end # 2D Backward pass: Input (ΔY, Y), Output (ΔX, X) function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingLayerIRIM; set_grad::Bool=true) where {T, N} + # Initialize layer parameters + #!set_grad && (Δθ = Array{Parameter, 1}(undef, 0)) + !set_grad && (p1 = Array{Parameter, 1}(undef, 0)) + !set_grad && (p2 = Array{Parameter, 1}(undef, 0)) + num_downsamp = length(L.C) - for j=num_downsamp:-1:1ß - ΔY_, Y_ = L.C[j].forward((ΔY, Y)) + for j=num_downsamp:-1:1 + if set_grad + ΔY_, Y_ = L.C[j].forward((ΔY, Y)) + else + ΔY_, 
Δθ_C1, Y_ = L.C[j].forward((ΔY, Y); set_grad=set_grad) + end ΔYl_, ΔYr_ = tensor_split(ΔY_) Y1_, Y2_ = tensor_split(Y_) - ΔYl_ .= L.RB[j].backwards(ΔYr_, Y1_) + ΔYl_ + if set_grad + ΔYl_ .= L.RB[j].backward(ΔYr_, Y1_) + ΔYl_ + else + ΔY_RB, Δθ_RB = L.RB[j].backward(ΔYr_, Y1_; set_grad=set_grad) + ΔYl_ .= ΔY_RB + ΔYl_ + end + Y2_ .-= L.RB[j].forward(Y1_) ΔY_ = tensor_cat(ΔYl_, ΔYr_) Y_ = tensor_cat(Y1_, Y2_) - - ΔY, Y = L.C[j].inverse((ΔY_,Y_)) + + if set_grad + ΔY, Y = L.C[j].inverse((ΔY_, Y_)) + else + ΔY, Δθ_C2, Y = L.C[j].inverse((ΔY_, Y_); set_grad=set_grad) + #append!(Δθ, cat(Δθ_C1+Δθ_C2, Δθ_RB; dims=1)) + #push!(p1, cat(Δθ_C1+Δθ_C2, Δθ_RB; dims=1)) + p1 = cat(p1, Δθ_C1+Δθ_C2; dims=1) + p2 = cat(p2, Δθ_RB; dims=1) + end end - return ΔY, Y + set_grad ? (return ΔY, Y) : (ΔY, cat(p1, p2; dims=1), Y) end ## Jacobian utilities @@ -163,22 +172,28 @@ end function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X::AbstractArray{T, N}, L::CouplingLayerIRIM) where {T, N} num_downsamp = length(L.C) - for j=num_downsamp:-1:1 - ΔX_, X_ = L.C[j].jacobian(ΔX, Δθ[1:3], X) + num_rb = 5 + num_1x1c = 3 + for j=1:num_downsamp + idx_conv = (j-1)*num_1x1c+1:j*num_1x1c + idx_rb = (j-1)*num_rb+1+num_downsamp*num_1x1c:(j)*num_rb+num_downsamp*num_1x1c + println("\nat j") + println("\nidx_conv = $(idx_conv)") + println("\nidx_rb = $(idx_rb)") + ΔX_, X_ = L.C[j].jacobian(ΔX, Δθ[idx_conv], X) X1_, X2_ = tensor_split(X_) ΔX1_, ΔX2_ = tensor_split(ΔX_) - ΔY1_, Y1__ = L.RB[j].jacobian(ΔX1_, Δθ[4:end], X1_) + ΔY1_, Y1__ = L.RB[j].jacobian(ΔX1_, Δθ[idx_rb], X1_) Y2_ = X2_ + Y1__ ΔY2_ = ΔX2_ + ΔY1_ Y_ = tensor_cat(X1_, Y2_) ΔY_ = tensor_cat(ΔX1_, ΔY2_) - ΔY, Y = L.C[j].jacobianInverse(ΔY_, Δθ[1:3], Y_) + ΔX, X = L.C[j].jacobianInverse(ΔY_, Δθ[idx_conv], Y_) end - return ΔY, Y - + return ΔX, X end # 2D/3D diff --git a/test/test_layers/test_coupling_layer_irim.jl b/test/test_layers/test_coupling_layer_irim.jl index b2522a3a..cc6bce5a 100644 --- 
a/test/test_layers/test_coupling_layer_irim.jl +++ b/test/test_layers/test_coupling_layer_irim.jl @@ -11,7 +11,8 @@ Random.seed!(1); nx = 28 ny = 28 n_in = 8 -n_hidden = 8 +n_hiddens = [4,8,4] +ds = [1,4,1] batchsize = 2 # Input images @@ -20,9 +21,9 @@ X0 = randn(Float32, nx, ny, n_in, batchsize) dX = X - X0 # Invertible layers -L = CouplingLayerIRIM(n_in, n_hidden) -L01 = CouplingLayerIRIM(n_in, n_hidden) -L02 = CouplingLayerIRIM(n_in, n_hidden) +L = CouplingLayerIRIM(n_in, n_hiddens, ds) +L01 = CouplingLayerIRIM(n_in, n_hiddens, ds) +L02 = CouplingLayerIRIM(n_in, n_hiddens, ds) ################################################################################################### # Test invertibility @@ -37,7 +38,8 @@ X_ = L.forward(L.inverse(X)) # Gradient tests # Loss Function -function loss(L, X, Y) +function loss(L, X, Y; ind=1) + println("ind $(ind)") Y_ = L.forward(X) ΔY = Y_ - Y @@ -45,7 +47,7 @@ function loss(L, X, Y) ΔX = L.backward(ΔY, Y_)[1] # Pass back gradients w.r.t. input X and from the residual block and 1x1 conv. layer - return f, ΔX, L.C.v1.grad, L.C.v2.grad, L.C.v3.grad, L.RB.W1.grad, L.RB.W2.grad, L.RB.W3.grad + return f, ΔX, L.C[ind].v1.grad, L.C[ind].v2.grad, L.C[ind].v3.grad, L.RB[ind].W1.grad, L.RB[ind].W2.grad, L.RB[ind].W3.grad end # Gradient test w.r.t. input X0 @@ -70,13 +72,14 @@ end # Gradient test w.r.t. 
weights of residual block +ind = 2 Y = L.forward(X) Lini = deepcopy(L02) -dW1 = L.RB.W1.data - L02.RB.W1.data -dW2 = L.RB.W2.data - L02.RB.W2.data -dW3 = L.RB.W3.data - L02.RB.W3.data +dW1 = L.RB[ind].W1.data - L02.RB[ind].W1.data +dW2 = L.RB[ind].W2.data - L02.RB[ind].W2.data +dW3 = L.RB[ind].W3.data - L02.RB[ind].W3.data -f0, ΔX, Δv1, Δv2, Δv3, ΔW1, ΔW2, ΔW3 = loss(L02, X, Y) +f0, ΔX, Δv1, Δv2, Δv3, ΔW1, ΔW2, ΔW3 = loss(L02, X, Y;ind=ind) h = 0.1f0 maxiter = 5 err3 = zeros(Float32, maxiter) @@ -84,9 +87,9 @@ err4 = zeros(Float32, maxiter) print("\nGradient test invertible layer\n") for j=1:maxiter - L02.RB.W1.data = Lini.RB.W1.data + h*dW1 - L02.RB.W2.data = Lini.RB.W2.data + h*dW2 - L02.RB.W3.data = Lini.RB.W3.data + h*dW3 + L02.RB[ind].W1.data = Lini.RB[ind].W1.data + h*dW1 + L02.RB[ind].W2.data = Lini.RB[ind].W2.data + h*dW2 + L02.RB[ind].W3.data = Lini.RB[ind].W3.data + h*dW3 f = loss(L02, X, Y)[1] err3[j] = abs(f - f0) err4[j] = abs(f - f0 - h*dot(dW1, ΔW1) - h*dot(dW2, ΔW2) - h*dot(dW3, ΔW3)) @@ -94,15 +97,15 @@ for j=1:maxiter global h = h/2f0 end -@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f1) -@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f1) +@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0) +@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0) # Gradient test w.r.t. 
1x1 conv weights Y = L.forward(X) Lini = deepcopy(L01) -dv1 = L.C.v1.data - L01.C.v1.data -dv2 = L.C.v2.data - L01.C.v2.data -dv3 = L.C.v3.data - L01.C.v3.data +dv1 = L.C[1].v1.data - L01.C[1].v1.data +dv2 = L.C[1].v2.data - L01.C[1].v2.data +dv3 = L.C[1].v3.data - L01.C[1].v3.data f0, ΔX, Δv1, Δv2, Δv3, ΔW1, ΔW2, ΔW3 = loss(L01, X, Y) h = 0.1f0 @@ -112,9 +115,9 @@ err6 = zeros(Float32, maxiter) print("\nGradient test invertible layer\n") for j=1:maxiter - L01.C.v1.data = Lini.C.v1.data + h*dv1 - L01.C.v2.data = Lini.C.v2.data + h*dv2 - L01.C.v3.data = Lini.C.v3.data + h*dv3 + L01.C[1].v1.data = Lini.C[1].v1.data + h*dv1 + L01.C[1].v2.data = Lini.C[1].v2.data + h*dv2 + L01.C[1].v3.data = Lini.C[1].v3.data + h*dv3 f = loss(L01, X, Y)[1] err5[j] = abs(f - f0) err6[j] = abs(f - f0 - h*dot(dv1, Δv1) - h*dot(dv2, Δv2) - h*dot(dv3, Δv3)) @@ -122,8 +125,8 @@ for j=1:maxiter global h = h/2f0 end -@test isapprox(err5[end] / (err5[1]/2^(maxiter-1)), 1f0; atol=1f1) -@test isapprox(err6[end] / (err6[1]/4^(maxiter-1)), 1f0; atol=1f1) +@test isapprox(err5[end] / (err5[1]/2^(maxiter-1)), 1f0; atol=1f0) +@test isapprox(err6[end] / (err6[1]/4^(maxiter-1)), 1f0; atol=1f0) ################################################################################################### # Jacobian-related tests @@ -131,9 +134,9 @@ end # Gradient test # Initialization -L = CouplingLayerIRIM(n_in, n_hidden) +L = CouplingLayerIRIM(n_in, n_hiddens, ds) θ = deepcopy(get_params(L)) -L0 = CouplingLayerIRIM(n_in, n_hidden) +L0 = CouplingLayerIRIM(n_in, n_hiddens, ds) θ0 = deepcopy(get_params(L0)) X = randn(Float32, nx, ny, n_in, batchsize) @@ -163,11 +166,12 @@ end @test isapprox(err8[end] / (err8[1]/4^(maxiter-1)), 1f0; atol=1f1) # Adjoint test - -set_params!(L, θ) -dY, Y = L.jacobian(dX, dθ, X) -dY_ = randn(Float32, size(dY)) -dX_, dθ_ = L.adjointJacobian(dY_, Y) -a = dot(dY, dY_) -b = dot(dX, dX_)+dot(dθ, dθ_) -@test isapprox(a, b; rtol=1f-3) \ No newline at end of file +# NOT IMPLEMENTED YET! 
+ +# set_params!(L, θ) +# dY, Y = L.jacobian(dX, dθ, X) +# dY_ = randn(Float32, size(dY)) +# dX_, dθ_ = L.adjointJacobian(dY_, Y) +# a = dot(dY, dY_) +# b = dot(dX, dX_)+dot(dθ, dθ_) +# @test isapprox(a, b; rtol=1f-3) \ No newline at end of file From 0a3a73e906a7514b15f8addb50de60a7eb14c521 Mon Sep 17 00:00:00 2001 From: rafaelorozco Date: Mon, 25 Apr 2022 16:05:03 -0400 Subject: [PATCH 6/9] decided that copyto does the same thing 1x1conv --- src/layers/invertible_layer_conv1x1.jl | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/layers/invertible_layer_conv1x1.jl b/src/layers/invertible_layer_conv1x1.jl index 5b0bc2c2..1f4d8f02 100644 --- a/src/layers/invertible_layer_conv1x1.jl +++ b/src/layers/invertible_layer_conv1x1.jl @@ -139,20 +139,6 @@ function conv1x1_grad_v(X::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, M3 = (I - 2 * (V1 + V2) + 4*V1*V2) tmp = cuzeros(X, k, k) for i=1:k -# <<<<<<< HEAD -# # ∂V1 -# mul!(tmp, ∂V1[i, :, :], M1) -# adjoint ? ∂V1[i, :, :] = tmp' : ∂V1[i, :, :] = tmp - -# # ∂V2 -# v2 = ∂V2[i, :, :] -# broadcast!(+, tmp, v2, 4 * V1 * v2 * V3 - 2 * (V1 * v2 + v2 * V3)) -# adjoint ? ∂V2[i, :, :] = tmp' : ∂V2[i, :, :] = tmp - -# # ∂V3 -# mul!(tmp, M3, ∂V3[i, :, :]) -# adjoint ? ∂V3[i, :, :] = tmp' : ∂V3[i, :, :] = tmp -# ======= # dV1 mul!(tmp, dV1[i, :, :], M1) @views adjoint ? copyto!(dV1[i, :, :], tmp') : copyto!(dV1[i, :, :], tmp) @@ -163,7 +149,6 @@ function conv1x1_grad_v(X::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, # dV3 mul!(tmp, M3, dV3[i, :, :]) @views adjoint ? 
copyto!(dV3[i, :, :], tmp') : copyto!(dV3[i, :, :], tmp) -#>>>>>>> master end prod_res = cuzeros(X, size(dV1, 1)) From 9fb057262134815dff2d27f909fd49cc7d02aeb1 Mon Sep 17 00:00:00 2001 From: rafaelorozco Date: Mon, 25 Apr 2022 18:31:50 -0400 Subject: [PATCH 7/9] add unet - basically one loop of irim --- src/InvertibleNetworks.jl | 1 + src/layers/invertible_layer_irim.jl | 25 +++-- src/networks/invertible_network_unet.jl | 133 ++++++++++++++++++++++++ test/runtests.jl | 1 + test/test_networks/test_unet.jl | 94 +++++++++++++++++ 5 files changed, 245 insertions(+), 9 deletions(-) create mode 100644 src/networks/invertible_network_unet.jl create mode 100644 test/test_networks/test_unet.jl diff --git a/src/InvertibleNetworks.jl b/src/InvertibleNetworks.jl index ca94d3ee..08dacae5 100644 --- a/src/InvertibleNetworks.jl +++ b/src/InvertibleNetworks.jl @@ -64,6 +64,7 @@ include("layers/invertible_layer_hint.jl") # Invertible network architectures include("networks/invertible_network_hint_multiscale.jl") include("networks/invertible_network_irim.jl") # i-RIM: Putzky and Welling (2019) +include("networks/invertible_network_unet.jl") # single loop i-RIM: Putzky and Welling (2019) include("networks/invertible_network_glow.jl") # Glow: Dinh et al. (2017), Kingma and Dhariwal (2018) include("networks/invertible_network_hyperbolic.jl") # Hyperbolic: Lensink et al. (2019) diff --git a/src/layers/invertible_layer_irim.jl b/src/layers/invertible_layer_irim.jl index 9994a5bc..5909a2f1 100644 --- a/src/layers/invertible_layer_irim.jl +++ b/src/layers/invertible_layer_irim.jl @@ -88,6 +88,9 @@ CouplingLayerIRIM3D(args...;kw...) 
= CouplingLayerIRIM(args...; kw..., ndims=3) # 2D Forward pass: Input X, Output Y function forward(X::AbstractArray{T, N}, L::CouplingLayerIRIM) where {T, N} + # Init tensors to avoid reallocation + Y_ = similar(X) + num_downsamp = length(L.C) for j=1:num_downsamp X_ = L.C[j].forward(X) @@ -96,7 +99,7 @@ function forward(X::AbstractArray{T, N}, L::CouplingLayerIRIM) where {T, N} Y1_ = X1_ Y2_ = X2_ + L.RB[j].forward(Y1_) - Y_ = tensor_cat(Y1_, Y2_) + tensor_cat!(Y_, Y1_, Y2_) X = L.C[j].inverse(Y_) end @@ -105,7 +108,10 @@ end # 2D Inverse pass: Input Y, Output X function inverse(Y::AbstractArray{T, N}, L::CouplingLayerIRIM; save=false) where {T, N} - + + # Init tensors to avoid reallocation + X_ = similar(Y) + num_downsamp = length(L.C) for j=num_downsamp:-1:1 Y_ = L.C[j].forward(Y) @@ -114,7 +120,7 @@ function inverse(Y::AbstractArray{T, N}, L::CouplingLayerIRIM; save=false) where X1_ = Y1_ X2_ = Y2_ - L.RB[j].forward(Y1_) - X_ = tensor_cat(X1_, X2_) + tensor_cat!(X_, X1_, X2_) Y = L.C[j].inverse(X_) end @@ -128,6 +134,10 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingL !set_grad && (p1 = Array{Parameter, 1}(undef, 0)) !set_grad && (p2 = Array{Parameter, 1}(undef, 0)) + # Init tensors to avoid reallocation + ΔY_ = similar(ΔY) + Y_ = similar(Y) + num_downsamp = length(L.C) for j=num_downsamp:-1:1 if set_grad @@ -148,8 +158,8 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingL Y2_ .-= L.RB[j].forward(Y1_) - ΔY_ = tensor_cat(ΔYl_, ΔYr_) - Y_ = tensor_cat(Y1_, Y2_) + tensor_cat!(ΔY_, ΔYl_, ΔYr_) + tensor_cat!(Y_, Y1_, Y2_) if set_grad ΔY, Y = L.C[j].inverse((ΔY_, Y_)) @@ -167,16 +177,13 @@ end # 2D function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X::AbstractArray{T, N}, L::CouplingLayerIRIM) where {T, N} - num_downsamp = length(L.C) num_rb = 5 num_1x1c = 3 for j=1:num_downsamp idx_conv = (j-1)*num_1x1c+1:j*num_1x1c idx_rb = 
(j-1)*num_rb+1+num_downsamp*num_1x1c:(j)*num_rb+num_downsamp*num_1x1c - println("\nat j") - println("\nidx_conv = $(idx_conv)") - println("\nidx_rb = $(idx_rb)") + ΔX_, X_ = L.C[j].jacobian(ΔX, Δθ[idx_conv], X) X1_, X2_ = tensor_split(X_) ΔX1_, ΔX2_ = tensor_split(ΔX_) diff --git a/src/networks/invertible_network_unet.jl b/src/networks/invertible_network_unet.jl new file mode 100644 index 00000000..cb063512 --- /dev/null +++ b/src/networks/invertible_network_unet.jl @@ -0,0 +1,133 @@ +# Invertible network layer from Putzky and Welling (2019): https://arxiv.org/abs/1911.10914 +# Author: Philipp Witte, pwitte3@gatech.edu +# Date: January 2020 + +export NetworkUNET, NetworkUNET3D + +""" + L = NetworkUNET(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, ndims=2) (2D) + + L = NetworkUNET3D(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4, s2=1) (3D) + + Create an invertible recurrent inference machine (i-RIM) consisting of an unrolled loop + for a given number of iterations. + + *Input*: + + - `n_in`: number of input channels + + - `n_hidden`: number of hidden units in residual blocks + + - `maxiter`: number of unrolled loop iterations + + - `Ψ`: link function + + - `k1`, `k2`: stencil sizes for convolutions in the residual blocks. The first convolution + uses a stencil of size and stride `k1`, thereby downsampling the input. The second + convolution uses a stencil of size `k2`. The last layer uses a stencil of size and stride `k1`, + but performs the transpose operation of the first convolution, thus upsampling the output to + the original input size. + + - `p1`, `p2`: padding for the first and third convolution (`p1`) and the second convolution (`p2`) in + residual block + + - `s1`, `s2`: stride for the first and third convolution (`s1`) and the second convolution (`s2`) in + residual block + + - `ndims` : number of dimensions + + *Output*: + + - `L`: invertible i-RIM network.
+ + *Usage:* + + - Forward mode: `η_out, s_out = L.forward(η_in, s_in, d, A)` + + - Inverse mode: `η_in, s_in = L.inverse(η_out, s_out, d, A)` + + - Backward mode: `Δη_in, Δs_in, η_in, s_in = L.backward(Δη_out, Δs_out, η_out, s_out, d, A)` + + *Trainable parameters:* + + - None in `L` itself + + - Trainable parameters in the invertible coupling layers `L.L[i]`, and actnorm layers + `L.AN[i]`, where `i` ranges from `1` to the number of loop iterations. + + See also: [`CouplingLayerIRIM`](@ref), [`ResidualBlock`](@ref), [`get_params`](@ref), [`clear_grad!`](@ref) +""" +struct NetworkUNET <: InvertibleNetwork + L::CouplingLayerIRIM + AN::ActNorm + n_mem::Int64 +end + +@Flux.functor NetworkUNET + +# 2D Constructor +function NetworkUNET(n_in::Int64, n_hiddens::Array{Int64,1}, ds::Array{Int64,1}; k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, ndims=2) + + L = CouplingLayerIRIM(n_in, n_hiddens, ds; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims) + AN = ActNorm(1) # Only for 1 channel gradient + n_mem = n_in - 1 + return NetworkUNET(L, AN, n_mem) +end + +# 3D Constructor +NetworkUNET3D(args...; kw...) = NetworkUNET(args...; kw..., ndims=3) + +# 2D Forward loop: Input (η), Output (η) +function forward(η::AbstractArray{T, N}, g::AbstractArray{T, N}, UL::NetworkUNET) where {T, N} + + # Dimensions + batchsize = size(η)[end] + nn = size(η)[1:N-2] + inds_c = [i!=(N-1) ? Colon() : 1 for i=1:N] + + # Forward pass + s = cuzeros(η, nn..., UL.n_mem, batchsize) + gn = UL.AN.forward(g) # normalize + s[inds_c...] 
= gn # gradient in first channel + + ηs = UL.L.forward(tensor_cat(η, s)) + η, s = tensor_split(ηs; split_index=1) + + return η, s +end + +# 2D Inverse loop: Input (η), Output (η) +function inverse(η::AbstractArray{T, N}, s::AbstractArray{T, N}, g::AbstractArray{T, N}, UL::NetworkUNET) where {T, N} + + # Inverse pass + ηs_ = UL.L.inverse(tensor_cat(η, s)) + η, s_ = tensor_split(ηs_; split_index=1) + + return η +end + +# 2D Backward loop: Input (Δη, Δs, η, s), Output (Δη, Δs, η, s) +function backward(Δη::AbstractArray{T, N}, + η::AbstractArray{T, N}, s::AbstractArray{T, N}, g::AbstractArray{T, N}, UL::NetworkUNET; set_grad::Bool=true) where {T, N} + + Δs = 0 .* s # make Δs zero tensor + + # Backwards pass + Δηs_, ηs_ = UL.L.backward(tensor_cat(Δη, Δs), tensor_cat(η, s)) + + η, s_ = tensor_split(ηs_; split_index=1) + Δη, Δs = tensor_split(Δηs_; split_index=1) + + gn = UL.AN.forward(g) # normalize + Δgn = tensor_split(Δs; split_index=1)[1] + Δg = UL.AN.backward(Δgn, gn)[1] + + return Δη, η +end + +## Jacobian-related utils +jacobian(::AbstractArray{T, 5}, ::AbstractArray{T, 5}, UL::NetworkUNET) where T = throw(ArgumentError("Jacobian for NetworkUNET not yet implemented")) + +adjointJacobian(Δη::AbstractArray{T, N}, η::AbstractArray{T, N}, s::AbstractArray{T, N}, UL::NetworkUNET; + set_grad::Bool=true) where {T, N} = throw(ArgumentError("Jacobian for NetworkUNET not yet implemented")) + diff --git a/test/runtests.jl b/test/runtests.jl index 0ad86b83..4742345f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,6 +35,7 @@ layers = ["test_layers/test_residual_block.jl", "test_layers/test_layer_affine.jl"] networks = ["test_networks/test_unrolled_loop.jl", + "test_networks/test_unet.jl", "test_networks/test_generator.jl", "test_networks/test_glow.jl", "test_networks/test_hyperbolic_network.jl", diff --git a/test/test_networks/test_unet.jl b/test/test_networks/test_unet.jl new file mode 100644 index 00000000..a3ea99e2 --- /dev/null +++ 
b/test/test_networks/test_unet.jl @@ -0,0 +1,94 @@ +using InvertibleNetworks, LinearAlgebra, Test, Random + +Random.seed!(11) + +# Input +nx = 16 +ny = 16 +nz = 16 +n_in = 4 +n_hiddens = [4,8,4] +ds = [1,4,1] +batchsize = 2 + +# Unrolled loop +L = NetworkUNET(n_in, n_hiddens, ds;ndims=3) + +# Initializations +η = 10*randn(Float32, nx, ny, nz, 1, batchsize) +g = 10*randn(Float32, nx, ny, nz, 1, batchsize) + +################################################################################################### + +# Test invertibility +η_, s_ = L.forward(η, g) +ηInv = L.inverse(η_, s_, g) +@test isapprox(norm(ηInv - η)/norm(η), 0f0, atol=1e-5) + +# Test invertibility +η_, s_ = L.forward(η, g) +ηInv = L.backward(0f0.*η_, η_, s_, g)[2] +@test isapprox(norm(ηInv - η)/norm(η), 0f0, atol=1e-5) + +################################################################################################### + +# Initializations +η = randn(Float32, nx, ny, nz, 1, batchsize) +η0 = randn(Float32, nx, ny, nz, 1, batchsize) +Δη = η - η0 + +# Observed data +η_, s_ = L.forward(η, g) # only need η + +function loss(L, η0, g, η) + η_, s_ = L.forward(η0, g) # reshape + Δη = η_ - η + f = .5f0*norm(Δη)^2 + Δη_ = L.backward(Δη, η_, s_, g)[1] + return f, Δη_, L.L.C[1].v1.grad, L.L.RB[1].W1.grad +end + +# Gradient test for input +f0, gη = loss(L, η0, g, η_)[1:2] +h = 0.1f0 +maxiter = 6 +err1 = zeros(Float32, maxiter) +err2 = zeros(Float32, maxiter) + +print("\nGradient test loop unrolling\n") +for j=1:maxiter + f = loss(L, η0 + h*Δη, g, η_)[1] + err1[j] = abs(f - f0) + err2[j] = abs(f - f0 - h*dot(Δη, gη)) + print(err1[j], "; ", err2[j], "\n") + global h = h/2f0 +end + +@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0) +@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0) + + +# Gradient test for weights +L0 = NetworkUNET3D(n_in, n_hiddens, ds; ) +L_ini = deepcopy(L0) +dv = L.L.C[1].v1.data - L0.L.C[1].v1.data # just test for 2 parameters +dW = L.L.RB[1].W1.data - 
L0.L.RB[1].W1.data +f0, gη, gv, gW = loss(L0, η, g, η_) +h = 0.05f0 +maxiter = 5 +err3 = zeros(Float32, maxiter) +err4 = zeros(Float32, maxiter) + +print("\nGradient test loop unrolling\n") +for j=1:maxiter + L0.L.C[1].v1.data = L_ini.L.C[1].v1.data + h*dv + L0.L.RB[1].W1.data = L_ini.L.RB[1].W1.data + h*dW + f = loss(L0, η, g, η_)[1] + err3[j] = abs(f - f0) + err4[j] = abs(f - f0 - h*dot(dv, gv) - h*dot(dW, gW)) + print(err3[j], "; ", err4[j], "\n") + global h = h/2f0 +end + +@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0) +@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0) From 338060efbc18f798c7b40fe99be326b7e8ecf86b Mon Sep 17 00:00:00 2001 From: Rafael Orozco Date: Wed, 12 Oct 2022 14:41:31 -0400 Subject: [PATCH 8/9] Update invertible_network_unet.jl add very simple invertible unet. No gradient as irim, just an input for clear comparison with traditional unets. --- src/networks/invertible_network_unet.jl | 85 +++++++++++++++---------- 1 file changed, 51 insertions(+), 34 deletions(-) diff --git a/src/networks/invertible_network_unet.jl b/src/networks/invertible_network_unet.jl index cb063512..3e68abe2 100644 --- a/src/networks/invertible_network_unet.jl +++ b/src/networks/invertible_network_unet.jl @@ -61,68 +61,86 @@ struct NetworkUNET <: InvertibleNetwork L::CouplingLayerIRIM AN::ActNorm n_mem::Int64 + n_grad::Int64 + early_squeeze end @Flux.functor NetworkUNET # 2D Constructor -function NetworkUNET(n_in::Int64, n_hiddens::Array{Int64,1}, ds::Array{Int64,1}; k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, ndims=2) +function NetworkUNET(n_in::Int64, n_hiddens::Array{Int64,1}, ds::Array{Int64,1}; early_squeeze=false, n_grad=1, k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, ndims=2) + n_mem = n_in + if early_squeeze + n_in = 4*n_in + end L = CouplingLayerIRIM(n_in, n_hiddens, ds; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims) - AN = ActNorm(1) # Only for 1 channel gradient - n_mem = n_in - 1 - return NetworkUNET(L, AN, n_mem) + AN = 
ActNorm(n_grad) # Only for 1 channel gradient # try turning off logdet + + return NetworkUNET(L, AN, n_mem, n_grad, early_squeeze) end # 3D Constructor NetworkUNET3D(args...; kw...) = NetworkUNET(args...; kw..., ndims=3) # 2D Forward loop: Input (η), Output (η) -function forward(η::AbstractArray{T, N}, g::AbstractArray{T, N}, UL::NetworkUNET) where {T, N} +function forward(g::AbstractArray{T, N}, UL::NetworkUNET) where {T, N} # Dimensions - batchsize = size(η)[end] - nn = size(η)[1:N-2] - inds_c = [i!=(N-1) ? Colon() : 1 for i=1:N] - + batchsize = size(g)[end] + nn = size(g)[1:N-2] + # Forward pass - s = cuzeros(η, nn..., UL.n_mem, batchsize) - gn = UL.AN.forward(g) # normalize - s[inds_c...] = gn # gradient in first channel + gs = cuzeros(g, nn..., UL.n_mem, batchsize) + gn = UL.AN.forward(g)[1] # normalize + + gs[:,:,1:UL.n_grad,:] = gn # gradient in first channel - ηs = UL.L.forward(tensor_cat(η, s)) - η, s = tensor_split(ηs; split_index=1) + if UL.early_squeeze + gs = squeeze(gs; pattern="checkerboard") + end + gs = UL.L.forward(gs) - return η, s + if UL.early_squeeze + gs = unsqueeze(gs; pattern="checkerboard") + end + + return gs end # 2D Inverse loop: Input (η), Output (η) -function inverse(η::AbstractArray{T, N}, s::AbstractArray{T, N}, g::AbstractArray{T, N}, UL::NetworkUNET) where {T, N} +function inverse(y::AbstractArray{T, N}, UL::NetworkUNET) where {T, N} - # Inverse pass - ηs_ = UL.L.inverse(tensor_cat(η, s)) - η, s_ = tensor_split(ηs_; split_index=1) - - return η + UL.early_squeeze && (y = squeeze(y; pattern="checkerboard")) + x = UL.L.inverse(y) + UL.early_squeeze && (x = unsqueeze(x; pattern="checkerboard")) + + x, _ = tensor_split(x; split_index=UL.n_grad) + + x = UL.AN.inverse(x)[1] # normalize + return x end # 2D Backward loop: Input (Δη, Δs, η, s), Output (Δη, Δs, η, s) -function backward(Δη::AbstractArray{T, N}, - η::AbstractArray{T, N}, s::AbstractArray{T, N}, g::AbstractArray{T, N}, UL::NetworkUNET; set_grad::Bool=true) where {T, N} 
+function backward(Δy::AbstractArray{T, N}, + y::AbstractArray{T, N}, UL::NetworkUNET; set_grad::Bool=true) where {T, N} - Δs = 0 .* s # make Δs zero tensor + if UL.early_squeeze + Δy = squeeze(Δy; pattern="checkerboard") + y = squeeze(y; pattern="checkerboard") + end + Δx, x = UL.L.backward(Δy, y) + if UL.early_squeeze + Δx = unsqueeze(Δx; pattern="checkerboard") + x = unsqueeze(x; pattern="checkerboard") + end - # Backwards pass - Δηs_, ηs_ = UL.L.backward(tensor_cat(Δη, Δs), tensor_cat(η, s)) + x, _ = tensor_split(x; split_index=UL.n_grad) + Δx, _ = tensor_split(Δx; split_index=UL.n_grad) - η, s_ = tensor_split(ηs_; split_index=1) - Δη, Δs = tensor_split(Δηs_; split_index=1) + Δx, x = UL.AN.backward(Δx, x) - gn = UL.AN.forward(g) # normalize - Δgn = tensor_split(Δs; split_index=1)[1] - Δg = UL.AN.backward(Δgn, gn)[1] - - return Δη, η + return Δx, x end ## Jacobian-related utils @@ -130,4 +148,3 @@ jacobian(::AbstractArray{T, 5}, ::AbstractArray{T, 5}, UL::NetworkUNET) where T adjointJacobian(Δη::AbstractArray{T, N}, η::AbstractArray{T, N}, s::AbstractArray{T, N}, UL::NetworkUNET; set_grad::Bool=true) where {T, N} = throw(ArgumentError("Jacobian for NetworkUNET not yet implemented")) - From 34a67dd2b7222604ef9b31b7ecf9e35b92f3dd44 Mon Sep 17 00:00:00 2001 From: Rafael Orozco Date: Wed, 12 Oct 2022 15:06:24 -0400 Subject: [PATCH 9/9] Update invertible_network_unet.jl take away indexing --- src/networks/invertible_network_unet.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/networks/invertible_network_unet.jl b/src/networks/invertible_network_unet.jl index 3e68abe2..64822e68 100644 --- a/src/networks/invertible_network_unet.jl +++ b/src/networks/invertible_network_unet.jl @@ -92,7 +92,7 @@ function forward(g::AbstractArray{T, N}, UL::NetworkUNET) where {T, N} # Forward pass gs = cuzeros(g, nn..., UL.n_mem, batchsize) - gn = UL.AN.forward(g)[1] # normalize + gn,_ = UL.AN.forward(g) # normalize gs[:,:,1:UL.n_grad,:] = gn # gradient in 
first channel @@ -117,7 +117,7 @@ function inverse(y::AbstractArray{T, N}, UL::NetworkUNET) where {T, N} x, _ = tensor_split(x; split_index=UL.n_grad) - x = UL.AN.inverse(x)[1] # normalize + x,_ = UL.AN.inverse(x) # normalize return x end