@@ -109,13 +109,10 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
109109 error (" Splitting must be one of: {:low, :default, :high, :max}" )
110110 end
111111
112- # Pull out sparsity information in the factorization
113- sparsity = detect_sparsity (factored, gradlist)
114-
115112 # Decide if the kernel needs to be split
116113 if (n_vars[end ] < 31 ) && ((n_lines[end ] <= max_size) || (findfirst (x -> x > split_point, n_lines)== length (n_lines)))
117114 # Complexity is fairly low; only a single kernel needed
118- create_kernel! (expr_hash, 1 , num, get_name .(gradlist), func_outputs, constants, factored, sparsity )
115+ create_kernel! (expr_hash, 1 , num, get_name .(gradlist), func_outputs, constants, factored)
119116 push! (kernel_nums, 1 )
120117 push! (inputs, string .(indep_vars))
121118 push! (outputs, " OUT" )
@@ -169,7 +166,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
169166 # ### Start of alternative to experimental section
170167
171168 # Send the element at `new_ID` to create_kernel!()
172- create_kernel! (expr_hash, kernel_count, extract (factored, new_ID), get_name .(gradlist), func_outputs, constants, factored, sparsity )
169+ create_kernel! (expr_hash, kernel_count, extract (factored, new_ID), get_name .(gradlist), func_outputs, constants, factored)
173170 push! (kernel_nums, kernel_count)
174171 push! (inputs, string .(get_name .(pull_vars (extract (factored, new_ID)))))
175172 push! (outputs, string (factored[new_ID]. lhs))
@@ -188,7 +185,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
188185 # If the total number of lines (not including the final line) is below the max size
189186 # and the number of variables is below 32, we can make the final kernel and be done
190187 if (n_vars[end ] < 32 ) && (all (n_lines[1 : end - 1 ] .<= max_size))
191- create_kernel! (expr_hash, kernel_count, extract (factored), get_name .(gradlist), func_outputs, constants, factored, sparsity )
188+ create_kernel! (expr_hash, kernel_count, extract (factored), get_name .(gradlist), func_outputs, constants, factored)
192189 push! (kernel_nums, kernel_count)
193190 push! (inputs, string .(get_name .(pull_vars (extract (factored)))))
194191 push! (outputs, " OUT" )
435432# This function takes information about the file name, kernel ID, and
436433# the expression that a SINGLE kernel is being created for, and creates
437434# that kernel in the specified file.
438- create_kernel! (expr_hash:: String , kernel_ID:: Int , num:: Num , gradlist:: Vector{Symbol} , func_outputs:: Vector{Symbol} , constants:: Vector{Num} , orig_factored:: Vector{Equation} , orig_sparsity :: Vector{Vector{Int}} ) = create_kernel! (expr_hash, kernel_ID, num. val, gradlist, func_outputs, constants, orig_factored, orig_sparsity )
439- function create_kernel! (expr_hash:: String , kernel_ID:: Int , num:: BasicSymbolic{Real} , gradlist:: Vector{Symbol} , func_outputs:: Vector{Symbol} , constants:: Vector{Num} , orig_factored:: Vector{Equation} , orig_sparsity :: Vector{Vector{Int}} )
435+ create_kernel! (expr_hash:: String , kernel_ID:: Int , num:: Num , gradlist:: Vector{Symbol} , func_outputs:: Vector{Symbol} , constants:: Vector{Num} , orig_factored:: Vector{Equation} ) = create_kernel! (expr_hash, kernel_ID, num. val, gradlist, func_outputs, constants, orig_factored)
436+ function create_kernel! (expr_hash:: String , kernel_ID:: Int , num:: BasicSymbolic{Real} , gradlist:: Vector{Symbol} , func_outputs:: Vector{Symbol} , constants:: Vector{Num} , orig_factored:: Vector{Equation} )
440437 # This function will create a kernel for `num`, with the name:
441438 # "f_" * expr_hash * "_$n". This name will be pushed to `kernels`,
442439 # and a vector of the required inputs variables will be pushed to
@@ -466,41 +463,7 @@ function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Re
466463 # normally, unless it's a temporary variable, in which case we have to refer to the original
467464 # factorization and sparsity.
468465 string_gradlist = string .(gradlist)
469- sparsity = Vector {Vector{Int}} (undef, length (varorder))
470- for i in eachindex (varorder)
471- if varorder[i] in string_gradlist
472- # Mark sparsity if the variable is already in gradlist
473- sparsity[i] = [findfirst (== (varorder[i]), string_gradlist)]
474- else
475- # Find out what index we're on
476- idx = findfirst (x -> isequal (string (x. lhs), varorder[i]), factorized)
477-
478- if isnothing (idx)
479- sparsity[i] = orig_sparsity[findfirst (x -> isequal (string (x. lhs), varorder[i]), orig_factored)]
480- else
481- # Extract all the variables for this index
482- vars = pull_vars (extract (factorized, idx))
483-
484- # For each variable in the expanded expression, add in sparsity information
485- curr_sparsity = Int[]
486- for var in vars
487- ID = findfirst (== (string (get_name (var))), string_gradlist)
488- if isnothing (ID)
489- # If we didn't find the variable, we need to scan the original factorization,
490- # and then pull sparsity info from the original sparsity list
491- ID = findfirst (x -> isequal (string (x. lhs), string (var)), orig_factored)
492- push! (curr_sparsity, orig_sparsity[ID]. .. )
493- else
494- # If we do find the variable, we can add this variable directly into the sparsity
495- push! (curr_sparsity, ID)
496- end
497- end
498-
499- # Add a sorted, unique list to the sparsity tracker
500- sparsity[i] = sort (unique (curr_sparsity))
501- end
502- end
503- end
466+ sparsity = get_sparsity (varorder, string_gradlist, factorized)
504467
505468 # Check if we need temporary variables at all. We don't need
506469 # temporary variables if we only have addition, or if we have
@@ -536,7 +499,7 @@ function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Re
536499 maxtemp = 0
537500 # if need_temps #(Skip for now)
538501 for i in eachindex (varorder) # Loop through every participating variable
539- if (varorder[i] in string .(get_name .( vars) ))
502+ if (varorder[i] in string .(vars))
540503 # Skip the variable if it's an input
541504 continue
542505 end
@@ -608,15 +571,15 @@ function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Re
608571 file = open (joinpath (@__DIR__ , " storage" , " f_" * expr_hash* " .jl" ), " a" )
609572
610573 # Put in the preamble.
611- write (file, preamble_string (expr_hash, [" OUT" ; string .(get_name .( vars) )], kernel_ID, maxtemp, length (gradlist)))
574+ write (file, preamble_string (expr_hash, [" OUT" ; string .(vars)], kernel_ID, maxtemp, length (gradlist)))
612575
613576 # Loop through the topological list to add calculations in order
614577 temp_endlist = []
615578 outvar = " "
616579 name_tracker = copy (varids)
617580 for i in eachindex (varorder) # Order in which variables are calculated
618581 # Skip calculation if the variable is one of the inputs
619- if (varorder[i] in string .(get_name .( vars) ))
582+ if (varorder[i] in string .(vars))
620583 continue
621584 end
622585
@@ -628,20 +591,17 @@ function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Re
628591 factorized_ID = findfirst (x -> isequal (string (x. lhs), varorder[i]), factorized)
629592 participants = get_name .(pull_vars (factorized[factorized_ID]. rhs))
630593 inputs = []
594+ input_IDs = Int[]
631595 for p in string .(participants)
632596 # Find the corresponding element in varids
633597 varids_ID = findfirst (x -> isequal (x, p), varids)
634- # Push the name_tracker name to the input list
598+ # Push the name_tracker name to the input list and the ID to the list of input IDs
635599 push! (inputs, name_tracker[varids_ID])
636- end
637600
638- # [Deprecating; I'll use temporary variables the whole way and then set
639- # the output variable at the end for final copying]
640- # # If this is the final variable, it'll be called "OUT". No need
641- # # for temp variables
642- # if i==length(varorder)
643- # name_tracker[ID] = "OUT"
644- # else
601+ # Find the corresponding element in varorder
602+ varorder_ID = findfirst (x -> isequal (x, p), varorder)
603+ push! (input_IDs, varorder_ID)
604+ end
645605
646606 # Determine which tempID to use/override. temp_endlist keeps
647607 # track of where variables will be used in the future (stored
@@ -695,21 +655,20 @@ function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Re
695655 # Now we can append this temporary variable to the list of inputs
696656 # for the correct operation
697657 inputs = [name_tracker[ID]; inputs]
658+ input_IDs = [i; input_IDs]
698659
699660 # Now we can pass the equation's RHS and the inputs to call the correct
700661 # writer function
701662 if length (inputs)> 2 && inputs[1 ]== inputs[2 ]
702663 # Special case. We're adding inputs[3] to inputs[2], so we only want
703664 # to pass the sparsity information of inputs[3] (rather than the
704665 # sparsity information of inputs[2] and inputs[3])
705- corrected_i = findfirst (x-> x== inputs[3 ], varorder)
706- write_operation (file, factorized[factorized_ID]. rhs, inputs, string .(gradlist), sparsity[corrected_i])
666+ write_operation (file, factorized[factorized_ID]. rhs, inputs, string .(gradlist), sparsity[input_IDs[3 ]])
707667 elseif length (inputs)> 2 && inputs[1 ]== inputs[3 ]
708668 # Special case. We're adding inputs[2] to inputs[3], so we only want
709669 # to pass the sparsity information of inputs[2] (rather than the
710670 # sparsity information of inputs[2] and inputs[3])
711- corrected_i = findfirst (x-> x== inputs[2 ], varorder)
712- write_operation (file, factorized[factorized_ID]. rhs, inputs, string .(gradlist), sparsity[corrected_i])
671+ write_operation (file, factorized[factorized_ID]. rhs, inputs, string .(gradlist), sparsity[input_IDs[2 ]])
713672 else
714673 write_operation (file, factorized[factorized_ID]. rhs, inputs, string .(gradlist), sparsity[i])
715674 end
765724# 1) (a^x1)^b = (a^b)^x1 [EAGO paper] (Can't do powers besides integers)
766725function perform_substitutions (old_factored:: Vector{Equation} )
767726 factored = deepcopy (old_factored)
768-
769- # Register any terms we want to substitute
770- @eval @register_symbolic SCMC_sigmoid (x)
771-
772727 scan_flag = true
773728 while scan_flag
774729 scan_flag = false
@@ -1468,23 +1423,28 @@ function pull_mult(factors, index; args=[])
14681423end
14691424
14701425
1471- # A helper function to calculate the sparsity of a factorization, given
1472- # the full gradlist
1473- function detect_sparsity (factored, gradlist)
1474- # Idea is to check every element of "factored" and pull out a list of indices
1475- # in gradlist that the variables depend on.
1476- sparsity = Vector{Int}[]
1477- string_gradlist = string .(gradlist)
1478-
1479- for i in eachindex (factored)
1480- vars = string .(pull_vars (extract (factored, i)))
1481- indices = zeros (Int, length (vars))
1482- for j in eachindex (indices)
1483- indices[j] = findfirst (== (string (vars[j])), string_gradlist)
1426+ # A helper function to calculate the sparsity of a factorization
1427+ function get_sparsity (varorder, string_gradlist, factored)
1428+ sparsity = [Int[] for _ in 1 : length (varorder)]
1429+ function calc_sp (v)
1430+ if ! isempty (sparsity[v])
1431+ return sparsity[v]
1432+ end
1433+ if varorder[v] in string_gradlist
1434+ sparsity[v] = [findfirst (== (varorder[v]), string_gradlist)]
1435+ return sparsity[v]
1436+ end
1437+ idx = findfirst (x -> isequal (string (x. lhs), varorder[v]), factored)
1438+ RHS_vars = pull_vars (factored[idx]. rhs)
1439+ for var in RHS_vars
1440+ var_idx = findfirst (== (string (get_name (var))), varorder)
1441+ sparsity[v] = [sparsity[v]. .. , calc_sp (var_idx)... ]
14841442 end
1485- push! (sparsity, indices)
1443+ return sort (unique (sparsity[v]))
1444+ end
1445+ for v in 1 : length (varorder)
1446+ sparsity[v] = calc_sp (v)
14861447 end
1487-
14881448 return sparsity
14891449end
14901450
0 commit comments