@@ -10,7 +10,7 @@ kgen(num::Num, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwr
1010kgen (num:: Num , gradlist:: Vector{Num} , raw_outputs:: Vector{Symbol} ; constants:: Vector{Num} = Num[], overwrite:: Bool = false , splitting:: Symbol = :default , affine_quadratic:: Bool = true ) = kgen (num, gradlist, raw_outputs, constants, overwrite, splitting, affine_quadratic)
1111function kgen (num:: Num , gradlist:: Vector{Num} , raw_outputs:: Vector{Symbol} , constants:: Vector{Num} , overwrite:: Bool , splitting:: Symbol , affine_quadratic:: Bool )
1212 # Create a hash of the expression and check if the function already exists
13- expr_hash = string (hash (num+ sum (gradlist)), base= 62 )
13+ expr_hash = string (hash (string ( num) * string (gradlist)), base= 62 )
1414 if (overwrite== false ) && (isfile (joinpath (@__DIR__ , " storage" , " f_" * expr_hash* " .jl" )))
1515 try func_name = eval (Meta. parse (" f_" * expr_hash))
1616 return func_name
@@ -102,9 +102,6 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
102102 elseif splitting== :high # Formerly default
103103 split_point = 1500
104104 max_size = 2000
105- # elseif splitting==:high # More splitting
106- # split_point = 1000
107- # max_size = 1200
108105 elseif splitting== :max # Extremely small
109106 split_point = 500
110107 max_size = 750
@@ -116,7 +113,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
116113 sparsity = detect_sparsity (factored, gradlist)
117114
118115 # Decide if the kernel needs to be split
119- if (n_vars[end ] < 31 ) && (n_lines[end ] <= max_size)
116+ if (n_vars[end ] < 31 ) && (( n_lines[end ] <= max_size) || ( findfirst (x -> x > split_point, n_lines) == length (n_lines)) )
120117 # Complexity is fairly low; only a single kernel needed
121118 create_kernel! (expr_hash, 1 , num, get_name .(gradlist), func_outputs, constants, factored, sparsity)
122119 push! (kernel_nums, 1 )
@@ -130,7 +127,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
130127 while ! complete
131128 # Determine which line to break at
132129 line_ID = findfirst (x -> x > split_point, n_lines)
133- vars_ID = findfirst (x -> x == 31 , n_vars)
130+ vars_ID = findfirst (x -> ( x == 30 ) || (x == 31 ) , n_vars)
134131 if isnothing (vars_ID)
135132 new_ID = line_ID
136133 elseif isnothing (line_ID)
@@ -188,7 +185,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
188185 n_lines = complexity (factored)
189186 n_vars = var_counts (factored)
190187
191- # If the total number of lines (not including the final line) is below 2000
188+ # If the total number of lines (not including the final line) is below the max size
192189 # and the number of variables is below 32, we can make the final kernel and be done
193190 if (n_vars[end ] < 32 ) && (all (n_lines[1 : end - 1 ] .<= max_size))
194191 create_kernel! (expr_hash, kernel_count, extract (factored), get_name .(gradlist), func_outputs, constants, factored, sparsity)
@@ -328,7 +325,12 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
328325 file = open (joinpath (@__DIR__ , " storage" , " f_" * expr_hash* " .jl" ), " a" )
329326
330327 # Put in the preamble.
331- write (file, preamble_string (expr_hash, [" OUT" ; string .(vars)], 1 , 1 , length (gradlist)))
328+ if isempty (vars)
329+ write (file, preamble_string (expr_hash, [" OUT" ;], 1 , 1 , length (gradlist)))
330+ else
331+ write (file, preamble_string (expr_hash, [" OUT" ; string .(vars)], 1 , 1 , length (gradlist)))
332+ end
333+
332334
333335 # Depending on the format of the expression, compose the kernel differently
334336 if typeof (expr) <: Real
@@ -360,9 +362,9 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
360362 end
361363 end
362364 else # There must be two elements in the dictionary
363- binary_vars = string .(get_name .(keys (key . dict)))
365+ binary_vars = string .(get_name .(keys (expr . dict)))
364366 binary_vars = binary_vars[sort_vars (binary_vars)]
365- write (file, SCMC_quadaff_binary (vars ... , expr. coeff, varlist))
367+ write (file, SCMC_quadaff_binary (binary_vars ... , expr. coeff, varlist))
366368 end
367369
368370 elseif exprtype (expr)== ADD
@@ -394,7 +396,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
394396 # EAGO already does this and bypasses the need to calculate relaxations.
395397 # But, for compatibility with McCormick-style relaxations in ParBB,
396398 # it's easier to simply calculate what ParBB is expecting.)
397- write (file, postamble_quadaff (string .(vars), varlist))
399+ if isempty (varlist)
400+ write (file, postamble_quadaff (String[], String[]))
401+ elseif isempty (vars)
402+ write (file, postamble_quadaff (String[], varlist))
403+ else
404+ write (file, postamble_quadaff (string .(vars), varlist))
405+ end
398406 close (file)
399407
400408 # Include this kernel so SCMC knows what it is
@@ -403,7 +411,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
403411 # Add onto the file the "main" CPU function that calls the kernel
404412 blocks = Int32 (CUDA. attribute (CUDA. device (), CUDA. DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT))
405413 file = open (joinpath (@__DIR__ , " storage" , " f_" * expr_hash* " .jl" ), " a" )
406- write (file, outro (expr_hash, [1 ], [string .(vars)], [" OUT" ], blocks, get_name .(gradlist)))
414+ if isempty (gradlist)
415+ write (file, outro (expr_hash, [1 ], [String[]], [" OUT" ], blocks, Symbol[]))
416+ elseif isempty (vars)
417+ write (file, outro (expr_hash, [1 ], [String[]], [" OUT" ], blocks, get_name .(gradlist)))
418+ else
419+ write (file, outro (expr_hash, [1 ], [string .(vars)], [" OUT" ], blocks, get_name .(gradlist)))
420+ end
407421 close (file)
408422
409423 # Include the file again to get the final kernel
@@ -684,7 +698,21 @@ function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Re
684698
685699 # Now we can pass the equation's RHS and the inputs to call the correct
686700 # writer function
687- write_operation (file, factorized[factorized_ID]. rhs, inputs, string .(gradlist), sparsity[i])
701+ if length (inputs)> 2 && inputs[1 ]== inputs[2 ]
702+ # Special case. We're adding inputs[3] to inputs[2], so we only want
703+ # to pass the sparsity information of inputs[3] (rather than the
704+ # sparsity information of inputs[2] and inputs[3])
705+ corrected_i = findfirst (x-> x== inputs[3 ], varorder)
706+ write_operation (file, factorized[factorized_ID]. rhs, inputs, string .(gradlist), sparsity[corrected_i])
707+ elseif length (inputs)> 2 && inputs[1 ]== inputs[3 ]
708+ # Special case. We're adding inputs[2] to inputs[3], so we only want
709+ # to pass the sparsity information of inputs[2] (rather than the
710+ # sparsity information of inputs[2] and inputs[3])
711+ corrected_i = findfirst (x-> x== inputs[2 ], varorder)
712+ write_operation (file, factorized[factorized_ID]. rhs, inputs, string .(gradlist), sparsity[corrected_i])
713+ else
714+ write_operation (file, factorized[factorized_ID]. rhs, inputs, string .(gradlist), sparsity[i])
715+ end
688716
689717 # Now that we're done with this variable, eliminate this variable
690718 # from the lists of temporary variables' requirements
731759# 7) log(inv(x1)) = -log(x1) [EAGO paper]
732760# 8) CONST1*CONST2*x1 = (CONST1*CONST2)*x1
733761# 9) 1 / (1 + exp(-x)) = Sigmoid(x)
762+ # 10) sin(x) = cos(x - pi/2)
734763#
735764# Forms that aren't relevant yet:
736765# 1) (a^x1)^b = (a^b)^x1 [EAGO paper] (Can't do powers besides integers)
@@ -826,7 +855,7 @@ function perform_substitutions(old_factored::Vector{Equation})
826855 end
827856 end
828857 # Create a factorization of this new expr
829- new_factorization = factor (new_expr)
858+ new_factorization = factor (new_expr, split_div = true )
830859 # Scan through the new factorization to see if we can merge elements
831860 # with the original factored list
832861 done = false
@@ -1191,7 +1220,7 @@ function perform_substitutions(old_factored::Vector{Equation})
11911220 new_expr *= arg
11921221 end
11931222 # Create a factorization of this new expr
1194- new_factorization = factor (new_expr)
1223+ new_factorization = factor (new_expr, split_div = true )
11951224
11961225
11971226 # Scan through the new factorization to see if we can merge elements
@@ -1315,6 +1344,38 @@ function perform_substitutions(old_factored::Vector{Equation})
13151344 end
13161345 end
13171346 end
1347+
1348+ # 10) sin(x) = cos(x - pi/2)
1349+ if exprtype (factored[index0]. rhs)== TERM
1350+ if factored[index0]. rhs. f== sin
1351+ # We found sin(arg). Check if (arg - pi/2) exists,
1352+ # and if so, also check if cos(arg - pi/2) exists.
1353+ scan_flag = true
1354+ index1 = findfirst (x -> isequal (x. rhs, arguments (factored[index0]. rhs)[] - pi / 2 ), factored)
1355+ if ! isnothing (index1)
1356+ index2 = findfirst (x -> isequal (x. rhs, cos (factored[index1]. lhs)), factored)
1357+ if ! isnothing (index2)
1358+ # cos(arg - pi/2) exists already (index2). Remove all reference to index0 and replace with index2
1359+ for i in eachindex (factored)
1360+ @eval $ factored[$ i] = $ factored[$ i]. lhs ~ substitute ($ factored[$ i]. rhs, Dict ($ factored[$ index0]. lhs => $ factored[$ index2]. lhs))
1361+ end
1362+ deleteat! (factored, index0)
1363+ else
1364+ # arg - pi/2 exists already (index1), but not cos(arg - pi/2). Change
1365+ # index0 to be cos of index1.lhs instead of sin of arg
1366+ @eval $ factored[$ index0] = $ factored[$ index0]. lhs ~ cos ($ factored[$ index1]. lhs)
1367+ end
1368+ else
1369+ # (arg - pi/2) doesn't exist, so we need to create it
1370+ newsym = gensym (:aux )
1371+ newsym = Symbol (string (newsym)[3 : 5 ] * string (newsym)[7 : end ])
1372+ newvar = genvar (newsym)
1373+ insert! (factored, index0, Equation (Symbolics. value (newvar), arguments (factored[index0]. rhs)[] - pi / 2 ))
1374+ @eval $ factored[$ index0+ 1 ] = $ factored[$ index0+ 1 ]. lhs ~ cos ($ newvar)
1375+ end
1376+ break
1377+ end
1378+ end
13181379 end
13191380 end
13201381
@@ -1511,6 +1572,10 @@ function write_operation(file::IOStream, RHS::BasicSymbolic{Real}, inputs::Vecto
15111572 write (file, SCMC_sigmoid_kernel (inputs... , gradlist, sparsity))
15121573 elseif RHS. f== sqrt
15131574 write (file, SCMC_float_power_kernel (inputs... , 0.5 , gradlist, sparsity))
1575+ elseif RHS. f== cos
1576+ write (file, SCMC_cos_kernel (inputs... , gradlist, sparsity))
1577+ elseif RHS. f== abs
1578+ write (file, SCMC_abs_kernel (inputs... , gradlist, sparsity))
15141579 else
15151580 close (file)
15161581 error (" Some function was used that we can't handle yet ($RHS )" )
@@ -1845,6 +1910,10 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star
18451910 else
18461911 total_lines += 190
18471912 end
1913+ new_ID = findfirst (x -> isequal (x. lhs, RHS. base), factorized)
1914+ if ! isnothing (new_ID)
1915+ total_lines += _complexity (complexity, factorized, new_ID)
1916+ end
18481917 elseif exprtype (RHS) == TERM
18491918 if RHS. f== exp
18501919 total_lines += 212 # Ranges from 212--310
@@ -1866,8 +1935,24 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star
18661935 end
18671936 elseif RHS. f== sqrt
18681937 total_lines += 190
1938+ new_ID = findfirst (x -> isequal (x. lhs, RHS. arguments[1 ]), factorized)
1939+ if ! isnothing (new_ID)
1940+ total_lines += _complexity (complexity, factorized, new_ID)
1941+ end
1942+ elseif RHS. f== cos || RHS. f== sin
1943+ total_lines += 300
1944+ new_ID = findfirst (x -> isequal (x. lhs, RHS. arguments[1 ]), factorized)
1945+ if ! isnothing (new_ID)
1946+ total_lines += _complexity (complexity, factorized, new_ID)
1947+ end
1948+ elseif RHS. f== abs
1949+ total_lines += 280
1950+ new_ID = findfirst (x -> isequal (x. lhs, RHS. arguments[1 ]), factorized)
1951+ if ! isnothing (new_ID)
1952+ total_lines += _complexity (complexity, factorized, new_ID)
1953+ end
18691954 else
1870- error (" Unknown function" )
1955+ error (" Some function was used that we can't handle yet ( $RHS ) " )
18711956 end
18721957 elseif exprtype (RHS) == SYM
18731958 nothing
0 commit comments