Skip to content

Commit c8d01d8

Browse files
committed
Assert the VSS protocol version matches
We assert that the VSS protocol version matches on all responses returned from the server, including errors.
1 parent 283af94 commit c8d01d8

3 files changed

Lines changed: 78 additions & 9 deletions

File tree

src/client.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ const CONTENT_TYPE: &str = "content-type";
2020
const DEFAULT_TIMEOUT_SECS: u64 = 10;
2121
const MAX_RESPONSE_BODY_SIZE: usize = 1024 * 1024 * 1024; // 1GB
2222
const DEFAULT_CLIENT_CAPACITY: usize = 10;
23+
const PROTOCOL_VERSION_HEADER: &str = "vss-protocol-version";
24+
const PROTOCOL_VERSION: &str = "0";
2325

2426
/// Thin-client to access a hosted instance of Versioned Storage Service (VSS).
2527
/// The provided [`VssClient`] API is minimalistic and is congruent to the VSS server-side API.
@@ -212,6 +214,16 @@ impl<R: RetryPolicy<E = VssError>> VssClient<R> {
212214
}
213215

214216
let response = self.client.send_async(http_request).await?;
217+
// Return early in case of version mismatch, this issue must be solved first.
218+
if response.headers.get(PROTOCOL_VERSION_HEADER).map(String::as_str)
219+
!= Some(PROTOCOL_VERSION)
220+
{
221+
let mut response = response;
222+
return Err(VssError::VSSVersionMismatchError {
223+
version_served: response.headers.remove(PROTOCOL_VERSION_HEADER),
224+
version_expected: String::from(PROTOCOL_VERSION),
225+
});
226+
}
215227

216228
let status_code = response.status_code;
217229
let payload = response.into_bytes();

src/error.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ pub enum VssError {
2626
/// There is an unknown error, it could be a client-side bug, unrecognized error-code, network error
2727
/// or something else.
2828
InternalError(String),
29+
30+
/// The VSS server and client speak different versions of the VSS protocol
31+
VSSVersionMismatchError {
32+
/// The VSS protocol version served
33+
version_served: Option<String>,
34+
/// The VSS protocol version expected
35+
version_expected: String,
36+
},
2937
}
3038

3139
impl VssError {
@@ -62,6 +70,21 @@ impl Display for VssError {
6270
VssError::InternalServerError(message) => {
6371
write!(f, "InternalServerError: {}", message)
6472
},
73+
VssError::VSSVersionMismatchError {
74+
version_served: Some(served),
75+
version_expected,
76+
} => {
77+
write!(
78+
f,
79+
"The VSS server and client speak different versions of the \
80+
VSS protocol, the server serves version {}, client expects \
81+
{}",
82+
served, version_expected,
83+
)
84+
},
85+
VssError::VSSVersionMismatchError { version_served: None, version_expected: _ } => {
86+
write!(f, "The server did not set the `vss-protocol-version` header")
87+
},
6588
VssError::InternalError(message) => {
6689
write!(f, "InternalError: {}", message)
6790
},

tests/tests.rs

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ mod tests {
2121

2222
const APPLICATION_OCTET_STREAM: &str = "application/octet-stream";
2323
const CONTENT_TYPE: &str = "content-type";
24+
const PROTOCOL_VERSION_HEADER: &str = "vss-protocol-version";
25+
const PROTOCOL_VERSION: &str = "0";
2426

2527
const GET_OBJECT_ENDPOINT: &'static str = "/getObject";
2628
const PUT_OBJECT_ENDPOINT: &'static str = "/putObjects";
@@ -44,6 +46,7 @@ mod tests {
4446
.match_header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
4547
.match_body(get_request.encode_to_vec())
4648
.with_status(200)
49+
.with_header(PROTOCOL_VERSION_HEADER, PROTOCOL_VERSION)
4750
.with_body(mock_response.encode_to_vec())
4851
.create();
4952

@@ -77,6 +80,7 @@ mod tests {
7780
.match_header("headerkey", "headervalue")
7881
.match_body(get_request.encode_to_vec())
7982
.with_status(200)
83+
.with_header(PROTOCOL_VERSION_HEADER, PROTOCOL_VERSION)
8084
.with_body(mock_response.encode_to_vec())
8185
.create();
8286

@@ -119,6 +123,7 @@ mod tests {
119123
.match_header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
120124
.match_body(request.encode_to_vec())
121125
.with_status(200)
126+
.with_header(PROTOCOL_VERSION_HEADER, PROTOCOL_VERSION)
122127
.with_body(mock_response.encode_to_vec())
123128
.create();
124129

@@ -154,6 +159,7 @@ mod tests {
154159
.match_header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
155160
.match_body(request.encode_to_vec())
156161
.with_status(200)
162+
.with_header(PROTOCOL_VERSION_HEADER, PROTOCOL_VERSION)
157163
.with_body(mock_response.encode_to_vec())
158164
.create();
159165

@@ -195,6 +201,7 @@ mod tests {
195201
.match_header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
196202
.match_body(request.encode_to_vec())
197203
.with_status(200)
204+
.with_header(PROTOCOL_VERSION_HEADER, PROTOCOL_VERSION)
198205
.with_body(mock_response.encode_to_vec())
199206
.create();
200207

@@ -222,6 +229,7 @@ mod tests {
222229
};
223230
let mock_server = mockito::mock("POST", GET_OBJECT_ENDPOINT)
224231
.with_status(404)
232+
.with_header(PROTOCOL_VERSION_HEADER, PROTOCOL_VERSION)
225233
.with_body(&error_response.encode_to_vec())
226234
.create();
227235

@@ -246,6 +254,7 @@ mod tests {
246254
let mock_response = GetObjectResponse { value: None, ..Default::default() };
247255
let mock_server = mockito::mock("POST", GET_OBJECT_ENDPOINT)
248256
.with_status(200)
257+
.with_header(PROTOCOL_VERSION_HEADER, PROTOCOL_VERSION)
249258
.with_body(&mock_response.encode_to_vec())
250259
.create();
251260

@@ -270,6 +279,7 @@ mod tests {
270279
};
271280
let mock_server = mockito::mock("POST", Matcher::Any)
272281
.with_status(400)
282+
.with_header(PROTOCOL_VERSION_HEADER, PROTOCOL_VERSION)
273283
.with_body(&error_response.encode_to_vec())
274284
.create();
275285

@@ -330,6 +340,7 @@ mod tests {
330340
};
331341
let mock_server = mockito::mock("POST", Matcher::Any)
332342
.with_status(401)
343+
.with_header(PROTOCOL_VERSION_HEADER, PROTOCOL_VERSION)
333344
.with_body(&error_response.encode_to_vec())
334345
.create();
335346

@@ -412,6 +423,7 @@ mod tests {
412423
};
413424
let mock_server = mockito::mock("POST", Matcher::Any)
414425
.with_status(409)
426+
.with_header(PROTOCOL_VERSION_HEADER, PROTOCOL_VERSION)
415427
.with_body(&error_response.encode_to_vec())
416428
.create();
417429

@@ -445,6 +457,7 @@ mod tests {
445457
};
446458
let mock_server = mockito::mock("POST", Matcher::Any)
447459
.with_status(500)
460+
.with_header(PROTOCOL_VERSION_HEADER, PROTOCOL_VERSION)
448461
.with_body(&error_response.encode_to_vec())
449462
.create();
450463

@@ -502,6 +515,7 @@ mod tests {
502515
ErrorResponse { error_code: 999, message: "UnknownException".to_string() };
503516
let mut _mock_server = mockito::mock("POST", Matcher::Any)
504517
.with_status(999)
518+
.with_header(PROTOCOL_VERSION_HEADER, PROTOCOL_VERSION)
505519
.with_body(&error_response.encode_to_vec())
506520
.create();
507521

@@ -534,6 +548,7 @@ mod tests {
534548
let malformed_error_response = b"malformed";
535549
_mock_server = mockito::mock("POST", Matcher::Any)
536550
.with_status(409)
551+
.with_header(PROTOCOL_VERSION_HEADER, PROTOCOL_VERSION)
537552
.with_body(&malformed_error_response)
538553
.create();
539554

@@ -546,17 +561,36 @@ mod tests {
546561
let list_malformed_err_response = vss_client.list_key_versions(&list_request).await;
547562
assert!(matches!(list_malformed_err_response.unwrap_err(), VssError::InternalError { .. }));
548563

549-
// Requests to endpoints are no longer mocked and will result in network error.
564+
// Requests to endpoints are no longer mocked and will result in version mismatch
565+
// errors.
550566
drop(_mock_server);
551567

552-
let get_network_err = vss_client.get_object(&get_request).await;
553-
assert!(matches!(get_network_err.unwrap_err(), VssError::InternalError { .. }));
554-
555-
let put_network_err = vss_client.put_object(&put_request).await;
556-
assert!(matches!(put_network_err.unwrap_err(), VssError::InternalError { .. }));
557-
558-
let list_network_err = vss_client.list_key_versions(&list_request).await;
559-
assert!(matches!(list_network_err.unwrap_err(), VssError::InternalError { .. }));
568+
let get_version_err = vss_client.get_object(&get_request).await;
569+
assert!(matches!(
570+
get_version_err.unwrap_err(),
571+
VssError::VSSVersionMismatchError {
572+
version_served: None,
573+
version_expected: version
574+
} if version == PROTOCOL_VERSION
575+
));
576+
577+
let put_version_err = vss_client.put_object(&put_request).await;
578+
assert!(matches!(
579+
put_version_err.unwrap_err(),
580+
VssError::VSSVersionMismatchError {
581+
version_served: None,
582+
version_expected: version
583+
} if version == PROTOCOL_VERSION
584+
));
585+
586+
let list_version_err = vss_client.list_key_versions(&list_request).await;
587+
assert!(matches!(
588+
list_version_err.unwrap_err(),
589+
VssError::VSSVersionMismatchError {
590+
version_served: None,
591+
version_expected: version
592+
} if version == PROTOCOL_VERSION
593+
));
560594
}
561595

562596
fn retry_policy() -> impl RetryPolicy<E = VssError> {

0 commit comments

Comments
 (0)