Skip to content

Commit 20c3ac3

Browse files
committed
Fix add-to sparsity
- Correct the case where information is added in-place to use the sparsity information of the component being added, rather than the sparsity information for the final sum - Simplify the case where a constant is added in-place
1 parent 1f0fef1 commit 20c3ac3

File tree

2 files changed

+76
-49
lines changed

2 files changed

+76
-49
lines changed

src/kernel_writer/kernel_write.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,21 @@ function create_kernel!(expr_hash::String, kernel_ID::Int, num::BasicSymbolic{Re
698698

699699
# Now we can pass the equation's RHS and the inputs to call the correct
700700
# writer function
701-
write_operation(file, factorized[factorized_ID].rhs, inputs, string.(gradlist), sparsity[i])
701+
if length(inputs)>2 && inputs[1]==inputs[2]
702+
# Special case. We're adding inputs[3] to inputs[2], so we only want
703+
# to pass the sparsity information of inputs[3] (rather than the
704+
# sparsity information of inputs[2] and inputs[3])
705+
corrected_i = findfirst(x->x==inputs[3], varorder)
706+
write_operation(file, factorized[factorized_ID].rhs, inputs, string.(gradlist), sparsity[corrected_i])
707+
elseif length(inputs)>2 && inputs[1]==inputs[3]
708+
# Special case. We're adding inputs[2] to inputs[3], so we only want
709+
# to pass the sparsity information of inputs[2] (rather than the
710+
# sparsity information of inputs[2] and inputs[3])
711+
corrected_i = findfirst(x->x==inputs[2], varorder)
712+
write_operation(file, factorized[factorized_ID].rhs, inputs, string.(gradlist), sparsity[corrected_i])
713+
else
714+
write_operation(file, factorized[factorized_ID].rhs, inputs, string.(gradlist), sparsity[i])
715+
end
702716

703717
# Now that we're done with this variable, eliminate this variable
704718
# from the lists of temporary variables' requirements

src/kernel_writer/string_math_kernels.jl

Lines changed: 61 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -91,60 +91,73 @@ function SCMC_cadd_kernel(OUT::String, v1::String, CONST::Real, varlist::Vector{
9191
buffer = Base.IOBuffer()
9292

9393
# Write all the lines to the buffer
94-
write(buffer, " ############################\n")
95-
write(buffer, " ## Addition of a constant ##\n")
96-
write(buffer, " ############################\n")
97-
write(buffer, "\n")
98-
write(buffer, " # Reset the column counter\n")
99-
write(buffer, " col = Int32(1)\n")
100-
write(buffer, "\n")
101-
write(buffer, " # Begin rule\n")
102-
write(buffer, " $OUT_lo $eq $v1_lo + $CONST\n")
103-
write(buffer, " $OUT_hi $eq $v1_hi + $CONST\n")
104-
write(buffer, " $OUT_cv $eq $v1_cv + $CONST\n")
105-
write(buffer, " $OUT_cc $eq $v1_cc + $CONST\n")
106-
write(buffer, " while col <= colmax\n")
107-
if sparsity_case == 1
108-
write(buffer, " if $sparsity_string\n")
109-
write(buffer, " $(OUT_cvgrad)col] = $v1_cvgrad\n")
110-
write(buffer, " $(OUT_ccgrad)col] = $v1_ccgrad\n")
111-
write(buffer, " else\n")
94+
if OUT != v1
95+
write(buffer, " ############################\n")
96+
write(buffer, " ## Addition of a constant ##\n")
97+
write(buffer, " ############################\n")
98+
write(buffer, "\n")
99+
write(buffer, " # Reset the column counter\n")
100+
write(buffer, " col = Int32(1)\n")
101+
write(buffer, "\n")
102+
write(buffer, " # Begin rule\n")
103+
write(buffer, " $OUT_lo $eq $v1_lo + $CONST\n")
104+
write(buffer, " $OUT_hi $eq $v1_hi + $CONST\n")
105+
write(buffer, " $OUT_cv $eq $v1_cv + $CONST\n")
106+
write(buffer, " $OUT_cc $eq $v1_cc + $CONST\n")
107+
write(buffer, " while col <= colmax\n")
108+
if sparsity_case == 1
109+
write(buffer, " if $sparsity_string\n")
110+
write(buffer, " $(OUT_cvgrad)col] = $v1_cvgrad\n")
111+
write(buffer, " $(OUT_ccgrad)col] = $v1_ccgrad\n")
112+
write(buffer, " else\n")
113+
write(buffer, " $(OUT_cvgrad)col] = 0.0\n")
114+
write(buffer, " $(OUT_ccgrad)col] = 0.0\n")
115+
write(buffer, " end\n")
116+
elseif sparsity_case == 2
117+
write(buffer, " if $antisparsity_string\n")
118+
write(buffer, " $(OUT_cvgrad)col] = 0.0\n")
119+
write(buffer, " $(OUT_ccgrad)col] = 0.0\n")
120+
write(buffer, " else\n")
121+
write(buffer, " $(OUT_cvgrad)col] = $v1_cvgrad\n")
122+
write(buffer, " $(OUT_ccgrad)col] = $v1_ccgrad\n")
123+
write(buffer, " end\n")
124+
else
125+
write(buffer, " $(OUT_cvgrad)col] = $v1_cvgrad\n")
126+
write(buffer, " $(OUT_ccgrad)col] = $v1_ccgrad\n")
127+
end
128+
write(buffer, " col += Int32(1)\n")
129+
write(buffer, " end\n")
130+
write(buffer, "\n")
131+
write(buffer, " # Cut\n")
132+
write(buffer, " if $OUT_cv < $OUT_lo\n")
133+
write(buffer, " $OUT_cv = $OUT_lo\n")
134+
write(buffer, " col = Int32(1)\n")
135+
write(buffer, " while col <= colmax\n")
112136
write(buffer, " $(OUT_cvgrad)col] = 0.0\n")
113-
write(buffer, " $(OUT_ccgrad)col] = 0.0\n")
137+
write(buffer, " col += Int32(1)\n")
114138
write(buffer, " end\n")
115-
elseif sparsity_case == 2
116-
write(buffer, " if $antisparsity_string\n")
117-
write(buffer, " $(OUT_cvgrad)col] = 0.0\n")
139+
write(buffer, " end\n")
140+
write(buffer, " if $OUT_cc > $OUT_hi\n")
141+
write(buffer, " $OUT_cc = $OUT_hi\n")
142+
write(buffer, " col = Int32(1)\n")
143+
write(buffer, " while col <= colmax\n")
118144
write(buffer, " $(OUT_ccgrad)col] = 0.0\n")
119-
write(buffer, " else\n")
120-
write(buffer, " $(OUT_cvgrad)col] = $v1_cvgrad\n")
121-
write(buffer, " $(OUT_ccgrad)col] = $v1_ccgrad\n")
145+
write(buffer, " col += Int32(1)\n")
122146
write(buffer, " end\n")
147+
write(buffer, " end\n")
148+
write(buffer, "\n")
123149
else
124-
write(buffer, " $(OUT_cvgrad)col] = $v1_cvgrad\n")
125-
write(buffer, " $(OUT_ccgrad)col] = $v1_ccgrad\n")
150+
write(buffer, " ############################\n")
151+
write(buffer, " ## Addition of a constant ##\n")
152+
write(buffer, " ############################\n")
153+
write(buffer, "\n")
154+
write(buffer, " # Begin rule\n")
155+
write(buffer, " $OUT_lo += $CONST\n")
156+
write(buffer, " $OUT_hi += $CONST\n")
157+
write(buffer, " $OUT_cv += $CONST\n")
158+
write(buffer, " $OUT_cc += $CONST\n")
159+
write(buffer, "\n")
126160
end
127-
write(buffer, " col += Int32(1)\n")
128-
write(buffer, " end\n")
129-
write(buffer, "\n")
130-
write(buffer, " # Cut\n")
131-
write(buffer, " if $OUT_cv < $OUT_lo\n")
132-
write(buffer, " $OUT_cv = $OUT_lo\n")
133-
write(buffer, " col = Int32(1)\n")
134-
write(buffer, " while col <= colmax\n")
135-
write(buffer, " $(OUT_cvgrad)col] = 0.0\n")
136-
write(buffer, " col += Int32(1)\n")
137-
write(buffer, " end\n")
138-
write(buffer, " end\n")
139-
write(buffer, " if $OUT_cc > $OUT_hi\n")
140-
write(buffer, " $OUT_cc = $OUT_hi\n")
141-
write(buffer, " col = Int32(1)\n")
142-
write(buffer, " while col <= colmax\n")
143-
write(buffer, " $(OUT_ccgrad)col] = 0.0\n")
144-
write(buffer, " col += Int32(1)\n")
145-
write(buffer, " end\n")
146-
write(buffer, " end\n")
147-
write(buffer, "\n")
148161
return String(take!(buffer))
149162
end
150163

0 commit comments

Comments
 (0)