From 1600e4b87f56425d3898ec33ee3eb6b2ac6c7666 Mon Sep 17 00:00:00 2001 From: Ali Siahkoohi Date: Thu, 17 Jun 2021 10:52:45 -0400 Subject: [PATCH 1/3] add flag for computing p(x|y) in conditional hint --- src/conditional_layers/conditional_layer_hint.jl | 8 ++++---- .../invertible_network_conditional_hint.jl | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/conditional_layers/conditional_layer_hint.jl b/src/conditional_layers/conditional_layer_hint.jl index da982777..44ed0eb1 100644 --- a/src/conditional_layers/conditional_layer_hint.jl +++ b/src/conditional_layers/conditional_layer_hint.jl @@ -82,7 +82,7 @@ end # 3D Constructor from input dimensions ConditionalLayerHINT3D(args...; kw...) = ConditionalLayerHINT(args...; kw..., ndims=3) -function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing) where {T, N} +function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing, x_lane=false) where {T, N} isnothing(logdet) ? logdet = (CH.logdet && ~CH.is_reversed) : logdet = logdet # Y-lane @@ -96,10 +96,10 @@ function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::Conditional # X-lane: conditional layer logdet ? (Zx, logdet3) = CH.CL_YX.forward(Yp, X)[2:3] : Zx = CH.CL_YX.forward(Yp, X)[2] - logdet ? (return Zx, Zy, logdet1 + logdet2 + logdet3) : (return Zx, Zy) + logdet ? (return Zx, Zy, logdet1 + !x_lane*logdet2 + logdet3) : (return Zx, Zy) end -function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing) where {T, N} +function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing, x_lane=false) where {T, N} isnothing(logdet) ? logdet = (CH.logdet && CH.is_reversed) : logdet = logdet # Y-lane @@ -114,7 +114,7 @@ function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::Condition logdet ? (Xp, logdet3) = CH.CL_X.inverse(X; logdet=true) : Xp = CH.CL_X.inverse(X; logdet=false) ~isnothing(CH.C_X) ? (X = CH.C_X.inverse(Xp)) : (X = copy(Xp)) - logdet ? (return X, Y, logdet1 + logdet2 + logdet3) : (return X, Y) + logdet ? (return X, Y, !x_lane*logdet1 + logdet2 + logdet3) : (return X, Y) end function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing, set_grad::Bool=true) where {T, N} diff --git a/src/networks/invertible_network_conditional_hint.jl b/src/networks/invertible_network_conditional_hint.jl index 6574b845..7fcd6c7d 100644 --- a/src/networks/invertible_network_conditional_hint.jl +++ b/src/networks/invertible_network_conditional_hint.jl @@ -78,7 +78,7 @@ end NetworkConditionalHINT3D(args...;kw...) = NetworkConditionalHINT(args...; kw..., ndims=3) # Forward pass and compute logdet -function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::NetworkConditionalHINT; logdet=nothing) where {T, N} +function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::NetworkConditionalHINT; logdet=nothing, x_lane=false) where {T, N} isnothing(logdet) ? logdet = (CH.logdet && ~CH.is_reversed) : logdet = logdet depth = length(CH.CL) @@ -86,23 +86,23 @@ function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::NetworkCond for j=1:depth logdet ? (X_, logdet1) = CH.AN_X[j].forward(X) : X_ = CH.AN_X[j].forward(X) logdet ? (Y_, logdet2) = CH.AN_Y[j].forward(Y) : Y_ = CH.AN_Y[j].forward(Y) - logdet ? (X, Y, logdet3) = CH.CL[j].forward(X_, Y_) : (X, Y) = CH.CL[j].forward(X_, Y_) - logdet && (logdet_ += (logdet1 + logdet2 + logdet3)) + logdet ? (X, Y, logdet3) = CH.CL[j].forward(X_, Y_; x_lane=x_lane) : (X, Y) = CH.CL[j].forward(X_, Y_) + logdet && (logdet_ += (logdet1 + !x_lane*logdet2 + logdet3)) end logdet ? (return X, Y, logdet_) : (return X, Y) end # Inverse pass and compute gradients -function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::NetworkConditionalHINT; logdet=nothing) where {T, N} +function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::NetworkConditionalHINT; logdet=nothing, x_lane=false) where {T, N} isnothing(logdet) ? logdet = (CH.logdet && CH.is_reversed) : logdet = logdet depth = length(CH.CL) logdet_ = 0 for j=depth:-1:1 - logdet ? (Zx_, Zy_, logdet1) = CH.CL[j].inverse(Zx, Zy; logdet=true) : (Zx_, Zy_) = CH.CL[j].inverse(Zx, Zy; logdet=false) + logdet ? (Zx_, Zy_, logdet1) = CH.CL[j].inverse(Zx, Zy; logdet=true, x_lane=x_lane) : (Zx_, Zy_) = CH.CL[j].inverse(Zx, Zy; logdet=false) logdet ? (Zy, logdet2) = CH.AN_Y[j].inverse(Zy_; logdet=true) : Zy = CH.AN_Y[j].inverse(Zy_; logdet=false) logdet ? (Zx, logdet3) = CH.AN_X[j].inverse(Zx_; logdet=true) : Zx = CH.AN_X[j].inverse(Zx_; logdet=false) - logdet && (logdet_ += (logdet1 + logdet2 + logdet3)) + logdet && (logdet_ += (logdet1 + !x_lane*logdet2 + logdet3)) end logdet ? (return Zx, Zy, logdet_) : (return Zx, Zy) end @@ -233,4 +233,4 @@ function tag_as_reversed!(CH::NetworkConditionalHINT, tag::Bool) tag_as_reversed!(CH.CL[j], tag) end return CH -end \ No newline at end of file +end From d8fc9899a1a9e7a70151f879d68e6c8e549fc8bb Mon Sep 17 00:00:00 2001 From: Ali Siahkoohi Date: Wed, 14 Jul 2021 13:07:55 -0400 Subject: [PATCH 2/3] adding x_lane to gradient calculations --- .../conditional_layer_hint.jl | 16 ++++++----- src/layers/invertible_layer_actnorm.jl | 8 +++--- src/layers/invertible_layer_basic.jl | 12 ++++---- src/layers/invertible_layer_hint.jl | 28 +++++++++---------- .../invertible_network_conditional_hint.jl | 18 ++++++------ src/utils/parameter.jl | 10 ++++++- 6 files changed, 52 insertions(+), 40 deletions(-) diff --git a/src/conditional_layers/conditional_layer_hint.jl b/src/conditional_layers/conditional_layer_hint.jl index 44ed0eb1..67013af0 100644 --- a/src/conditional_layers/conditional_layer_hint.jl +++ b/src/conditional_layers/conditional_layer_hint.jl @@ -82,7 +82,7 @@ end # 3D Constructor from input dimensions ConditionalLayerHINT3D(args...; kw...) = ConditionalLayerHINT(args...; kw..., ndims=3) -function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing, x_lane=false) where {T, N} +function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing, x_lane::Bool=false) where {T, N} isnothing(logdet) ? logdet = (CH.logdet && ~CH.is_reversed) : logdet = logdet # Y-lane @@ -99,7 +99,7 @@ function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::Conditional logdet ? (return Zx, Zy, logdet1 + !x_lane*logdet2 + logdet3) : (return Zx, Zy) end -function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing, x_lane=false) where {T, N} +function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing, x_lane::Bool=false) where {T, N} isnothing(logdet) ? logdet = (CH.logdet && CH.is_reversed) : logdet = logdet # Y-lane @@ -117,15 +117,16 @@ function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::Condition logdet ? (return X, Y, !x_lane*logdet1 + logdet2 + logdet3) : (return X, Y) end -function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing, set_grad::Bool=true) where {T, N} +function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, + CH::ConditionalLayerHINT; logdet=nothing, set_grad::Bool=true, x_lane::Bool=false) where {T, N} isnothing(logdet) ? logdet = (CH.logdet && ~CH.is_reversed) : logdet = logdet # Y-lane if set_grad - ΔYp, Yp = CH.CL_Y.backward(ΔZy, Zy) + ΔYp, Yp = CH.CL_Y.backward(ΔZy, Zy; x_lane=x_lane) else if logdet - ΔYp, Δθ_CLY, Yp, ∇logdet_CLY = CH.CL_Y.backward(ΔZy, Zy; set_grad=set_grad) + ΔYp, Δθ_CLY, Yp, ∇logdet_CLY = CH.CL_Y.backward(ΔZy, Zy; set_grad=set_grad, x_lane=x_lane) else ΔYp, Δθ_CLY, Yp = CH.CL_Y.backward(ΔZy, Zy; set_grad=set_grad) end @@ -185,7 +186,8 @@ function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::Abst end end -function backward_inv(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T, N} +function backward_inv(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, X::AbstractArray{T, N}, + Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; x_lane::Bool=false) where {T, N} # 1x1 Convolutions if isnothing(CH.C_X) || isnothing(CH.C_Y) @@ -204,7 +206,7 @@ function backward_inv(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, X::Abs ΔYp += ΔYp_ # Y-lane - ΔZy, Zy = backward_inv(ΔYp, Yp, CH.CL_Y) + ΔZy, Zy = backward_inv(ΔYp, Yp, CH.CL_Y; x_lane=x_lane) return ΔZx, ΔZy, Zx, Zy end diff --git a/src/layers/invertible_layer_actnorm.jl b/src/layers/invertible_layer_actnorm.jl index 3c3a23f5..5d525b6a 100644 --- a/src/layers/invertible_layer_actnorm.jl +++ b/src/layers/invertible_layer_actnorm.jl @@ -97,7 +97,7 @@ function inverse(Y::AbstractArray{T, N}, AN::ActNorm; logdet=nothing) where {T, end # 2-3D Backward pass: Input (ΔY, Y), Output (ΔY, Y) -function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, AN::ActNorm; set_grad::Bool = true) where {T, N} +function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, AN::ActNorm; set_grad::Bool = true, x_lane::Bool=false) where {T, N} inds = [i!=(N-1) ? 1 : (:) for i=1:N] dims = collect(1:N-1); dims[end] +=1 nn = size(ΔY)[1:N-2] @@ -106,7 +106,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, AN::ActNorm; ΔX = ΔY .* reshape(AN.s.data, inds...) Δs = sum(ΔY .* X, dims=dims)[inds...] if AN.logdet - set_grad ? (Δs -= logdet_backward(nn..., AN.s)) : (Δs_ = logdet_backward(nn..., AN.s)) + set_grad ? (Δs -= !x_lane*logdet_backward(nn..., AN.s)) : (Δs_ = !x_lane*logdet_backward(nn..., AN.s)) end Δb = sum(ΔY, dims=dims)[inds...] if set_grad @@ -124,7 +124,7 @@ end ## Reverse-layer functions # 2-3D Backward pass (inverse): Input (ΔX, X), Output (ΔX, X) -function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, AN::ActNorm; set_grad::Bool = true) where {T, N} +function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, AN::ActNorm; set_grad::Bool = true, x_lane::Bool=false) where {T, N} inds = [i!=(N-1) ? 1 : (:) for i=1:N] dims = collect(1:N-1); dims[end] +=1 nn = size(ΔX)[1:N-2] @@ -133,7 +133,7 @@ function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, AN::ActN ΔY = ΔX ./ reshape(AN.s.data, inds...) Δs = -sum(ΔX .* X ./ reshape(AN.s.data, inds...), dims=dims)[inds...] if AN.logdet - set_grad ? (Δs += logdet_backward(nn..., AN.s)) : (∇logdet = -logdet_backward(nn..., AN.s)) + set_grad ? (Δs += !x_lane*logdet_backward(nn..., AN.s)) : (∇logdet = !x_lane*(-logdet_backward(nn..., AN.s))) end Δb = -sum(ΔX ./ reshape(AN.s.data, inds...), dims=dims)[inds...] if set_grad diff --git a/src/layers/invertible_layer_basic.jl b/src/layers/invertible_layer_basic.jl index cdbc9d1d..4c12caa0 100644 --- a/src/layers/invertible_layer_basic.jl +++ b/src/layers/invertible_layer_basic.jl @@ -119,7 +119,7 @@ function inverse(Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLa end # 2D/3D Backward pass: Input (ΔY, Y), Output (ΔX, X) -function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true) where {T, N} +function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true, x_lane::Bool=false) where {T, N} # Recompute forward state X1, X2, S = inverse(Y1, Y2, L; save=true, logdet=false) @@ -128,7 +128,7 @@ function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::Abst ΔT = copy(ΔY2) ΔS = ΔY2 .* X2 if L.logdet - set_grad && (ΔS -= coupling_logdet_backward(S)) + set_grad && (ΔS -= !x_lane*coupling_logdet_backward(S)) end ΔX2 = ΔY2 .* S if set_grad @@ -136,7 +136,7 @@ function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::Abst else ΔX1, Δθ = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X1; set_grad=set_grad) if L.logdet - _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(coupling_logdet_backward(S), S), 0 .*ΔT), X1; set_grad=set_grad) + _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(!x_lane*coupling_logdet_backward(S), S), 0f0.*ΔT), X1; set_grad=set_grad) end ΔX1 += ΔY1 end @@ -149,7 +149,7 @@ function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::Abst end # 2D/3D Reverse backward pass: Input (ΔX, X), Output (ΔY, Y) -function backward_inv(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, X1::AbstractArray{T, N}, X2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true) where {T, N} +function backward_inv(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, X1::AbstractArray{T, N}, X2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true, x_lane::Bool=false) where {T, N} # Recompute inverse state Y1, Y2, S = forward(X1, X2, L; save=true, logdet=false) @@ -158,7 +158,7 @@ function backward_inv(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, X1:: ΔT = -ΔX2 ./ S ΔS = X2 .* ΔT if L.logdet == true - set_grad ? (ΔS += coupling_logdet_backward(S)) : (∇logdet = -coupling_logdet_backward(S)) + set_grad ? (ΔS += !x_lane*coupling_logdet_backward(S)) : (∇logdet = !x_lane*(-coupling_logdet_backward(S))) end if set_grad ΔY1 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), Y1) + ΔX1 @@ -223,4 +223,4 @@ get_params(L::CouplingLayerBasic) = get_params(L.RB) function tag_as_reversed!(L::CouplingLayerBasic, tag::Bool) L.is_reversed = tag return L -end \ No newline at end of file +end diff --git a/src/layers/invertible_layer_hint.jl b/src/layers/invertible_layer_hint.jl index a604c336..86992cd3 100644 --- a/src/layers/invertible_layer_hint.jl +++ b/src/layers/invertible_layer_hint.jl @@ -201,7 +201,7 @@ function inverse(Y::AbstractArray{T, N} , H::CouplingLayerHINT; scale=1, permute end # Input are two tensors ΔY, Y -function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing, set_grad::Bool=true) where {T, N} +function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing, set_grad::Bool=true, x_lane::Bool=false) where {T, N} isnothing(permute) ? permute = H.permute : permute = permute # Initializing output parameter array @@ -231,16 +231,16 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, H::CouplingL # HINT coupling if recursive if set_grad - ΔXa, Xa = backward(ΔYa, Ya, H; scale=scale+1, permute="none") - ΔXa_temp, ΔXb_temp, X_temp = H.CL[scale].backward(ΔXa.*0, ΔYb, Xa, Yb)[[1,2,4]] - ΔXb, Xb = backward(ΔXb_temp, X_temp, H; scale=scale+1, permute="none") + ΔXa, Xa = backward(ΔYa, Ya, H; scale=scale+1, permute="none", x_lane=x_lane) + ΔXa_temp, ΔXb_temp, X_temp = H.CL[scale].backward(ΔXa.*0f0, ΔYb, Xa, Yb; x_lane=x_lane)[[1,2,4]] + ΔXb, Xb = backward(ΔXb_temp, X_temp, H; scale=scale+1, permute="none", x_lane=x_lane) else if H.logdet ΔXa, Δθa, Xa, ∇logdet_a = backward(ΔYa, Ya, H; scale=scale+1, permute="none", set_grad=set_grad) ΔXa_temp, ΔXb_temp, Δθ_scale, _, X_temp, ∇logdet_scale = H.CL[scale].backward(ΔXa.*0, ΔYb, Xa, Yb; set_grad=set_grad) ΔXb, Δθb, Xb, ∇logdet_b = backward(ΔXb_temp, X_temp, H; scale=scale+1, permute="none", set_grad=set_grad) - ∇logdet[1:5] .= ∇logdet_scale - ∇logdet[6:5+length(∇logdet_a)] .= ∇logdet_a+∇logdet_b + ∇logdet[1:5] .= !x_lane*∇logdet_scale + ∇logdet[6:5+length(∇logdet_a)] .= !x_lane*∇logdet_a+!x_lane*∇logdet_b else ΔXa, Δθa, Xa = backward(ΔYa, Ya, H; scale=scale+1, permute="none", set_grad=set_grad) ΔXa_temp, ΔXb_temp, Δθ_scale, _, X_temp = H.CL[scale].backward(ΔXa.*0, ΔYb, Xa, Yb; set_grad=set_grad) @@ -254,11 +254,11 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, H::CouplingL Xa = copy(Ya) ΔXa = copy(ΔYa) if set_grad - ΔXa_, ΔXb, Xb = H.CL[scale].backward(ΔYa.*0, ΔYb, Ya, Yb)[[1,2,4]] + ΔXa_, ΔXb, Xb = H.CL[scale].backward(ΔYa.*0f0, ΔYb, Ya, Yb; x_lane=x_lane)[[1,2,4]] else if H.logdet - ΔXa_, ΔXb, Δθ_scale, _, Xb, ∇logdet_scale = H.CL[scale].backward(ΔYa.*0, ΔYb, Ya, Yb; set_grad=set_grad) - ∇logdet[1:5] .= ∇logdet_scale + ΔXa_, ΔXb, Δθ_scale, _, Xb, ∇logdet_scale = H.CL[scale].backward(ΔYa.*0f0, ΔYb, Ya, Yb; set_grad=set_grad) + ∇logdet[1:5] .= !x_lane*∇logdet_scale else ΔXa_, ΔXb, Δθ_scale, _, Xb = H.CL[scale].backward(ΔYa.*0, ΔYb, Ya, Yb; set_grad=set_grad) end @@ -299,7 +299,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, H::CouplingL end # Input are two tensors ΔX, X -function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing) where {T, N} +function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing, x_lane::Bool=false) where {T, N} isnothing(permute) ? permute = H.permute : permute = permute # Permutation @@ -315,13 +315,13 @@ function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::Coupl # Coupling layer backprop if recursive - ΔY_temp, Y_temp = backward_inv(ΔXb, Xb, H; scale=scale+1, permute="none") - ΔYa_temp, ΔYb, Yb = backward_inv(0 .*ΔXa, ΔY_temp, Xa, Y_temp, H.CL[scale])[[1,2,4]] - ΔYa, Ya = backward_inv(ΔXa+ΔYa_temp, Xa, H; scale=scale+1, permute="none") + ΔY_temp, Y_temp = backward_inv(ΔXb, Xb, H; scale=scale+1, permute="none", x_lane=x_lane) + ΔYa_temp, ΔYb, Yb = backward_inv(0f0.*ΔXa, ΔY_temp, Xa, Y_temp, H.CL[scale]; x_lane=x_lane)[[1,2,4]] + ΔYa, Ya = backward_inv(ΔXa+ΔYa_temp, Xa, H; scale=scale+1, permute="none", x_lane=x_lane) else ΔYa = copy(ΔXa) Ya = copy(Xa) - ΔYa_temp, ΔYb, Yb = backward_inv(0 .*ΔYa, ΔXb, Xa, Xb, H.CL[scale])[[1,2,4]] + ΔYa_temp, ΔYb, Yb = backward_inv(0f0.*ΔYa, ΔXb, Xa, Xb, H.CL[scale]; x_lane=x_lane)[[1,2,4]] ΔYa += ΔYa_temp end ΔY = tensor_cat(ΔYa, ΔYb) diff --git a/src/networks/invertible_network_conditional_hint.jl b/src/networks/invertible_network_conditional_hint.jl index 7fcd6c7d..da5e3d62 100644 --- a/src/networks/invertible_network_conditional_hint.jl +++ b/src/networks/invertible_network_conditional_hint.jl @@ -108,7 +108,8 @@ function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::NetworkCo end # Backward pass and compute gradients -function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::NetworkConditionalHINT; set_grad::Bool=true) where {T, N} +function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::AbstractArray{T, N}, + Zy::AbstractArray{T, N}, CH::NetworkConditionalHINT; set_grad::Bool=true, x_lane::Bool=false) where {T, N} depth = length(CH.CL) if ~set_grad Δθ = Array{Parameter, 1}(undef, 0) @@ -116,14 +117,14 @@ function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::Abst end for j=depth:-1:1 if set_grad - ΔZx_, ΔZy_, Zx_, Zy_ = CH.CL[j].backward(ΔZx, ΔZy, Zx, Zy) + ΔZx_, ΔZy_, Zx_, Zy_ = CH.CL[j].backward(ΔZx, ΔZy, Zx, Zy; x_lane=x_lane) ΔZx, Zx = CH.AN_X[j].backward(ΔZx_, Zx_) - ΔZy, Zy = CH.AN_Y[j].backward(ΔZy_, Zy_) + ΔZy, Zy = CH.AN_Y[j].backward(ΔZy_, Zy_; x_lane=x_lane) else if CH.logdet - ΔZx_, ΔZy_, Δθcl, Zx_, Zy_, ∇logdetcl = CH.CL[j].backward(ΔZx, ΔZy, Zx, Zy; set_grad=set_grad) + ΔZx_, ΔZy_, Δθcl, Zx_, Zy_, ∇logdetcl = CH.CL[j].backward(ΔZx, ΔZy, Zx, Zy; set_grad=set_grad, x_lane=x_lane) ΔZx, Δθx, Zx, ∇logdetx = CH.AN_X[j].backward(ΔZx_, Zx_; set_grad=set_grad) - ΔZy, Δθy, Zy, ∇logdety = CH.AN_Y[j].backward(ΔZy_, Zy_; set_grad=set_grad) + ΔZy, Δθy, Zy, ∇logdety = CH.AN_Y[j].backward(ΔZy_, Zy_; set_grad=set_grad, x_lane=x_lane) ∇logdet = cat(∇logdetx, ∇logdety, ∇logdetcl, ∇logdet; dims=1) else ΔZx_, ΔZy_, Δθcl, Zx_, Zy_ = CH.CL[j].backward(ΔZx, ΔZy, Zx, Zy; set_grad=set_grad) @@ -141,12 +142,13 @@ function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::Abst end # Backward reverse pass and compute gradients -function backward_inv(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::NetworkConditionalHINT) where {T, N} +function backward_inv(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, X::AbstractArray{T, N}, + Y::AbstractArray{T, N}, CH::NetworkConditionalHINT; x_lane::Bool=false) where {T, N} depth = length(CH.CL) for j=1:depth ΔX_, X_ = backward_inv(ΔX, X, CH.AN_X[j]) - ΔY_, Y_ = backward_inv(ΔY, Y, CH.AN_Y[j]) - ΔX, ΔY, X, Y = backward_inv(ΔX_, ΔY_, X_, Y_, CH.CL[j]) + ΔY_, Y_ = backward_inv(ΔY, Y, CH.AN_Y[j]; x_lane=x_lane) + ΔX, ΔY, X, Y = backward_inv(ΔX_, ΔY_, X_, Y_, CH.CL[j]; x_lane=x_lane) end return ΔX, ΔY, X, Y end diff --git a/src/utils/parameter.jl b/src/utils/parameter.jl index 9ac134f1..de068f67 100644 --- a/src/utils/parameter.jl +++ b/src/utils/parameter.jl @@ -134,6 +134,14 @@ function /(p1::T, p2::Parameter) where {T<:Real} return Parameter(p1/p2.data) end +function *(p1::Parameter, p2::Bool) + return Parameter(p1.data*p2) +end + +function *(p1::Bool, p2::Parameter) + return p2*p1 +end + # Shape manipulation par2vec(x::Parameter) = vec(x.data), size(x.data) @@ -157,4 +165,4 @@ function vec2par(x::AbstractArray{T, 1}, s::Array{Any, 1}) where T idx_i += prod(s[i]) end return xpar -end \ No newline at end of file +end From d9e4ea05526dcfc1f4c153ad031e4e37bb2b87c1 Mon Sep 17 00:00:00 2001 From: Ali Siahkoohi Date: Mon, 14 Mar 2022 15:17:04 -0400 Subject: [PATCH 3/3] adding tests for x_lane --- src/layers/invertible_layer_basic.jl | 6 ++--- test/test_layers/test_actnorm.jl | 24 ++++++++++++++++++- test/test_layers/test_coupling_layer_basic.jl | 19 ++++++++++++++- 3 files changed, 44 insertions(+), 5 deletions(-) diff --git a/src/layers/invertible_layer_basic.jl b/src/layers/invertible_layer_basic.jl index 4c12caa0..3709e622 100644 --- a/src/layers/invertible_layer_basic.jl +++ b/src/layers/invertible_layer_basic.jl @@ -128,7 +128,7 @@ function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::Abst ΔT = copy(ΔY2) ΔS = ΔY2 .* X2 if L.logdet - set_grad && (ΔS -= !x_lane*coupling_logdet_backward(S)) + set_grad && (ΔS -= !x_lane * coupling_logdet_backward(S)) end ΔX2 = ΔY2 .* S if set_grad @@ -136,7 +136,7 @@ function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::Abst else ΔX1, Δθ = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X1; set_grad=set_grad) if L.logdet - _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(!x_lane*coupling_logdet_backward(S), S), 0f0.*ΔT), X1; set_grad=set_grad) + _, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(!x_lane * coupling_logdet_backward(S), S), 0f0.*ΔT), X1; set_grad=set_grad) end ΔX1 += ΔY1 end @@ -195,7 +195,7 @@ function jacobian(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, Δθ::Ab # Gauss-Newton approximation of logdet terms JΔθ = tensor_split(L.RB.jacobian(zeros(Float32, size(ΔX1)), Δθ, X1)[1])[1] GNΔθ = -L.RB.adjointJacobian(tensor_cat(L.activation.backward(JΔθ, S), zeros(Float32, size(S))), X1)[2] - + save ? (return ΔX1, ΔY2, X1, Y2, coupling_logdet_forward(S), GNΔθ, S) : (return ΔX1, ΔY2, X1, Y2, coupling_logdet_forward(S), GNΔθ) else save ? (return ΔX1, ΔY2, X1, Y2, S) : (return ΔX1, ΔY2, X1, Y2) diff --git a/test/test_layers/test_actnorm.jl b/test/test_layers/test_actnorm.jl index 852549df..6453610e 100644 --- a/test/test_layers/test_actnorm.jl +++ b/test/test_layers/test_actnorm.jl @@ -146,6 +146,28 @@ end @test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f1) +# Gradient test when x_lane = true +AN_x = ActNorm(nc; logdet=true) +X = randn(Float32, nx, ny, nc, batchsize) +X0 = randn(Float32, nx, ny, nc, batchsize) + +# Forward pass +Y = AN.forward(X)[1] + +# Forward pass +Y_, lgdet = AN.forward(X) + +# Residual and function value +ΔY = Y_ - Y +f = .5f0/batchsize*norm(ΔY)^2 +AN.logdet == true && (f -= lgdet) + +# Back propagation +ΔX, X_ = AN.backward(ΔY./batchsize, Y_, x_lane = true) + +grads = get_grads(AN) +@test isapprox(grads[1].data, zeros(size(grads[1]))) + # Gradient test for parameters AN0 = ActNorm(nc; logdet=true); AN0.forward(randn(Float32, nx, ny, nc, batchsize)) AN_ini = deepcopy(AN0) @@ -274,4 +296,4 @@ dY_ = randn(Float32, size(dY)) logdet ? ((dX_, dθ_, _, _) = AN.adjointJacobian(dY_, Y)) : ((dX_, dθ_, _) = AN.adjointJacobian(dY_, Y)) a = dot(dY, dY_) b = dot(dX, dX_)+dot(dθ, dθ_) -@test isapprox(a, b; rtol=1f-3) \ No newline at end of file +@test isapprox(a, b; rtol=1f-3) diff --git a/test/test_layers/test_coupling_layer_basic.jl b/test/test_layers/test_coupling_layer_basic.jl index 2b691765..0804d018 100644 --- a/test/test_layers/test_coupling_layer_basic.jl +++ b/test/test_layers/test_coupling_layer_basic.jl @@ -129,6 +129,23 @@ end @test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f1) +# Gradient test w.r.t. input X0 with x_lane = true +# Invertible layers +RB0 = ResidualBlock(n_in, n_hidden; fan=true) +L = CouplingLayerBasic(RB; logdet=true) +L01 = CouplingLayerBasic(RB; logdet=true) +L02 = CouplingLayerBasic(RB0; logdet=true) + +Ya_, Yb_, logdet = L.forward(Xa, Xb) +f = mse(tensor_cat(Ya_, Yb_), tensor_cat(Ya, Yb)) - logdet +ΔY = ∇mse(tensor_cat(Ya_, Yb_), tensor_cat(Ya, Yb)) +ΔYa, ΔYb = tensor_split(ΔY) +∇logdet = L.backward(ΔYa, ΔYb, Ya_, Yb_; set_grad = false, x_lane = true)[end] + +for param in ∇logdet + @test isapprox(param.data, zeros(size(param))) +end + # Gradient test w.r.t. weights of residual block Ya, Yb = L.forward(Xa, Xb)[1:2] Lini = deepcopy(L02) @@ -281,4 +298,4 @@ dY1_ = randn(Float32, size(dY1)); dY2_ = randn(Float32, size(dY2)) dX1_, dX2_, dθ_ = L.adjointJacobian(dY1_, dY2_, Y1, Y2) a = dot(dY1, dY1_)+dot(dY2, dY2_) b = dot(dX1, dX1_)+dot(dX2, dX2_)+dot(dθ, dθ_) -@test isapprox(a, b; rtol=1f-3) \ No newline at end of file +@test isapprox(a, b; rtol=1f-3)