Skip to content

Commit 3a52102

Browse files
committed
add reshapeInput function wrapper
1 parent 383d978 commit 3a52102

File tree

5 files changed

+221
-0
lines changed

5 files changed

+221
-0
lines changed

docs/src/calculus.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,5 @@ Precompose
3232
PrecomposeDiagonal
3333
Tilt
3434
Translate
35+
ReshapeInput
3536
```

src/ProximalOperators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ include("calculus/precomposeDiagonal.jl")
9191
include("calculus/regularize.jl")
9292
include("calculus/separableSum.jl")
9393
include("calculus/slicedSeparableSum.jl")
94+
include("calculus/reshapeInput.jl")
9495
include("calculus/sqrDistL2.jl")
9596
include("calculus/tilt.jl")
9697
include("calculus/translate.jl")

src/calculus/reshapeInput.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# wrap a function to reshape the input
2+
3+
export ReshapeInput
4+
5+
"""
6+
ReshapeInput(f, expected_shape)
7+
8+
Wrap a function to reshape the input.
9+
It is useful when the function `f` expects a specific shape of the input, but you want to pass it a different shape.
10+
11+
```julia
12+
julia> f = ReshapeInput(IndballRank(5), (10, 10))
13+
ReshapeInput(IndBallRank{Int64}(5), (10, 10))
14+
15+
julia> f(rand(100))
16+
Inf
17+
```
18+
"""
19+
struct ReshapeInput{F, S}
20+
f::F
21+
expected_shape::S
22+
end
23+
24+
function (f::ReshapeInput)(x)
25+
if size(x) != f.expected_shape
26+
x = reshape(x, f.expected_shape)
27+
end
28+
return f.f(x)
29+
end
30+
31+
function prox!(y, f::ReshapeInput, x, gamma)
32+
if size(x) != f.expected_shape
33+
x = reshape(x, f.expected_shape)
34+
end
35+
if size(y) != f.expected_shape
36+
y = reshape(y, f.expected_shape)
37+
end
38+
return prox!(y, f.f, x, gamma)
39+
end
40+
41+
function gradient!(y, f::ReshapeInput, x)
42+
if size(x) != f.expected_shape
43+
x = reshape(x, f.expected_shape)
44+
end
45+
if size(y) != f.expected_shape
46+
y = reshape(y, f.expected_shape)
47+
end
48+
return gradient!(y, f.f, x)
49+
end
50+
51+
function prox_naive(f::ReshapeInput, x, gamma)
52+
if size(x) != f.expected_shape
53+
x = reshape(x, f.expected_shape)
54+
end
55+
return prox_naive(f.f, x, gamma)
56+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ end
156156
include("test_separableSum.jl")
157157
include("test_slicedSeparableSum.jl")
158158
include("test_sum.jl")
159+
include("test_reshapeInput.jl")
159160
end
160161

161162
@testset "Equivalences" begin

test/test_reshapeInput.jl

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
using LinearAlgebra
2+
using ProximalOperators
3+
using Test
4+
5+
# Define a simple test function that we can use with ReshapeInput
6+
struct SimpleTestFunc end
7+
8+
# Make it callable - returns squared norm
9+
# This function requires 2D input (matrix), and will error for vectors or higher-dimensional arrays
10+
function (::SimpleTestFunc)(x)
11+
if ndims(x) != 2
12+
throw(DimensionMismatch("SimpleTestFunc requires 2D input (matrix), got $(ndims(x))D array"))
13+
end
14+
return sum(abs2, x)
15+
end
16+
17+
# Define a prox! method for SimpleTestFunc
18+
function ProximalOperators.prox!(y, f::SimpleTestFunc, x, gamma)
19+
if ndims(x) != 2
20+
throw(DimensionMismatch("SimpleTestFunc requires 2D input (matrix), got $(ndims(x))D array"))
21+
end
22+
# Simple soft-thresholding prox: prox(||·||^2) = x / (1 + 2*gamma)
23+
y .= x ./ (1 + 2 * gamma)
24+
return sum(abs2, y)
25+
end
26+
27+
# Define a gradient! method for SimpleTestFunc
28+
function ProximalOperators.gradient!(y, f::SimpleTestFunc, x)
29+
if ndims(x) != 2
30+
throw(DimensionMismatch("SimpleTestFunc requires 2D input (matrix), got $(ndims(x))D array"))
31+
end
32+
# Gradient of squared norm: 2*x
33+
y .= 2 .* x
34+
return sum(abs2, y)
35+
end
36+
37+
38+
39+
@testset "ReshapeInput Tests" begin
40+
41+
@testset "Basic Function Call with Correct Shape" begin
42+
# Create a ReshapeInput wrapper
43+
f = ReshapeInput(SimpleTestFunc(), (2, 2))
44+
45+
# Create input with correct shape
46+
x = reshape(1.0:4.0, 2, 2)
47+
result = f(x)
48+
49+
# Should return squared norm of all elements: 1 + 4 + 9 + 16 = 30
50+
expected = sum(abs2, x)
51+
@test result expected
52+
end
53+
54+
@testset "Function Call with Shape Reshaping" begin
55+
# Create a ReshapeInput wrapper expecting (2, 2)
56+
f = ReshapeInput(SimpleTestFunc(), (2, 2))
57+
58+
# Create input as a vector (different shape)
59+
x = vec(reshape(1.0:4.0, 2, 2)) # [1, 2, 3, 4]
60+
result = f(x)
61+
62+
# Should reshape to (2, 2) internally and compute squared norm
63+
x_reshaped = reshape(x, 2, 2)
64+
expected = sum(abs2, x_reshaped)
65+
@test result expected
66+
end
67+
68+
@testset "Function Call with Multiple Reshaping" begin
69+
# Create a ReshapeInput wrapper expecting (3, 4)
70+
f = ReshapeInput(SimpleTestFunc(), (3, 4))
71+
72+
# Create input as a vector of 12 elements
73+
x = collect(1.0:12.0)
74+
result = f(x)
75+
76+
# Should reshape to (3, 4) and compute squared norm
77+
x_reshaped = reshape(x, 3, 4)
78+
expected = sum(abs2, x_reshaped)
79+
@test result expected
80+
end
81+
82+
@testset "prox! with Correct Shape" begin
83+
# Create a ReshapeInput wrapper
84+
f = ReshapeInput(SimpleTestFunc(), (2, 2))
85+
86+
# Create input and output with correct shape
87+
x = reshape(1.0:4.0, 2, 2)
88+
y = zeros(2, 2)
89+
gamma = 0.5
90+
91+
result = prox!(y, f, x, gamma)
92+
93+
# prox of squared norm with soft-thresholding
94+
expected_y = x ./ (1 + 2 * gamma)
95+
expected_result = sum(abs2, expected_y)
96+
97+
@test y expected_y
98+
@test result expected_result
99+
end
100+
101+
@testset "prox! with Shape Reshaping" begin
102+
# Create a ReshapeInput wrapper expecting (2, 2)
103+
f = ReshapeInput(SimpleTestFunc(), (2, 2))
104+
105+
# Create input and output as vectors
106+
x = collect(1.0:4.0)
107+
y = zeros(4)
108+
gamma = 0.5
109+
110+
result = prox!(y, f, x, gamma)
111+
112+
# Should internally reshape to (2, 2)
113+
x_reshaped = reshape(x, 2, 2)
114+
expected_y_reshaped = x_reshaped ./ (1 + 2 * gamma)
115+
expected_result = sum(abs2, expected_y_reshaped)
116+
117+
# y should contain the reshaped result flattened back
118+
y_expected = vec(expected_y_reshaped)
119+
@test y y_expected
120+
@test result expected_result
121+
end
122+
123+
@testset "gradient! with Correct Shape" begin
124+
# Create a ReshapeInput wrapper
125+
f = ReshapeInput(SimpleTestFunc(), (2, 2))
126+
127+
# Create input and output with correct shape
128+
x = reshape(1.0:4.0, 2, 2)
129+
y = zeros(2, 2)
130+
131+
result = gradient!(y, f, x)
132+
133+
# Gradient of squared norm: 2*x
134+
expected_y = 2 .* x
135+
expected_result = sum(abs2, expected_y)
136+
137+
@test y expected_y
138+
@test result expected_result
139+
end
140+
141+
@testset "gradient! with Shape Reshaping" begin
142+
# Create a ReshapeInput wrapper expecting (2, 2)
143+
f = ReshapeInput(SimpleTestFunc(), (2, 2))
144+
145+
# Create input and output as vectors
146+
x = collect(1.0:4.0)
147+
y = zeros(4)
148+
149+
result = gradient!(y, f, x)
150+
151+
# Should internally reshape to (2, 2)
152+
x_reshaped = reshape(x, 2, 2)
153+
expected_y_reshaped = 2 .* x_reshaped
154+
expected_result = sum(abs2, expected_y_reshaped)
155+
156+
# y should contain the reshaped result flattened back
157+
y_expected = vec(expected_y_reshaped)
158+
@test y y_expected
159+
@test result expected_result
160+
end
161+
162+
end

0 commit comments

Comments
 (0)