-
Notifications
You must be signed in to change notification settings - Fork 151
Expand file tree
/
Copy pathproto.rs
More file actions
90 lines (76 loc) · 3.23 KB
/
proto.rs
File metadata and controls
90 lines (76 loc) · 3.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
//! Protobuf serialization for [`FixedShapeTensorMetadata`].
use prost::Message;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_err;
use crate::fixed_shape::FixedShapeTensorMetadata;
/// Protobuf wire representation of [`FixedShapeTensorMetadata`].
///
/// All three fields are `repeated`. Protobuf does not distinguish between an absent repeated
/// field and an empty one (both deserialize as an empty `Vec`), so "empty" doubles as "not set".
/// This is fine because the semantic meaning is unambiguous:
///
/// - `logical_shape` empty: 0-dimensional (scalar) tensor.
/// - `dim_names` empty: no dimension names (`None`).
/// - `permutation` empty: no permutation, i.e., identity layout (`None`).
///
/// Dimension sizes and permutation indices are carried as `uint32` on the wire.
#[derive(Clone, PartialEq, Message)]
struct FixedShapeTensorMetadataProto {
/// The size of each logical dimension, in logical order (field tag 1). Empty for a
/// 0-dimensional scalar tensor.
#[prost(uint32, repeated, tag = "1")]
logical_shape: Vec<u32>,
/// Optional human-readable names for each logical dimension (field tag 2). When present, must
/// have the same length as `logical_shape`. Empty means no names are set.
#[prost(string, repeated, tag = "2")]
dim_names: Vec<String>,
/// Optional dimension permutation mapping logical to physical indices (field tag 3). When
/// present, must be a permutation of `[0, 1, ..., N-1]`. Empty means identity (row-major)
/// layout.
#[prost(uint32, repeated, tag = "3")]
permutation: Vec<u32>,
}
/// Encodes a [`FixedShapeTensorMetadata`] as protobuf bytes.
///
/// Dimension sizes and permutation indices are narrowed to `u32` for the wire format. Missing
/// dimension names and a missing permutation are both written as empty repeated fields.
///
/// # Panics
///
/// Panics (via `vortex_expect`) if a dimension size or a permutation index does not fit in a
/// `u32`.
pub(crate) fn serialize(metadata: &FixedShapeTensorMetadata) -> Vec<u8> {
    FixedShapeTensorMetadataProto {
        logical_shape: metadata
            .logical_shape()
            .iter()
            .copied()
            .map(|dim| u32::try_from(dim).vortex_expect("dimension size exceeds u32"))
            .collect(),
        // `None` names are represented by an empty list on the wire.
        dim_names: metadata
            .dim_names()
            .map_or_else(Vec::new, |names| names.to_vec()),
        // `None` permutation (identity layout) is likewise represented by an empty list.
        permutation: match metadata.permutation() {
            Some(order) => order
                .iter()
                .copied()
                .map(|idx| u32::try_from(idx).vortex_expect("permutation index exceeds u32"))
                .collect(),
            None => Vec::new(),
        },
    }
    .encode_to_vec()
}
/// Deserializes [`FixedShapeTensorMetadata`] from protobuf bytes.
///
/// For 0-dimensional tensors, all three repeated fields are empty, which correctly produces a
/// metadata with an empty shape and no names or permutation.
pub(crate) fn deserialize(bytes: &[u8]) -> VortexResult<FixedShapeTensorMetadata> {
let proto = FixedShapeTensorMetadataProto::decode(bytes).map_err(|e| vortex_err!("{e}"))?;
let logical_shape = proto
.logical_shape
.into_iter()
.map(|d| d as usize)
.collect();
let mut m = FixedShapeTensorMetadata::new(logical_shape);
// Note that this is fine for 0 dimensions since if we do not have any dimensions, we cannot
// have any names or permutations.
if !proto.dim_names.is_empty() {
m = m.with_dim_names(proto.dim_names)?;
}
if !proto.permutation.is_empty() {
let permutation = proto.permutation.into_iter().map(|i| i as usize).collect();
m = m.with_permutation(permutation)?;
}
Ok(m)
}