-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathStateSerialization.jl
More file actions
177 lines (159 loc) · 5.39 KB
/
StateSerialization.jl
File metadata and controls
177 lines (159 loc) · 5.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""
    stringVariableType(varT::AbstractStateType{N}) where {N}

Return the serialized name of the concrete type of `varT`.

The name has the form `"Module.TypeName"`. When the type parameter `N` is an integer,
`"{N}"` is appended, e.g. `"Module.TypeName{3}"`. When `N` is `Any`, no parameter is written.
Any other kind of parameter throws a `SerializationError`.
"""
function stringVariableType(varT::AbstractStateType{N}) where {N}
    T = typeof(varT)
    qualname = string(parentmodule(T), ".", nameof(T))
    N == Any && return qualname
    if !(N isa Integer)
        throw(
            SerializationError(
                "Serializing Variable State type only supports an integer parameter, got '$(T)'.",
            ),
        )
    end
    return string(qualname, "{", join(N, ","), "}")
end
"""
    parseVariableType(_typeString::AbstractString)

Resolve a serialized state-type string, such as `"MyModule.MyType"` or
`"MyModule.MyType{3}"`, back to the corresponding Julia type.

Resolution rules:
- A single integer parameter in braces (`{N}`) is split off and re-applied as `T{N}`.
- No module prefix: the name is looked up in `Main`, and a warning is logged.
- A prefix naming a loaded module: the type is looked up directly in that module object,
  which need not be bound in `Main`. If the name is not bound there, any intermediate
  submodules in the path are walked, e.g. `"Pkg.Sub.MyType"`.
- Any other prefix: the last name component is looked up in `Main`, and a warning is
  logged. This is kept for backwards compatibility.

Throws a `SerializationError` if the type cannot be found.
"""
function parseVariableType(_typeString::AbstractString)
    paramMatch = match(r"{(\d+)}", _typeString)
    if !isnothing(paramMatch) #parameters in type
        param = parse(Int, paramMatch[1])
        typeString = _typeString[1:(paramMatch.offset - 1)]
    else
        param = nothing
        typeString = _typeString
    end
    split_typeSyms = Symbol.(split(typeString, "."))
    typeSym = split_typeSyms[end]

    # `getfield` throws on a missing binding instead of returning `nothing`, so check
    # with `isdefined` first; this lets us report a `SerializationError` below.
    lookup(mod::Module, sym::Symbol) = isdefined(mod, sym) ? getfield(mod, sym) : nothing

    subtype = nothing
    if length(split_typeSyms) == 1
        @warn "Module not found in variable '$typeString'." maxlog = 1
        subtype = lookup(Main, typeSym) # no module specified, use Main
    else
        rootSym = split_typeSyms[1]
        # Use the loaded module object itself rather than `Main`'s binding for it:
        # a loaded package is not necessarily bound in `Main`.
        rootmod = nothing
        for md in values(Base.loaded_modules)
            if nameof(md) == rootSym
                rootmod = md
                break
            end
        end
        if !isnothing(rootmod)
            #FIXME interim fallback for backwards compatibility in IIFTypes and RoMETypes
            subtype = lookup(rootmod, typeSym)
            if isnothing(subtype) && length(split_typeSyms) > 2
                # walk the full module path, e.g. "Pkg.SubModule.MyType"
                mod = rootmod
                for sym in split_typeSyms[2:(end - 1)]
                    mod = mod isa Module ? lookup(mod, sym) : nothing
                end
                subtype = mod isa Module ? lookup(mod, typeSym) : nothing
            end
        else
            @warn "Module not found in Main, using Main for type '$typeString'." maxlog = 1
            subtype = lookup(Main, typeSym)
        end
    end
    if isnothing(subtype)
        throw(SerializationError("Unable to deserialize type $(_typeString), not found"))
    end
    # re-apply the integer parameter, if one was serialized
    return isnothing(param) ? subtype : subtype{param}
end
##==============================================================================
## State Packing and unpacking
##==============================================================================
# Old PackedState struct fields
# id::Union{UUID, Nothing}
# vecval::Vector{Float64}
# dimval::Int
# vecbw::Vector{Float64}
# dimbw::Int
# BayesNetOutVertIDs::Vector{Symbol}
# dimIDs::Vector{Int}
# dims::Int
# eliminated::Bool
# BayesNetVertID::Symbol
# separator::Vector{Symbol}
# variableType::String
# initialized::Bool
# infoPerCoord::Vector{Float64}
# ismargin::Bool
# dontmargin::Bool
# solvedCount::Int
# solveKey::Symbol
# covar::Vector{Float64}
# _version::VersionNumber = _getDFGVersion()
# returns a named tuple until State serialization is fully consolidated
"""
    packState(state::State{T}) where {T <: StateType}

Pack `state` into a `NamedTuple` of plain serializable fields.

Each point in `state.val` is converted with `getCoordinates(T, point)`. The results form
the columns of a matrix, which is flattened into `vecval`; `dimval` holds the number of
rows. An empty `state.val` is packed as a `1×0` matrix.

`state.bw` is flattened into `vecbw`, with its row count in `dimbw`. Only the first entry
of `state.covar` is packed, as a flat vector, and a warning is logged if there are more.
"""
function packState(state::State{T}) where {T <: StateType}
    if isempty(state.val)
        pointmat = zeros(1, 0)
    else
        coords = getCoordinates.(T, state.val)
        # column j holds the coordinates of point j
        @cast pointmat[i, j] := coords[j][i]
    end
    if length(state.covar) > 1
        @warn(
            "Packing of more than one parametric covariance is NOT supported yet, only packing first."
        )
    end
    packedcovar = isempty(state.covar) ? Float64[] : vec(state.covar[1])
    return (
        label = state.label,
        vecval = pointmat[:],
        dimval = size(pointmat, 1),
        vecbw = state.bw[:],
        dimbw = size(state.bw, 1),
        separator = state.separator,
        statetype = stringVariableType(getStateKind(state)),
        initialized = state.initialized,
        observability = state.observability,
        marginalized = state.marginalized,
        solves = state.solves,
        covar = packedcovar,
        version = version(State),
    )
end
"""
    unpackState(obj)

Rebuild a `State` from its packed form, i.e. the field layout produced by `packState`.

`obj` is accessed by property and must provide these fields:
- `statetype`
- `label`
- `vecval` and `dimval`
- `vecbw` and `dimbw`
- `covar`
- `separator`, `initialized`, `observability`, `marginalized` and `solves`

`vecval` is reshaped to `dimval` rows, and each column becomes one point via `getPoint`.
`vecbw` is reshaped to `dimbw` rows. A non-empty `covar` is read as one `N×N` matrix,
where `N = getDimension(T)`; only one covariance is currently supported.
"""
function unpackState(obj)
    T = parseVariableType(obj.statetype)
    PT = getPointType(T)
    # points are stored column-wise in a flat vector with `dimval` rows
    r3 = obj.dimval
    c3 = r3 > 0 ? div(length(obj.vecval), r3) : 0
    M3 = reshape(obj.vecval, r3, c3)
    @cast val_[j][i] := M3[i, j]
    vals = PT[getPoint(T, v) for v in val_]
    r4 = obj.dimbw
    c4 = r4 > 0 ? div(length(obj.vecbw), r4) : 0
    BW = reshape(obj.vecbw, r4, c4)
    N = getDimension(T)
    #TODO only one covar is currently supported in packed VND
    # Build the SMatrix explicitly so both branches have the same element type.
    covar = if isempty(obj.covar)
        SMatrix{N, N, Float64}[]
    else
        [SMatrix{N, N, Float64}(obj.covar)]
    end
    return State{T, PT, N}(;
        label = Symbol(obj.label),
        val = vals,
        bw = BW,
        covar = covar,
        separator = Symbol.(obj.separator),
        initialized = obj.initialized,
        observability = obj.observability,
        marginalized = obj.marginalized,
        solves = obj.solves,
    )
end
"""
    unpackOldState(d)

Rebuild a `State` from the legacy packed-state layout.

Fields of `d` are accessed by property and mapped as follows:
- `variableType` → state type, resolved with `parseVariableType`
- `vecval`/`dimval` → points; each column becomes one point via `getPoint`
- `vecbw`/`dimbw` → `bw`
- `label` → `label`; if `d` has no `label` property, the legacy `solveKey` is used
- `covar` → one `N×N` covariance, where `N = getDimension(T)`
- `separator` and `initialized` → same-named fields
- `infoPerCoord` → `observability`, `ismargin` → `marginalized`, `solvedCount` → `solves`
"""
function unpackOldState(d)
    @debug "Dispatching conversion packed variable -> variable for type $(string(d.variableType))"
    # Figuring out the variableType
    T = parseVariableType(d.variableType)
    PT = getPointType(T)
    # points are stored column-wise in a flat vector with `dimval` rows
    r3 = d.dimval
    c3 = r3 > 0 ? div(length(d.vecval), r3) : 0
    M3 = reshape(d.vecval, r3, c3)
    @cast val_[j][i] := M3[i, j]
    vals = PT[getPoint(T, v) for v in val_]
    r4 = d.dimbw
    c4 = r4 > 0 ? div(length(d.vecbw), r4) : 0
    BW = reshape(d.vecbw, r4, c4)
    N = getDimension(T)
    # The legacy PackedState layout has no `label` field; its `solveKey` plays that role.
    statelabel = hasproperty(d, :label) ? d.label : d.solveKey
    #TODO only one covar is currently supported in packed VND
    # Build the SMatrix explicitly so both branches have the same element type.
    covar = if isempty(d.covar)
        SMatrix{N, N, Float64}[]
    else
        [SMatrix{N, N, Float64}(d.covar)]
    end
    return State{T, PT, N}(;
        label = Symbol(statelabel),
        val = vals,
        bw = BW,
        covar = covar,
        separator = Symbol.(d.separator),
        initialized = d.initialized,
        observability = d.infoPerCoord,
        marginalized = d.ismargin,
        solves = d.solvedCount,
    )
end