Skip to content

Commit 82a2fb3

Browse files
committed
mmds: performance: use zerocopy instead of bitcode for MMDS tokens
MMDS Token is just a POD type with byte arrays inside. This means we can convert it to base64 directly without needing intermediate bitcode serialization/deserialization. This improves performance by 5-30% on aarch64 platform. Signed-off-by: Egor Lazarchuk <yegorlz@amazon.co.uk>
1 parent c07bff1 commit 82a2fb3

1 file changed

Lines changed: 9 additions & 26 deletions

File tree

src/vmm/src/mmds/token.rs

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ use std::ops::Add;
77

88
use aws_lc_rs::aead::{AES_256_GCM, Aad, Nonce, RandomizedNonceKey};
99
use base64::Engine;
10-
use serde::{Deserialize, Serialize};
1110
use utils::time::{ClockType, get_time_ms};
11+
use zerocopy::{FromBytes, Immutable, IntoBytes};
1212

1313
/// Length of initialization vector.
1414
pub const IV_LEN: usize = 12;
@@ -19,11 +19,6 @@ pub const PAYLOAD_LEN: usize = std::mem::size_of::<u64>();
1919
/// Length of encryption tag.
2020
pub const TAG_LEN: usize = 16;
2121

22-
/// Maximum size in bytes for token deserialization to prevent DOS attacks.
23-
/// The Token struct contains fixed-size arrays (IV_LEN + PAYLOAD_LEN + TAG_LEN = 36 bytes)
24-
/// plus bitcode serialization overhead. This limit provides a safe margin.
25-
const TOKEN_DESERIALIZATION_BYTES_LIMIT: usize = 100;
26-
2722
/// Constant to convert seconds to milliseconds.
2823
pub const MILLISECONDS_PER_SECOND: u64 = 1_000;
2924

@@ -50,8 +45,6 @@ pub enum MmdsTokenError {
5045
ExpiryExtraction,
5146
/// Invalid time to live value provided for token: {0}. Please provide a value between {MIN_TOKEN_TTL_SECONDS:} and {MAX_TOKEN_TTL_SECONDS:}.
5247
InvalidTtlValue(u32),
53-
/// Bitcode serialization failed: {0}.
54-
Serialization(#[from] bitcode::Error),
5548
/// Failed to encrypt token.
5649
TokenEncryption,
5750
}
@@ -99,7 +92,7 @@ impl TokenAuthority {
9992
// Create token structure containing the encrypted expiry value.
10093
let token = self.create_token(ttl_seconds)?;
10194
// Encode struct into base64 in order to obtain token string.
102-
let encoded_token = token.base64_encode()?;
95+
let encoded_token = token.base64_encode();
10396
// Increase the count of encrypted tokens.
10497
self.num_encrypted_tokens += 1;
10598

@@ -245,7 +238,8 @@ impl TokenAuthority {
245238
}
246239

247240
/// Structure for token information.
248-
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
241+
#[derive(Clone, Debug, FromBytes, Immutable, IntoBytes, PartialEq)]
242+
#[repr(C)]
249243
struct Token {
250244
// Nonce or Initialization Vector.
251245
iv: [u8; IV_LEN],
@@ -262,27 +256,16 @@ impl Token {
262256
}
263257

264258
/// Encode token structure into a string using base64 encoding.
265-
fn base64_encode(&self) -> Result<String, MmdsTokenError> {
266-
let token_bytes: Vec<u8> = bitcode::serialize(self)?;
267-
268-
// Encode token structure bytes into base64.
269-
Ok(base64::engine::general_purpose::STANDARD.encode(token_bytes))
259+
fn base64_encode(&self) -> String {
260+
base64::engine::general_purpose::STANDARD.encode(self.as_bytes())
270261
}
271262

272263
/// Decode token structure from base64 string.
273264
fn base64_decode(encoded_token: &str) -> Result<Self, MmdsTokenError> {
274-
let token_bytes = base64::engine::general_purpose::STANDARD
265+
let bytes = base64::engine::general_purpose::STANDARD
275266
.decode(encoded_token)
276267
.map_err(|_| MmdsTokenError::ExpiryExtraction)?;
277-
278-
// Check size limit to prevent DOS attacks
279-
if token_bytes.len() > TOKEN_DESERIALIZATION_BYTES_LIMIT {
280-
return Err(MmdsTokenError::ExpiryExtraction);
281-
}
282-
283-
let token: Token =
284-
bitcode::deserialize(&token_bytes).map_err(|_| MmdsTokenError::ExpiryExtraction)?;
285-
Ok(token)
268+
Self::read_from_bytes(&bytes).map_err(|_| MmdsTokenError::ExpiryExtraction)
286269
}
287270
}
288271

@@ -392,7 +375,7 @@ mod tests {
392375
#[test]
393376
fn test_encode_decode() {
394377
let expected_token = Token::new([0u8; IV_LEN], [0u8; PAYLOAD_LEN], [0u8; TAG_LEN]);
395-
let mut encoded_token = expected_token.base64_encode().unwrap();
378+
let mut encoded_token = expected_token.base64_encode();
396379
let actual_token = Token::base64_decode(&encoded_token).unwrap();
397380
assert_eq!(actual_token, expected_token);
398381

0 commit comments

Comments
 (0)