From 30a98b8d045af126736e41b683e51b4ebe291ea2 Mon Sep 17 00:00:00 2001 From: Sam Buercklin Date: Wed, 1 Dec 2021 09:37:09 -0500 Subject: [PATCH 1/8] added basic rrules, ProjectTo for Quantity --- Project.toml | 1 + src/Unitful.jl | 3 +++ src/chainrules.jl | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+) create mode 100644 src/chainrules.jl diff --git a/Project.toml b/Project.toml index d66697299..dcc2d0125 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" version = "1.11.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/Unitful.jl b/src/Unitful.jl index dac829346..bf642c76b 100644 --- a/src/Unitful.jl +++ b/src/Unitful.jl @@ -27,6 +27,8 @@ import Random import ConstructionBase: constructorof +import ChainRulesCore: rrule, NoTangent, ProjectTo + export logunit, unit, absoluteunit, dimension, uconvert, ustrip, upreferred export @dimension, @derived_dimension, @refunit, @unit, @affineunit, @u_str export Quantity, DimensionlessQuantity, NoUnits, NoDims @@ -69,5 +71,6 @@ include("logarithm.jl") include("complex.jl") include("pkgdefaults.jl") include("dates.jl") +include("chainrules.jl") end diff --git a/src/chainrules.jl b/src/chainrules.jl new file mode 100644 index 000000000..b4b5a4316 --- /dev/null +++ b/src/chainrules.jl @@ -0,0 +1,34 @@ +function rrule(UT::Type{Quantity{T,D,U}}, x::Number) where {T,D,U} + unitful_x = Quantity{T,D,U}(x) + projector_x = ProjectTo(x) + uq_pullback(Δx) = (NoTangent(), projector_x(Δx) * oneunit(UT)) + return unitful_x, uq_pullback +end + +function ProjectTo(x::Quantity) + project_val = ProjectTo(x.val) # Project the literal number + return ProjectTo{typeof(x)}(; project_val = project_val) +end + +function (projector::ProjectTo{<:Quantity})(x::Number) + new_val = projector.project_val(ustrip(x)) + return new_val*x +end + +# Project Unitful Quantities onto numerical types by projecting the value and carrying units +(project::ProjectTo{<:Real})(dx::Quantity) = project(ustrip(dx))*unit(dx) +(project::ProjectTo{<:Complex})(dx::Quantity) = project(ustrip(dx))*unit(dx) + +function rrule(::typeof(*), x::Quantity, y::Units, z::Units...) + Ω = *(x, y, z...) + project_x = ProjectTo(x) + function times_pb(Δ) + δ = project_x(Δ) + units = (y, z...) + return (NoTangent(), *(δ, y, z...), ntuple(_ -> NoTangent(), length(units))...) + end + return Ω, times_pb +end + +rrule(::typeof(/), x::Number, y::Units) = rrule(*, x, inv(y)) +rrule(::typeof(/), x::Units, y::Number) = rrule(*, x, inv(y)) From 16281fe58e1c6a349df9633875b1ac6083543d1e Mon Sep 17 00:00:00 2001 From: Sam Buercklin Date: Wed, 1 Dec 2021 10:08:06 -0500 Subject: [PATCH 2/8] Fixed multiplying by value in Quantity projector, should be just units --- src/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index b4b5a4316..5f4fce11e 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -12,7 +12,7 @@ end function (projector::ProjectTo{<:Quantity})(x::Number) new_val = projector.project_val(ustrip(x)) - return new_val*x + return new_val*unit(x) end # Project Unitful Quantities onto numerical types by projecting the value and carrying units From 3c487492fee3153ad62c546597206b40caa090de Mon Sep 17 00:00:00 2001 From: Sam Buercklin Date: Mon, 6 Dec 2021 09:54:09 -0500 Subject: [PATCH 3/8] ProjectTo maps to projecting onto the inner val, cleaner NoTangents in * pullback --- src/chainrules.jl | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/chainrules.jl b/src/chainrules.jl index 5f4fce11e..996e52806 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -5,27 +5,22 @@ function rrule(UT::Type{Quantity{T,D,U}}, x::Number) where {T,D,U} return unitful_x, uq_pullback end -function ProjectTo(x::Quantity) - project_val = ProjectTo(x.val) # Project the literal number - return ProjectTo{typeof(x)}(; project_val = project_val) -end - function (projector::ProjectTo{<:Quantity})(x::Number) new_val = projector.project_val(ustrip(x)) return new_val*unit(x) end # Project Unitful Quantities onto numerical types by projecting the value and carrying units +ProjectTo(x::Quantity) = ProjectTo(x.val) + (project::ProjectTo{<:Real})(dx::Quantity) = project(ustrip(dx))*unit(dx) (project::ProjectTo{<:Complex})(dx::Quantity) = project(ustrip(dx))*unit(dx) function rrule(::typeof(*), x::Quantity, y::Units, z::Units...) Ω = *(x, y, z...) - project_x = ProjectTo(x) function times_pb(Δ) - δ = project_x(Δ) - units = (y, z...) - return (NoTangent(), *(δ, y, z...), ntuple(_ -> NoTangent(), length(units))...) + nots = ntuple(_ -> NoTangent(), 1 + length(z)) + return (NoTangent(), *(ProjectTo(x)(Δ), y, z...), nots...) end return Ω, times_pb end From 6f47bdf53598e4a899e385509f3cdc4c73fb1241 Mon Sep 17 00:00:00 2001 From: Sam Buercklin Date: Sat, 15 Jan 2022 11:20:31 -0500 Subject: [PATCH 4/8] added rrule tests --- test/chainrules.jl | 52 ++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 4 ++++ 2 files changed, 56 insertions(+) create mode 100644 test/chainrules.jl diff --git a/test/chainrules.jl b/test/chainrules.jl new file mode 100644 index 000000000..8314966af --- /dev/null +++ b/test/chainrules.jl @@ -0,0 +1,52 @@ +using ChainRulesCore: rrule, ProjectTo, NoTangent + +@testset "ProjectTo" begin + real_test(proj, val) = proj(val) == real(val) + complex_test(proj, val) = proj(val) == val + uval = 8.0*u"W" + p_uval = ProjectTo(uval) + cuval = (1.0+im)*u"kg" + p_cuval = ProjectTo(cuval) + + p_real = ProjectTo(1.0) + p_complex = ProjectTo(1.0+im) + + δval = 6.0*u"m" + δcval = (2.0+3.0im)*u"L" + + # Test projection onto real unitful quantities + for δ in (δval, δcval, 1.0, 1.0+im) + @test real_test(p_uval, δ) + end + + # Test projection onto complex unitful quantities + for δ in (δval, δcval, 1.0, 1.0+im) + @test complex_test(p_cuval, δ) + end + + # Projecting Unitful quantities onto real values + @test p_real(δval) == δval + @test p_real(δcval) == real(δcval) + + # Projecting Unitful quantities onto complex values + @test p_complex(δval) == δval + @test p_complex(δcval) == δcval +end + +@testset "rrules" begin + @testset "Quantity rrule" begin + UT = typeof(1.0*u"W") + x = 5.0 + Ω, pb = rrule(UT, x) + @test Ω == 5.0 * u"W" + @test pb(3.0) == (NoTangent(), 3.0 * u"W") + end + @testset "* rrule" begin + x = 5.0*u"W" + y = u"m" + z = u"L" + Ω, pb = rrule(*, x, y, z) + @test Ω == x*y*z + @test pb(3.0) == (NoTangent(), 3.0*y*z, NoTangent(), NoTangent()) + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index b64810124..53803642c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2047,6 +2047,10 @@ end """ end +@testset "ChainRules" begin + include("./chainrules.jl") +end + # Test precompiled Unitful extension modules load_path = mktempdir() load_cache_path = mktempdir() From c98ebdc527d4060bded567958bbafe2363ccab3e Mon Sep 17 00:00:00 2001 From: Samuel Buercklin Date: Tue, 15 Mar 2022 09:45:04 -0400 Subject: [PATCH 5/8] Update test/chainrules.jl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit add endline Co-authored-by: Mosè Giordano --- test/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index 8314966af..03e39dddc 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -49,4 +49,4 @@ end @test Ω == x*y*z @test pb(3.0) == (NoTangent(), 3.0*y*z, NoTangent(), NoTangent()) end -end \ No newline at end of file +end From 3a5e0a3aa38c10c3d982a96692b66d3bfc34fa5c Mon Sep 17 00:00:00 2001 From: Sam Buercklin Date: Tue, 15 Mar 2022 09:50:27 -0400 Subject: [PATCH 6/8] add CRC compat bound --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index dcc2d0125..1d56f170b 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] +ChainRulesCore = "1" ConstructionBase = "1" julia = "1" From 5e24deb0b6fbf0da915a3a642507c5fdf63624a9 Mon Sep 17 00:00:00 2001 From: Sam Buercklin Date: Tue, 15 Mar 2022 09:51:23 -0400 Subject: [PATCH 7/8] bump to 1.11.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1d56f170b..c8da697a3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Unitful" uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" -version = "1.11.0" +version = "1.11.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 5f11a686f1a087b2057dc14f048db87c6155613d Mon Sep 17 00:00:00 2001 From: Sam Buercklin Date: Tue, 15 Mar 2022 16:38:48 -0400 Subject: [PATCH 8/8] bump minor version to 1.12 instead of patch --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c8da697a3..7b769f1d3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Unitful" uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" -version = "1.11.1" +version = "1.12.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"