-
Notifications
You must be signed in to change notification settings - Fork 24
Change irim block - add invertible UNET #57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
35cf1e4
2ad8772
2b517de
4fbd203
8a91549
238e99b
6311ced
0a3a73e
9fb0572
338060e
34a67dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -58,19 +58,27 @@ or | |
| See also: [`Conv1x1`](@ref), [`ResidualBlock!`](@ref), [`get_params`](@ref), [`clear_grad!`](@ref) | ||
| """ | ||
| struct CouplingLayerIRIM <: NeuralNetLayer | ||
| C::Conv1x1 | ||
| RB::Union{ResidualBlock, FluxBlock} | ||
| C::AbstractArray{Conv1x1, 1} | ||
| RB::AbstractArray{ResidualBlock, 1} | ||
| end | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are we removing `FluxBlock`? |
||
|
|
||
| @Flux.functor CouplingLayerIRIM | ||
|
|
||
| # 2D Constructor from input dimensions | ||
| function CouplingLayerIRIM(n_in::Int64, n_hidden::Int64; | ||
| function CouplingLayerIRIM(n_in::Int64, n_hiddens::Array{Int64,1}, ds::Array{Int64,1}; | ||
| k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, ndims=2) | ||
|
|
||
| # 1x1 Convolution and residual block for invertible layer | ||
| C = Conv1x1(n_in) | ||
| RB = ResidualBlock(n_in÷2, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims) | ||
| if length(n_hiddens) != length(ds) | ||
| throw("Number of downsampling factors in ds must be the same defined hidden channels in n_hidden") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| end | ||
|
|
||
| num_downsamp = length(n_hiddens) | ||
| C = Array{Conv1x1}(undef, num_downsamp) | ||
| RB = Array{ResidualBlock}(undef, num_downsamp) | ||
|
|
||
| for j=1:num_downsamp | ||
| C[j] = Conv1x1(n_in) | ||
| RB[j] = ResidualBlock(n_in÷2, n_hiddens[j]; d=ds[j], k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=false, ndims=ndims) | ||
| end | ||
|
|
||
| return CouplingLayerIRIM(C, RB) | ||
| end | ||
|
|
@@ -79,92 +87,120 @@ CouplingLayerIRIM3D(args...;kw...) = CouplingLayerIRIM(args...; kw..., ndims=3) | |
|
|
||
| # 2D Forward pass: Input X, Output Y | ||
| function forward(X::AbstractArray{T, N}, L::CouplingLayerIRIM) where {T, N} | ||
| X_ = L.C.forward(X) | ||
| X1_, X2_ = tensor_split(X_) | ||
|
|
||
| Y1_ = X1_ | ||
| Y2_ = X2_ + L.RB.forward(Y1_) | ||
| # Init tensors to avoid reallocation | ||
| Y_ = similar(X) | ||
|
|
||
| num_downsamp = length(L.C) | ||
| for j=1:num_downsamp | ||
| X_ = L.C[j].forward(X) | ||
| X1_, X2_ = tensor_split(X_) | ||
|
|
||
| Y_ = tensor_cat(Y1_, Y2_) | ||
| Y = L.C.inverse(Y_) | ||
| Y1_ = X1_ | ||
| Y2_ = X2_ + L.RB[j].forward(Y1_) | ||
|
|
||
| tensor_cat!(Y_, Y1_, Y2_) | ||
| X = L.C[j].inverse(Y_) | ||
| end | ||
|
|
||
| return Y | ||
| return X | ||
| end | ||
|
|
||
| # 2D Inverse pass: Input Y, Output X | ||
| function inverse(Y::AbstractArray{T, N}, L::CouplingLayerIRIM; save=false) where {T, N} | ||
| Y_ = L.C.forward(Y) | ||
| Y1_, Y2_ = tensor_split(Y_) | ||
|
|
||
| X1_ = Y1_ | ||
| X2_ = Y2_ - L.RB.forward(Y1_) | ||
| # Init tensors to avoid reallocation | ||
| X_ = similar(Y) | ||
|
|
||
| num_downsamp = length(L.C) | ||
| for j=num_downsamp:-1:1 | ||
| Y_ = L.C[j].forward(Y) | ||
| Y1_, Y2_ = tensor_split(Y_) | ||
|
|
||
| X_ = tensor_cat(X1_, X2_) | ||
| X = L.C.inverse(X_) | ||
| X1_ = Y1_ | ||
| X2_ = Y2_ - L.RB[j].forward(Y1_) | ||
|
|
||
| if save == false | ||
| return X | ||
| else | ||
| return X, X_, Y1_ | ||
| tensor_cat!(X_, X1_, X2_) | ||
| Y = L.C[j].inverse(X_) | ||
| end | ||
|
|
||
| return Y | ||
| end | ||
|
|
||
| # 2D Backward pass: Input (ΔY, Y), Output (ΔX, X) | ||
| function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingLayerIRIM; set_grad::Bool=true) where {T, N} | ||
|
|
||
| # Recompute forward state | ||
| k = Int(L.C.k/2) | ||
| X, X_, Y1_ = inverse(Y, L; save=true) | ||
|
|
||
| # Backpropagate residual | ||
| if set_grad | ||
| ΔY_ = L.C.forward((ΔY, Y))[1] | ||
| else | ||
| ΔY_, Δθ_C1 = L.C.forward((ΔY, Y); set_grad=set_grad)[1:2] | ||
| end | ||
| ΔYl_, ΔYr_ = tensor_split(ΔY_) | ||
| if set_grad | ||
| ΔY1_ = L.RB.backward(ΔYr_, Y1_) + ΔYl_ | ||
| else | ||
| ΔY1_, Δθ_RB = L.RB.backward(ΔYr_, Y1_; set_grad=set_grad) | ||
| ΔY1_ = ΔY1_ + ΔYl_ | ||
| end | ||
|
|
||
| ΔX_ = tensor_cat(ΔY1_, ΔYr_) | ||
| if set_grad | ||
| ΔX = L.C.inverse((ΔX_, X_))[1] | ||
| else | ||
| ΔX, Δθ_C2 = L.C.inverse((ΔX_, X_); set_grad=set_grad)[1:2] | ||
| # Initialize layer parameters | ||
| !set_grad && (p1 = Array{Parameter, 1}(undef, 0)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we need to initialize? |
||
| !set_grad && (p2 = Array{Parameter, 1}(undef, 0)) | ||
|
|
||
| # Init tensors to avoid reallocation | ||
| ΔY_ = similar(ΔY) | ||
| Y_ = similar(Y) | ||
|
|
||
| num_downsamp = length(L.C) | ||
| for j=num_downsamp:-1:1 | ||
| if set_grad | ||
| ΔY_, Y_ = L.C[j].forward((ΔY, Y)) | ||
| else | ||
| ΔY_, Δθ_C1, Y_ = L.C[j].forward((ΔY, Y); set_grad=set_grad) | ||
| end | ||
|
|
||
| ΔYl_, ΔYr_ = tensor_split(ΔY_) | ||
| Y1_, Y2_ = tensor_split(Y_) | ||
|
|
||
| if set_grad | ||
| ΔYl_ .= L.RB[j].backward(ΔYr_, Y1_) + ΔYl_ | ||
| else | ||
| ΔY_RB, Δθ_RB = L.RB[j].backward(ΔYr_, Y1_; set_grad=set_grad) | ||
| ΔYl_ .= ΔY_RB + ΔYl_ | ||
| end | ||
|
|
||
| Y2_ .-= L.RB[j].forward(Y1_) | ||
|
|
||
| tensor_cat!(ΔY_, ΔYl_, ΔYr_) | ||
| tensor_cat!(Y_, Y1_, Y2_) | ||
|
|
||
| if set_grad | ||
| ΔY, Y = L.C[j].inverse((ΔY_, Y_)) | ||
| else | ||
| ΔY, Δθ_C2, Y = L.C[j].inverse((ΔY_, Y_); set_grad=set_grad) | ||
| p1 = cat(p1, Δθ_C1+Δθ_C2; dims=1) | ||
| p2 = cat(p2, Δθ_RB; dims=1) | ||
| end | ||
| end | ||
|
|
||
| set_grad ? (return ΔX, X) : (return ΔX, cat(Δθ_C1+Δθ_C2, Δθ_RB; dims=1), X) | ||
| set_grad ? (return ΔY, Y) : (ΔY, cat(p1, p2; dims=1), Y) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would be good to keep the naming convention the same, i.e., this should return `ΔX, X`. |
||
| end | ||
|
|
||
| ## Jacobian utilities | ||
|
|
||
| # 2D | ||
| function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X::AbstractArray{T, N}, L::CouplingLayerIRIM) where {T, N} | ||
|
|
||
| # Get dimensions | ||
| k = Int(L.C.k/2) | ||
|
|
||
| ΔX_, X_ = L.C.jacobian(ΔX, Δθ[1:3], X) | ||
| X1_, X2_ = tensor_split(X_) | ||
| ΔX1_, ΔX2_ = tensor_split(ΔX_) | ||
|
|
||
| ΔY1_, Y1__ = L.RB.jacobian(ΔX1_, Δθ[4:end], X1_) | ||
| Y2_ = X2_ + Y1__ | ||
| ΔY2_ = ΔX2_ + ΔY1_ | ||
|
|
||
| Y_ = tensor_cat(X1_, Y2_) | ||
| ΔY_ = tensor_cat(ΔX1_, ΔY2_) | ||
| ΔY, Y = L.C.jacobianInverse(ΔY_, Δθ[1:3], Y_) | ||
| num_downsamp = length(L.C) | ||
| num_rb = 5 | ||
| num_1x1c = 3 | ||
| for j=1:num_downsamp | ||
| idx_conv = (j-1)*num_1x1c+1:j*num_1x1c | ||
| idx_rb = (j-1)*num_rb+1+num_downsamp*num_1x1c:(j)*num_rb+num_downsamp*num_1x1c | ||
|
|
||
| ΔX_, X_ = L.C[j].jacobian(ΔX, Δθ[idx_conv], X) | ||
| X1_, X2_ = tensor_split(X_) | ||
| ΔX1_, ΔX2_ = tensor_split(ΔX_) | ||
|
|
||
| ΔY1_, Y1__ = L.RB[j].jacobian(ΔX1_, Δθ[idx_rb], X1_) | ||
| Y2_ = X2_ + Y1__ | ||
| ΔY2_ = ΔX2_ + ΔY1_ | ||
|
|
||
| Y_ = tensor_cat(X1_, Y2_) | ||
| ΔY_ = tensor_cat(ΔX1_, ΔY2_) | ||
| ΔX, X = L.C[j].jacobianInverse(ΔY_, Δθ[idx_conv], Y_) | ||
| end | ||
|
|
||
| return ΔY, Y | ||
|
|
||
| return ΔX, X | ||
| end | ||
|
|
||
| # 2D/3D | ||
| function adjointJacobian(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingLayerIRIM) where {T, N} | ||
| return backward(ΔY, Y, L; set_grad=false) | ||
| end | ||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -75,10 +75,18 @@ end | |
| # Constructors | ||
|
|
||
| # Constructor | ||
| function ResidualBlock(n_in, n_hidden; k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2) | ||
| function ResidualBlock(n_in, n_hidden; d=nothing, k1=3, k2=3, p1=1, p2=1, s1=1, s2=1, fan=false, ndims=2) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not clear what `d` is. |
||
|
|
||
| # Check if downsampling factor d is defined | ||
| if !isnothing(d) | ||
| k1 = d | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why `k1 = d`? |
||
| s1 = d | ||
| p1 = 0 | ||
| end | ||
|
|
||
| k1 = Tuple(k1 for i=1:ndims) | ||
| k2 = Tuple(k2 for i=1:ndims) | ||
|
|
||
| # Initialize weights | ||
| W1 = Parameter(glorot_uniform(k1..., n_in, n_hidden)) | ||
| W2 = Parameter(glorot_uniform(k2..., n_hidden, n_hidden)) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -66,7 +66,7 @@ end | |
| @Flux.functor NetworkLoop | ||
|
|
||
| # 2D Constructor | ||
| function NetworkLoop(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2) | ||
| function NetworkLoop(n_in, n_hidden, maxiter, Ψ; n_hiddens=nothing, ds=nothing, k1=4, k2=3, p1=0, p2=1, s1=4, s2=1, type="additive", ndims=2) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar variable names `n_hidden` and `n_hiddens`. |
||
|
|
||
| if type == "additive" | ||
| L = Array{CouplingLayerIRIM}(undef, maxiter) | ||
|
|
@@ -77,7 +77,7 @@ function NetworkLoop(n_in, n_hidden, maxiter, Ψ; k1=4, k2=3, p1=0, p2=1, s1=4, | |
| AN = Array{ActNorm}(undef, maxiter) | ||
| for j=1:maxiter | ||
| if type == "additive" | ||
| L[j] = CouplingLayerIRIM(n_in, n_hidden; k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims) | ||
| L[j] = CouplingLayerIRIM(n_in, n_hidden; n_hiddens=n_hiddens, ds=ds, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, ndims=ndims) | ||
| elseif type == "HINT" | ||
| L[j] = CouplingLayerHINT(n_in, n_hidden; logdet=false, permute="both", k1=k1, k2=k2, p1=p1, p2=p2, | ||
| s1=s1, s2=s2, ndims=ndims) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this implementing https://github.com/pputzky/invertible_rim/blob/master/irim/core/invertible_unet.py?