Skip to content

Commit 1f0fef1

Browse files
committed
Add abs
1 parent c363eb8 commit 1f0fef1

File tree

3 files changed

+638
-4
lines changed

3 files changed

+638
-4
lines changed

src/kernel_writer/kernel_write.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1560,6 +1560,8 @@ function write_operation(file::IOStream, RHS::BasicSymbolic{Real}, inputs::Vecto
15601560
write(file, SCMC_float_power_kernel(inputs..., 0.5, gradlist, sparsity))
15611561
elseif RHS.f==cos
15621562
write(file, SCMC_cos_kernel(inputs..., gradlist, sparsity))
1563+
elseif RHS.f==abs
1564+
write(file, SCMC_abs_kernel(inputs..., gradlist, sparsity))
15631565
else
15641566
close(file)
15651567
error("Some function was used that we can't handle yet ($RHS)")
@@ -1929,8 +1931,14 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star
19291931
if !isnothing(new_ID)
19301932
total_lines += _complexity(complexity, factorized, new_ID)
19311933
end
1934+
elseif RHS.f==abs
1935+
total_lines += 280
1936+
new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized)
1937+
if !isnothing(new_ID)
1938+
total_lines += _complexity(complexity, factorized, new_ID)
1939+
end
19321940
else
1933-
error("Unknown function")
1941+
error("Some function was used that we can't handle yet ($RHS)")
19341942
end
19351943
elseif exprtype(RHS) == SYM
19361944
nothing

src/kernel_writer/math_kernels.jl

Lines changed: 157 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,162 @@ function SCMC_inv_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix)
11501150
return nothing
11511151
end
11521152

1153+
1154+
# Absolute value
1155+
# max threads: ???
1156+
function SCMC_abs_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix)
1157+
idx = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x
1158+
stride = blockDim().x * gridDim().x
1159+
colmax = Int32((size(OUT,2)-4)/2)
1160+
1161+
while idx <= Int32(size(OUT,1))
1162+
# Reset the column counter
1163+
col = Int32(1)
1164+
1165+
# Get interval extension
1166+
if x[idx,4] >= 0.0 && x[idx,3] <= 0.0
1167+
OUT[idx,3] = 0.0
1168+
else
1169+
OUT[idx,3] = min(abs(x[idx,3]), abs(x[idx,4]))
1170+
end
1171+
OUT[idx,4] = max(abs(x[idx,3]), abs(x[idx,4]))
1172+
1173+
# Calculate eps_min and eps_max
1174+
if x[idx,3] >= 0.0
1175+
eps_min = x[idx,3]
1176+
elseif x[idx,4] <= 0.0
1177+
eps_min = x[idx,4]
1178+
else
1179+
eps_min = 0.0
1180+
end
1181+
if abs(x[idx,4]) >= abs(x[idx,3])
1182+
eps_max = x[idx,4]
1183+
else
1184+
eps_max = x[idx,3]
1185+
end
1186+
1187+
# Get midcv and midcc by finding the middle values of (cv, cc, eps_min), (cv, cc, eps_max)
1188+
midcv, cv_id, midcc, cc_id = midvals(x[idx,1], x[idx,2], eps_min, eps_max)
1189+
1190+
# Get derivative values
1191+
if x[idx,4] - x[idx,3] == 0.0
1192+
OUT[idx,2] = abs(midcc)
1193+
if midcc > 0.0
1194+
dcc = 1.0
1195+
elseif midcc < 0.0
1196+
dcc = -1.0
1197+
else
1198+
dcc = 0.0
1199+
end
1200+
else
1201+
OUT[idx,2] = (abs(x[idx,3])*(x[idx,4] - midcc) + abs(x[idx,4])*(midcc - x[idx,3]))/(x[idx,4]-x[idx,3])
1202+
dcc = (abs(x[idx,4]) - abs(x[idx,3]))/(x[idx,4]-x[idx,3])
1203+
end
1204+
OUT[idx,1] = abs(midcv)
1205+
if midcv > 0.0
1206+
dcv = 1.0
1207+
elseif midcc < 0.0
1208+
dcv = -1.0
1209+
else
1210+
dcv = 0.0
1211+
end
1212+
1213+
# Calculate subgradients
1214+
if cv_id==1
1215+
if cc_id==1
1216+
col = Int32(1)
1217+
while col <= colmax
1218+
OUT[idx,end-2*colmax+col] = x[idx,end-1*colmax+col]*dcv
1219+
OUT[idx,end-1*colmax+col] = x[idx,end-1*colmax+col]*dcc
1220+
col += Int32(1)
1221+
end
1222+
elseif cc_id==2
1223+
col = Int32(1)
1224+
while col <= colmax
1225+
OUT[idx,end-2*colmax+col] = x[idx,end-1*colmax+col]*dcv
1226+
OUT[idx,end-1*colmax+col] = x[idx,end-2*colmax+col]*dcc
1227+
col += Int32(1)
1228+
end
1229+
else
1230+
col = Int32(1)
1231+
while col <= colmax
1232+
OUT[idx,end-2*colmax+col] = x[idx,end-1*colmax+col]*dcv
1233+
OUT[idx,end-1*colmax+col] = 0.0
1234+
col += Int32(1)
1235+
end
1236+
end
1237+
elseif cv_id==2
1238+
if cc_id==1
1239+
col = Int32(1)
1240+
while col <= colmax
1241+
OUT[idx,end-2*colmax+col] = x[idx,end-2*colmax+col]*dcv
1242+
OUT[idx,end-1*colmax+col] = x[idx,end-1*colmax+col]*dcc
1243+
col += Int32(1)
1244+
end
1245+
elseif cc_id==2
1246+
col = Int32(1)
1247+
while col <= colmax
1248+
OUT[idx,end-2*colmax+col] = x[idx,end-2*colmax+col]*dcv
1249+
OUT[idx,end-1*colmax+col] = x[idx,end-2*colmax+col]*dcc
1250+
col += Int32(1)
1251+
end
1252+
else
1253+
col = Int32(1)
1254+
while col <= colmax
1255+
OUT[idx,end-2*colmax+col] = x[idx,end-2*colmax+col]*dcv
1256+
OUT[idx,end-1*colmax+col] = 0.0
1257+
col += Int32(1)
1258+
end
1259+
end
1260+
else
1261+
if cc_id==1
1262+
col = Int32(1)
1263+
while col <= colmax
1264+
OUT[idx,end-2*colmax+col] = 0.0
1265+
OUT[idx,end-1*colmax+col] = x[idx,end-1*colmax+col]*dcc
1266+
col += Int32(1)
1267+
end
1268+
elseif cc_id==2
1269+
col = Int32(1)
1270+
while col <= colmax
1271+
OUT[idx,end-2*colmax+col] = 0.0
1272+
OUT[idx,end-1*colmax+col] = x[idx,end-2*colmax+col]*dcc
1273+
col += Int32(1)
1274+
end
1275+
else
1276+
col = Int32(1)
1277+
while col <= colmax
1278+
OUT[idx,end-2*colmax+col] = 0.0
1279+
OUT[idx,end-1*colmax+col] = 0.0
1280+
col += Int32(1)
1281+
end
1282+
end
1283+
end
1284+
1285+
# Perform the cut operation
1286+
if OUT[idx,1] < OUT[idx,3]
1287+
OUT[idx,1] = OUT[idx,3]
1288+
col = Int32(1)
1289+
while col <= colmax
1290+
OUT[idx,end-2*colmax+col] = 0.0
1291+
col += Int32(1)
1292+
end
1293+
end
1294+
if OUT[idx,2] > OUT[idx,4]
1295+
OUT[idx,2] = OUT[idx,4]
1296+
col = Int32(1)
1297+
while col <= colmax
1298+
OUT[idx,end-1*colmax+col] = 0.0
1299+
col += Int32(1)
1300+
end
1301+
end
1302+
1303+
idx += stride
1304+
end
1305+
return nothing
1306+
end
1307+
1308+
11531309
# Multiplication by a constant
11541310
# max threads: 640
11551311
function SCMC_cmul_kernel(OUT::CuDeviceMatrix, CONST::Real, x::CuDeviceMatrix)
@@ -4463,8 +4619,7 @@ end
44634619

44644620
# Cosine (argument should be in radians)
44654621
# NOTE: Sine can be cos(x - pi/2)
4466-
function SCMC_cos_kernel(OUT, x)
4467-
# function SCMC_cos_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix)
4622+
function SCMC_cos_kernel(OUT::CuDeviceMatrix, x::CuDeviceMatrix)
44684623
idx = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x
44694624
stride = blockDim().x * gridDim().x
44704625
colmax = Int32((size(OUT,2)-4)/2)

0 commit comments

Comments
 (0)