Make some functions more AD friendly #91
DhairyaLGandhi wants to merge 2 commits into JuliaMolSim:master
|
Seems okay to me. Is there a Zygote issue discussing why changing broadcast to map is required here? It might be worth referencing that in a code comment otherwise this could get changed back in future. Beyond this PR we could think about adding a Zygote test if we want to make sure we don't break AD compat. |
|
Zygote hasn't changed here, what is required is overloading |
|
I like the idea of adding tests here |
|
Definitely agree about not depending on Zygote and about keeping the same implementation with and without AD. I'm just wondering why that overload is required at all; it sounds like something that could be tracked/improved in Zygote? |
|
LGTM modulo adding a test. Would it make sense to make this part of AtomsBaseTesting so that it is also tested in downstream codes? |
|
It could be an optional extra or emit a warning in AtomsBaseTesting, but I don't think we should make Zygote compat a required part of the interface. Tests to check that the systems in AtomsBase are Zygote-compatible would be useful to avoid regressions, though; I guess taking on Zygote as a test dependency is fine. |
|
I was working on some tests to add to this and am running into a missing adjoint issue...
box = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]u"m"
bcs = [Periodic(), Periodic(), DirichletZero()]
elements = [:C, :C]
atoms = [Atom(elements[i], positions[i]) for i in 1:2]
# distance between first two particles
function dist(sys::AbstractSystem)
    sepvec = diff(position(sys))[1]
    sqrt(dot(sepvec, sepvec))
end
gradient(0) do x
    positions = [[0, 0, 0], [x, 0.5, 0.5]]u"m"
    atoms = [Atom(elements[i], positions[i]) for i in 1:2]
    flexible = FlexibleSystem(atoms, box, bcs)
    dist(flexible)
end
And I get a super long stacktrace that starts with:
ERROR: Need an adjoint for constructor StaticArrays.SVector{3, Quantity{Float64, 𝐋, Unitful.FreeUnits{(m,), 𝐋, nothing}}}. Gradient is of type StaticArrays.SVector{3, Float64}
I found a whole chain of discussions across several PRs on various packages (1 -> 2 -> 3 -> 4 -> 5 -> 6), and I don't follow enough of the nitty-gritty details to know for sure whether that last one will fix this when merged (it also seems like it might depend on the Julia version; I was doing this on 1.9.2), but hopefully @DhairyaLGandhi can lend some insight? |
|
I get a different issue, related to mutation, with
using AtomsBase, Zygote, Unitful, LinearAlgebra
box = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]u"m"
bcs = [Periodic(), Periodic(), DirichletZero()]
elements = [:C, :C]
function dist(sys::AbstractSystem)
    sepvec = diff(position(sys))[1]
    sqrt(dot(sepvec, sepvec))
end
gradient(0) do x
    positions = [[0, 0, 0], [x, 0.5, 0.5]]u"m"
    atoms = [Atom(elements[i], positions[i]) for i in 1:2]
    flexible = FlexibleSystem(atoms, box, bcs)
    dist(flexible)
end |
|
So far, I've looked at |
|
I have always found it hard to get units to play well with AD, and don't use them when taking gradients in my own code. There is https://github.com/SBuercklin/UnitfulChainRules.jl which may be useful. |
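For what it's worth, a unit-free reference value can still be handy when writing the tests discussed above. The following is only a sketch under stated assumptions: it uses plain Float64 vectors instead of AtomsBase systems, and the names dist_plain, positions_at, analytic_grad and fd_grad are made up for illustration. It does not call Zygote or Unitful at all; it just gives an AD-independent number that a gradient from any backend could be compared against.

```julia
using LinearAlgebra

# Unit-free analogue of `dist` from the thread: separation of the first two
# positions, given as plain vectors of Float64 (hypothetical helper).
function dist_plain(positions)
    sepvec = diff(positions)[1]
    return sqrt(dot(sepvec, sepvec))
end

# Two-particle geometry from the thread, with the x-coordinate of the
# second particle left free (hypothetical helper).
positions_at(x) = [[0.0, 0.0, 0.0], [x, 0.5, 0.5]]

# Distance as a function of x only.
f(x) = dist_plain(positions_at(x))

# Analytic derivative of f: d/dx sqrt(x^2 + 0.5) = x / f(x).
analytic_grad(x) = x / f(x)

# Central finite difference, used only as an AD-independent cross-check.
fd_grad(g, x; h=1e-6) = (g(x + h) - g(x - h)) / (2h)

println("analytic: ", analytic_grad(1.0))
println("finite difference: ", fd_grad(f, 1.0))
```

A test could then compare the stripped value of an AD gradient against analytic_grad at the same x, within a tolerance.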
|
Update on this PR? |
Zygote has a property called literal_indexed_iterate, which types that support some form of iteration can implement to allow for cleaner accumulation of gradients when working with AD. However, this adds a dependency on Zygote, which might be costly for a base package.
Package extensions also cannot be used, since the extension would basically overwrite methods, which amounts to piracy. That is also disallowed as of Julia 1.10. This PR is therefore a simple way to still benefit from AD-able code gen without introducing any extra complexity.
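To illustrate the kind of change discussed in the conversation (broadcast to map), here is a rough sketch. ToyAtom, ToySystem and toy_position are hypothetical stand-ins, not the AtomsBase types or accessors, and the sketch does not involve Zygote. It only shows that, outside of AD, the two forms return the same values, so the swap should be invisible to callers; whether a given AD backend handles one form better than the other is not something this snippet demonstrates.

```julia
# Hypothetical stand-ins for illustration only; not the AtomsBase types.
struct ToyAtom
    symbol::Symbol
    position::Vector{Float64}
end

struct ToySystem
    particles::Vector{ToyAtom}
end

# Per-particle accessor (hypothetical name).
toy_position(a::ToyAtom) = a.position

# Broadcast-style system accessor.
positions_broadcast(sys::ToySystem) = toy_position.(sys.particles)

# map-style system accessor, the form the conversation refers to.
positions_map(sys::ToySystem) = map(toy_position, sys.particles)

sys = ToySystem([ToyAtom(:C, [0.0, 0.0, 0.0]), ToyAtom(:C, [1.0, 0.5, 0.5])])

# Both accessors return the same vector of positions.
println(positions_broadcast(sys) == positions_map(sys))
```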