diff --git a/src/rules/discrete_transition/categoricals.jl b/src/rules/discrete_transition/categoricals.jl index 0c837d990..822da5bd8 100644 --- a/src/rules/discrete_transition/categoricals.jl +++ b/src/rules/discrete_transition/categoricals.jl @@ -159,7 +159,7 @@ Compute the message for one of the Categorical interfaces of the `DiscreteTransi function discrete_transition_structured_message_rule(message_names, messages, marginals_names, marginals, q_a) e_log_a = mean(BroadcastFunction(clamplog), q_a) e_log_a = discrete_transition_process_marginals(e_log_a, marginals_names, marginals) - msg = clamp.(exp.(e_log_a), tiny, huge) + msg = clamp.(softmax!(e_log_a), tiny, huge) msg = discrete_transition_process_messages(msg, message_names, messages, sum_out_dimensions) msg = reshape(msg, :) normalize!(msg, 1) diff --git a/src/rules/discrete_transition/marginals.jl b/src/rules/discrete_transition/marginals.jl index 2d8306a8a..fa6b91cb5 100644 --- a/src/rules/discrete_transition/marginals.jl +++ b/src/rules/discrete_transition/marginals.jl @@ -13,11 +13,13 @@ function marginalrule( ::Any, ::Any ) where {marginal_symbol, message_names, N} - return Contingency(outer_product(probvec.(messages)) .* exp.(mean(BroadcastFunction(clamplog), first(marginals)))) + result = outer_product(probvec.(messages)) .* softmax!(mean(BroadcastFunction(clamplog), first(marginals))) + normalize!(result, 1) + return Contingency(result, Val(false)) end -nonparametric_distribution(v::Vector{<:Real}) = Categorical(normalize!(v, 1)) -nonparametric_distribution(v::AbstractArray{<:Real, N} where {N}) = Contingency(v) +nonparametric_distribution(v::Vector{<:Real}) = Categorical(normalize!(v, 1); check_args = false) +nonparametric_distribution(v::AbstractArray{<:Real, N} where {N}) = Contingency(normalize!(v, 1), Val(false)) # Generic implementation """ @@ -39,15 +41,16 @@ function discrete_transition_marginal_rule( e_log_a = mean(BroadcastFunction(clamplog), q_a) e_log_a = discrete_transition_process_marginals(e_log_a, marginals_names, marginals) - marginal = clamp.(exp.(e_log_a), tiny, huge) + marginal = clamp.(softmax!(e_log_a), tiny, huge) marginal = discrete_transition_process_messages(marginal, message_names, messages, multiply_dimensions!) dims = Tuple(findall(size(marginal) .== 1)) marginal = dropdims(marginal, dims = dims) + normalize!(marginal, 1) return marginal end discrete_transition_marginal_rule_contingency(message_names::NTuple{N, Symbol}, messages::NTuple{N, Union{<:Message{<:DiscreteNonParametric}, <:Message{<:Bernoulli}}}, marginals_names::NTuple{M, Symbol}, marginals, q_a) where {N, M} = Contingency( - discrete_transition_marginal_rule(message_names, messages, marginals_names, marginals, q_a) + discrete_transition_marginal_rule(message_names, messages, marginals_names, marginals, q_a), Val(false) ) function marginalrule( diff --git a/src/rules/discrete_transition/predefined/a.jl b/src/rules/discrete_transition/predefined/a.jl index 4e82c20ed..057433d0c 100644 --- a/src/rules/discrete_transition/predefined/a.jl +++ b/src/rules/discrete_transition/predefined/a.jl @@ -12,6 +12,7 @@ end @rule DiscreteTransition(:a, Marginalisation) (q_out_in::Contingency, q_T1::PointMass{<:AbstractVector{T}}, meta::Any) where {T} = begin out_in = components(q_out_in) T1 = probvec(q_T1) - @tullio result[a, b, c] := out_in[a, b] * T1[c] - return DirichletCollection(result .+ 1) + result = ones(T, size(out_in)..., length(T1)) + result[:, :, findfirst(isone, T1)] .+= out_in + return DirichletCollection(result) end diff --git a/src/rules/discrete_transition/predefined/belief_propagation.jl b/src/rules/discrete_transition/predefined/belief_propagation.jl index a264486a3..1cec68927 100644 --- a/src/rules/discrete_transition/predefined/belief_propagation.jl +++ b/src/rules/discrete_transition/predefined/belief_propagation.jl @@ -2,45 +2,52 @@ using Tullio # # --------------- Rules for 2 interfaces (PointMass q_a) --------------- @rule DiscreteTransition(:out, Marginalisation) (m_in::DiscreteNonParametric, q_a::PointMass{<:AbstractArray{T, 2}}, meta::Any) where {T} = begin - eloga = mean(q_a) + N = eltype(probvec(m_in)) + eloga = clamp.(mean(q_a), tiny(N), one(N)) out = eloga * probvec(m_in) + return Categorical(normalize!(out, 1); check_args = false) end @rule DiscreteTransition(:in, Marginalisation) (m_out::DiscreteNonParametric, q_a::PointMass{<:AbstractArray{T, 2}}, meta::Any) where {T} = begin - eloga = mean(q_a) + N = eltype(probvec(m_out)) + eloga = clamp.(mean(q_a), tiny(N), one(N)) out = eloga' * probvec(m_out) return Categorical(normalize!(out, 1); check_args = false) end # --------------- Rules for 2 interfaces (DirichletCollection q_a) --------------- @rule DiscreteTransition(:out, Marginalisation) (m_in::DiscreteNonParametric, q_a::DirichletCollection, meta::Any) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + N = eltype(probvec(m_in)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) out = eloga * probvec(m_in) return Categorical(normalize!(out, 1); check_args = false) end @rule DiscreteTransition(:in, Marginalisation) (m_out::DiscreteNonParametric, q_a::DirichletCollection, meta::Any) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) out = eloga' * probvec(m_out) return Categorical(normalize!(out, 1); check_args = false) end # --------------- Rules for 3 interfaces (PointMass q_a) --------------- @rule DiscreteTransition(:out, Marginalisation) (m_in::DiscreteNonParametric, m_T1::DiscreteNonParametric, q_a::PointMass{<:AbstractArray{T, 3}}, meta::Any) where {T} = begin - eloga = mean(q_a) + N = eltype(probvec(m_in)) + eloga = clamp.(mean(q_a), tiny(N), one(N)) @tullio out[i] := eloga[i, a, b] * probvec(m_in)[a] * probvec(m_T1)[b] return Categorical(normalize!(out, 1); check_args = false) end @rule DiscreteTransition(:in, Marginalisation) (m_out::DiscreteNonParametric, m_T1::DiscreteNonParametric, q_a::PointMass{<:AbstractArray{T, 3}}, meta::Any) where {T} = begin - eloga = mean(q_a) + N = eltype(probvec(m_out)) + eloga = clamp.(mean(q_a), tiny(N), one(N)) @tullio out[i] := eloga[a, i, b] * probvec(m_out)[a] * probvec(m_T1)[b] return Categorical(normalize!(out, 1); check_args = false) end @rule DiscreteTransition(:T1, Marginalisation) (m_out::DiscreteNonParametric, m_in::DiscreteNonParametric, q_a::PointMass{<:AbstractArray{T, 3}}, meta::Any) where {T} = begin - eloga = mean(q_a) + N = eltype(probvec(m_out)) + eloga = clamp.(mean(q_a), tiny(N), one(N)) @tullio out[i] := eloga[a, b, i] * probvec(m_out)[a] * probvec(m_in)[b] return Categorical(normalize!(out, 1); check_args = false) end @@ -49,34 +56,34 @@ end @rule DiscreteTransition(:out, Marginalisation) (m_in::DiscreteNonParametric, q_a::DirichletCollection, q_T1::PointMass{<:AbstractArray{T, 3}}, meta::Any) where {T} = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio intermediate[i, a] := eloga[i, a, b] * probvec(q_T1)[b] - out .= exp.(intermediate) - result = out * probvec(m_in) + softmax!(intermediate) + result = intermediate * probvec(m_in) return Categorical(normalize!(result, 1); check_args = false) end @rule DiscreteTransition(:in, Marginalisation) (m_out::DiscreteNonParametric, q_a::DirichletCollection, q_T1::PointMass{<:AbstractArray{T, 3}}, meta::Any) where {T} = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio intermediate[a, i] := eloga[a, i, b] * probvec(q_T1)[b] - out .= exp.(intermediate) - result = out' * probvec(m_out) + softmax!(intermediate) + result = intermediate' * probvec(m_out) return Categorical(normalize!(result, 1); check_args = false) end # --------------- Rules for 3 interfaces (DirichletCollection q_a) --------------- @rule DiscreteTransition(:out, Marginalisation) (m_in::DiscreteNonParametric, m_T1::DiscreteNonParametric, q_a::DirichletCollection, meta::Any) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) @tullio out[i] := eloga[i, a, b] * probvec(m_in)[a] * probvec(m_T1)[b] return Categorical(normalize!(out, 1); check_args = false) end @rule DiscreteTransition(:in, Marginalisation) (m_out::DiscreteNonParametric, m_T1::DiscreteNonParametric, q_a::DirichletCollection, meta::Any) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) @tullio out[i] := eloga[a, i, b] * probvec(m_out)[a] * probvec(m_T1)[b] return Categorical(normalize!(out, 1); check_args = false) end @rule DiscreteTransition(:T1, Marginalisation) (m_out::DiscreteNonParametric, m_in::DiscreteNonParametric, q_a::DirichletCollection, meta::Any) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) @tullio out[i] := eloga[a, b, i] * probvec(m_out)[a] * probvec(m_in)[b] return Categorical(normalize!(out, 1); check_args = false) end @@ -85,7 +92,8 @@ end @rule DiscreteTransition(:out, Marginalisation) ( m_in::DiscreteNonParametric, m_T1::DiscreteNonParametric, m_T2::DiscreteNonParametric, q_a::PointMass{<:AbstractArray{T, 4}}, meta::Any ) where {T} = begin - eloga = mean(q_a) + N = eltype(probvec(m_in)) + eloga = clamp.(mean(q_a), tiny(N), one(N)) @tullio out[i] := eloga[i, a, b, c] * probvec(m_in)[a] * probvec(m_T1)[b] * probvec(m_T2)[c] return Categorical(normalize!(out, 1); check_args = false) end @@ -93,7 +101,8 @@ end @rule DiscreteTransition(:in, Marginalisation) ( m_out::DiscreteNonParametric, m_T1::DiscreteNonParametric, m_T2::DiscreteNonParametric, q_a::PointMass{<:AbstractArray{T, 4}}, meta::Any ) where {T} = begin - eloga = mean(q_a) + N = eltype(probvec(m_out)) + eloga = clamp.(mean(q_a), tiny(N), one(N)) @tullio out[i] := eloga[a, i, b, c] * probvec(m_out)[a] * probvec(m_T1)[b] * probvec(m_T2)[c] return Categorical(normalize!(out, 1); check_args = false) end @@ -101,7 +110,8 @@ end @rule DiscreteTransition(:T1, Marginalisation) ( m_out::DiscreteNonParametric, m_in::DiscreteNonParametric, m_T2::DiscreteNonParametric, q_a::PointMass{<:AbstractArray{T, 4}}, meta::Any ) where {T} = begin - eloga = mean(q_a) + N = eltype(probvec(m_out)) + eloga = clamp.(mean(q_a), tiny(N), one(N)) @tullio out[i] := eloga[a, b, i, c] * probvec(m_out)[a] * probvec(m_in)[b] * probvec(m_T2)[c] return Categorical(normalize!(out, 1); check_args = false) end @@ -109,32 +119,33 @@ end @rule DiscreteTransition(:T2, Marginalisation) ( m_out::DiscreteNonParametric, m_in::DiscreteNonParametric, m_T1::DiscreteNonParametric, q_a::PointMass{<:AbstractArray{T, 4}}, meta::Any ) where {T} = begin - eloga = mean(q_a) + N = eltype(probvec(m_out)) + eloga = clamp.(mean(q_a), tiny(N), one(N)) @tullio out[i] := eloga[a, b, c, i] * probvec(m_out)[a] * probvec(m_in)[b] * probvec(m_T1)[c] return Categorical(normalize!(out, 1); check_args = false) end # --------------- Rules for 4 interfaces (DirichletCollection q_a) --------------- @rule DiscreteTransition(:out, Marginalisation) (m_in::DiscreteNonParametric, m_T1::DiscreteNonParametric, m_T2::DiscreteNonParametric, q_a::DirichletCollection, meta::Any) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) @tullio out[i] := eloga[i, a, b, c] * probvec(m_in)[a] * probvec(m_T1)[b] * probvec(m_T2)[c] return Categorical(normalize!(out, 1); check_args = false) end @rule DiscreteTransition(:in, Marginalisation) (m_out::DiscreteNonParametric, m_T1::DiscreteNonParametric, m_T2::DiscreteNonParametric, q_a::DirichletCollection, meta::Any) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) @tullio out[i] := eloga[a, i, b, c] * probvec(m_out)[a] * probvec(m_T1)[b] * probvec(m_T2)[c] return Categorical(normalize!(out, 1); check_args = false) end @rule DiscreteTransition(:T1, Marginalisation) (m_out::DiscreteNonParametric, m_in::DiscreteNonParametric, m_T2::DiscreteNonParametric, q_a::DirichletCollection, meta::Any) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) @tullio out[i] := eloga[a, b, i, c] * probvec(m_out)[a] * probvec(m_in)[b] * probvec(m_T2)[c] return Categorical(normalize!(out, 1); check_args = false) end @rule DiscreteTransition(:T2, Marginalisation) (m_out::DiscreteNonParametric, m_in::DiscreteNonParametric, m_T1::DiscreteNonParametric, q_a::DirichletCollection, meta::Any) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) @tullio out[i] := eloga[a, b, c, i] * probvec(m_out)[a] * probvec(m_in)[b] * probvec(m_T1)[c] return Categorical(normalize!(out, 1); check_args = false) end @@ -143,7 +154,8 @@ end @rule DiscreteTransition(:out, Marginalisation) ( m_in::DiscreteNonParametric, m_T1::DiscreteNonParametric, m_T2::DiscreteNonParametric, m_T3::DiscreteNonParametric, q_a::PointMass{<:AbstractArray{T, 5}}, meta::Any ) where {T} = begin - eloga = mean(q_a) + N = eltype(probvec(m_in)) + eloga = clamp.(mean(q_a), tiny(N), one(N)) @tullio out[i] := eloga[i, a, b, c, d] * probvec(m_in)[a] * probvec(m_T1)[b] * probvec(m_T2)[c] * probvec(m_T3)[d] return Categorical(normalize!(out, 1); check_args = false) end @@ -151,7 +163,8 @@ end @rule DiscreteTransition(:in, Marginalisation) ( m_out::DiscreteNonParametric, m_T1::DiscreteNonParametric, m_T2::DiscreteNonParametric, m_T3::DiscreteNonParametric, q_a::PointMass{<:AbstractArray{T, 5}}, meta::Any ) where {T} = begin - eloga = mean(q_a) + N = eltype(probvec(m_out)) + eloga = clamp.(mean(q_a), tiny(N), one(N)) @tullio out[i] := eloga[a, i, b, c, d] * probvec(m_out)[a] * probvec(m_T1)[b] * probvec(m_T2)[c] * probvec(m_T3)[d] return Categorical(normalize!(out, 1); check_args = false) end @@ -159,7 +172,8 @@ end @rule DiscreteTransition(:T1, Marginalisation) ( m_out::DiscreteNonParametric, m_in::DiscreteNonParametric, m_T2::DiscreteNonParametric, m_T3::DiscreteNonParametric, q_a::PointMass{<:AbstractArray{T, 5}}, meta::Any ) where {T} = begin - eloga = mean(q_a) + N = eltype(probvec(m_out)) + eloga = clamp.(mean(q_a), tiny(N), one(N)) @tullio out[i] := eloga[a, b, i, c, d] * probvec(m_out)[a] * probvec(m_in)[b] * probvec(m_T2)[c] * probvec(m_T3)[d] return Categorical(normalize!(out, 1); check_args = false) end @@ -167,7 +181,8 @@ end @rule DiscreteTransition(:T2, Marginalisation) ( m_out::DiscreteNonParametric, m_in::DiscreteNonParametric, m_T1::DiscreteNonParametric, m_T2::DiscreteNonParametric, q_a::PointMass{<:AbstractArray{T, 5}}, meta::Any ) where {T} = begin - eloga = mean(q_a) + N = eltype(probvec(m_out)) + eloga = clamp.(mean(q_a), tiny(N), one(N)) @tullio out[i] := eloga[a, b, c, i, d] * probvec(m_out)[a] * probvec(m_in)[b] * probvec(m_T1)[c] * probvec(m_T2)[d] return Categorical(normalize!(out, 1); check_args = false) end @@ -175,7 +190,8 @@ end @rule DiscreteTransition(:T3, Marginalisation) ( m_out::DiscreteNonParametric, m_in::DiscreteNonParametric, m_T1::DiscreteNonParametric, m_T2::DiscreteNonParametric, q_a::PointMass{<:AbstractArray{T, 5}}, meta::Any ) where {T} = begin - eloga = mean(q_a) + N = eltype(probvec(m_out)) + eloga = clamp.(mean(q_a), tiny(N), one(N)) @tullio out[i] := eloga[a, b, c, d, i] * probvec(m_out)[a] * probvec(m_in)[b] * probvec(m_T1)[c] * probvec(m_T2)[d] return Categorical(normalize!(out, 1); check_args = false) end @@ -184,7 +200,7 @@ end @rule DiscreteTransition(:out, Marginalisation) ( m_in::DiscreteNonParametric, m_T1::DiscreteNonParametric, m_T2::DiscreteNonParametric, m_T3::DiscreteNonParametric, q_a::DirichletCollection, meta::Any ) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) @tullio out[i] := eloga[i, a, b, c, d] * probvec(m_in)[a] * probvec(m_T1)[b] * probvec(m_T2)[c] * probvec(m_T3)[d] return Categorical(normalize!(out, 1); check_args = false) end @@ -192,7 +208,7 @@ end @rule DiscreteTransition(:in, Marginalisation) ( m_out::DiscreteNonParametric, m_T1::DiscreteNonParametric, m_T2::DiscreteNonParametric, m_T3::DiscreteNonParametric, q_a::DirichletCollection, meta::Any ) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) @tullio out[i] := eloga[a, i, b, c, d] * probvec(m_out)[a] * probvec(m_T1)[b] * probvec(m_T2)[c] * probvec(m_T3)[d] return Categorical(normalize!(out, 1); check_args = false) end @@ -200,7 +216,7 @@ end @rule DiscreteTransition(:T1, Marginalisation) ( m_out::DiscreteNonParametric, m_in::DiscreteNonParametric, m_T2::DiscreteNonParametric, m_T3::DiscreteNonParametric, q_a::DirichletCollection, meta::Any ) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) @tullio out[i] := eloga[a, b, i, c, d] * probvec(m_out)[a] * probvec(m_in)[b] * probvec(m_T2)[c] * probvec(m_T3)[d] return Categorical(normalize!(out, 1); check_args = false) end @@ -208,7 +224,7 @@ end @rule DiscreteTransition(:T2, Marginalisation) ( m_out::DiscreteNonParametric, m_in::DiscreteNonParametric, m_T1::DiscreteNonParametric, m_T2::DiscreteNonParametric, q_a::DirichletCollection, meta::Any ) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) @tullio out[i] := eloga[a, b, c, i, d] * probvec(m_out)[a] * probvec(m_in)[b] * probvec(m_T1)[c] * probvec(m_T2)[d] return Categorical(normalize!(out, 1); check_args = false) end @@ -216,7 +232,7 @@ end @rule DiscreteTransition(:T3, Marginalisation) ( m_out::DiscreteNonParametric, m_in::DiscreteNonParametric, m_T1::DiscreteNonParametric, m_T2::DiscreteNonParametric, q_a::DirichletCollection, meta::Any ) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a), dims = 1) @tullio out[i] := eloga[a, b, c, d, i] * probvec(m_out)[a] * probvec(m_in)[b] * probvec(m_T1)[c] * probvec(m_T2)[d] return Categorical(normalize!(out, 1); check_args = false) end diff --git a/src/rules/discrete_transition/predefined/marginals.jl b/src/rules/discrete_transition/predefined/marginals.jl index 710bef6db..e0394012b 100644 --- a/src/rules/discrete_transition/predefined/marginals.jl +++ b/src/rules/discrete_transition/predefined/marginals.jl @@ -1,21 +1,24 @@ using Tullio @marginalrule DiscreteTransition(:out_in) (m_out::DiscreteNonParametric, m_in::DiscreteNonParametric, q_a::DirichletCollection, meta::Any) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) @tullio result[a, b] := eloga[a, b] * probvec(m_out)[a] * probvec(m_in)[b] - return Contingency(result) + normalize!(result, 1) + return Contingency(result, Val(false)) end @marginalrule DiscreteTransition(:out_in) (m_out::DiscreteNonParametric, m_in::DiscreteNonParametric, q_a::DirichletCollection, q_T1::PointMass, meta::Any) = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) - @tullio result[a, b] := eloga[a, b, i] * probvec(q_T1)[i] - result = exp.(result) + result = eloga[:, :, findfirst(isone, probvec(q_T1))] + softmax!(result) @tullio result[a, b] = result[a, b] * probvec(m_out)[a] * probvec(m_in)[b] - return Contingency(result) + normalize!(result, 1) + return Contingency(result, Val(false)) end @marginalrule DiscreteTransition(:out_in_T1) (m_out::Categorical, m_in::Categorical, m_T1::Categorical, q_a::DirichletCollection, meta::Any) = begin - eloga = exp.(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) + eloga = softmax!(mean(Base.Broadcast.BroadcastFunction(clamplog), q_a)) @tullio result[a, b, c] := eloga[a, b, c] * probvec(m_out)[a] * probvec(m_in)[b] * probvec(m_T1)[c] - return Contingency(result) + normalize!(result, 1) + return Contingency(result, Val(false)) end diff --git a/src/rules/discrete_transition/predefined/structured_vmp.jl b/src/rules/discrete_transition/predefined/structured_vmp.jl index fb037d5c7..e1c632609 100644 --- a/src/rules/discrete_transition/predefined/structured_vmp.jl +++ b/src/rules/discrete_transition/predefined/structured_vmp.jl @@ -6,21 +6,23 @@ using Tullio out = eloga' * probvec(q_out) out .= exp.(out) return Categorical(normalize!(out, 1); check_args = false) + softmax!(out) + return Categorical(out; check_args = false) end # --------------- Rules for 2 interfaces (q_out PointMass, q_a DirichletCollection) --------------- @rule DiscreteTransition(:in, Marginalisation) (q_out::PointMass{<:AbstractVector}, q_a::DirichletCollection, meta::Any) = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) out = eloga' * probvec(q_out) - out .= exp.(out) - return Categorical(normalize!(out, 1); check_args = false) + softmax!(out) + return Categorical(out; check_args = false) end # --------------- Rules for 3 interfaces (q_out PointMass, q_a PointMass) --------------- @rule DiscreteTransition(:in, Marginalisation) (q_out::PointMass{<:AbstractVector}, q_a::PointMass{<:AbstractArray{T, 3}}, m_T1::DiscreteNonParametric, meta::Any) where {T} = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j] := eloga[a, i, j] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) msg = out * probvec(m_T1) return Categorical(normalize!(msg, 1); check_args = false) end @@ -28,7 +30,7 @@ end @rule DiscreteTransition(:T1, Marginalisation) (q_out::PointMass{<:AbstractVector}, m_in::DiscreteNonParametric, q_a::PointMass{<:AbstractArray{T, 3}}, meta::Any) where {T} = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j] := eloga[a, i, j] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) msg = out' * probvec(m_in) return Categorical(normalize!(msg, 1); check_args = false) end @@ -37,7 +39,7 @@ end @rule DiscreteTransition(:in, Marginalisation) (q_out::PointMass{<:AbstractVector}, q_a::DirichletCollection, m_T1::DiscreteNonParametric, meta::Any) = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j] := eloga[a, i, j] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) msg = out * probvec(m_T1) return Categorical(normalize!(msg, 1); check_args = false) end @@ -45,7 +47,7 @@ end @rule DiscreteTransition(:T1, Marginalisation) (q_out::PointMass{<:AbstractVector}, m_in::DiscreteNonParametric, q_a::DirichletCollection, meta::Any) = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j] := eloga[a, i, j] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) msg = out' * probvec(m_in) return Categorical(normalize!(msg, 1); check_args = false) end @@ -56,7 +58,7 @@ end ) where {T} = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j, k] := eloga[a, i, j, k] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) @tullio msg[i] := out[i, j, k] * probvec(m_T1)[j] * probvec(m_T2)[k] return Categorical(normalize!(msg, 1); check_args = false) end @@ -66,7 +68,7 @@ end ) where {T} = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j, k] := eloga[a, i, j, k] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) @tullio msg[j] := out[i, j, k] * probvec(m_in)[i] * probvec(m_T2)[k] return Categorical(normalize!(msg, 1); check_args = false) end @@ -76,7 +78,7 @@ end ) where {T} = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j, k] := eloga[a, i, j, k] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) @tullio msg[k] := out[i, j, k] * probvec(m_in)[i] * probvec(m_T1)[j] return Categorical(normalize!(msg, 1); check_args = false) end @@ -86,7 +88,7 @@ end begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j, k] := eloga[a, i, j, k] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) @tullio msg[i] := out[i, j, k] * probvec(m_T1)[j] * probvec(m_T2)[k] return Categorical(normalize!(msg, 1); check_args = false) end @@ -95,7 +97,7 @@ end begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j, k] := eloga[a, i, j, k] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) @tullio msg[j] := out[i, j, k] * probvec(m_in)[i] * probvec(m_T2)[k] return Categorical(normalize!(msg, 1); check_args = false) end @@ -104,7 +106,7 @@ end begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j, k] := eloga[a, i, j, k] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) @tullio msg[k] := out[i, j, k] * probvec(m_in)[i] * probvec(m_T1)[j] return Categorical(normalize!(msg, 1); check_args = false) end @@ -115,7 +117,7 @@ end ) where {T} = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j, k, l] := eloga[a, i, j, k, l] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) @tullio msg[i] := out[i, j, k, l] * probvec(m_T1)[j] * probvec(m_T2)[k] * probvec(m_T3)[l] return Categorical(normalize!(msg, 1); check_args = false) end @@ -125,7 +127,7 @@ end ) where {T} = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j, k, l] := eloga[a, i, j, k, l] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) @tullio msg[j] := out[i, j, k, l] * probvec(m_in)[i] * probvec(m_T2)[k] * probvec(m_T3)[l] return Categorical(normalize!(msg, 1); check_args = false) end @@ -135,7 +137,7 @@ end ) where {T} = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j, k, l] := eloga[a, i, j, k, l] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) @tullio msg[k] := out[i, j, k, l] * probvec(m_in)[i] * probvec(m_T1)[j] * probvec(m_T2)[l] return Categorical(normalize!(msg, 1); check_args = false) end @@ -145,7 +147,7 @@ end ) where {T} = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j, k, l] := eloga[a, i, j, k, l] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) @tullio msg[l] := out[i, j, k, l] * probvec(m_in)[i] * probvec(m_T1)[j] * probvec(m_T2)[k] return Categorical(normalize!(msg, 1); check_args = false) end @@ -156,7 +158,7 @@ end ) = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j, k, l] := eloga[a, i, j, k, l] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) @tullio msg[i] := out[i, j, k, l] * probvec(m_T1)[j] * probvec(m_T2)[k] * probvec(m_T3)[l] return Categorical(normalize!(msg, 1); check_args = false) end @@ -166,7 +168,7 @@ end ) = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j, k, l] := eloga[a, i, j, k, l] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) @tullio msg[j] := out[i, j, k, l] * probvec(m_in)[i] * probvec(m_T2)[k] * probvec(m_T3)[l] return Categorical(normalize!(msg, 1); check_args = false) end @@ -176,7 +178,7 @@ end ) = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j, k, l] := eloga[a, i, j, k, l] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) @tullio msg[k] := out[i, j, k, l] * probvec(m_in)[i] * probvec(m_T1)[j] * probvec(m_T2)[l] return Categorical(normalize!(msg, 1); check_args = false) end @@ -186,7 +188,7 @@ end ) = begin eloga = mean(Base.Broadcast.BroadcastFunction(clamplog), q_a) @tullio out[i, j, k, l] := eloga[a, i, j, k, l] * probvec(q_out)[a] - out .= exp.(out) + softmax!(out) @tullio msg[l] := out[i, j, k, l] * probvec(m_in)[i] * probvec(m_T1)[j] * probvec(m_T2)[k] return Categorical(normalize!(msg, 1); check_args = false) end diff --git a/test/rules/discrete_transition/in_tests.jl b/test/rules/discrete_transition/in_tests.jl index 3eb6f11c3..27015a68a 100644 --- a/test/rules/discrete_transition/in_tests.jl +++ b/test/rules/discrete_transition/in_tests.jl @@ -2,10 +2,23 @@ @testitem "rules:DiscreteTransition:in:Variational Bayes: (q_out::Any, q_a::DirichletCollection)" begin using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + import ReactiveMP: @test_rules + @test_rules [check_type_promotion = false] DiscreteTransition(:in, Marginalisation) [ + ( + input = (q_out = PointMass([0.1, 0.4, 0.5]), q_a = DirichletCollection([0.2 0.1 0.7; 0.4 0.3 0.3; 0.1 0.6 0.3])), + output = Categorical([0.03245589526827472, 0.5950912160314408, 0.37245288870028453]) + ), + (input = (q_out = PointMass([0.0, 1.0, 0.0]), q_a = DirichletCollection(diageye(3) .+ tiny)), output = Categorical([0.0, 1.0, 0.0])) + ] +end + +@testitem "rules:DiscreteTransition:in:Variational Bayes: (q_out::Any, q_a::PointMass)" begin + using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions + import ReactiveMP: @test_rules @test_rules [check_type_promotion = false] DiscreteTransition(:in, Marginalisation) [( - input = (q_out = PointMass([0.1, 0.4, 0.5]), q_a = DirichletCollection([0.2 0.1 0.7; 0.4 0.3 0.3; 0.1 0.6 0.3])), - output = Categorical([0.03245589526827472, 0.5950912160314408, 0.37245288870028453]) + input = (q_out = PointMass([0.1, 0.4, 0.5]), q_a = PointMass([0.2 0.4 0.1; 0.1 0.3 0.6; 0.7 0.3 0.3])), + output = Categorical([0.29943853278212923, 0.32603993277541166, 0.3745215344424593]) )] end @@ -29,10 +42,13 @@ end using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions import ReactiveMP: @test_rules - @test_rules [check_type_promotion = false] DiscreteTransition(:in, Marginalisation) [( - input = (m_out = Categorical([0.1, 0.4, 0.5]), q_a = PointMass([0.2 0.1 0.7; 0.4 0.3 0.3; 0.1 0.6 0.3])), - output = Categorical([0.23000000000000004, 0.43, 0.33999999999999997]) - )] + @test_rules [check_type_promotion = false] DiscreteTransition(:in, Marginalisation) [ + ( + input = (m_out = Categorical([0.1, 0.4, 0.5]), q_a = PointMass([0.2 0.4 0.1; 0.1 0.3 0.6; 0.7 0.3 0.3])), + output = Categorical([0.36607142857142855, 0.27678571428571425, 0.35714285714285715]) + ), + (input = (m_out = Categorical([1.0, 0.0, 0.0]), q_a = PointMass([0.0 0.0 0.0; 0.5 0.5 0.5; 0.5 0.5 0.5])), output = Categorical([1 / 3, 1 / 3, 1 / 3])) + ] end @testitem "rules:DiscreteTransition:in:Belief Propagation: (m_out::Categorical, q_a::DirichletCollection, m_t1::Categorical)" begin diff --git a/test/rules/discrete_transition/out_tests.jl b/test/rules/discrete_transition/out_tests.jl index 74113f3da..5b5ba9370 100644 --- a/test/rules/discrete_transition/out_tests.jl +++ b/test/rules/discrete_transition/out_tests.jl @@ -5,12 +5,12 @@ import ReactiveMP: @test_rules @test_rules [check_type_promotion = false] DiscreteTransition(:out, Marginalisation) [ ( - input = (q_in = PointMass([0.1, 0.4, 0.5]), q_a = PointMass([0.2 0.1 0.7; 0.4 0.3 0.3; 0.1 0.6 0.3])), - output = Categorical([0.29943853278212923, 0.32603993277541166, 0.3745215344424593]) + input = (q_in = PointMass([0.1, 0.4, 0.5]), q_a = PointMass([0.2 0.4 0.1; 0.1 0.3 0.6; 0.7 0.3 0.3])), + output = Categorical([0.2089059609775621, 0.425547421552122, 0.3655466174703161]) ), ( - input = (q_in = PointMass([0.2, 0.5, 0.3]), q_a = PointMass([0.1 0.8 0.1; 0.6 0.3 0.1; 0.2 0.4 0.4])), - output = Categorical([0.3218092957943072, 0.2819969875260698, 0.39619371667962294]) + input = (q_in = PointMass([0.2, 0.5, 0.3]), q_a = PointMass([0.1 0.6 0.2; 0.8 0.3 0.4; 0.1 0.1 0.4])), + output = Categorical([0.35434347876072936, 0.4675590233555693, 0.17809749788370124]) ), (input = (q_in = PointMass([1.0, 0.0, 0.0]), q_a = PointMass([1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0])), output = Categorical([1.0, 0.0, 0.0])) ] @@ -21,8 +21,8 @@ end import ReactiveMP: @test_rules @test_rules [check_type_promotion = false] DiscreteTransition(:out, Marginalisation) [( - input = (q_in = Categorical([0.1, 0.4, 0.5]), q_a = PointMass([0.2 0.1 0.7; 0.4 0.3 0.3; 0.1 0.6 0.3])), - output = Categorical([0.2994385327821292, 0.3260399327754116, 0.37452153444245917]) + input = (q_in = Categorical([0.1, 0.4, 0.5]), q_a = PointMass([0.2 0.4 0.1; 0.1 0.3 0.6; 0.7 0.3 0.3])), + output = Categorical([0.2089059609775621, 0.425547421552122, 0.3655466174703161]) )] end @@ -44,14 +44,8 @@ end import ReactiveMP: @test_rules @test_rules [check_type_promotion = false] DiscreteTransition(:out, Marginalisation) [ - ( - input = (m_in = Categorical([0.1, 0.4, 0.5]), q_a = PointMass([0.2 0.1 0.7; 0.4 0.3 0.3; 0.1 0.6 0.3])), - output = Categorical([0.3660714285714285, 0.27678571428571425, 0.35714285714285715]) - ), - ( - input = (m_in = Categorical([0.2, 0.5, 0.3]), q_a = PointMass([0.1 0.8 0.1; 0.6 0.3 0.1; 0.2 0.4 0.4])), - output = Categorical([0.40540540540540543, 0.2702702702702703, 0.32432432432432434]) - ) + (input = (m_in = Categorical([0.1, 0.4, 0.5]), q_a = PointMass([0.2 0.4 0.1; 0.1 0.3 0.6; 0.7 0.3 0.3])), output = Categorical([0.23, 0.43, 0.34])), + (input = (m_in = Categorical([0.2, 0.5, 0.3]), q_a = PointMass([0.1 0.6 0.2; 0.8 0.3 0.4; 0.1 0.1 0.4])), output = Categorical([0.38, 0.43, 0.19])) ] end @@ -77,9 +71,9 @@ end import ReactiveMP: @test_rules @test_rules [check_type_promotion = false] DiscreteTransition(:out, Marginalisation) [ (input = (m_in = PointMass([0.0, 1.0, 0.0]), q_a = PointMass([0.2 0.1 0.7; 0.4 0.3 0.3; 0.1 0.6 0.3])), output = Categorical([0.1, 0.3, 0.6])), - (input = (m_in = PointMass([1.0, 0.0, 0.0]), q_a = PointMass([0.1 0.8 0.1; 0.6 0.3 0.1; 0.2 0.4 0.4])), output = Categorical([0.1 / 0.9, 0.6 / 0.9, 0.2 / 0.9])), + (input = (m_in = PointMass([1.0, 0.0, 0.0]), q_a = PointMass([0.1 0.6 0.2; 0.8 0.3 0.4; 0.1 0.1 0.4])), output = Categorical([0.1, 0.8, 0.1])), (input = (m_in = PointMass([0, 1, 0]), q_a = PointMass([0.2 0.1 0.7; 0.4 0.3 0.3; 0.1 0.6 0.3])), output = Categorical([0.1, 0.3, 0.6])), - (input = (m_in = PointMass([1, 0, 0]), q_a = PointMass([0.1 0.8 0.1; 0.6 0.3 0.1; 0.2 0.4 0.4])), output = Categorical([0.1 / 0.9, 0.6 / 0.9, 0.2 / 0.9])) + (input = (m_in = PointMass([1, 0, 0]), q_a = PointMass([0.1 0.6 0.2; 0.8 0.3 0.4; 0.1 0.1 0.4])), output = Categorical([0.1, 0.8, 0.1])) ] end diff --git a/test/rules/discrete_transition/t_tests.jl b/test/rules/discrete_transition/t_tests.jl index e7005359d..7e6ec70e2 100644 --- a/test/rules/discrete_transition/t_tests.jl +++ b/test/rules/discrete_transition/t_tests.jl @@ -10,7 +10,7 @@ m_in = Categorical([0.2, 0.5, 0.3]), q_a = PointMass([1.0 6.0 32.0; 2.0 2.0 9.0; 5.0 5.0 6.0;;; 9.0 5.0 6.0; 4.0 10.0 6.0; 10.0 6.0 32.0;;; 6.0 1.0 8.0; 2.0 10.0 7.0; 1.0 3.0 8.0]) ), - output = Categorical([0.28971962616822433, 0.4392523364485981, 0.2710280373831776]) + output = Categorical([1 / 3, 1 / 3, 1 / 3]) ), ( input = ( @@ -285,7 +285,7 @@ end # Test T1 interface with 3 interfaces (BP with DirichletCollection q_a) using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions - import ReactiveMP: @test_rules + import ReactiveMP: @test_rules, normalize! @testset "Belief Propagation: T1 with 3 interfaces (DirichletCollection q_a)" begin @test_rules [check_type_promotion = true, extra_float_types = [Float64, Float32]] DiscreteTransition(:T1, Marginalisation) [ ( @@ -324,11 +324,16 @@ end input = ( m_out = Categorical([0.3, 0.4, 0.3]), m_in = Categorical([0.2, 0.5, 0.3]), - q_a = PointMass([ - 1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0;;; - 10.0 11.0 12.0; 13.0 14.0 15.0; 16.0 17.0 18.0;;; - 19.0 20.0 21.0; 22.0 23.0 24.0; 25.0 26.0 27.0 - ]) + q_a = PointMass( + normalize!( + [ + 1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0;;; + 10.0 11.0 12.0; 13.0 14.0 15.0; 16.0 17.0 18.0;;; + 19.0 20.0 21.0; 22.0 23.0 24.0; 25.0 26.0 27.0 + ], + 1 + ) + ) ), output = Categorical([0.12056737588652484, 0.3333333333333333, 0.546099290780142]) ), @@ -336,11 +341,16 @@ end input = ( m_out = Categorical([0.6, 0.3, 0.1]), m_in = Categorical([0.1, 0.7, 0.2]), - q_a = PointMass([ - 5.0 1.0 2.0; 3.0 8.0 4.0; 6.0 7.0 9.0;;; - 2.0 4.0 6.0; 8.0 10.0 12.0; 14.0 16.0 18.0;;; - 1.0 3.0 5.0; 7.0 9.0 11.0; 13.0 15.0 17.0 - ]) + q_a = PointMass( + normalize!( + [ + 5.0 1.0 2.0; 3.0 8.0 4.0; 6.0 7.0 9.0;;; + 2.0 4.0 6.0; 8.0 10.0 12.0; 14.0 16.0 18.0;;; + 1.0 3.0 5.0; 7.0 9.0 11.0; 13.0 15.0 17.0 + ], + 1 + ) + ) ), output = Categorical([0.2163742690058479, 0.4210526315789473, 0.3625730994152047]) ) @@ -392,11 +402,14 @@ end m_in = Categorical([0.2, 0.5, 0.3]), m_T2 = Categorical([0.1, 0.6, 0.3]), q_a = PointMass( - [ - 8.0 8.0 10.0; 3.0 2.0 3.0; 2.0 2.0 10.0;;; 5.0 2.0 8.0; 2.0 10.0 8.0; 9.0 6.0 6.0;;; 11.0 6.0 1.0; 3.0 8.0 4.0; 3.0 7.0 9.0;;;; - 4.0 8.0 6.0; 8.0 5.0 6.0; 5.0 10.0 11.0;;; 11.0 2.0 10.0; 2.0 7.0 10.0; 7.0 9.0 11.0;;; 5.0 10.0 4.0; 3.0 7.0 9.0; 7.0 2.0 8.0;;;; - 10.0 6.0 4.0; 10.0 11.0 5.0; 6.0 4.0 3.0;;; 9.0 9.0 2.0; 6.0 6.0 1.0; 5.0 4.0 5.0;;; 11.0 5.0 7.0; 6.0 3.0 3.0; 1.0 8.0 8.0 - ] + normalize!( + [ + 8.0 8.0 10.0; 3.0 2.0 3.0; 2.0 2.0 10.0;;; 5.0 2.0 8.0; 2.0 10.0 8.0; 9.0 6.0 6.0;;; 11.0 6.0 1.0; 3.0 8.0 4.0; 3.0 7.0 9.0;;;; + 4.0 8.0 6.0; 8.0 5.0 6.0; 5.0 10.0 11.0;;; 11.0 2.0 10.0; 2.0 7.0 10.0; 7.0 9.0 11.0;;; 5.0 10.0 4.0; 3.0 7.0 9.0; 7.0 2.0 8.0;;;; + 10.0 6.0 4.0; 10.0 11.0 5.0; 6.0 4.0 3.0;;; 9.0 9.0 2.0; 6.0 6.0 1.0; 5.0 4.0 5.0;;; 11.0 5.0 7.0; 6.0 3.0 3.0; 1.0 8.0 8.0 + ], + 1 + ) ) ), output = Categorical([0.348472379384837, 0.3413229091657237, 0.31020471144943934]) @@ -407,11 +420,14 @@ end m_in = Categorical([0.1, 0.7, 0.2]), m_T2 = Categorical([0.4, 0.3, 0.3]), q_a = PointMass( - [ - 19.0 7.0 3.0; 3.0 17.0 2.0; 9.0 9.0 12.0;;; 13.0 8.0 2.0; 6.0 16.0 7.0; 7.0 3.0 16.0;;; 15.0 5.0 1.0; 9.0 12.0 5.0; 9.0 10.0 14.0;;;; - 19.0 6.0 3.0; 4.0 17.0 9.0; 1.0 8.0 18.0;;; 11.0 1.0 8.0; 9.0 13.0 4.0; 7.0 1.0 16.0;;; 20.0 8.0 1.0; 9.0 18.0 3.0; 1.0 5.0 14.0;;;; - 20.0 9.0 8.0; 4.0 18.0 5.0; 4.0 6.0 11.0;;; 12.0 5.0 4.0; 8.0 16.0 9.0; 5.0 4.0 18.0;;; 17.0 1.0 10.0; 8.0 13.0 5.0; 5.0 1.0 13.0 - ] + normalize!( + [ + 19.0 7.0 3.0; 3.0 17.0 2.0; 9.0 9.0 12.0;;; 13.0 8.0 2.0; 6.0 16.0 7.0; 7.0 3.0 16.0;;; 15.0 5.0 1.0; 9.0 12.0 5.0; 9.0 10.0 14.0;;;; + 19.0 6.0 3.0; 4.0 17.0 9.0; 1.0 8.0 18.0;;; 11.0 1.0 8.0; 9.0 13.0 4.0; 7.0 1.0 16.0;;; 20.0 8.0 1.0; 9.0 18.0 3.0; 1.0 5.0 14.0;;;; + 20.0 9.0 8.0; 4.0 18.0 5.0; 4.0 6.0 11.0;;; 12.0 5.0 4.0; 8.0 16.0 9.0; 5.0 4.0 18.0;;; 17.0 1.0 10.0; 8.0 13.0 5.0; 5.0 1.0 13.0 + ], + 1 + ) ) ), output = Categorical([0.38575359948949073, 0.30857894946755476, 0.30566745104295456]) @@ -464,11 +480,14 @@ end m_in = Categorical([0.2, 0.5, 0.3]), m_T1 = Categorical([0.1, 0.6, 0.3]), q_a = PointMass( - [ - 8.0 8.0 10.0; 3.0 2.0 3.0; 2.0 2.0 10.0;;; 5.0 2.0 8.0; 2.0 10.0 8.0; 9.0 6.0 6.0;;; 11.0 6.0 1.0; 3.0 8.0 4.0; 3.0 7.0 9.0;;;; - 4.0 8.0 6.0; 8.0 5.0 6.0; 5.0 10.0 11.0;;; 11.0 2.0 10.0; 2.0 7.0 10.0; 7.0 9.0 11.0;;; 5.0 10.0 4.0; 3.0 7.0 9.0; 7.0 2.0 8.0;;;; - 10.0 6.0 4.0; 10.0 11.0 5.0; 6.0 4.0 3.0;;; 9.0 9.0 2.0; 6.0 6.0 1.0; 5.0 4.0 5.0;;; 11.0 5.0 7.0; 6.0 3.0 3.0; 1.0 8.0 8.0 - ] + normalize!( + [ + 8.0 8.0 10.0; 3.0 2.0 3.0; 2.0 2.0 10.0;;; 5.0 2.0 8.0; 2.0 10.0 8.0; 9.0 6.0 6.0;;; 11.0 6.0 1.0; 3.0 8.0 4.0; 3.0 7.0 9.0;;;; + 4.0 8.0 6.0; 8.0 5.0 6.0; 5.0 10.0 11.0;;; 11.0 2.0 10.0; 2.0 7.0 10.0; 7.0 9.0 11.0;;; 5.0 10.0 4.0; 3.0 7.0 9.0; 7.0 2.0 8.0;;;; + 10.0 6.0 4.0; 10.0 11.0 5.0; 6.0 4.0 3.0;;; 9.0 9.0 2.0; 6.0 6.0 1.0; 5.0 4.0 5.0;;; 11.0 5.0 7.0; 6.0 3.0 3.0; 1.0 8.0 8.0 + ], + 1 + ) ) ), output = Categorical([0.3300037591966059, 0.3779066645185543, 0.29208957628483967]) @@ -479,11 +498,14 @@ end m_in = Categorical([0.1, 0.7, 0.2]), m_T1 = Categorical([0.4, 0.3, 0.3]), q_a = PointMass( - [ - 19.0 7.0 3.0; 3.0 17.0 2.0; 9.0 9.0 12.0;;; 13.0 8.0 2.0; 6.0 16.0 7.0; 7.0 3.0 16.0;;; 15.0 5.0 1.0; 9.0 12.0 5.0; 9.0 10.0 14.0;;;; - 19.0 6.0 3.0; 4.0 17.0 9.0; 1.0 8.0 18.0;;; 11.0 1.0 8.0; 9.0 13.0 4.0; 7.0 1.0 16.0;;; 20.0 8.0 1.0; 9.0 18.0 3.0; 1.0 5.0 14.0;;;; - 20.0 9.0 8.0; 4.0 18.0 5.0; 4.0 6.0 11.0;;; 12.0 5.0 4.0; 8.0 16.0 9.0; 5.0 4.0 18.0;;; 17.0 1.0 10.0; 8.0 13.0 5.0; 5.0 1.0 13.0 - ] + normalize!( + [ + 19.0 7.0 3.0; 3.0 17.0 2.0; 9.0 9.0 12.0;;; 13.0 8.0 2.0; 6.0 16.0 7.0; 7.0 3.0 16.0;;; 15.0 5.0 1.0; 9.0 12.0 5.0; 9.0 10.0 14.0;;;; + 19.0 6.0 3.0; 4.0 17.0 9.0; 1.0 8.0 18.0;;; 11.0 1.0 8.0; 9.0 13.0 4.0; 7.0 1.0 16.0;;; 20.0 8.0 1.0; 9.0 18.0 3.0; 1.0 5.0 14.0;;;; + 20.0 9.0 8.0; 4.0 18.0 5.0; 4.0 6.0 11.0;;; 12.0 5.0 4.0; 8.0 16.0 9.0; 5.0 4.0 18.0;;; 17.0 1.0 10.0; 8.0 13.0 5.0; 5.0 1.0 13.0 + ], + 1 + ) ) ), output = Categorical([0.33546815449294726, 0.3245059133236415, 0.3400259321834113]) @@ -539,14 +561,17 @@ end m_T2 = Categorical([0.1, 0.6, 0.3]), m_T3 = Categorical([0.2, 0.3, 0.5]), q_a = PointMass( - [ - 18.0 6.0 5.0; 7.0 16.0 5.0; 2.0 2.0 18.0;;; 14.0 9.0 1.0; 2.0 13.0 7.0; 1.0 8.0 14.0;;; 18.0 2.0 2.0; 9.0 11.0 4.0; 6.0 1.0 14.0;;;; 14.0 3.0 10.0; 10.0 17.0 9.0; 5.0 7.0 13.0;;; 13.0 10.0 5.0; 9.0 18.0 4.0; 1.0 4.0 12.0;;; 18.0 7.0 9.0; 2.0 20.0 7.0; 5.0 8.0 11.0;;;; 18.0 3.0 8.0; 9.0 20.0 9.0; 8.0 5.0 19.0;;; 15.0 9.0 2.0; 1.0 14.0 1.0; 4.0 2.0 12.0;;; 18.0 5.0 10.0; 1.0 18.0 2.0; 9.0 2.0 12.0;;;;; - 16.0 2.0 1.0; 5.0 18.0 8.0; 9.0 6.0 14.0;;; 17.0 6.0 6.0; 1.0 18.0 7.0; 4.0 5.0 16.0;;; 19.0 8.0 9.0; 4.0 11.0 9.0; 9.0 10.0 18.0;;;; 14.0 7.0 8.0; 1.0 18.0 2.0; 4.0 3.0 12.0;;; 15.0 5.0 4.0; 9.0 20.0 2.0; 7.0 10.0 15.0;;; 13.0 5.0 6.0; 6.0 15.0 4.0; 1.0 10.0 16.0;;;; 12.0 2.0 1.0; 3.0 18.0 7.0; 8.0 7.0 15.0;;; 19.0 3.0 4.0; 7.0 14.0 8.0; 1.0 6.0 14.0;;; 16.0 8.0 8.0; 5.0 16.0 9.0; 1.0 10.0 11.0;;;;; - 19.0 1.0 6.0; 3.0 20.0 9.0; 5.0 9.0 12.0;;; 13.0 6.0 1.0; 9.0 12.0 7.0; 3.0 1.0 20.0;;; 19.0 10.0 8.0; 6.0 16.0 6.0; 9.0 10.0 16.0;;;; 13.0 1.0 4.0; 6.0 18.0 8.0; 5.0 9.0 20.0;;; 18.0 8.0 5.0; 8.0 11.0 4.0; 7.0 9.0 13.0;;; 15.0 8.0 10.0; 10.0 14.0 10.0; 10.0 2.0 16.0;;;; 14.0 8.0 10.0; 5.0 18.0 9.0; 9.0 4.0 15.0;;; 15.0 9.0 1.0; 7.0 12.0 8.0; 5.0 8.0 12.0;;; 20.0 3.0 1.0; 9.0 14.0 1.0; 2.0 9.0 13.0 - ] + normalize!( + [ + 18.0 6.0 5.0; 7.0 16.0 5.0; 2.0 2.0 18.0;;; 14.0 9.0 1.0; 2.0 13.0 7.0; 1.0 8.0 14.0;;; 18.0 2.0 2.0; 9.0 11.0 4.0; 6.0 1.0 14.0;;;; 14.0 3.0 10.0; 10.0 17.0 9.0; 5.0 7.0 13.0;;; 13.0 10.0 5.0; 9.0 18.0 4.0; 1.0 4.0 12.0;;; 18.0 7.0 9.0; 2.0 20.0 7.0; 5.0 8.0 11.0;;;; 18.0 3.0 8.0; 9.0 20.0 9.0; 8.0 5.0 19.0;;; 15.0 9.0 2.0; 1.0 14.0 1.0; 4.0 2.0 12.0;;; 18.0 5.0 10.0; 1.0 18.0 2.0; 9.0 2.0 12.0;;;;; + 16.0 2.0 1.0; 5.0 18.0 8.0; 9.0 6.0 14.0;;; 17.0 6.0 6.0; 1.0 18.0 7.0; 4.0 5.0 16.0;;; 19.0 8.0 9.0; 4.0 11.0 9.0; 9.0 10.0 18.0;;;; 14.0 7.0 8.0; 1.0 18.0 2.0; 4.0 3.0 12.0;;; 15.0 5.0 4.0; 9.0 20.0 2.0; 7.0 10.0 15.0;;; 13.0 5.0 6.0; 6.0 15.0 4.0; 1.0 10.0 16.0;;;; 12.0 2.0 1.0; 3.0 18.0 7.0; 8.0 7.0 15.0;;; 19.0 3.0 4.0; 7.0 14.0 8.0; 1.0 6.0 14.0;;; 16.0 8.0 8.0; 5.0 16.0 9.0; 1.0 10.0 11.0;;;;; + 19.0 1.0 6.0; 3.0 20.0 9.0; 5.0 9.0 12.0;;; 13.0 6.0 1.0; 9.0 12.0 7.0; 3.0 1.0 20.0;;; 19.0 10.0 8.0; 6.0 16.0 6.0; 9.0 10.0 16.0;;;; 13.0 1.0 4.0; 6.0 18.0 8.0; 5.0 9.0 20.0;;; 18.0 8.0 5.0; 8.0 11.0 4.0; 7.0 9.0 13.0;;; 15.0 8.0 10.0; 10.0 14.0 10.0; 10.0 2.0 16.0;;;; 14.0 8.0 10.0; 5.0 18.0 9.0; 9.0 4.0 15.0;;; 15.0 9.0 1.0; 7.0 12.0 8.0; 5.0 8.0 12.0;;; 20.0 3.0 1.0; 9.0 14.0 1.0; 2.0 9.0 13.0 + ], + 1 + ) ) ), - output = Categorical([0.3404836595269985, 0.3200594005236625, 0.339456939949339]) + output = Categorical([0.34048365920174845, 0.3200594007157279, 0.33945694008252375]) ), ( input = ( @@ -555,11 +580,14 @@ end m_T2 = Categorical([0.4, 0.3, 0.3]), m_T3 = Categorical([0.5, 0.2, 0.3]), q_a = PointMass( - [ - 15.0 4.0 9.0; 10.0 13.0 8.0; 2.0 8.0 15.0;;; 20.0 1.0 9.0; 5.0 15.0 2.0; 5.0 1.0 16.0;;; 19.0 9.0 9.0; 4.0 11.0 7.0; 2.0 10.0 16.0;;;; 19.0 3.0 3.0; 9.0 17.0 2.0; 9.0 10.0 20.0;;; 20.0 1.0 8.0; 6.0 16.0 6.0; 7.0 7.0 20.0;;; 11.0 8.0 10.0; 10.0 17.0 1.0; 7.0 3.0 14.0;;;; 12.0 10.0 6.0; 7.0 20.0 8.0; 7.0 2.0 20.0;;; 18.0 8.0 4.0; 8.0 20.0 7.0; 9.0 2.0 14.0;;; 13.0 10.0 3.0; 4.0 11.0 5.0; 2.0 5.0 17.0;;;;; - 11.0 3.0 4.0; 4.0 11.0 5.0; 8.0 5.0 17.0;;; 11.0 8.0 7.0; 10.0 11.0 3.0; 9.0 9.0 13.0;;; 15.0 8.0 5.0; 2.0 18.0 6.0; 2.0 8.0 15.0;;;; 16.0 8.0 3.0; 2.0 15.0 4.0; 7.0 4.0 13.0;;; 11.0 4.0 2.0; 5.0 18.0 9.0; 7.0 1.0 16.0;;; 16.0 3.0 8.0; 3.0 15.0 4.0; 7.0 8.0 13.0;;;; 16.0 3.0 6.0; 10.0 12.0 2.0; 2.0 9.0 17.0;;; 19.0 7.0 5.0; 7.0 13.0 4.0; 8.0 7.0 14.0;;; 13.0 6.0 9.0; 7.0 18.0 6.0; 7.0 2.0 11.0;;;;; - 11.0 1.0 7.0; 9.0 11.0 4.0; 8.0 1.0 11.0;;; 12.0 9.0 3.0; 4.0 15.0 1.0; 7.0 10.0 16.0;;; 11.0 8.0 8.0; 4.0 17.0 2.0; 7.0 2.0 17.0;;;; 13.0 2.0 9.0; 10.0 18.0 4.0; 10.0 2.0 17.0;;; 20.0 9.0 5.0; 10.0 15.0 3.0; 6.0 5.0 15.0;;; 18.0 5.0 3.0; 10.0 15.0 3.0; 7.0 5.0 20.0;;;; 16.0 7.0 6.0; 2.0 20.0 2.0; 2.0 8.0 11.0;;; 14.0 2.0 2.0; 3.0 13.0 2.0; 10.0 9.0 11.0;;; 11.0 6.0 1.0; 10.0 12.0 4.0; 4.0 5.0 13.0 - ] + normalize!( + [ + 15.0 4.0 9.0; 10.0 13.0 8.0; 2.0 8.0 15.0;;; 20.0 1.0 9.0; 5.0 15.0 2.0; 5.0 1.0 16.0;;; 19.0 9.0 9.0; 4.0 11.0 7.0; 2.0 10.0 16.0;;;; 19.0 3.0 3.0; 9.0 17.0 2.0; 9.0 10.0 20.0;;; 20.0 1.0 8.0; 6.0 16.0 6.0; 7.0 7.0 20.0;;; 11.0 8.0 10.0; 10.0 17.0 1.0; 7.0 3.0 14.0;;;; 12.0 10.0 6.0; 7.0 20.0 8.0; 7.0 2.0 20.0;;; 18.0 8.0 4.0; 8.0 20.0 7.0; 9.0 2.0 14.0;;; 13.0 10.0 3.0; 4.0 11.0 5.0; 2.0 5.0 17.0;;;;; + 11.0 3.0 4.0; 4.0 11.0 5.0; 8.0 5.0 17.0;;; 11.0 8.0 7.0; 10.0 11.0 3.0; 9.0 9.0 13.0;;; 15.0 8.0 5.0; 2.0 18.0 6.0; 2.0 8.0 15.0;;;; 16.0 8.0 3.0; 2.0 15.0 4.0; 7.0 4.0 13.0;;; 11.0 4.0 2.0; 5.0 18.0 9.0; 7.0 1.0 16.0;;; 16.0 3.0 8.0; 3.0 15.0 4.0; 7.0 8.0 13.0;;;; 16.0 3.0 6.0; 10.0 12.0 2.0; 2.0 9.0 17.0;;; 19.0 7.0 5.0; 7.0 13.0 4.0; 8.0 7.0 14.0;;; 13.0 6.0 9.0; 7.0 18.0 6.0; 7.0 2.0 11.0;;;;; + 11.0 1.0 7.0; 9.0 11.0 4.0; 8.0 1.0 11.0;;; 12.0 9.0 3.0; 4.0 15.0 1.0; 7.0 10.0 16.0;;; 11.0 8.0 8.0; 4.0 17.0 2.0; 7.0 2.0 17.0;;;; 13.0 2.0 9.0; 10.0 18.0 4.0; 10.0 2.0 17.0;;; 20.0 9.0 5.0; 10.0 15.0 3.0; 6.0 5.0 15.0;;; 18.0 5.0 3.0; 10.0 15.0 3.0; 7.0 5.0 20.0;;;; 16.0 7.0 6.0; 2.0 20.0 2.0; 2.0 8.0 11.0;;; 14.0 2.0 2.0; 3.0 13.0 2.0; 10.0 9.0 11.0;;; 11.0 6.0 1.0; 10.0 12.0 4.0; 4.0 5.0 13.0 + ], + 1 + ) ) ), output = Categorical([0.3191507064239812, 0.3234732446374377, 0.35737604893858116]) @@ -615,11 +643,14 @@ end m_T1 = Categorical([0.1, 0.6, 0.3]), m_T3 = Categorical([0.2, 0.3, 0.5]), q_a = PointMass( - [ - 18.0 6.0 5.0; 7.0 16.0 5.0; 2.0 2.0 18.0;;; 14.0 9.0 1.0; 2.0 13.0 7.0; 1.0 8.0 14.0;;; 18.0 2.0 2.0; 9.0 11.0 4.0; 6.0 1.0 14.0;;;; 14.0 3.0 10.0; 10.0 17.0 9.0; 5.0 7.0 13.0;;; 13.0 10.0 5.0; 9.0 18.0 4.0; 1.0 4.0 12.0;;; 18.0 7.0 9.0; 2.0 20.0 7.0; 5.0 8.0 11.0;;;; 18.0 3.0 8.0; 9.0 20.0 9.0; 8.0 5.0 19.0;;; 15.0 9.0 2.0; 1.0 14.0 1.0; 4.0 2.0 12.0;;; 18.0 5.0 10.0; 1.0 18.0 2.0; 9.0 2.0 12.0;;;;; - 16.0 2.0 1.0; 5.0 18.0 8.0; 9.0 6.0 14.0;;; 17.0 6.0 6.0; 1.0 18.0 7.0; 4.0 5.0 16.0;;; 19.0 8.0 9.0; 4.0 11.0 9.0; 9.0 10.0 18.0;;;; 14.0 7.0 8.0; 1.0 18.0 2.0; 4.0 3.0 12.0;;; 15.0 5.0 4.0; 9.0 20.0 2.0; 7.0 10.0 15.0;;; 13.0 5.0 6.0; 6.0 15.0 4.0; 1.0 10.0 16.0;;;; 12.0 2.0 1.0; 3.0 18.0 7.0; 8.0 7.0 15.0;;; 19.0 3.0 4.0; 7.0 14.0 8.0; 1.0 6.0 14.0;;; 16.0 8.0 8.0; 5.0 16.0 9.0; 1.0 10.0 11.0;;;;; - 19.0 1.0 6.0; 3.0 20.0 9.0; 5.0 9.0 12.0;;; 13.0 6.0 1.0; 9.0 12.0 7.0; 3.0 1.0 20.0;;; 19.0 10.0 8.0; 6.0 16.0 6.0; 9.0 10.0 16.0;;;; 13.0 1.0 4.0; 6.0 18.0 8.0; 5.0 9.0 20.0;;; 18.0 8.0 5.0; 8.0 11.0 4.0; 7.0 9.0 13.0;;; 15.0 8.0 10.0; 10.0 14.0 10.0; 10.0 2.0 16.0;;;; 14.0 8.0 10.0; 5.0 18.0 9.0; 9.0 4.0 15.0;;; 15.0 9.0 1.0; 7.0 12.0 8.0; 5.0 8.0 12.0;;; 20.0 3.0 1.0; 9.0 14.0 1.0; 2.0 9.0 13.0 - ] + normalize!( + [ + 18.0 6.0 5.0; 7.0 16.0 5.0; 2.0 2.0 18.0;;; 14.0 9.0 1.0; 2.0 13.0 7.0; 1.0 8.0 14.0;;; 18.0 2.0 2.0; 9.0 11.0 4.0; 6.0 1.0 14.0;;;; 14.0 3.0 10.0; 10.0 17.0 9.0; 5.0 7.0 13.0;;; 13.0 10.0 5.0; 9.0 18.0 4.0; 1.0 4.0 12.0;;; 18.0 7.0 9.0; 2.0 20.0 7.0; 5.0 8.0 11.0;;;; 18.0 3.0 8.0; 9.0 20.0 9.0; 8.0 5.0 19.0;;; 15.0 9.0 2.0; 1.0 14.0 1.0; 4.0 2.0 12.0;;; 18.0 5.0 10.0; 1.0 18.0 2.0; 9.0 2.0 12.0;;;;; + 16.0 2.0 1.0; 5.0 18.0 8.0; 9.0 6.0 14.0;;; 17.0 6.0 6.0; 1.0 18.0 7.0; 4.0 5.0 16.0;;; 19.0 8.0 9.0; 4.0 11.0 9.0; 9.0 10.0 18.0;;;; 14.0 7.0 8.0; 1.0 18.0 2.0; 4.0 3.0 12.0;;; 15.0 5.0 4.0; 9.0 20.0 2.0; 7.0 10.0 15.0;;; 13.0 5.0 6.0; 6.0 15.0 4.0; 1.0 10.0 16.0;;;; 12.0 2.0 1.0; 3.0 18.0 7.0; 8.0 7.0 15.0;;; 19.0 3.0 4.0; 7.0 14.0 8.0; 1.0 6.0 14.0;;; 16.0 8.0 8.0; 5.0 16.0 9.0; 1.0 10.0 11.0;;;;; + 19.0 1.0 6.0; 3.0 20.0 9.0; 5.0 9.0 12.0;;; 13.0 6.0 1.0; 9.0 12.0 7.0; 3.0 1.0 20.0;;; 19.0 10.0 8.0; 6.0 16.0 6.0; 9.0 10.0 16.0;;;; 13.0 1.0 4.0; 6.0 18.0 8.0; 5.0 9.0 20.0;;; 18.0 8.0 5.0; 8.0 11.0 4.0; 7.0 9.0 13.0;;; 15.0 8.0 10.0; 10.0 14.0 10.0; 10.0 2.0 16.0;;;; 14.0 8.0 10.0; 5.0 18.0 9.0; 9.0 4.0 15.0;;; 15.0 9.0 1.0; 7.0 12.0 8.0; 5.0 8.0 12.0;;; 20.0 3.0 1.0; 9.0 14.0 1.0; 2.0 9.0 13.0 + ], + 1 + ) ) ), output = Categorical([0.3314510207078866, 0.3506351887208107, 0.31791379057130265]) @@ -631,11 +662,14 @@ end m_T1 = Categorical([0.4, 0.3, 0.3]), m_T3 = Categorical([0.5, 0.2, 0.3]), q_a = PointMass( - [ - 15.0 4.0 9.0; 10.0 13.0 8.0; 2.0 8.0 15.0;;; 20.0 1.0 9.0; 5.0 15.0 2.0; 5.0 1.0 16.0;;; 19.0 9.0 9.0; 4.0 11.0 7.0; 2.0 10.0 16.0;;;; 19.0 3.0 3.0; 9.0 17.0 2.0; 9.0 10.0 20.0;;; 20.0 1.0 8.0; 6.0 16.0 6.0; 7.0 7.0 20.0;;; 11.0 8.0 10.0; 10.0 17.0 1.0; 7.0 3.0 14.0;;;; 12.0 10.0 6.0; 7.0 20.0 8.0; 7.0 2.0 20.0;;; 18.0 8.0 4.0; 8.0 20.0 7.0; 9.0 2.0 14.0;;; 13.0 10.0 3.0; 4.0 11.0 5.0; 2.0 5.0 17.0;;;;; - 11.0 3.0 4.0; 4.0 11.0 5.0; 8.0 5.0 17.0;;; 11.0 8.0 7.0; 10.0 11.0 3.0; 9.0 9.0 13.0;;; 15.0 8.0 5.0; 2.0 18.0 6.0; 2.0 8.0 15.0;;;; 16.0 8.0 3.0; 2.0 15.0 4.0; 7.0 4.0 13.0;;; 11.0 4.0 2.0; 5.0 18.0 9.0; 7.0 1.0 16.0;;; 16.0 3.0 8.0; 3.0 15.0 4.0; 7.0 8.0 13.0;;;; 16.0 3.0 6.0; 10.0 12.0 2.0; 2.0 9.0 17.0;;; 19.0 7.0 5.0; 7.0 13.0 4.0; 8.0 7.0 14.0;;; 13.0 6.0 9.0; 7.0 18.0 6.0; 7.0 2.0 11.0;;;;; - 11.0 1.0 7.0; 9.0 11.0 4.0; 8.0 1.0 11.0;;; 12.0 9.0 3.0; 4.0 15.0 1.0; 7.0 10.0 16.0;;; 11.0 8.0 8.0; 4.0 17.0 2.0; 7.0 2.0 17.0;;;; 13.0 2.0 9.0; 10.0 18.0 4.0; 10.0 2.0 17.0;;; 20.0 9.0 5.0; 10.0 15.0 3.0; 6.0 5.0 15.0;;; 18.0 5.0 3.0; 10.0 15.0 3.0; 7.0 5.0 20.0;;;; 16.0 7.0 6.0; 2.0 20.0 2.0; 2.0 8.0 11.0;;; 14.0 2.0 2.0; 3.0 13.0 2.0; 10.0 9.0 11.0;;; 11.0 6.0 1.0; 10.0 12.0 4.0; 4.0 5.0 13.0 - ] + normalize!( + [ + 15.0 4.0 9.0; 10.0 13.0 8.0; 2.0 8.0 15.0;;; 20.0 1.0 9.0; 5.0 15.0 2.0; 5.0 1.0 16.0;;; 19.0 9.0 9.0; 4.0 11.0 7.0; 2.0 10.0 16.0;;;; 19.0 3.0 3.0; 9.0 17.0 2.0; 9.0 10.0 20.0;;; 20.0 1.0 8.0; 6.0 16.0 6.0; 7.0 7.0 20.0;;; 11.0 8.0 10.0; 10.0 17.0 1.0; 7.0 3.0 14.0;;;; 12.0 10.0 6.0; 7.0 20.0 8.0; 7.0 2.0 20.0;;; 18.0 8.0 4.0; 8.0 20.0 7.0; 9.0 2.0 14.0;;; 13.0 10.0 3.0; 4.0 11.0 5.0; 2.0 5.0 17.0;;;;; + 11.0 3.0 4.0; 4.0 11.0 5.0; 8.0 5.0 17.0;;; 11.0 8.0 7.0; 10.0 11.0 3.0; 9.0 9.0 13.0;;; 15.0 8.0 5.0; 2.0 18.0 6.0; 2.0 8.0 15.0;;;; 16.0 8.0 3.0; 2.0 15.0 4.0; 7.0 4.0 13.0;;; 11.0 4.0 2.0; 5.0 18.0 9.0; 7.0 1.0 16.0;;; 16.0 3.0 8.0; 3.0 15.0 4.0; 7.0 8.0 13.0;;;; 16.0 3.0 6.0; 10.0 12.0 2.0; 2.0 9.0 17.0;;; 19.0 7.0 5.0; 7.0 13.0 4.0; 8.0 7.0 14.0;;; 13.0 6.0 9.0; 7.0 18.0 6.0; 7.0 2.0 11.0;;;;; + 11.0 1.0 7.0; 9.0 11.0 4.0; 8.0 1.0 11.0;;; 12.0 9.0 3.0; 4.0 15.0 1.0; 7.0 10.0 16.0;;; 11.0 8.0 8.0; 4.0 17.0 2.0; 7.0 2.0 17.0;;;; 13.0 2.0 9.0; 10.0 18.0 4.0; 10.0 2.0 17.0;;; 20.0 9.0 5.0; 10.0 15.0 3.0; 6.0 5.0 15.0;;; 18.0 5.0 3.0; 10.0 15.0 3.0; 7.0 5.0 20.0;;;; 16.0 7.0 6.0; 2.0 20.0 2.0; 2.0 8.0 11.0;;; 14.0 2.0 2.0; 3.0 13.0 2.0; 10.0 9.0 11.0;;; 11.0 6.0 1.0; 10.0 12.0 4.0; 4.0 5.0 13.0 + ], + 1 + ) ) ), output = Categorical([0.31465933370261934, 0.3276568805887473, 0.3576837857086334]) @@ -660,7 +694,7 @@ end ] ) ), - output = Categorical([0.3319038296364184, 0.3354501488301648, 0.33264602153341694]) + output = Categorical([0.3327791834414443, 0.33479029996811277, 0.332430516590443]) ), ( input = ( @@ -676,7 +710,7 @@ end ] ) ), - output = Categorical([0.3415243015031942, 0.32861514758186977, 0.32986055091493616]) + output = Categorical([0.3405707102171213, 0.32860320985809627, 0.3308260799247824]) ) ] end @@ -692,13 +726,13 @@ end m_T2 = Categorical([0.2, 0.3, 0.5]), q_a = PointMass( [ - 18.0 6.0 5.0; 7.0 16.0 5.0; 2.0 2.0 18.0;;; 14.0 9.0 1.0; 2.0 13.0 7.0; 1.0 8.0 14.0;;; 18.0 2.0 2.0; 9.0 11.0 4.0; 6.0 1.0 14.0;;;; 14.0 3.0 10.0; 10.0 17.0 9.0; 5.0 7.0 13.0;;; 13.0 10.0 5.0; 9.0 18.0 4.0; 1.0 4.0 12.0;;; 18.0 7.0 9.0; 2.0 20.0 7.0; 5.0 8.0 11.0;;;; 18.0 3.0 8.0; 9.0 20.0 9.0; 8.0 5.0 19.0;;; 15.0 9.0 2.0; 1.0 14.0 1.0; 4.0 2.0 12.0;;; 18.0 5.0 10.0; 1.0 18.0 2.0; 9.0 2.0 12.0;;;;; - 16.0 2.0 1.0; 5.0 18.0 8.0; 9.0 6.0 14.0;;; 17.0 6.0 6.0; 1.0 18.0 7.0; 4.0 5.0 16.0;;; 19.0 8.0 9.0; 4.0 11.0 9.0; 9.0 10.0 18.0;;;; 14.0 7.0 8.0; 1.0 18.0 2.0; 4.0 3.0 12.0;;; 15.0 5.0 4.0; 9.0 20.0 2.0; 7.0 10.0 15.0;;; 13.0 5.0 6.0; 6.0 15.0 4.0; 1.0 10.0 16.0;;;; 12.0 2.0 1.0; 3.0 18.0 7.0; 8.0 7.0 15.0;;; 19.0 3.0 4.0; 7.0 14.0 8.0; 1.0 6.0 14.0;;; 16.0 8.0 8.0; 5.0 16.0 9.0; 1.0 10.0 11.0;;;;; - 19.0 1.0 6.0; 3.0 20.0 9.0; 5.0 9.0 12.0;;; 13.0 6.0 1.0; 9.0 12.0 7.0; 3.0 1.0 20.0;;; 19.0 10.0 8.0; 6.0 16.0 6.0; 9.0 10.0 16.0;;;; 13.0 1.0 4.0; 6.0 18.0 8.0; 5.0 9.0 20.0;;; 18.0 8.0 5.0; 8.0 11.0 4.0; 7.0 9.0 13.0;;; 15.0 8.0 10.0; 10.0 14.0 10.0; 10.0 2.0 16.0;;;; 14.0 8.0 10.0; 5.0 18.0 9.0; 9.0 4.0 15.0;;; 15.0 9.0 1.0; 7.0 12.0 8.0; 5.0 8.0 12.0;;; 20.0 3.0 1.0; 9.0 14.0 1.0; 2.0 9.0 13.0 + 0.6666666666666666 0.25 0.17857142857142858; 0.25925925925925924 0.6666666666666666 0.17857142857142858; 0.07407407407407407 0.08333333333333333 0.6428571428571429;;; 0.8235294117647058 0.3 0.045454545454545456; 0.11764705882352941 0.43333333333333335 0.3181818181818182; 0.058823529411764705 0.26666666666666666 0.6363636363636364;;; 0.5454545454545454 0.14285714285714285 0.1; 0.2727272727272727 0.7857142857142857 0.2; 0.18181818181818182 0.07142857142857142 0.7;;;; 0.4827586206896552 0.1111111111111111 0.3125; 0.3448275862068966 0.6296296296296297 0.28125; 0.1724137931034483 0.25925925925925924 0.40625;;; 0.5652173913043478 0.3125 0.23809523809523808; 0.391304347826087 0.5625 0.19047619047619047; 0.043478260869565216 0.125 0.5714285714285714;;; 0.72 0.2 0.3333333333333333; 0.08 0.5714285714285714 0.25925925925925924; 0.2 0.22857142857142856 0.4074074074074074;;;; 0.5142857142857142 0.10714285714285714 0.2222222222222222; 0.2571428571428571 0.7142857142857143 0.25; 0.22857142857142856 0.17857142857142858 0.5277777777777778;;; 0.75 0.36 0.13333333333333333; 0.05 0.56 0.06666666666666667; 0.2 0.08 0.8;;; 0.6428571428571429 0.2 0.4166666666666667; 0.03571428571428571 0.72 0.08333333333333333; 0.32142857142857145 0.08 0.5;;;;; + 0.5333333333333333 0.07692307692307693 0.043478260869565216; 0.16666666666666666 0.6923076923076923 0.34782608695652173; 0.3 0.23076923076923078 0.6086956521739131;;; 0.7727272727272727 0.20689655172413793 0.20689655172413793; 0.045454545454545456 0.6206896551724138 0.2413793103448276; 0.18181818181818182 0.1724137931034483 0.5517241379310345;;; 0.59375 0.27586206896551724 0.25; 0.125 0.3793103448275862 0.25; 0.28125 0.3448275862068966 0.5;;;; 0.7368421052631579 0.25 0.36363636363636365; 0.05263157894736842 0.6428571428571429 0.09090909090909091; 0.21052631578947367 0.10714285714285714 0.5454545454545454;;; 0.4838709677419355 0.14285714285714285 0.19047619047619047; 0.2903225806451613 0.5714285714285714 0.09523809523809523; 0.22580645161290322 0.2857142857142857 0.7142857142857143;;; 0.65 0.16666666666666666 0.23076923076923078; 0.3 0.5 0.15384615384615385; 0.05 0.3333333333333333 0.6153846153846154;;;; 0.5217391304347826 0.07407407407407407 0.043478260869565216; 0.13043478260869565 0.6666666666666666 0.30434782608695654; 0.34782608695652173 0.25925925925925924 0.6521739130434783;;; 0.7037037037037037 0.13043478260869565 0.15384615384615385; 0.25925925925925924 0.6086956521739131 0.3076923076923077; 0.037037037037037035 0.2608695652173913 0.5384615384615384;;; 0.7272727272727273 0.23529411764705882 0.2857142857142857; 0.22727272727272727 0.47058823529411764 0.32142857142857145; 0.045454545454545456 0.29411764705882354 0.39285714285714285;;;;; + 0.7037037037037037 0.03333333333333333 0.2222222222222222; 0.1111111111111111 0.6666666666666666 0.3333333333333333; 0.18518518518518517 0.3 0.4444444444444444;;; 0.52 0.3157894736842105 0.03571428571428571; 0.36 0.631578947368421 0.25; 0.12 0.05263157894736842 0.7142857142857143;;; 0.5588235294117647 0.2777777777777778 0.26666666666666666; 0.17647058823529413 0.4444444444444444 0.2; 0.2647058823529412 0.2777777777777778 0.5333333333333333;;;; 0.5416666666666666 0.03571428571428571 0.125; 0.25 0.6428571428571429 0.25; 0.20833333333333334 0.32142857142857145 0.625;;; 0.5454545454545454 0.2857142857142857 0.22727272727272727; 0.24242424242424243 0.39285714285714285 0.18181818181818182; 0.21212121212121213 0.32142857142857145 0.5909090909090909;;; 0.42857142857142855 0.3333333333333333 0.2777777777777778; 0.2857142857142857 0.5833333333333334 0.2777777777777778; 0.2857142857142857 0.08333333333333333 0.4444444444444444;;;; 0.5 0.26666666666666666 0.29411764705882354; 0.17857142857142858 0.6 0.2647058823529412; 0.32142857142857145 0.13333333333333333 0.4411764705882353;;; 0.5555555555555556 0.3103448275862069 0.047619047619047616; 0.25925925925925924 0.41379310344827586 0.38095238095238093; 0.18518518518518517 0.27586206896551724 0.5714285714285714;;; 0.6451612903225806 0.11538461538461539 0.06666666666666667; 0.2903225806451613 0.5384615384615384 0.06666666666666667; 0.06451612903225806 0.34615384615384615 0.8666666666666667 ] ) ), - output = Categorical([0.3151127561452472, 0.3483388803486411, 0.33654836350611167]) + output = Categorical([0.33284552594585337, 0.3347077171525566, 0.33244675690159]) ), ( input = ( @@ -708,13 +742,13 @@ end m_T2 = Categorical([0.5, 0.2, 0.3]), q_a = PointMass( [ - 15.0 4.0 9.0; 10.0 13.0 8.0; 2.0 8.0 15.0;;; 20.0 1.0 9.0; 5.0 15.0 2.0; 5.0 1.0 16.0;;; 19.0 9.0 9.0; 4.0 11.0 7.0; 2.0 10.0 16.0;;;; 19.0 3.0 3.0; 9.0 17.0 2.0; 9.0 10.0 20.0;;; 20.0 1.0 8.0; 6.0 16.0 6.0; 7.0 7.0 20.0;;; 11.0 8.0 10.0; 10.0 17.0 1.0; 7.0 3.0 14.0;;;; 12.0 10.0 6.0; 7.0 20.0 8.0; 7.0 2.0 20.0;;; 18.0 8.0 4.0; 8.0 20.0 7.0; 9.0 2.0 14.0;;; 13.0 10.0 3.0; 4.0 11.0 5.0; 2.0 5.0 17.0;;;;; - 11.0 3.0 4.0; 4.0 11.0 5.0; 8.0 5.0 17.0;;; 11.0 8.0 7.0; 10.0 11.0 3.0; 9.0 9.0 13.0;;; 15.0 8.0 5.0; 2.0 18.0 6.0; 2.0 8.0 15.0;;;; 16.0 8.0 3.0; 2.0 15.0 4.0; 7.0 4.0 13.0;;; 11.0 4.0 2.0; 5.0 18.0 9.0; 7.0 1.0 16.0;;; 16.0 3.0 8.0; 3.0 15.0 4.0; 7.0 8.0 13.0;;;; 16.0 3.0 6.0; 10.0 12.0 2.0; 2.0 9.0 17.0;;; 19.0 7.0 5.0; 7.0 13.0 4.0; 8.0 7.0 14.0;;; 13.0 6.0 9.0; 7.0 18.0 6.0; 7.0 2.0 11.0;;;;; - 11.0 1.0 7.0; 9.0 11.0 4.0; 8.0 1.0 11.0;;; 12.0 9.0 3.0; 4.0 15.0 1.0; 7.0 10.0 16.0;;; 11.0 8.0 8.0; 4.0 17.0 2.0; 7.0 2.0 17.0;;;; 13.0 2.0 9.0; 10.0 18.0 4.0; 10.0 2.0 17.0;;; 20.0 9.0 5.0; 10.0 15.0 3.0; 6.0 5.0 15.0;;; 18.0 5.0 3.0; 10.0 15.0 3.0; 7.0 5.0 20.0;;;; 16.0 7.0 6.0; 2.0 20.0 2.0; 2.0 8.0 11.0;;; 14.0 2.0 2.0; 3.0 13.0 2.0; 10.0 9.0 11.0;;; 11.0 6.0 1.0; 10.0 12.0 4.0; 4.0 5.0 13.0 + 0.5555555555555556 0.16 0.28125; 0.37037037037037035 0.52 0.25; 0.07407407407407407 0.32 0.46875;;; 0.6666666666666666 0.058823529411764705 0.3333333333333333; 0.16666666666666666 0.8823529411764706 0.07407407407407407; 0.16666666666666666 0.058823529411764705 0.5925925925925926;;; 0.76 0.3 0.28125; 0.16 0.36666666666666664 0.21875; 0.08 0.3333333333333333 0.5;;;; 0.5135135135135135 0.1 0.12; 0.24324324324324326 0.5666666666666667 0.08; 0.24324324324324326 0.3333333333333333 0.8;;; 0.6060606060606061 0.041666666666666664 0.23529411764705882; 0.18181818181818182 0.6666666666666666 0.17647058823529413; 0.21212121212121213 0.2916666666666667 0.5882352941176471;;; 0.39285714285714285 0.2857142857142857 0.4; 0.35714285714285715 0.6071428571428571 0.04; 0.25 0.10714285714285714 0.56;;;; 0.46153846153846156 0.3125 0.17647058823529413; 0.2692307692307692 0.625 0.23529411764705882; 0.2692307692307692 0.0625 0.5882352941176471;;; 0.5142857142857142 0.26666666666666666 0.16; 0.22857142857142856 0.6666666666666666 0.28; 0.2571428571428571 0.06666666666666667 0.56;;; 0.6842105263157895 0.38461538461538464 0.12; 0.21052631578947367 0.4230769230769231 0.2; 0.10526315789473684 0.19230769230769232 0.68;;;;; + 0.4782608695652174 0.15789473684210525 0.15384615384615385; 0.17391304347826086 0.5789473684210527 0.19230769230769232; 0.34782608695652173 0.2631578947368421 0.6538461538461539;;; 0.36666666666666664 0.2857142857142857 0.30434782608695654; 0.3333333333333333 0.39285714285714285 0.13043478260869565; 0.3 0.32142857142857145 0.5652173913043478;;; 0.7894736842105263 0.23529411764705882 0.19230769230769232; 0.10526315789473684 0.5294117647058824 0.23076923076923078; 0.10526315789473684 0.23529411764705882 0.5769230769230769;;;; 0.64 0.2962962962962963 0.15; 0.08 0.5555555555555556 0.2; 0.28 0.14814814814814814 0.65;;; 0.4782608695652174 0.17391304347826086 0.07407407407407407; 0.21739130434782608 0.782608695652174 0.3333333333333333; 0.30434782608695654 0.043478260869565216 0.5925925925925926;;; 0.6153846153846154 0.11538461538461539 0.32; 0.11538461538461539 0.5769230769230769 0.16; 0.2692307692307692 0.3076923076923077 0.52;;;; 0.5714285714285714 0.125 0.24; 0.35714285714285715 0.5 0.08; 0.07142857142857142 0.375 0.68;;; 0.5588235294117647 0.25925925925925924 0.21739130434782608; 0.20588235294117646 0.48148148148148145 0.17391304347826086; 0.23529411764705882 0.25925925925925924 0.6086956521739131;;; 0.48148148148148145 0.23076923076923078 0.34615384615384615; 0.25925925925925924 0.6923076923076923 0.23076923076923078; 0.25925925925925924 0.07692307692307693 0.4230769230769231;;;;; + 0.39285714285714285 0.07692307692307693 0.3181818181818182; 0.32142857142857145 0.8461538461538461 0.18181818181818182; 0.2857142857142857 0.07692307692307693 0.5;;; 0.5217391304347826 0.2647058823529412 0.15; 0.17391304347826086 0.4411764705882353 0.05; 0.30434782608695654 0.29411764705882354 0.8;;; 0.5 0.2962962962962963 0.2962962962962963; 0.18181818181818182 0.6296296296296297 0.07407407407407407; 0.3181818181818182 0.07407407407407407 0.6296296296296297;;;; 0.3939393939393939 0.09090909090909091 0.3; 0.30303030303030304 0.8181818181818182 0.13333333333333333; 0.30303030303030304 0.09090909090909091 0.5666666666666667;;; 0.5555555555555556 0.3103448275862069 0.21739130434782608; 0.2777777777777778 0.5172413793103449 0.13043478260869565; 0.16666666666666666 0.1724137931034483 0.6521739130434783;;; 0.5142857142857142 0.2 0.11538461538461539; 0.2857142857142857 0.6 0.11538461538461539; 0.2 0.2 0.7692307692307693;;;; 0.8 0.2 0.3157894736842105; 0.1 0.5714285714285714 0.10526315789473684; 0.1 0.22857142857142856 0.5789473684210527;;; 0.5185185185185185 0.08333333333333333 0.13333333333333333; 0.1111111111111111 0.5416666666666666 0.13333333333333333; 0.37037037037037035 0.375 0.7333333333333333;;; 0.44 0.2608695652173913 0.05555555555555555; 0.4 0.5217391304347826 0.2222222222222222; 0.16 0.21739130434782608 0.7222222222222222 ] ) ), - output = Categorical([0.357112685209254, 0.32496913179100606, 0.3179181829997401]) + output = Categorical([0.33997124, 0.3288227, 0.33120605]) ) ] end @@ -722,7 +756,7 @@ end @testitem "rules:DiscreteTransition:T:Additional T-interface tests for structured VMP" begin using ReactiveMP, BayesBase, Random, ExponentialFamily, Distributions - import ReactiveMP: @test_rules + import ReactiveMP: @test_rules, normalize! # Test T1 interface with 3 interfaces (BP with DirichletCollection q_a) @testset "Structured VMP: T1 with 3 interfaces (DirichletCollection q_a)" begin @test_rules [check_type_promotion = true, extra_float_types = [Float64, Float32]] DiscreteTransition(:T1, Marginalisation) [ @@ -762,11 +796,16 @@ end input = ( q_out = PointMass([0.0, 1.0, 0.0]), m_in = Categorical([0.2, 0.5, 0.3]), - q_a = PointMass([ - 1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0;;; - 10.0 11.0 12.0; 13.0 14.0 15.0; 16.0 17.0 18.0;;; - 19.0 20.0 21.0; 22.0 23.0 24.0; 25.0 26.0 27.0 - ]) + q_a = PointMass( + normalize!( + [ + 1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0;;; + 10.0 11.0 12.0; 13.0 14.0 15.0; 16.0 17.0 18.0;;; + 19.0 20.0 21.0; 22.0 23.0 24.0; 25.0 26.0 27.0 + ], + 1 + ) + ) ), output = Categorical([0.12056737588652484, 0.3333333333333333, 0.546099290780142]) ), @@ -774,11 +813,16 @@ end input = ( q_out = PointMass([0.0, 1.0, 0.0]), m_in = Categorical([0.1, 0.7, 0.2]), - q_a = PointMass([ - 5.0 1.0 2.0; 3.0 8.0 4.0; 6.0 7.0 9.0;;; - 2.0 4.0 6.0; 8.0 10.0 12.0; 14.0 16.0 18.0;;; - 1.0 3.0 5.0; 7.0 9.0 11.0; 13.0 15.0 17.0 - ]) + q_a = PointMass( + normalize!( + [ + 5.0 1.0 2.0; 3.0 8.0 4.0; 6.0 7.0 9.0;;; + 2.0 4.0 6.0; 8.0 10.0 12.0; 14.0 16.0 18.0;;; + 1.0 3.0 5.0; 7.0 9.0 11.0; 13.0 15.0 17.0 + ], + 1 + ) + ) ), output = Categorical([0.2567049808429118, 0.39080459770114945, 0.35249042145593873]) ) @@ -830,11 +874,14 @@ end m_in = Categorical([0.2, 0.5, 0.3]), m_T2 = Categorical([0.1, 0.6, 0.3]), q_a = PointMass( - [ - 8.0 8.0 10.0; 3.0 2.0 3.0; 2.0 2.0 10.0;;; 5.0 2.0 8.0; 2.0 10.0 8.0; 9.0 6.0 6.0;;; 11.0 6.0 1.0; 3.0 8.0 4.0; 3.0 7.0 9.0;;;; - 4.0 8.0 6.0; 8.0 5.0 6.0; 5.0 10.0 11.0;;; 11.0 2.0 10.0; 2.0 7.0 10.0; 7.0 9.0 11.0;;; 5.0 10.0 4.0; 3.0 7.0 9.0; 7.0 2.0 8.0;;;; - 10.0 6.0 4.0; 10.0 11.0 5.0; 6.0 4.0 3.0;;; 9.0 9.0 2.0; 6.0 6.0 1.0; 5.0 4.0 5.0;;; 11.0 5.0 7.0; 6.0 3.0 3.0; 1.0 8.0 8.0 - ] + normalize!( + [ + 8.0 8.0 10.0; 3.0 2.0 3.0; 2.0 2.0 10.0;;; 5.0 2.0 8.0; 2.0 10.0 8.0; 9.0 6.0 6.0;;; 11.0 6.0 1.0; 3.0 8.0 4.0; 3.0 7.0 9.0;;;; + 4.0 8.0 6.0; 8.0 5.0 6.0; 5.0 10.0 11.0;;; 11.0 2.0 10.0; 2.0 7.0 10.0; 7.0 9.0 11.0;;; 5.0 10.0 4.0; 3.0 7.0 9.0; 7.0 2.0 8.0;;;; + 10.0 6.0 4.0; 10.0 11.0 5.0; 6.0 4.0 3.0;;; 9.0 9.0 2.0; 6.0 6.0 1.0; 5.0 4.0 5.0;;; 11.0 5.0 7.0; 6.0 3.0 3.0; 1.0 8.0 8.0 + ], + 1 + ) ) ), output = Categorical([0.35081081081081084, 0.3389189189189189, 0.3102702702702703]) @@ -845,11 +892,14 @@ end m_in = Categorical([0.1, 0.7, 0.2]), m_T2 = Categorical([0.4, 0.3, 0.3]), q_a = PointMass( - [ - 19.0 7.0 3.0; 3.0 17.0 2.0; 9.0 9.0 12.0;;; 13.0 8.0 2.0; 6.0 16.0 7.0; 7.0 3.0 16.0;;; 15.0 5.0 1.0; 9.0 12.0 5.0; 9.0 10.0 14.0;;;; - 19.0 6.0 3.0; 4.0 17.0 9.0; 1.0 8.0 18.0;;; 11.0 1.0 8.0; 9.0 13.0 4.0; 7.0 1.0 16.0;;; 20.0 8.0 1.0; 9.0 18.0 3.0; 1.0 5.0 14.0;;;; - 20.0 9.0 8.0; 4.0 18.0 5.0; 4.0 6.0 11.0;;; 12.0 5.0 4.0; 8.0 16.0 9.0; 5.0 4.0 18.0;;; 17.0 1.0 10.0; 8.0 13.0 5.0; 5.0 1.0 13.0 - ] + normalize!( + [ + 19.0 7.0 3.0; 3.0 17.0 2.0; 9.0 9.0 12.0;;; 13.0 8.0 2.0; 6.0 16.0 7.0; 7.0 3.0 16.0;;; 15.0 5.0 1.0; 9.0 12.0 5.0; 9.0 10.0 14.0;;;; + 19.0 6.0 3.0; 4.0 17.0 9.0; 1.0 8.0 18.0;;; 11.0 1.0 8.0; 9.0 13.0 4.0; 7.0 1.0 16.0;;; 20.0 8.0 1.0; 9.0 18.0 3.0; 1.0 5.0 14.0;;;; + 20.0 9.0 8.0; 4.0 18.0 5.0; 4.0 6.0 11.0;;; 12.0 5.0 4.0; 8.0 16.0 9.0; 5.0 4.0 18.0;;; 17.0 1.0 10.0; 8.0 13.0 5.0; 5.0 1.0 13.0 + ], + 1 + ) ) ), output = Categorical([0.35682119205298013, 0.33536423841059604, 0.30781456953642383]) @@ -902,11 +952,14 @@ end m_in = Categorical([0.2, 0.5, 0.3]), m_T1 = Categorical([0.1, 0.6, 0.3]), q_a = PointMass( - [ - 8.0 8.0 10.0; 3.0 2.0 3.0; 2.0 2.0 10.0;;; 5.0 2.0 8.0; 2.0 10.0 8.0; 9.0 6.0 6.0;;; 11.0 6.0 1.0; 3.0 8.0 4.0; 3.0 7.0 9.0;;;; - 4.0 8.0 6.0; 8.0 5.0 6.0; 5.0 10.0 11.0;;; 11.0 2.0 10.0; 2.0 7.0 10.0; 7.0 9.0 11.0;;; 5.0 10.0 4.0; 3.0 7.0 9.0; 7.0 2.0 8.0;;;; - 10.0 6.0 4.0; 10.0 11.0 5.0; 6.0 4.0 3.0;;; 9.0 9.0 2.0; 6.0 6.0 1.0; 5.0 4.0 5.0;;; 11.0 5.0 7.0; 6.0 3.0 3.0; 1.0 8.0 8.0 - ] + normalize!( + [ + 8.0 8.0 10.0; 3.0 2.0 3.0; 2.0 2.0 10.0;;; 5.0 2.0 8.0; 2.0 10.0 8.0; 9.0 6.0 6.0;;; 11.0 6.0 1.0; 3.0 8.0 4.0; 3.0 7.0 9.0;;;; + 4.0 8.0 6.0; 8.0 5.0 6.0; 5.0 10.0 11.0;;; 11.0 2.0 10.0; 2.0 7.0 10.0; 7.0 9.0 11.0;;; 5.0 10.0 4.0; 3.0 7.0 9.0; 7.0 2.0 8.0;;;; + 10.0 6.0 4.0; 10.0 11.0 5.0; 6.0 4.0 3.0;;; 9.0 9.0 2.0; 6.0 6.0 1.0; 5.0 4.0 5.0;;; 11.0 5.0 7.0; 6.0 3.0 3.0; 1.0 8.0 8.0 + ], + 1 + ) ) ), output = Categorical([0.3681015452538632, 0.3736203090507726, 0.2582781456953642]) @@ -917,11 +970,14 @@ end m_in = Categorical([0.1, 0.7, 0.2]), m_T1 = Categorical([0.4, 0.3, 0.3]), q_a = PointMass( - [ - 19.0 7.0 3.0; 3.0 17.0 2.0; 9.0 9.0 12.0;;; 13.0 8.0 2.0; 6.0 16.0 7.0; 7.0 3.0 16.0;;; 15.0 5.0 1.0; 9.0 12.0 5.0; 9.0 10.0 14.0;;;; - 19.0 6.0 3.0; 4.0 17.0 9.0; 1.0 8.0 18.0;;; 11.0 1.0 8.0; 9.0 13.0 4.0; 7.0 1.0 16.0;;; 20.0 8.0 1.0; 9.0 18.0 3.0; 1.0 5.0 14.0;;;; - 20.0 9.0 8.0; 4.0 18.0 5.0; 4.0 6.0 11.0;;; 12.0 5.0 4.0; 8.0 16.0 9.0; 5.0 4.0 18.0;;; 17.0 1.0 10.0; 8.0 13.0 5.0; 5.0 1.0 13.0 - ] + normalize!( + [ + 19.0 7.0 3.0; 3.0 17.0 2.0; 9.0 9.0 12.0;;; 13.0 8.0 2.0; 6.0 16.0 7.0; 7.0 3.0 16.0;;; 15.0 5.0 1.0; 9.0 12.0 5.0; 9.0 10.0 14.0;;;; + 19.0 6.0 3.0; 4.0 17.0 9.0; 1.0 8.0 18.0;;; 11.0 1.0 8.0; 9.0 13.0 4.0; 7.0 1.0 16.0;;; 20.0 8.0 1.0; 9.0 18.0 3.0; 1.0 5.0 14.0;;;; + 20.0 9.0 8.0; 4.0 18.0 5.0; 4.0 6.0 11.0;;; 12.0 5.0 4.0; 8.0 16.0 9.0; 5.0 4.0 18.0;;; 17.0 1.0 10.0; 8.0 13.0 5.0; 5.0 1.0 13.0 + ], + 1 + ) ) ), output = Categorical([0.31640931693274016, 0.343103899502748, 0.3404867835645119]) @@ -977,11 +1033,14 @@ end m_T2 = Categorical([0.1, 0.6, 0.3]), m_T3 = Categorical([0.2, 0.3, 0.5]), q_a = PointMass( - [ - 18.0 6.0 5.0; 7.0 16.0 5.0; 2.0 2.0 18.0;;; 14.0 9.0 1.0; 2.0 13.0 7.0; 1.0 8.0 14.0;;; 18.0 2.0 2.0; 9.0 11.0 4.0; 6.0 1.0 14.0;;;; 14.0 3.0 10.0; 10.0 17.0 9.0; 5.0 7.0 13.0;;; 13.0 10.0 5.0; 9.0 18.0 4.0; 1.0 4.0 12.0;;; 18.0 7.0 9.0; 2.0 20.0 7.0; 5.0 8.0 11.0;;;; 18.0 3.0 8.0; 9.0 20.0 9.0; 8.0 5.0 19.0;;; 15.0 9.0 2.0; 1.0 14.0 1.0; 4.0 2.0 12.0;;; 18.0 5.0 10.0; 1.0 18.0 2.0; 9.0 2.0 12.0;;;;; - 16.0 2.0 1.0; 5.0 18.0 8.0; 9.0 6.0 14.0;;; 17.0 6.0 6.0; 1.0 18.0 7.0; 4.0 5.0 16.0;;; 19.0 8.0 9.0; 4.0 11.0 9.0; 9.0 10.0 18.0;;;; 14.0 7.0 8.0; 1.0 18.0 2.0; 4.0 3.0 12.0;;; 15.0 5.0 4.0; 9.0 20.0 2.0; 7.0 10.0 15.0;;; 13.0 5.0 6.0; 6.0 15.0 4.0; 1.0 10.0 16.0;;;; 12.0 2.0 1.0; 3.0 18.0 7.0; 8.0 7.0 15.0;;; 19.0 3.0 4.0; 7.0 14.0 8.0; 1.0 6.0 14.0;;; 16.0 8.0 8.0; 5.0 16.0 9.0; 1.0 10.0 11.0;;;;; - 19.0 1.0 6.0; 3.0 20.0 9.0; 5.0 9.0 12.0;;; 13.0 6.0 1.0; 9.0 12.0 7.0; 3.0 1.0 20.0;;; 19.0 10.0 8.0; 6.0 16.0 6.0; 9.0 10.0 16.0;;;; 13.0 1.0 4.0; 6.0 18.0 8.0; 5.0 9.0 20.0;;; 18.0 8.0 5.0; 8.0 11.0 4.0; 7.0 9.0 13.0;;; 15.0 8.0 10.0; 10.0 14.0 10.0; 10.0 2.0 16.0;;;; 14.0 8.0 10.0; 5.0 18.0 9.0; 9.0 4.0 15.0;;; 15.0 9.0 1.0; 7.0 12.0 8.0; 5.0 8.0 12.0;;; 20.0 3.0 1.0; 9.0 14.0 1.0; 2.0 9.0 13.0 - ] + normalize!( + [ + 18.0 6.0 5.0; 7.0 16.0 5.0; 2.0 2.0 18.0;;; 14.0 9.0 1.0; 2.0 13.0 7.0; 1.0 8.0 14.0;;; 18.0 2.0 2.0; 9.0 11.0 4.0; 6.0 1.0 14.0;;;; 14.0 3.0 10.0; 10.0 17.0 9.0; 5.0 7.0 13.0;;; 13.0 10.0 5.0; 9.0 18.0 4.0; 1.0 4.0 12.0;;; 18.0 7.0 9.0; 2.0 20.0 7.0; 5.0 8.0 11.0;;;; 18.0 3.0 8.0; 9.0 20.0 9.0; 8.0 5.0 19.0;;; 15.0 9.0 2.0; 1.0 14.0 1.0; 4.0 2.0 12.0;;; 18.0 5.0 10.0; 1.0 18.0 2.0; 9.0 2.0 12.0;;;;; + 16.0 2.0 1.0; 5.0 18.0 8.0; 9.0 6.0 14.0;;; 17.0 6.0 6.0; 1.0 18.0 7.0; 4.0 5.0 16.0;;; 19.0 8.0 9.0; 4.0 11.0 9.0; 9.0 10.0 18.0;;;; 14.0 7.0 8.0; 1.0 18.0 2.0; 4.0 3.0 12.0;;; 15.0 5.0 4.0; 9.0 20.0 2.0; 7.0 10.0 15.0;;; 13.0 5.0 6.0; 6.0 15.0 4.0; 1.0 10.0 16.0;;;; 12.0 2.0 1.0; 3.0 18.0 7.0; 8.0 7.0 15.0;;; 19.0 3.0 4.0; 7.0 14.0 8.0; 1.0 6.0 14.0;;; 16.0 8.0 8.0; 5.0 16.0 9.0; 1.0 10.0 11.0;;;;; + 19.0 1.0 6.0; 3.0 20.0 9.0; 5.0 9.0 12.0;;; 13.0 6.0 1.0; 9.0 12.0 7.0; 3.0 1.0 20.0;;; 19.0 10.0 8.0; 6.0 16.0 6.0; 9.0 10.0 16.0;;;; 13.0 1.0 4.0; 6.0 18.0 8.0; 5.0 9.0 20.0;;; 18.0 8.0 5.0; 8.0 11.0 4.0; 7.0 9.0 13.0;;; 15.0 8.0 10.0; 10.0 14.0 10.0; 10.0 2.0 16.0;;;; 14.0 8.0 10.0; 5.0 18.0 9.0; 9.0 4.0 15.0;;; 15.0 9.0 1.0; 7.0 12.0 8.0; 5.0 8.0 12.0;;; 20.0 3.0 1.0; 9.0 14.0 1.0; 2.0 9.0 13.0 + ], + 1 + ) ) ), output = Categorical([0.3682781916337425, 0.3036462888681123, 0.3280755194981452]) @@ -993,11 +1052,14 @@ end m_T2 = Categorical([0.4, 0.3, 0.3]), m_T3 = Categorical([0.5, 0.2, 0.3]), q_a = PointMass( - [ - 15.0 4.0 9.0; 10.0 13.0 8.0; 2.0 8.0 15.0;;; 20.0 1.0 9.0; 5.0 15.0 2.0; 5.0 1.0 16.0;;; 19.0 9.0 9.0; 4.0 11.0 7.0; 2.0 10.0 16.0;;;; 19.0 3.0 3.0; 9.0 17.0 2.0; 9.0 10.0 20.0;;; 20.0 1.0 8.0; 6.0 16.0 6.0; 7.0 7.0 20.0;;; 11.0 8.0 10.0; 10.0 17.0 1.0; 7.0 3.0 14.0;;;; 12.0 10.0 6.0; 7.0 20.0 8.0; 7.0 2.0 20.0;;; 18.0 8.0 4.0; 8.0 20.0 7.0; 9.0 2.0 14.0;;; 13.0 10.0 3.0; 4.0 11.0 5.0; 2.0 5.0 17.0;;;;; - 11.0 3.0 4.0; 4.0 11.0 5.0; 8.0 5.0 17.0;;; 11.0 8.0 7.0; 10.0 11.0 3.0; 9.0 9.0 13.0;;; 15.0 8.0 5.0; 2.0 18.0 6.0; 2.0 8.0 15.0;;;; 16.0 8.0 3.0; 2.0 15.0 4.0; 7.0 4.0 13.0;;; 11.0 4.0 2.0; 5.0 18.0 9.0; 7.0 1.0 16.0;;; 16.0 3.0 8.0; 3.0 15.0 4.0; 7.0 8.0 13.0;;;; 16.0 3.0 6.0; 10.0 12.0 2.0; 2.0 9.0 17.0;;; 19.0 7.0 5.0; 7.0 13.0 4.0; 8.0 7.0 14.0;;; 13.0 6.0 9.0; 7.0 18.0 6.0; 7.0 2.0 11.0;;;;; - 11.0 1.0 7.0; 9.0 11.0 4.0; 8.0 1.0 11.0;;; 12.0 9.0 3.0; 4.0 15.0 1.0; 7.0 10.0 16.0;;; 11.0 8.0 8.0; 4.0 17.0 2.0; 7.0 2.0 17.0;;;; 13.0 2.0 9.0; 10.0 18.0 4.0; 10.0 2.0 17.0;;; 20.0 9.0 5.0; 10.0 15.0 3.0; 6.0 5.0 15.0;;; 18.0 5.0 3.0; 10.0 15.0 3.0; 7.0 5.0 20.0;;;; 16.0 7.0 6.0; 2.0 20.0 2.0; 2.0 8.0 11.0;;; 14.0 2.0 2.0; 3.0 13.0 2.0; 10.0 9.0 11.0;;; 11.0 6.0 1.0; 10.0 12.0 4.0; 4.0 5.0 13.0 - ] + normalize!( + [ + 15.0 4.0 9.0; 10.0 13.0 8.0; 2.0 8.0 15.0;;; 20.0 1.0 9.0; 5.0 15.0 2.0; 5.0 1.0 16.0;;; 19.0 9.0 9.0; 4.0 11.0 7.0; 2.0 10.0 16.0;;;; 19.0 3.0 3.0; 9.0 17.0 2.0; 9.0 10.0 20.0;;; 20.0 1.0 8.0; 6.0 16.0 6.0; 7.0 7.0 20.0;;; 11.0 8.0 10.0; 10.0 17.0 1.0; 7.0 3.0 14.0;;;; 12.0 10.0 6.0; 7.0 20.0 8.0; 7.0 2.0 20.0;;; 18.0 8.0 4.0; 8.0 20.0 7.0; 9.0 2.0 14.0;;; 13.0 10.0 3.0; 4.0 11.0 5.0; 2.0 5.0 17.0;;;;; + 11.0 3.0 4.0; 4.0 11.0 5.0; 8.0 5.0 17.0;;; 11.0 8.0 7.0; 10.0 11.0 3.0; 9.0 9.0 13.0;;; 15.0 8.0 5.0; 2.0 18.0 6.0; 2.0 8.0 15.0;;;; 16.0 8.0 3.0; 2.0 15.0 4.0; 7.0 4.0 13.0;;; 11.0 4.0 2.0; 5.0 18.0 9.0; 7.0 1.0 16.0;;; 16.0 3.0 8.0; 3.0 15.0 4.0; 7.0 8.0 13.0;;;; 16.0 3.0 6.0; 10.0 12.0 2.0; 2.0 9.0 17.0;;; 19.0 7.0 5.0; 7.0 13.0 4.0; 8.0 7.0 14.0;;; 13.0 6.0 9.0; 7.0 18.0 6.0; 7.0 2.0 11.0;;;;; + 11.0 1.0 7.0; 9.0 11.0 4.0; 8.0 1.0 11.0;;; 12.0 9.0 3.0; 4.0 15.0 1.0; 7.0 10.0 16.0;;; 11.0 8.0 8.0; 4.0 17.0 2.0; 7.0 2.0 17.0;;;; 13.0 2.0 9.0; 10.0 18.0 4.0; 10.0 2.0 17.0;;; 20.0 9.0 5.0; 10.0 15.0 3.0; 6.0 5.0 15.0;;; 18.0 5.0 3.0; 10.0 15.0 3.0; 7.0 5.0 20.0;;;; 16.0 7.0 6.0; 2.0 20.0 2.0; 2.0 8.0 11.0;;; 14.0 2.0 2.0; 3.0 13.0 2.0; 10.0 9.0 11.0;;; 11.0 6.0 1.0; 10.0 12.0 4.0; 4.0 5.0 13.0 + ], + 1 + ) ) ), output = Categorical([0.3456193186527781, 0.3381040532699693, 0.3162766280772525]) @@ -1053,11 +1115,14 @@ end m_T1 = Categorical([0.1, 0.6, 0.3]), m_T3 = Categorical([0.2, 0.3, 0.5]), q_a = PointMass( - [ - 18.0 6.0 5.0; 7.0 16.0 5.0; 2.0 2.0 18.0;;; 14.0 9.0 1.0; 2.0 13.0 7.0; 1.0 8.0 14.0;;; 18.0 2.0 2.0; 9.0 11.0 4.0; 6.0 1.0 14.0;;;; 14.0 3.0 10.0; 10.0 17.0 9.0; 5.0 7.0 13.0;;; 13.0 10.0 5.0; 9.0 18.0 4.0; 1.0 4.0 12.0;;; 18.0 7.0 9.0; 2.0 20.0 7.0; 5.0 8.0 11.0;;;; 18.0 3.0 8.0; 9.0 20.0 9.0; 8.0 5.0 19.0;;; 15.0 9.0 2.0; 1.0 14.0 1.0; 4.0 2.0 12.0;;; 18.0 5.0 10.0; 1.0 18.0 2.0; 9.0 2.0 12.0;;;;; - 16.0 2.0 1.0; 5.0 18.0 8.0; 9.0 6.0 14.0;;; 17.0 6.0 6.0; 1.0 18.0 7.0; 4.0 5.0 16.0;;; 19.0 8.0 9.0; 4.0 11.0 9.0; 9.0 10.0 18.0;;;; 14.0 7.0 8.0; 1.0 18.0 2.0; 4.0 3.0 12.0;;; 15.0 5.0 4.0; 9.0 20.0 2.0; 7.0 10.0 15.0;;; 13.0 5.0 6.0; 6.0 15.0 4.0; 1.0 10.0 16.0;;;; 12.0 2.0 1.0; 3.0 18.0 7.0; 8.0 7.0 15.0;;; 19.0 3.0 4.0; 7.0 14.0 8.0; 1.0 6.0 14.0;;; 16.0 8.0 8.0; 5.0 16.0 9.0; 1.0 10.0 11.0;;;;; - 19.0 1.0 6.0; 3.0 20.0 9.0; 5.0 9.0 12.0;;; 13.0 6.0 1.0; 9.0 12.0 7.0; 3.0 1.0 20.0;;; 19.0 10.0 8.0; 6.0 16.0 6.0; 9.0 10.0 16.0;;;; 13.0 1.0 4.0; 6.0 18.0 8.0; 5.0 9.0 20.0;;; 18.0 8.0 5.0; 8.0 11.0 4.0; 7.0 9.0 13.0;;; 15.0 8.0 10.0; 10.0 14.0 10.0; 10.0 2.0 16.0;;;; 14.0 8.0 10.0; 5.0 18.0 9.0; 9.0 4.0 15.0;;; 15.0 9.0 1.0; 7.0 12.0 8.0; 5.0 8.0 12.0;;; 20.0 3.0 1.0; 9.0 14.0 1.0; 2.0 9.0 13.0 - ] + normalize!( + [ + 18.0 6.0 5.0; 7.0 16.0 5.0; 2.0 2.0 18.0;;; 14.0 9.0 1.0; 2.0 13.0 7.0; 1.0 8.0 14.0;;; 18.0 2.0 2.0; 9.0 11.0 4.0; 6.0 1.0 14.0;;;; 14.0 3.0 10.0; 10.0 17.0 9.0; 5.0 7.0 13.0;;; 13.0 10.0 5.0; 9.0 18.0 4.0; 1.0 4.0 12.0;;; 18.0 7.0 9.0; 2.0 20.0 7.0; 5.0 8.0 11.0;;;; 18.0 3.0 8.0; 9.0 20.0 9.0; 8.0 5.0 19.0;;; 15.0 9.0 2.0; 1.0 14.0 1.0; 4.0 2.0 12.0;;; 18.0 5.0 10.0; 1.0 18.0 2.0; 9.0 2.0 12.0;;;;; + 16.0 2.0 1.0; 5.0 18.0 8.0; 9.0 6.0 14.0;;; 17.0 6.0 6.0; 1.0 18.0 7.0; 4.0 5.0 16.0;;; 19.0 8.0 9.0; 4.0 11.0 9.0; 9.0 10.0 18.0;;;; 14.0 7.0 8.0; 1.0 18.0 2.0; 4.0 3.0 12.0;;; 15.0 5.0 4.0; 9.0 20.0 2.0; 7.0 10.0 15.0;;; 13.0 5.0 6.0; 6.0 15.0 4.0; 1.0 10.0 16.0;;;; 12.0 2.0 1.0; 3.0 18.0 7.0; 8.0 7.0 15.0;;; 19.0 3.0 4.0; 7.0 14.0 8.0; 1.0 6.0 14.0;;; 16.0 8.0 8.0; 5.0 16.0 9.0; 1.0 10.0 11.0;;;;; + 19.0 1.0 6.0; 3.0 20.0 9.0; 5.0 9.0 12.0;;; 13.0 6.0 1.0; 9.0 12.0 7.0; 3.0 1.0 20.0;;; 19.0 10.0 8.0; 6.0 16.0 6.0; 9.0 10.0 16.0;;;; 13.0 1.0 4.0; 6.0 18.0 8.0; 5.0 9.0 20.0;;; 18.0 8.0 5.0; 8.0 11.0 4.0; 7.0 9.0 13.0;;; 15.0 8.0 10.0; 10.0 14.0 10.0; 10.0 2.0 16.0;;;; 14.0 8.0 10.0; 5.0 18.0 9.0; 9.0 4.0 15.0;;; 15.0 9.0 1.0; 7.0 12.0 8.0; 5.0 8.0 12.0;;; 20.0 3.0 1.0; 9.0 14.0 1.0; 2.0 9.0 13.0 + ], + 1 + ) ) ), output = Categorical([0.3307771289225438, 0.3462427003786177, 0.32298017069883844]) @@ -1069,11 +1134,14 @@ end m_T1 = Categorical([0.4, 0.3, 0.3]), m_T3 = Categorical([0.5, 0.2, 0.3]), q_a = PointMass( - [ - 15.0 4.0 9.0; 10.0 13.0 8.0; 2.0 8.0 15.0;;; 20.0 1.0 9.0; 5.0 15.0 2.0; 5.0 1.0 16.0;;; 19.0 9.0 9.0; 4.0 11.0 7.0; 2.0 10.0 16.0;;;; 19.0 3.0 3.0; 9.0 17.0 2.0; 9.0 10.0 20.0;;; 20.0 1.0 8.0; 6.0 16.0 6.0; 7.0 7.0 20.0;;; 11.0 8.0 10.0; 10.0 17.0 1.0; 7.0 3.0 14.0;;;; 12.0 10.0 6.0; 7.0 20.0 8.0; 7.0 2.0 20.0;;; 18.0 8.0 4.0; 8.0 20.0 7.0; 9.0 2.0 14.0;;; 13.0 10.0 3.0; 4.0 11.0 5.0; 2.0 5.0 17.0;;;;; - 11.0 3.0 4.0; 4.0 11.0 5.0; 8.0 5.0 17.0;;; 11.0 8.0 7.0; 10.0 11.0 3.0; 9.0 9.0 13.0;;; 15.0 8.0 5.0; 2.0 18.0 6.0; 2.0 8.0 15.0;;;; 16.0 8.0 3.0; 2.0 15.0 4.0; 7.0 4.0 13.0;;; 11.0 4.0 2.0; 5.0 18.0 9.0; 7.0 1.0 16.0;;; 16.0 3.0 8.0; 3.0 15.0 4.0; 7.0 8.0 13.0;;;; 16.0 3.0 6.0; 10.0 12.0 2.0; 2.0 9.0 17.0;;; 19.0 7.0 5.0; 7.0 13.0 4.0; 8.0 7.0 14.0;;; 13.0 6.0 9.0; 7.0 18.0 6.0; 7.0 2.0 11.0;;;;; - 11.0 1.0 7.0; 9.0 11.0 4.0; 8.0 1.0 11.0;;; 12.0 9.0 3.0; 4.0 15.0 1.0; 7.0 10.0 16.0;;; 11.0 8.0 8.0; 4.0 17.0 2.0; 7.0 2.0 17.0;;;; 13.0 2.0 9.0; 10.0 18.0 4.0; 10.0 2.0 17.0;;; 20.0 9.0 5.0; 10.0 15.0 3.0; 6.0 5.0 15.0;;; 18.0 5.0 3.0; 10.0 15.0 3.0; 7.0 5.0 20.0;;;; 16.0 7.0 6.0; 2.0 20.0 2.0; 2.0 8.0 11.0;;; 14.0 2.0 2.0; 3.0 13.0 2.0; 10.0 9.0 11.0;;; 11.0 6.0 1.0; 10.0 12.0 4.0; 4.0 5.0 13.0 - ] + normalize!( + [ + 15.0 4.0 9.0; 10.0 13.0 8.0; 2.0 8.0 15.0;;; 20.0 1.0 9.0; 5.0 15.0 2.0; 5.0 1.0 16.0;;; 19.0 9.0 9.0; 4.0 11.0 7.0; 2.0 10.0 16.0;;;; 19.0 3.0 3.0; 9.0 17.0 2.0; 9.0 10.0 20.0;;; 20.0 1.0 8.0; 6.0 16.0 6.0; 7.0 7.0 20.0;;; 11.0 8.0 10.0; 10.0 17.0 1.0; 7.0 3.0 14.0;;;; 12.0 10.0 6.0; 7.0 20.0 8.0; 7.0 2.0 20.0;;; 18.0 8.0 4.0; 8.0 20.0 7.0; 9.0 2.0 14.0;;; 13.0 10.0 3.0; 4.0 11.0 5.0; 2.0 5.0 17.0;;;;; + 11.0 3.0 4.0; 4.0 11.0 5.0; 8.0 5.0 17.0;;; 11.0 8.0 7.0; 10.0 11.0 3.0; 9.0 9.0 13.0;;; 15.0 8.0 5.0; 2.0 18.0 6.0; 2.0 8.0 15.0;;;; 16.0 8.0 3.0; 2.0 15.0 4.0; 7.0 4.0 13.0;;; 11.0 4.0 2.0; 5.0 18.0 9.0; 7.0 1.0 16.0;;; 16.0 3.0 8.0; 3.0 15.0 4.0; 7.0 8.0 13.0;;;; 16.0 3.0 6.0; 10.0 12.0 2.0; 2.0 9.0 17.0;;; 19.0 7.0 5.0; 7.0 13.0 4.0; 8.0 7.0 14.0;;; 13.0 6.0 9.0; 7.0 18.0 6.0; 7.0 2.0 11.0;;;;; + 11.0 1.0 7.0; 9.0 11.0 4.0; 8.0 1.0 11.0;;; 12.0 9.0 3.0; 4.0 15.0 1.0; 7.0 10.0 16.0;;; 11.0 8.0 8.0; 4.0 17.0 2.0; 7.0 2.0 17.0;;;; 13.0 2.0 9.0; 10.0 18.0 4.0; 10.0 2.0 17.0;;; 20.0 9.0 5.0; 10.0 15.0 3.0; 6.0 5.0 15.0;;; 18.0 5.0 3.0; 10.0 15.0 3.0; 7.0 5.0 20.0;;;; 16.0 7.0 6.0; 2.0 20.0 2.0; 2.0 8.0 11.0;;; 14.0 2.0 2.0; 3.0 13.0 2.0; 10.0 9.0 11.0;;; 11.0 6.0 1.0; 10.0 12.0 4.0; 4.0 5.0 13.0 + ], + 1 + ) ) ), output = Categorical([0.29589809194588856, 0.3530580581942895, 0.35104384985982195])