Skip to content

Commit 5f9c86a

Browse files
authored
Add support for sin/cos and abs (#16)
* Support for sin/cos * Update kernel_write.jl * Add `abs` * Fix add-to sparsity - Correct the case where information is added in-place to use the sparsity information of the component being added, rather than the sparsity information for the final sum - Simplify the case where a constant is added in-place
1 parent 4dfbc2b commit 5f9c86a

File tree

5 files changed

+2630
-917
lines changed

5 files changed

+2630
-917
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SourceCodeMcCormick"
22
uuid = "a7283dc5-4ecf-47fb-a95b-1412723fc960"
33
authors = ["Robert Gottlieb <Robert.x.gottlieb@uconn.edu>"]
4-
version = "0.5.0"
4+
version = "0.5.1"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

src/kernel_writer/kernel_write.jl

Lines changed: 101 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ kgen(num::Num, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwr
1010
kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true) = kgen(num, gradlist, raw_outputs, constants, overwrite, splitting, affine_quadratic)
1111
function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, constants::Vector{Num}, overwrite::Bool, splitting::Symbol, affine_quadratic::Bool)
1212
# Create a hash of the expression and check if the function already exists
13-
expr_hash = string(hash(num+sum(gradlist)), base=62)
13+
expr_hash = string(hash(string(num)*string(gradlist)), base=62)
1414
if (overwrite==false) && (isfile(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl")))
1515
try func_name = eval(Meta.parse("f_"*expr_hash))
1616
return func_name
@@ -102,9 +102,6 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
102102
elseif splitting==:high # Formerly default
103103
split_point = 1500
104104
max_size = 2000
105-
# elseif splitting==:high # More splitting
106-
# split_point = 1000
107-
# max_size = 1200
108105
elseif splitting==:max # Extremely small
109106
split_point = 500
110107
max_size = 750
@@ -116,7 +113,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
116113
sparsity = detect_sparsity(factored, gradlist)
117114

118115
# Decide if the kernel needs to be split
119-
if (n_vars[end] < 31) && (n_lines[end] <= max_size)
116+
if (n_vars[end] < 31) && ((n_lines[end] <= max_size) || (findfirst(x -> x > split_point, n_lines)==length(n_lines)))
120117
# Complexity is fairly low; only a single kernel needed
121118
create_kernel!(expr_hash, 1, num, get_name.(gradlist), func_outputs, constants, factored, sparsity)
122119
push!(kernel_nums, 1)
@@ -130,7 +127,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
130127
while !complete
131128
# Determine which line to break at
132129
line_ID = findfirst(x -> x > split_point, n_lines)
133-
vars_ID = findfirst(x -> x == 31, n_vars)
130+
vars_ID = findfirst(x -> (x == 30) || (x == 31), n_vars)
134131
if isnothing(vars_ID)
135132
new_ID = line_ID
136133
elseif isnothing(line_ID)
@@ -188,7 +185,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
188185
n_lines = complexity(factored)
189186
n_vars = var_counts(factored)
190187

191-
# If the total number of lines (not including the final line) is below 2000
188+
# If the total number of lines (not including the final line) is below the max size
192189
# and the number of variables is below 32, we can make the final kernel and be done
193190
if (n_vars[end] < 32) && (all(n_lines[1:end-1] .<= max_size))
194191
create_kernel!(expr_hash, kernel_count, extract(factored), get_name.(gradlist), func_outputs, constants, factored, sparsity)
@@ -328,7 +325,12 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
328325
file = open(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"), "a")
329326

330327
# Put in the preamble.
331-
write(file, preamble_string(expr_hash, ["OUT"; string.(vars)], 1, 1, length(gradlist)))
328+
if isempty(vars)
329+
write(file, preamble_string(expr_hash, ["OUT";], 1, 1, length(gradlist)))
330+
else
331+
write(file, preamble_string(expr_hash, ["OUT"; string.(vars)], 1, 1, length(gradlist)))
332+
end
333+
332334

333335
# Depending on the format of the expression, compose the kernel differently
334336
if typeof(expr) <: Real
@@ -360,9 +362,9 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
360362
end
361363
end
362364
else # There must be two elements in the dictionary
363-
binary_vars = string.(get_name.(keys(key.dict)))
365+
binary_vars = string.(get_name.(keys(expr.dict)))
364366
binary_vars = binary_vars[sort_vars(binary_vars)]
365-
write(file, SCMC_quadaff_binary(vars..., expr.coeff, varlist))
367+
write(file, SCMC_quadaff_binary(binary_vars..., expr.coeff, varlist))
366368
end
367369

368370
elseif exprtype(expr)==ADD
@@ -394,7 +396,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
394396
# EAGO already does this and bypasses the need to calculate relaxations.
395397
# But, for compatibility with McCormick-style relaxations in ParBB,
396398
# it's easier to simply calculate what ParBB is expecting.)
397-
write(file, postamble_quadaff(string.(vars), varlist))
399+
if isempty(varlist)
400+
write(file, postamble_quadaff(String[], String[]))
401+
elseif isempty(vars)
402+
write(file, postamble_quadaff(String[], varlist))
403+
else
404+
write(file, postamble_quadaff(string.(vars), varlist))
405+
end
398406
close(file)
399407

400408
# Include this kernel so SCMC knows what it is
@@ -403,7 +411,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
403411
# Add onto the file the "main" CPU function that calls the kernel
404412
blocks = Int32(CUDA.attribute(CUDA.device(), CUDA.DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT))
405413
file = open(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"), "a")
406-
write(file, outro(expr_hash, [1], [string.(vars)], ["OUT"], blocks, get_name.(gradlist)))
414+
if isempty(gradlist)
415+
write(file, outro(expr_hash, [1], [String[]], ["OUT"], blocks, Symbol[]))
416+
elseif isempty(vars)
417+
write(file, outro(expr_hash, [1], [String[]], ["OUT"], blocks, get_name.(gradlist)))
418+
else
419+
write(file, outro(expr_hash, [1], [string.(vars)], ["OUT"], blocks, get_name.(gradlist)))
420+
end
407421
close(file)
408422

409423
# Include the file again to get the final kernel
@@ -684,7 +698,21 @@ function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Re
684698

685699
# Now we can pass the equation's RHS and the inputs to call the correct
686700
# writer function
687-
write_operation(file, factorized[factorized_ID].rhs, inputs, string.(gradlist), sparsity[i])
701+
if length(inputs)>2 && inputs[1]==inputs[2]
702+
# Special case. We're adding inputs[3] to inputs[2], so we only want
703+
# to pass the sparsity information of inputs[3] (rather than the
704+
# sparsity information of inputs[2] and inputs[3])
705+
corrected_i = findfirst(x->x==inputs[3], varorder)
706+
write_operation(file, factorized[factorized_ID].rhs, inputs, string.(gradlist), sparsity[corrected_i])
707+
elseif length(inputs)>2 && inputs[1]==inputs[3]
708+
# Special case. We're adding inputs[2] to inputs[3], so we only want
709+
# to pass the sparsity information of inputs[2] (rather than the
710+
# sparsity information of inputs[2] and inputs[3])
711+
corrected_i = findfirst(x->x==inputs[2], varorder)
712+
write_operation(file, factorized[factorized_ID].rhs, inputs, string.(gradlist), sparsity[corrected_i])
713+
else
714+
write_operation(file, factorized[factorized_ID].rhs, inputs, string.(gradlist), sparsity[i])
715+
end
688716

689717
# Now that we're done with this variable, eliminate this variable
690718
# from the lists of temporary variables' requirements
@@ -731,6 +759,7 @@ end
731759
# 7) log(inv(x1)) = -log(x1) [EAGO paper]
732760
# 8) CONST1*CONST2*x1 = (CONST1*CONST2)*x1
733761
# 9) 1 / (1 + exp(-x)) = Sigmoid(x)
762+
# 10) sin(x) = cos(x - pi/2)
734763
#
735764
# Forms that aren't relevant yet:
736765
# 1) (a^x1)^b = (a^b)^x1 [EAGO paper] (Can't do powers besides integers)
@@ -826,7 +855,7 @@ function perform_substitutions(old_factored::Vector{Equation})
826855
end
827856
end
828857
# Create a factorization of this new expr
829-
new_factorization = factor(new_expr)
858+
new_factorization = factor(new_expr, split_div=true)
830859
# Scan through the new factorization to see if we can merge elements
831860
# with the original factored list
832861
done = false
@@ -1191,7 +1220,7 @@ function perform_substitutions(old_factored::Vector{Equation})
11911220
new_expr *= arg
11921221
end
11931222
# Create a factorization of this new expr
1194-
new_factorization = factor(new_expr)
1223+
new_factorization = factor(new_expr, split_div=true)
11951224

11961225

11971226
# Scan through the new factorization to see if we can merge elements
@@ -1315,6 +1344,38 @@ function perform_substitutions(old_factored::Vector{Equation})
13151344
end
13161345
end
13171346
end
1347+
1348+
# 10) sin(x) = cos(x - pi/2)
1349+
if exprtype(factored[index0].rhs)==TERM
1350+
if factored[index0].rhs.f==sin
1351+
# We found sin(arg). Check if (arg - pi/2) exists,
1352+
# and if so, also check if cos(arg - pi/2) exists.
1353+
scan_flag = true
1354+
index1 = findfirst(x -> isequal(x.rhs, arguments(factored[index0].rhs)[] - pi/2), factored)
1355+
if !isnothing(index1)
1356+
index2 = findfirst(x -> isequal(x.rhs, cos(factored[index1].lhs)), factored)
1357+
if !isnothing(index2)
1358+
# cos(arg - pi/2) exists already (index2). Remove all reference to index0 and replace with index2
1359+
for i in eachindex(factored)
1360+
@eval $factored[$i] = $factored[$i].lhs ~ substitute($factored[$i].rhs, Dict($factored[$index0].lhs => $factored[$index2].lhs))
1361+
end
1362+
deleteat!(factored, index0)
1363+
else
1364+
# arg - pi/2 exists already (index1), but not cos(arg - pi/2). Change
1365+
# index0 to be cos of index1.lhs instead of sin of arg
1366+
@eval $factored[$index0] = $factored[$index0].lhs ~ cos($factored[$index1].lhs)
1367+
end
1368+
else
1369+
# (arg - pi/2) doesn't exist, so we need to create it
1370+
newsym = gensym(:aux)
1371+
newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end])
1372+
newvar = genvar(newsym)
1373+
insert!(factored, index0, Equation(Symbolics.value(newvar), arguments(factored[index0].rhs)[] - pi/2))
1374+
@eval $factored[$index0+1] = $factored[$index0+1].lhs ~ cos($newvar)
1375+
end
1376+
break
1377+
end
1378+
end
13181379
end
13191380
end
13201381

@@ -1511,6 +1572,10 @@ function write_operation(file::IOStream, RHS::BasicSymbolic{Real}, inputs::Vecto
15111572
write(file, SCMC_sigmoid_kernel(inputs..., gradlist, sparsity))
15121573
elseif RHS.f==sqrt
15131574
write(file, SCMC_float_power_kernel(inputs..., 0.5, gradlist, sparsity))
1575+
elseif RHS.f==cos
1576+
write(file, SCMC_cos_kernel(inputs..., gradlist, sparsity))
1577+
elseif RHS.f==abs
1578+
write(file, SCMC_abs_kernel(inputs..., gradlist, sparsity))
15141579
else
15151580
close(file)
15161581
error("Some function was used that we can't handle yet ($RHS)")
@@ -1845,6 +1910,10 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star
18451910
else
18461911
total_lines += 190
18471912
end
1913+
new_ID = findfirst(x -> isequal(x.lhs, RHS.base), factorized)
1914+
if !isnothing(new_ID)
1915+
total_lines += _complexity(complexity, factorized, new_ID)
1916+
end
18481917
elseif exprtype(RHS) == TERM
18491918
if RHS.f==exp
18501919
total_lines += 212 # Ranges from 212--310
@@ -1866,8 +1935,24 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star
18661935
end
18671936
elseif RHS.f==sqrt
18681937
total_lines += 190
1938+
new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized)
1939+
if !isnothing(new_ID)
1940+
total_lines += _complexity(complexity, factorized, new_ID)
1941+
end
1942+
elseif RHS.f==cos || RHS.f==sin
1943+
total_lines += 300
1944+
new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized)
1945+
if !isnothing(new_ID)
1946+
total_lines += _complexity(complexity, factorized, new_ID)
1947+
end
1948+
elseif RHS.f==abs
1949+
total_lines += 280
1950+
new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized)
1951+
if !isnothing(new_ID)
1952+
total_lines += _complexity(complexity, factorized, new_ID)
1953+
end
18691954
else
1870-
error("Unknown function")
1955+
error("Some function was used that we can't handle yet ($RHS)")
18711956
end
18721957
elseif exprtype(RHS) == SYM
18731958
nothing

0 commit comments

Comments
 (0)