Skip to content

Commit 1f22741

Browse files
committed
Update for Julia v1.12
- Explicitly use `Float64(pi)` in some places in written kernels - Switch to using MultiFloats in `quadrant` to avoid use of `BigFloat` on GPU - Improve sparsity detection for more complex expressions - Move `@register_symbolic` to global scope for `SCMC_sigmoid`
1 parent 5f9c86a commit 1f22741

File tree

5 files changed

+61
-90
lines changed

5 files changed

+61
-90
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
name = "SourceCodeMcCormick"
22
uuid = "a7283dc5-4ecf-47fb-a95b-1412723fc960"
33
authors = ["Robert Gottlieb <Robert.x.gottlieb@uconn.edu>"]
4-
version = "0.5.1"
4+
version = "0.5.2"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
88
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
99
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1010
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1111
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
12+
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
1213
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1314
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1415
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -20,6 +21,7 @@ CUDA = "5"
2021
DocStringExtensions = "0.8 - 0.9"
2122
Graphs = "1"
2223
IfElse = "0.1.0 - 0.1.1"
24+
MultiFloats = "3.2.3"
2325
PrecompileTools = "~1"
2426
Reexport = "~1"
2527
StaticArrays = "~1"

src/SourceCodeMcCormick.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using DocStringExtensions
99
using Graphs
1010
using CUDA
1111
using StaticArrays: @MVector
12+
using MultiFloats
1213
import Dates
1314
import SymbolicUtils: BasicSymbolic, exprtype, SYM, TERM, ADD, MUL, POW, DIV
1415

src/kernel_writer/kernel_write.jl

Lines changed: 38 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,10 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
109109
error("Splitting must be one of: {:low, :default, :high, :max}")
110110
end
111111

112-
# Pull out sparsity information in the factorization
113-
sparsity = detect_sparsity(factored, gradlist)
114-
115112
# Decide if the kernel needs to be split
116113
if (n_vars[end] < 31) && ((n_lines[end] <= max_size) || (findfirst(x -> x > split_point, n_lines)==length(n_lines)))
117114
# Complexity is fairly low; only a single kernel needed
118-
create_kernel!(expr_hash, 1, num, get_name.(gradlist), func_outputs, constants, factored, sparsity)
115+
create_kernel!(expr_hash, 1, num, get_name.(gradlist), func_outputs, constants, factored)
119116
push!(kernel_nums, 1)
120117
push!(inputs, string.(indep_vars))
121118
push!(outputs, "OUT")
@@ -169,7 +166,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
169166
#### Start of alternative to experimental section
170167

171168
# Send the element at `new_ID` to create_kernel!()
172-
create_kernel!(expr_hash, kernel_count, extract(factored, new_ID), get_name.(gradlist), func_outputs, constants, factored, sparsity)
169+
create_kernel!(expr_hash, kernel_count, extract(factored, new_ID), get_name.(gradlist), func_outputs, constants, factored)
173170
push!(kernel_nums, kernel_count)
174171
push!(inputs, string.(get_name.(pull_vars(extract(factored, new_ID)))))
175172
push!(outputs, string(factored[new_ID].lhs))
@@ -188,7 +185,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
188185
# If the total number of lines (not including the final line) is below the max size
189186
# and the number of variables is below 32, we can make the final kernel and be done
190187
if (n_vars[end] < 32) && (all(n_lines[1:end-1] .<= max_size))
191-
create_kernel!(expr_hash, kernel_count, extract(factored), get_name.(gradlist), func_outputs, constants, factored, sparsity)
188+
create_kernel!(expr_hash, kernel_count, extract(factored), get_name.(gradlist), func_outputs, constants, factored)
192189
push!(kernel_nums, kernel_count)
193190
push!(inputs, string.(get_name.(pull_vars(extract(factored)))))
194191
push!(outputs, "OUT")
@@ -435,8 +432,8 @@ end
435432
# This function takes information about the file name, kernel ID, and
436433
# the expression that a SINGLE kernel is being created for, and creates
437434
# that kernel in the specified file.
438-
create_kernel!(expr_hash::String, kernel_ID::Int, num::Num, gradlist::Vector{Symbol}, func_outputs::Vector{Symbol}, constants::Vector{Num}, orig_factored::Vector{Equation}, orig_sparsity::Vector{Vector{Int}}) = create_kernel!(expr_hash, kernel_ID, num.val, gradlist, func_outputs, constants, orig_factored, orig_sparsity)
439-
function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Real}, gradlist::Vector{Symbol}, func_outputs::Vector{Symbol}, constants::Vector{Num}, orig_factored::Vector{Equation}, orig_sparsity::Vector{Vector{Int}})
435+
create_kernel!(expr_hash::String, kernel_ID::Int, num::Num, gradlist::Vector{Symbol}, func_outputs::Vector{Symbol}, constants::Vector{Num}, orig_factored::Vector{Equation}) = create_kernel!(expr_hash, kernel_ID, num.val, gradlist, func_outputs, constants, orig_factored)
436+
function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Real}, gradlist::Vector{Symbol}, func_outputs::Vector{Symbol}, constants::Vector{Num}, orig_factored::Vector{Equation})
440437
# This function will create a kernel for `num`, with the name:
441438
# "f_" * expr_hash * "_$n". This name will be pushed to `kernels`,
442439
# and a vector of the required inputs variables will be pushed to
@@ -466,41 +463,7 @@ function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Re
466463
# normally, unless it's a temporary variable, in which case we have to refer to the original
467464
# factorization and sparsity.
468465
string_gradlist = string.(gradlist)
469-
sparsity = Vector{Vector{Int}}(undef, length(varorder))
470-
for i in eachindex(varorder)
471-
if varorder[i] in string_gradlist
472-
# Mark sparsity if the variable is already in gradlist
473-
sparsity[i] = [findfirst(==(varorder[i]), string_gradlist)]
474-
else
475-
# Find out what index we're on
476-
idx = findfirst(x -> isequal(string(x.lhs), varorder[i]), factorized)
477-
478-
if isnothing(idx)
479-
sparsity[i] = orig_sparsity[findfirst(x -> isequal(string(x.lhs), varorder[i]), orig_factored)]
480-
else
481-
# Extract all the variables for this index
482-
vars = pull_vars(extract(factorized, idx))
483-
484-
# For each variable in the expanded expression, add in sparsity information
485-
curr_sparsity = Int[]
486-
for var in vars
487-
ID = findfirst(==(string(get_name(var))), string_gradlist)
488-
if isnothing(ID)
489-
# If we didn't find the variable, we need to scan the original factorization,
490-
# and then pull sparsity info from the original sparsity list
491-
ID = findfirst(x -> isequal(string(x.lhs), string(var)), orig_factored)
492-
push!(curr_sparsity, orig_sparsity[ID]...)
493-
else
494-
# If we do find the variable, we can add this variable directly into the sparsity
495-
push!(curr_sparsity, ID)
496-
end
497-
end
498-
499-
# Add a sorted, unique list to the sparsity tracker
500-
sparsity[i] = sort(unique(curr_sparsity))
501-
end
502-
end
503-
end
466+
sparsity = get_sparsity(varorder, string_gradlist, factorized)
504467

505468
# Check if we need temporary variables at all. We don't need
506469
# temporary variables if we only have addition, or if we have
@@ -536,7 +499,7 @@ function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Re
536499
maxtemp = 0
537500
# if need_temps #(Skip for now)
538501
for i in eachindex(varorder) # Loop through every participating variable
539-
if (varorder[i] in string.(get_name.(vars)))
502+
if (varorder[i] in string.(vars))
540503
# Skip the variable if it's an input
541504
continue
542505
end
@@ -608,15 +571,15 @@ function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Re
608571
file = open(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"), "a")
609572

610573
# Put in the preamble.
611-
write(file, preamble_string(expr_hash, ["OUT"; string.(get_name.(vars))], kernel_ID, maxtemp, length(gradlist)))
574+
write(file, preamble_string(expr_hash, ["OUT"; string.(vars)], kernel_ID, maxtemp, length(gradlist)))
612575

613576
# Loop through the topological list to add calculations in order
614577
temp_endlist = []
615578
outvar = ""
616579
name_tracker = copy(varids)
617580
for i in eachindex(varorder) # Order in which variables are calculated
618581
# Skip calculation if the variable is one of the inputs
619-
if (varorder[i] in string.(get_name.(vars)))
582+
if (varorder[i] in string.(vars))
620583
continue
621584
end
622585

@@ -628,20 +591,17 @@ function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Re
628591
factorized_ID = findfirst(x -> isequal(string(x.lhs), varorder[i]), factorized)
629592
participants = get_name.(pull_vars(factorized[factorized_ID].rhs))
630593
inputs = []
594+
input_IDs = Int[]
631595
for p in string.(participants)
632596
# Find the corresponding element in varids
633597
varids_ID = findfirst(x -> isequal(x, p), varids)
634-
# Push the name_tracker name to the input list
598+
# Push the name_tracker name to the input list and the ID to the list of input IDs
635599
push!(inputs, name_tracker[varids_ID])
636-
end
637600

638-
# [Deprecating; I'll use temporary variables the whole way and then set
639-
# the output variable at the end for final copying]
640-
# # If this is the final variable, it'll be called "OUT". No need
641-
# # for temp variables
642-
# if i==length(varorder)
643-
# name_tracker[ID] = "OUT"
644-
# else
601+
# Find the corresponding element in varorder
602+
varorder_ID = findfirst(x -> isequal(x, p), varorder)
603+
push!(input_IDs, varorder_ID)
604+
end
645605

646606
# Determine which tempID to use/override. temp_endlist keeps
647607
# track of where variables will be used in the future (stored
@@ -695,21 +655,20 @@ function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Re
695655
# Now we can append this temporary variable to the list of inputs
696656
# for the correct operation
697657
inputs = [name_tracker[ID]; inputs]
658+
input_IDs = [i; input_IDs]
698659

699660
# Now we can pass the equation's RHS and the inputs to call the correct
700661
# writer function
701662
if length(inputs)>2 && inputs[1]==inputs[2]
702663
# Special case. We're adding inputs[3] to inputs[2], so we only want
703664
# to pass the sparsity information of inputs[3] (rather than the
704665
# sparsity information of inputs[2] and inputs[3])
705-
corrected_i = findfirst(x->x==inputs[3], varorder)
706-
write_operation(file, factorized[factorized_ID].rhs, inputs, string.(gradlist), sparsity[corrected_i])
666+
write_operation(file, factorized[factorized_ID].rhs, inputs, string.(gradlist), sparsity[input_IDs[3]])
707667
elseif length(inputs)>2 && inputs[1]==inputs[3]
708668
# Special case. We're adding inputs[2] to inputs[3], so we only want
709669
# to pass the sparsity information of inputs[2] (rather than the
710670
# sparsity information of inputs[2] and inputs[3])
711-
corrected_i = findfirst(x->x==inputs[2], varorder)
712-
write_operation(file, factorized[factorized_ID].rhs, inputs, string.(gradlist), sparsity[corrected_i])
671+
write_operation(file, factorized[factorized_ID].rhs, inputs, string.(gradlist), sparsity[input_IDs[2]])
713672
else
714673
write_operation(file, factorized[factorized_ID].rhs, inputs, string.(gradlist), sparsity[i])
715674
end
@@ -765,10 +724,6 @@ end
765724
# 1) (a^x1)^b = (a^b)^x1 [EAGO paper] (Can't do powers besides integers)
766725
function perform_substitutions(old_factored::Vector{Equation})
767726
factored = deepcopy(old_factored)
768-
769-
# Register any terms we want to substitute
770-
@eval @register_symbolic SCMC_sigmoid(x)
771-
772727
scan_flag = true
773728
while scan_flag
774729
scan_flag = false
@@ -1468,23 +1423,28 @@ function pull_mult(factors, index; args=[])
14681423
end
14691424

14701425

1471-
# A helper function to calculate the sparsity of a factorization, given
1472-
# the full gradlist
1473-
function detect_sparsity(factored, gradlist)
1474-
# Idea is to check every element of "factored" and pull out a list of indices
1475-
# in gradlist that the variables depend on.
1476-
sparsity = Vector{Int}[]
1477-
string_gradlist = string.(gradlist)
1478-
1479-
for i in eachindex(factored)
1480-
vars = string.(pull_vars(extract(factored, i)))
1481-
indices = zeros(Int, length(vars))
1482-
for j in eachindex(indices)
1483-
indices[j] = findfirst(==(string(vars[j])), string_gradlist)
1426+
# A helper function to calculate the sparsity of a factorization
1427+
function get_sparsity(varorder, string_gradlist, factored)
1428+
sparsity = [Int[] for _ in 1:length(varorder)]
1429+
function calc_sp(v)
1430+
if !isempty(sparsity[v])
1431+
return sparsity[v]
1432+
end
1433+
if varorder[v] in string_gradlist
1434+
sparsity[v] = [findfirst(==(varorder[v]), string_gradlist)]
1435+
return sparsity[v]
1436+
end
1437+
idx = findfirst(x -> isequal(string(x.lhs), varorder[v]), factored)
1438+
RHS_vars = pull_vars(factored[idx].rhs)
1439+
for var in RHS_vars
1440+
var_idx = findfirst(==(string(get_name(var))), varorder)
1441+
sparsity[v] = [sparsity[v]..., calc_sp(var_idx)...]
14841442
end
1485-
push!(sparsity, indices)
1443+
return sort(unique(sparsity[v]))
1444+
end
1445+
for v in 1:length(varorder)
1446+
sparsity[v] = calc_sp(v)
14861447
end
1487-
14881448
return sparsity
14891449
end
14901450

src/kernel_writer/math_kernels.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1365,6 +1365,7 @@ end
13651365

13661366
# Sigmoid function
13671367
# max threads: 640
1368+
@register_symbolic SCMC_sigmoid(x) # Register as symbolic so that we can use it later
13681369
function SCMC_sigmoid_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix)
13691370
idx = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x
13701371
stride = blockDim().x * gridDim().x
@@ -4670,21 +4671,21 @@ function SCMC_cos_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix)
46704671
kL = Base.ceil(-0.5 - x[idx,3]/(2.0*pi))
46714672
xL1 = x[idx,3] + 2.0*pi*kL
46724673
xU1 = x[idx,4] + 2.0*pi*kL
4673-
if (xL1 < -pi) || (xL1 > pi)
4674+
if (xL1 < -pi) || (xL1 > Float64(pi))
46744675
eps_min = NaN
46754676
eps_max = NaN
46764677
elseif xL1 <= 0.0
46774678
if xU1 <= 0.0
46784679
eps_min = x[idx,3]
46794680
eps_max = x[idx,4]
4680-
elseif xU1 >= pi
4681+
elseif xU1 >= Float64(pi)
46814682
eps_min = pi - 2.0*pi*kL
46824683
eps_max = -2.0*pi*kL
46834684
else
46844685
eps_min = (cos(xL1) <= cos(xU1)) ? x[idx,3] : x[idx,4]
46854686
eps_max = -2.0*pi*kL
46864687
end
4687-
elseif xU1 <= pi
4688+
elseif xU1 <= Float64(pi)
46884689
eps_min = x[idx,4]
46894690
eps_max = x[idx,3]
46904691
elseif xU1 >= 2.0*pi
@@ -5449,9 +5450,16 @@ function cos_newton_or_golden_section(x0::Float64, xL::Float64, xU::Float64, env
54495450
return xk
54505451
end
54515452

5452-
# Directly from IntervalArithmetic.jl
5453+
# Similar to IntervalArithmetic.jl, but not using `rem2pi`
54535454
function quadrant(x::Float64)
5454-
x_mod2pi = rem2pi(x, RoundNearest)
5455+
bigx = MultiFloats.Float64x2(x)
5456+
bigpi = MultiFloats._MF{Float64,2}((3.141592653589793, 1.2246467991473532e-16))
5457+
rem = Float64(floor(bigx/bigpi))
5458+
if iseven(rem)
5459+
x_mod2pi = Float64(bigx - rem*bigpi)
5460+
else
5461+
x_mod2pi = Float64(bigx - (rem+1)*bigpi)
5462+
end
54555463

54565464
x_mod2pi < -(pi/2.0) && return (Int32(2), x_mod2pi)
54575465
x_mod2pi < 0 && return (Int32(3), x_mod2pi)

src/kernel_writer/string_math_kernels.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10291,21 +10291,21 @@ function SCMC_cos_kernel(OUT::String, v1::String, varlist::Vector{String}, spars
1029110291
write(buffer, " kL = Base.ceil(-0.5 - $v1_lo/(2.0*pi))\n")
1029210292
write(buffer, " xL1 = $v1_lo + 2.0*pi*kL\n")
1029310293
write(buffer, " xU1 = $v1_hi + 2.0*pi*kL\n")
10294-
write(buffer, " if (xL1 < -pi) || (xL1 > pi)\n")
10294+
write(buffer, " if (xL1 < -pi) || (xL1 > Float64(pi))\n")
1029510295
write(buffer, " eps_min = NaN\n")
1029610296
write(buffer, " eps_max = NaN\n")
1029710297
write(buffer, " elseif xL1 <= 0.0\n")
1029810298
write(buffer, " if xU1 <= 0.0\n")
1029910299
write(buffer, " eps_min = $v1_lo\n")
1030010300
write(buffer, " eps_max = $v1_hi\n")
10301-
write(buffer, " elseif xU1 >= pi\n")
10301+
write(buffer, " elseif xU1 >= Float64(pi)\n")
1030210302
write(buffer, " eps_min = pi - 2.0*pi*kL\n")
1030310303
write(buffer, " eps_max = -2.0*pi*kL\n")
1030410304
write(buffer, " else\n")
1030510305
write(buffer, " eps_min = (cos(xL1) <= cos(xU1)) ? $v1_lo : $v1_hi\n")
1030610306
write(buffer, " eps_max = -2.0*pi*kL\n")
1030710307
write(buffer, " end\n")
10308-
write(buffer, " elseif xU1 <= pi\n")
10308+
write(buffer, " elseif xU1 <= Float64(pi)\n")
1030910309
write(buffer, " eps_min = $v1_hi\n")
1031010310
write(buffer, " eps_max = $v1_lo\n")
1031110311
write(buffer, " elseif xU1 >= 2.0*pi\n")
@@ -10601,21 +10601,21 @@ function SCMC_cos_kernel(OUT::String, v1::String, varlist::Vector{String}, spars
1060110601
write(buffer, " kL = Base.ceil(-0.5 - $v1_lo/(2.0*pi))\n")
1060210602
write(buffer, " xL1 = $v1_lo + 2.0*pi*kL\n")
1060310603
write(buffer, " xU1 = $v1_hi + 2.0*pi*kL\n")
10604-
write(buffer, " if (xL1 < -pi) || (xL1 > pi)\n")
10604+
write(buffer, " if (xL1 < -pi) || (xL1 > Float64(pi))\n")
1060510605
write(buffer, " eps_min = NaN\n")
1060610606
write(buffer, " eps_max = NaN\n")
1060710607
write(buffer, " elseif xL1 <= 0.0\n")
1060810608
write(buffer, " if xU1 <= 0.0\n")
1060910609
write(buffer, " eps_min = $v1_lo\n")
1061010610
write(buffer, " eps_max = $v1_hi\n")
10611-
write(buffer, " elseif xU1 >= pi\n")
10611+
write(buffer, " elseif xU1 >= Float64(pi)\n")
1061210612
write(buffer, " eps_min = pi - 2.0*pi*kL\n")
1061310613
write(buffer, " eps_max = -2.0*pi*kL\n")
1061410614
write(buffer, " else\n")
1061510615
write(buffer, " eps_min = (cos(xL1) <= cos(xU1)) ? $v1_lo : $v1_hi\n")
1061610616
write(buffer, " eps_max = -2.0*pi*kL\n")
1061710617
write(buffer, " end\n")
10618-
write(buffer, " elseif xU1 <= pi\n")
10618+
write(buffer, " elseif xU1 <= Float64(pi)\n")
1061910619
write(buffer, " eps_min = $v1_hi\n")
1062010620
write(buffer, " eps_max = $v1_lo\n")
1062110621
write(buffer, " elseif xU1 >= 2.0*pi\n")

0 commit comments

Comments
 (0)