Skip to content

Commit eaec7bb

Browse files
committed
feat(auth): add support for custom env, fs, and http implementations
1 parent 7483755 commit eaec7bb

24 files changed

+2021
-650
lines changed

src/auth/src/access_boundary.rs

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use crate::credentials::{
1818
AccessToken, AccessTokenCredentialsProvider, CacheableResource, CredentialsProvider, dynamic,
1919
};
2020
use crate::errors::CredentialsError;
21+
use crate::io::{HttpRequest, SharedHttpClientProvider};
2122
use crate::mds::client::Client as MDSClient;
2223
use crate::{Result, errors};
2324
use google_cloud_gax::Result as GaxResult;
@@ -28,7 +29,6 @@ use google_cloud_gax::retry_loop_internal::retry_loop;
2829
use google_cloud_gax::retry_policy::{Aip194Strict, RetryPolicy, RetryPolicyExt};
2930
use google_cloud_gax::retry_throttler::{AdaptiveThrottler, RetryThrottlerArg};
3031
use http::{Extensions, HeaderMap, HeaderValue};
31-
use reqwest::Client;
3232
use std::clone::Clone;
3333
use std::error::Error;
3434
use std::fmt::Debug;
@@ -235,11 +235,16 @@ impl<T> CredentialsWithAccessBoundary<T>
235235
where
236236
T: dynamic::AccessTokenCredentialsProvider + 'static,
237237
{
238-
pub(crate) fn new(credentials: T, access_boundary_url: Option<String>) -> Self {
238+
pub(crate) fn new(
239+
credentials: T,
240+
access_boundary_url: Option<String>,
241+
http: SharedHttpClientProvider,
242+
) -> Self {
239243
let credentials = Arc::new(credentials);
240244
let provider = IAMAccessBoundaryProvider {
241245
credentials: credentials.clone(),
242246
url: access_boundary_url,
247+
http,
243248
};
244249
let access_boundary = Arc::new(AccessBoundary::new(provider));
245250
Self {
@@ -253,13 +258,15 @@ where
253258
credentials: T,
254259
mds_client: MDSClient,
255260
iam_endpoint_override: Option<String>,
261+
http: SharedHttpClientProvider,
256262
) -> Self {
257263
let credentials = Arc::new(credentials);
258264
let provider = MDSAccessBoundaryProvider {
259265
credentials: credentials.clone(),
260266
mds_client,
261267
iam_endpoint_override,
262268
url: OnceLock::new(),
269+
http,
263270
};
264271
let access_boundary = Arc::new(AccessBoundary::new(provider));
265272
Self {
@@ -403,6 +410,7 @@ where
403410
{
404411
credentials: Arc<T>,
405412
url: Option<String>,
413+
http: SharedHttpClientProvider,
406414
}
407415

408416
#[async_trait::async_trait]
@@ -413,7 +421,11 @@ where
413421
async fn fetch_access_boundary(&self) -> Result<Option<String>> {
414422
match self.url.as_ref() {
415423
Some(url) => {
416-
let client = AccessBoundaryClient::new(self.credentials.clone(), url.clone());
424+
let client = AccessBoundaryClient::new(
425+
self.credentials.clone(),
426+
url.clone(),
427+
SharedHttpClientProvider::clone(&self.http),
428+
);
417429
client.fetch().await
418430
}
419431
None => Ok(None), // No URL means no access boundary
@@ -431,6 +443,7 @@ where
431443
mds_client: MDSClient,
432444
iam_endpoint_override: Option<String>,
433445
url: OnceLock<String>,
446+
http: SharedHttpClientProvider,
434447
}
435448

436449
#[async_trait::async_trait]
@@ -449,7 +462,11 @@ where
449462
}
450463

451464
let url = self.url.get().unwrap().to_string();
452-
let client = AccessBoundaryClient::new(self.credentials.clone(), url);
465+
let client = AccessBoundaryClient::new(
466+
self.credentials.clone(),
467+
url,
468+
SharedHttpClientProvider::clone(&self.http),
469+
);
453470
client.fetch().await
454471
}
455472
}
@@ -472,10 +489,11 @@ struct AccessBoundaryClient<T> {
472489
url: String,
473490
retry_policy: Arc<dyn RetryPolicy>,
474491
backoff_policy: Arc<dyn BackoffPolicy>,
492+
http: SharedHttpClientProvider,
475493
}
476494

477495
impl<T> AccessBoundaryClient<T> {
478-
fn new(credentials: Arc<T>, url: String) -> Self {
496+
fn new(credentials: Arc<T>, url: String, http: SharedHttpClientProvider) -> Self {
479497
let retry_policy = Aip194Strict.with_time_limit(Duration::from_secs(60));
480498
let backoff_policy = ExponentialBackoff::default();
481499

@@ -484,6 +502,7 @@ impl<T> AccessBoundaryClient<T> {
484502
url,
485503
retry_policy: Arc::new(retry_policy),
486504
backoff_policy: Arc::new(backoff_policy),
505+
http,
487506
}
488507
}
489508
}
@@ -509,19 +528,19 @@ where
509528
}
510529

511530
async fn fetch_with_retry(self) -> GaxResult<AllowedLocationsResponse> {
512-
let client = Client::new();
513531
let sleep = async |d| tokio::time::sleep(d).await;
514532

515533
let retry_throttler: RetryThrottlerArg = AdaptiveThrottler::default().into();
516534
let creds = self.credentials;
517535
let url = self.url;
536+
let http = self.http;
518537
let inner = async move |d| {
519538
let headers = creds
520539
.headers(Extensions::new())
521540
.await
522541
.map_err(GaxError::authentication)?;
523542

524-
let attempt = self::fetch_access_boundary_call(&client, &url, headers);
543+
let attempt = self::fetch_access_boundary_call(&http, &url, headers);
525544
match d {
526545
Some(timeout) => match tokio::time::timeout(timeout, attempt).await {
527546
Ok(r) => r,
@@ -544,7 +563,7 @@ where
544563
}
545564

546565
async fn fetch_access_boundary_call(
547-
client: &Client,
566+
http: &SharedHttpClientProvider,
548567
url: &str,
549568
headers: CacheableResource<HeaderMap>,
550569
) -> GaxResult<AllowedLocationsResponse> {
@@ -555,24 +574,19 @@ async fn fetch_access_boundary_call(
555574
}
556575
};
557576

558-
let resp = client
559-
.get(url)
560-
.headers(headers)
561-
.send()
562-
.await
563-
.map_err(GaxError::io)?;
564-
565-
let status = resp.status();
566-
if !status.is_success() {
567-
let err_headers = resp.headers().clone();
568-
let err_payload = resp
569-
.bytes()
570-
.await
571-
.map_err(|e| GaxError::transport(err_headers.clone(), e))?;
572-
return Err(GaxError::http(status.as_u16(), err_headers, err_payload));
577+
let request = HttpRequest::get(url).headers_from_map(&headers);
578+
579+
let response = http.execute(request).await.map_err(GaxError::io)?;
580+
581+
if !response.is_success() {
582+
return Err(GaxError::http(
583+
response.status.as_u16(),
584+
response.headers,
585+
response.body.into(),
586+
));
573587
}
574588

575-
resp.json().await.map_err(GaxError::io)
589+
response.json().map_err(GaxError::io)
576590
}
577591

578592
async fn refresh_task<T>(provider: T, tx_header: watch::Sender<(Option<BoundaryValue>, EntityTag)>)
@@ -750,7 +764,11 @@ pub(crate) mod tests {
750764
});
751765
let url = server.url("/allowedLocations").to_string();
752766

753-
let creds = CredentialsWithAccessBoundary::new(mock, Some(url));
767+
let creds = CredentialsWithAccessBoundary::new(
768+
mock,
769+
Some(url),
770+
SharedHttpClientProvider::default(),
771+
);
754772

755773
// wait for the background task to fetch the access boundary.
756774
creds.wait_for_boundary().await;
@@ -804,9 +822,18 @@ pub(crate) mod tests {
804822
})
805823
});
806824
let endpoint = server.url("").to_string().trim_end_matches('/').to_string();
807-
let mds_client = MDSClient::new(Some(endpoint.clone()));
825+
let mds_client = MDSClient::new(
826+
Some(endpoint.clone()),
827+
SharedEnvProvider::default(),
828+
SharedHttpClientProvider::default(),
829+
);
808830

809-
let creds = CredentialsWithAccessBoundary::new_for_mds(mock, mds_client, Some(endpoint));
831+
let creds = CredentialsWithAccessBoundary::new_for_mds(
832+
mock,
833+
mds_client,
834+
Some(endpoint),
835+
SharedHttpClientProvider::default(),
836+
);
810837

811838
// wait for the background task to fetch the access boundary.
812839
creds.wait_for_boundary().await;
@@ -849,7 +876,8 @@ pub(crate) mod tests {
849876
})
850877
});
851878
let url = server.url("/allowedLocations").to_string();
852-
let client = AccessBoundaryClient::new(Arc::new(mock), url);
879+
let client =
880+
AccessBoundaryClient::new(Arc::new(mock), url, SharedHttpClientProvider::default());
853881
let val = client.fetch().await?;
854882
assert!(val.is_none(), "{val:?}");
855883

@@ -879,7 +907,8 @@ pub(crate) mod tests {
879907
});
880908

881909
let url = server.url("/allowedLocations").to_string();
882-
let mut client = AccessBoundaryClient::new(Arc::new(mock), url);
910+
let mut client =
911+
AccessBoundaryClient::new(Arc::new(mock), url, SharedHttpClientProvider::default());
883912
client.retry_policy = Arc::new(Aip194Strict.with_attempt_limit(3));
884913
client.backoff_policy = Arc::new(test_backoff_policy());
885914

@@ -899,7 +928,11 @@ pub(crate) mod tests {
899928
))
900929
});
901930

902-
let client = AccessBoundaryClient::new(Arc::new(mock), "http://localhost".to_string());
931+
let client = AccessBoundaryClient::new(
932+
Arc::new(mock),
933+
"http://localhost".to_string(),
934+
SharedHttpClientProvider::default(),
935+
);
903936
let err = client.fetch().await.unwrap_err();
904937
assert!(!err.is_transient(), "{err:?}");
905938
}
@@ -918,7 +951,8 @@ pub(crate) mod tests {
918951
data: headers,
919952
})
920953
});
921-
let creds = CredentialsWithAccessBoundary::new(mock, None);
954+
let creds =
955+
CredentialsWithAccessBoundary::new(mock, None, SharedHttpClientProvider::default());
922956

923957
let cached_headers = creds.headers(Extensions::new()).await?;
924958
let token = get_token_from_headers(cached_headers.clone());
@@ -1240,7 +1274,8 @@ pub(crate) mod tests {
12401274
});
12411275

12421276
let url = server.url("/allowedLocations").to_string();
1243-
let mut client = AccessBoundaryClient::new(Arc::new(mock), url);
1277+
let mut client =
1278+
AccessBoundaryClient::new(Arc::new(mock), url, SharedHttpClientProvider::default());
12441279
client.backoff_policy = Arc::new(test_backoff_policy());
12451280
let val = client.fetch().await?;
12461281
assert_eq!(val.as_deref(), Some("0x123"));

0 commit comments

Comments
 (0)