-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathmodel_instance_patch.rs
More file actions
280 lines (267 loc) · 11.3 KB
/
model_instance_patch.rs
File metadata and controls
280 lines (267 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
use std::{collections::BTreeMap, sync::Arc};
use arrow::{
array::{Array as _, ArrayRef, BinaryBuilder, StructArray, UInt32Builder},
datatypes::DataType,
};
use arrow_schema::Field;
use cid::Cid;
use datafusion::{
common::{
cast::{as_binary_array, as_uint32_array},
exec_datafusion_err, Result,
},
logical_expr::{
function::PartitionEvaluatorArgs, PartitionEvaluator, Signature, TypeSignature, Volatility,
WindowUDF, WindowUDFImpl,
},
};
use json_patch::PatchOperation;
use tracing::warn;
use super::{bytes_value_at, u32_value_at, EventDataContainer};
/// Window UDF, registered under the name `model_instance_patch`, that applies Ceramic data
/// events to document states and returns the new document state for each event.
#[derive(Debug)]
pub struct ModelInstancePatch {
    // Fixed argument signature reported to DataFusion.
    signature: Signature,
}
impl ModelInstancePatch {
    /// Wraps this implementation in a DataFusion [`WindowUDF`] ready for registration.
    pub fn new_udwf() -> WindowUDF {
        let implementation = Self::new();
        WindowUDF::new_from_impl(implementation)
    }

    /// Builds the implementation with its exact, immutable five-argument signature.
    fn new() -> Self {
        // Positional argument types, in the order the evaluator reads them.
        let argument_types = vec![
            DataType::Binary, // Event CID
            DataType::Binary, // Previous CID
            DataType::Binary, // Previous State
            DataType::UInt32, // Previous event height
            DataType::Binary, // State/Patch
        ];
        let signature = Signature::new(
            TypeSignature::Exact(argument_types),
            Volatility::Immutable,
        );
        Self { signature }
    }
}
impl WindowUDFImpl for ModelInstancePatch {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Name under which the window function is exposed.
    fn name(&self) -> &str {
        "model_instance_patch"
    }

    fn signature(&self) -> &datafusion::logical_expr::Signature {
        &self.signature
    }

    /// Creates the (stateless) evaluator used to compute a partition.
    fn partition_evaluator(
        &self,
        _partition_evaluator_args: PartitionEvaluatorArgs,
    ) -> Result<Box<dyn PartitionEvaluator>> {
        let evaluator: Box<dyn PartitionEvaluator> = Box::new(CeramicPatchEvaluator);
        Ok(evaluator)
    }

    /// Declares the output as a nullable struct of nullable children.
    fn field(
        &self,
        field_args: datafusion::logical_expr::function::WindowUDFFieldArgs,
    ) -> Result<arrow_schema::Field> {
        let nullable = true;
        // These children must stay in sync with the columns of the StructArray produced by
        // `CeramicPatchEvaluator::evaluate_all`.
        let children = vec![
            Field::new("model_version", DataType::Binary, nullable),
            Field::new("data", DataType::Binary, nullable),
            Field::new("patch", DataType::Binary, nullable),
            Field::new("event_height", DataType::UInt32, nullable),
        ];
        Ok(Field::new_struct(field_args.name(), children, nullable))
    }
}
/// A data event whose `content` is a list of JSON patch operations.
type MIDDataContainerPatch = EventDataContainer<Vec<PatchOperation>>;
/// A document state whose `content` is an arbitrary JSON value.
type MIDDataContainerState = EventDataContainer<serde_json::Value>;
/// EventDataContainer with only the metadata.
///
/// Used to read `metadata` out of a JSON-encoded state without deserializing its content;
/// other fields in the input are ignored by serde. Deserialization fails if `metadata` is
/// absent.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct MetadataContainer {
    // Raw metadata entries keyed by name (e.g. "modelVersion").
    metadata: BTreeMap<String, serde_json::Value>,
}
/// Stateless partition evaluator for `model_instance_patch`; it computes all output rows for a
/// partition in `evaluate_all`.
#[derive(Debug)]
struct CeramicPatchEvaluator;
impl CeramicPatchEvaluator {
fn apply_patch(patch: &[u8], previous_state: &[u8]) -> Result<(Vec<u8>, Option<Cid>)> {
let patch: MIDDataContainerPatch = serde_json::from_slice(patch)
.map_err(|err| exec_datafusion_err!("Error parsing patch: {err}"))?;
let mut state: MIDDataContainerState = serde_json::from_slice(previous_state)
.map_err(|err| exec_datafusion_err!("Error parsing previous state: {err}"))?;
// If the state is null use an empty object in order to apply the patch to a valid object.
if serde_json::Value::Null == state.content {
state.content = serde_json::Value::Object(serde_json::Map::default());
}
state.metadata.extend(patch.metadata);
json_patch::patch(&mut state.content, &patch.content)
.map_err(|err| exec_datafusion_err!("Error applying JSON patch: {err}"))?;
let model_version = state
.metadata
.get("modelVersion")
.and_then(|mv| {
mv.as_str().map(|mv| -> Result<_> {
mv.parse().map_err(|err| {
exec_datafusion_err!("modelVersion must be a valid CID: {err}")
})
})
})
.transpose()?;
Ok((
serde_json::to_vec(&state)
.map_err(|err| exec_datafusion_err!("Error JSON encoding: {err}"))?,
model_version,
))
}
fn parse_model_version(data: &[u8]) -> Result<Option<Cid>> {
if data.is_empty() {
return Ok(None);
}
let patch: MetadataContainer = serde_json::from_slice(data)
.map_err(|err| exec_datafusion_err!("Error parsing model version from data: {err}"))?;
patch
.metadata
.get("modelVersion")
.and_then(|mv| {
mv.as_str().map(|mv| -> Result<_> {
mv.parse().map_err(|err| {
exec_datafusion_err!("modelVersion must be a valid CID: {err}")
})
})
})
.transpose()
}
}
impl PartitionEvaluator for CeramicPatchEvaluator {
    // Compute the new state of each document for a batch of events.
    // Produces num_rows new document states, i.e. one for each input event.
    //
    // Assumptions made by the function:
    // * Window partitions are by stream_cid
    // * Rows are ordered by the index column
    //
    // With these assumptions the code assumes it has all events for a stream and only events from
    // a single stream.
    // Additionally index sort order means that any event's previous event comes earlier in the
    // data set and so a single pass algorithm can be implemented.
    //
    // Input data must have the following columns, in this order:
    // * event_cid - unique id of the event
    // * previous_cid - id of the previous event, nullable implies an init event.
    // * previous_state - state of the previous event, nullable implies that the previous event
    //   exists in the current dataset.
    // * previous_height - height of the previous event; only read when previous_state is
    //   non-null.
    // * patch - json patch to apply; null for time events. For init events this is the initial
    //   state (possibly null).
    //
    // The output is a struct array with one row per input row:
    // * model_version - modelVersion CID from the resulting state's metadata, if any
    // * data - the new document state; null when it could not be computed
    // * patch - the patch bytes of a data event (kept even if applying it failed); null for init
    //   and time events
    // * event_height - previous height + 1, 0 for init events, null if the previous event is
    //   unknown
    fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
        let event_cids = as_binary_array(&values[0])?;
        let previous_cids = as_binary_array(&values[1])?;
        let previous_states = as_binary_array(&values[2])?;
        let previous_heights = as_uint32_array(&values[3])?;
        let patches = as_binary_array(&values[4])?;
        let mut new_states = BinaryBuilder::new();
        let mut model_versions = BinaryBuilder::new();
        let mut new_heights = UInt32Builder::new();
        // We need to keep the patch around for validation.
        let mut resolved_patches = BinaryBuilder::new();
        // Row index of the most recent already-processed event for each event CID. This replaces
        // a backwards linear scan over all earlier rows, which is quadratic in the partition size
        // when previous events are far back or missing.
        let mut row_by_event_cid: std::collections::HashMap<&[u8], usize> =
            std::collections::HashMap::with_capacity(num_rows);
        for i in 0..num_rows {
            if previous_cids.is_valid(i) {
                let previous = if previous_states.is_valid(i) {
                    // We know the previous state already
                    Some((previous_states.value(i), previous_heights.value(i)))
                } else if let Some(&j) = row_by_event_cid.get(previous_cids.value(i)) {
                    // The previous event is an earlier row; reuse its computed state and height.
                    Some((
                        bytes_value_at(&new_states, j),
                        u32_value_at(&new_heights, j),
                    ))
                } else {
                    None
                };
                if let Some((previous_state, previous_height)) = previous {
                    new_heights.append_value(previous_height + 1);
                    if patches.is_null(i) {
                        // We have a time event, new state is just the previous state
                        model_versions.append_option(
                            Self::parse_model_version(previous_state)?.map(|mv| mv.to_bytes()),
                        );
                        // Time events do not change data, no patch
                        resolved_patches.append_null();
                        // Allow clippy warning as previous_state may be a reference back into
                        // new_states. So we need to copy the data to a new location before we
                        // can copy it back into the new_states.
                        #[allow(clippy::unnecessary_to_owned)]
                        new_states.append_value(previous_state.to_owned());
                    } else {
                        let patch = patches.value(i);
                        resolved_patches.append_value(patch);
                        match Self::apply_patch(patch, previous_state) {
                            Ok((data, model_version)) => {
                                new_states.append_value(data);
                                model_versions
                                    .append_option(model_version.map(|mv| mv.to_bytes()));
                            }
                            Err(err) => {
                                warn!(
                                    %err,
                                    event_cid = ?Cid::read_bytes(event_cids.value(i)),
                                    "failed to apply patch to model instance event"
                                );
                                tracing::debug!(
                                    %previous_height,
                                    %num_rows,
                                    patch = ?String::from_utf8_lossy(patch),
                                    previous_state = ?String::from_utf8_lossy(previous_state),
                                    "failed to apply patch to model instance event"
                                );
                                new_states.append_null();
                                model_versions.append_null();
                            }
                        };
                    }
                } else {
                    // Unreachable when data is well formed.
                    // Appending null means well formed documents can continue to be aggregated.
                    new_states.append_null();
                    model_versions.append_null();
                    resolved_patches.append_null();
                    new_heights.append_null();
                }
            } else {
                //Init event, patch value is the initial state
                new_heights.append_value(0);
                if patches.is_null(i) {
                    // We have an init event without data
                    new_states.append_null();
                    model_versions.append_null();
                    resolved_patches.append_null();
                } else {
                    let data = patches.value(i);
                    new_states.append_value(data);
                    model_versions
                        .append_option(Self::parse_model_version(data)?.map(|mv| mv.to_bytes()));
                    // An init event's initial data is not a patch.
                    resolved_patches.append_null();
                }
            }
            // Register this row only after it has been processed so that later rows resolve to
            // the most recent earlier row with this CID, and a row never resolves to itself.
            row_by_event_cid.insert(event_cids.value(i), i);
        }
        Ok(Arc::new(StructArray::from(vec![
            (
                Arc::new(Field::new("model_version", DataType::Binary, true)),
                Arc::new(model_versions.finish()) as ArrayRef,
            ),
            (
                Arc::new(Field::new("data", DataType::Binary, true)),
                Arc::new(new_states.finish()) as ArrayRef,
            ),
            (
                Arc::new(Field::new("patch", DataType::Binary, true)),
                Arc::new(resolved_patches.finish()) as ArrayRef,
            ),
            (
                Arc::new(Field::new("event_height", DataType::UInt32, true)),
                Arc::new(new_heights.finish()) as ArrayRef,
            ),
        ])))
    }
}