forked from JuliaSIMD/LoopVectorization.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathForwardDiffNNlibExt.jl
More file actions
42 lines (38 loc) · 926 Bytes
/
Copy pathForwardDiffNNlibExt.jl
File metadata and controls
42 lines (38 loc) · 926 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
module ForwardDiffNNlibExt
import ForwardDiff
using LoopVectorization, VectorizationBase, SLEEFPirates, ForwardDiff, NNlib

# Methods of `NNlib.relu` and `NNlib.leakyrelu` for ForwardDiff `Dual` numbers whose
# value is a SIMD vector (`LoopVectorization.AbstractSIMD`). Both the value and each of
# the `N` partials are selected lane-wise with `ifelse` on a comparison mask, so no
# branching on the SIMD data is required.

# relu(x) = max(x, 0), applied lane-wise.
# Lanes with value < 0 get value 0 and all-zero partials; all other lanes (including
# value == 0, since the test is a strict `<`) pass the value and partials through.
@inline function NNlib.relu(
    x::ForwardDiff.Dual{T,<:LoopVectorization.AbstractSIMD,N}
) where {T,N}
    v = x.value
    z = zero(v)
    isneg = v < z
    p = x.partials
    # `ntuple` with `Val(N)` unrolls at compile time, replacing a hand-written
    # `@generated` + `Base.Cartesian.@ntuple` construction.
    dp = ntuple(n -> ifelse(isneg, z, p[n]), Val(N))
    return ForwardDiff.Dual{T}(ifelse(isneg, z, v), ForwardDiff.Partials(dp))
end

# leakyrelu(x, a) = x < 0 ? a * x : x, applied lane-wise.
# Lanes with value < 0 scale the value and every partial by the slope `a`; all other
# lanes (including value == 0) pass the value and partials through unchanged.
@inline function NNlib.leakyrelu(
    x::ForwardDiff.Dual{T,<:LoopVectorization.AbstractSIMD,N},
    a = 0.01,
) where {T,N}
    v = x.value
    # Convert the slope to the value's own (SIMD) type once, up front, so the
    # products below are between values of the same type.
    α = convert(typeof(v), a)
    isneg = v < zero(v)
    p = x.partials
    dp = ntuple(n -> ifelse(isneg, α * p[n], p[n]), Val(N))
    return ForwardDiff.Dual{T}(ifelse(isneg, α * v, v), ForwardDiff.Partials(dp))
end
end