-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathclient.rs
More file actions
163 lines (141 loc) · 4.39 KB
/
Copy pathclient.rs
File metadata and controls
163 lines (141 loc) · 4.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
use super::image::ImageGenerationModel;
use super::{completion::CompletionModel, embedding::EmbeddingModel};
use aws_config::{BehaviorVersion, Region};
use rig::client::Nothing;
use rig::prelude::*;
use std::sync::Arc;
use tokio::sync::OnceCell;
/// Region used by `ClientBuilder` when no region is set explicitly via
/// `ClientBuilder::region`.
pub const DEFAULT_AWS_REGION: &str = "us-east-1";
/// Builder that produces a [`Client`] pinned to a specific AWS region.
///
/// Prefer `Client::from_env` or `Client::with_profile_name` for new code; see
/// the deprecation note on `ClientBuilder::new`.
// `Debug` is derived for parity with `Client`, so builders can be logged or
// asserted on in tests.
#[derive(Clone, Debug)]
pub struct ClientBuilder<'a> {
    /// Region name handed to the AWS SDK config loader when building.
    region: &'a str,
}
impl<'a> ClientBuilder<'a> {
#[deprecated(
since = "0.2.6",
note = "Use `Client::from_env` or `Client::with_profile_name(\"aws_profile\")` instead"
)]
pub fn new() -> Self {
Self {
region: DEFAULT_AWS_REGION,
}
}
/// Make sure to verify model and region [compatibility]
///
/// [compatibility]: https://docs.aws.amazon.com/bedrock/latest/userguide/models-regions.html
pub fn region(mut self, region: &'a str) -> Self {
self.region = region;
self
}
/// Make sure you have permissions to access [Amazon Bedrock foundation model]
///
/// [ Amazon Bedrock foundation model]: <https://docs.aws.amazon.com/bedrock/latest/userguide/model-access-modify.html>
pub async fn build(self) -> Client {
let sdk_config = aws_config::defaults(BehaviorVersion::latest())
.region(Region::new(String::from(self.region)))
.load()
.await;
let client = aws_sdk_bedrockruntime::Client::new(&sdk_config);
Client {
profile_name: None,
aws_client: Arc::new(OnceCell::from(client)),
}
}
}
impl Default for ClientBuilder<'_> {
fn default() -> Self {
#[allow(deprecated)]
Self::new()
}
}
/// AWS Bedrock provider client.
///
/// The Bedrock runtime client is held in a `OnceCell` behind an `Arc`. Cloning
/// a `Client` therefore shares the same cell, so the SDK client is built at
/// most once across all clones (on the first `Client::get_inner` call when it
/// was not supplied up front).
#[derive(Clone, Debug)]
pub struct Client {
    /// AWS profile to load configuration from when the SDK client is built
    /// lazily; `None` means configuration is loaded from the environment.
    profile_name: Option<String>,
    /// Shared, lazily-initialized Bedrock runtime client (crate-visible).
    pub(crate) aws_client: Arc<OnceCell<aws_sdk_bedrockruntime::Client>>,
}
impl From<aws_sdk_bedrockruntime::Client> for Client {
    /// Wraps an already-configured Bedrock runtime client.
    ///
    /// The cell is pre-populated, so `Client::get_inner` returns this client
    /// as-is and never loads AWS configuration itself; no profile is recorded.
    fn from(aws_client: aws_sdk_bedrockruntime::Client) -> Self {
        let ready = OnceCell::from(aws_client);
        Self {
            profile_name: None,
            aws_client: Arc::new(ready),
        }
    }
}
impl Client {
    /// Creates a client that loads its AWS configuration from the environment
    /// on first use.
    fn new() -> Self {
        Self::lazy(None)
    }

    /// Create an AWS Bedrock client using AWS profile name.
    ///
    /// The profile is only read when the SDK client is first needed.
    pub fn with_profile_name(profile_name: &str) -> Self {
        Self::lazy(Some(profile_name.to_owned()))
    }

    /// Shared constructor: records the optional profile and leaves the SDK
    /// client cell empty so it is initialized by `get_inner`.
    fn lazy(profile_name: Option<String>) -> Self {
        Self {
            profile_name,
            aws_client: Arc::new(OnceCell::new()),
        }
    }

    /// Returns the Bedrock runtime client, building it on the first call.
    ///
    /// With a profile name, configuration comes from
    /// `aws_config::defaults(BehaviorVersion::latest())` with that profile
    /// applied. Otherwise it comes from `aws_config::load_from_env()`. The
    /// `OnceCell` ensures initialization runs at most once per shared cell.
    pub async fn get_inner(&self) -> &aws_sdk_bedrockruntime::Client {
        let build_client = || async {
            let sdk_config = match self.profile_name.as_deref() {
                Some(profile) => {
                    aws_config::defaults(BehaviorVersion::latest())
                        .profile_name(profile)
                        .load()
                        .await
                }
                None => aws_config::load_from_env().await,
            };
            aws_sdk_bedrockruntime::Client::new(&sdk_config)
        };
        self.aws_client.get_or_init(build_client).await
    }
}
impl ProviderClient for Client {
    type Input = Nothing;

    /// Builds a client whose AWS configuration is loaded from the environment
    /// the first time the underlying SDK client is requested.
    fn from_env() -> Self {
        Self::new()
    }

    /// Not supported for Bedrock, which authenticates through the AWS
    /// configuration chain rather than a provided value.
    ///
    /// # Panics
    ///
    /// Always panics, pointing the caller at `Client::from_env` or
    /// `Client::with_profile_name`.
    fn from_val(_: Nothing) -> Self {
        panic!(
            "Please use `Client::from_env` or `Client::with_profile_name(\"aws_profile\")` instead"
        );
    }
}
impl CompletionClient for Client {
    type CompletionModel = CompletionModel;

    /// Creates a completion model handle for `model` that shares this
    /// client's (lazily-initialized) Bedrock runtime client.
    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
        let client = self.clone();
        CompletionModel::new(client, model)
    }
}
impl EmbeddingsClient for Client {
    type EmbeddingModel = EmbeddingModel;

    /// Creates an embedding model handle with no explicit dimensionality
    /// (`None` is passed through to `EmbeddingModel::new`).
    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
        let client = self.clone();
        let dims: Option<usize> = None;
        EmbeddingModel::new(client, model, dims)
    }

    /// Creates an embedding model handle that records `ndims` as the
    /// embedding dimensionality.
    fn embedding_model_with_ndims(
        &self,
        model: impl Into<String>,
        ndims: usize,
    ) -> Self::EmbeddingModel {
        let client = self.clone();
        let dims = Some(ndims);
        EmbeddingModel::new(client, model, dims)
    }
}
impl ImageGenerationClient for Client {
    type ImageGenerationModel = ImageGenerationModel;

    /// Creates an image-generation model handle for `model` backed by a clone
    /// of this client (clones share the same SDK client cell).
    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        let client = self.clone();
        ImageGenerationModel::new(client, model)
    }
}
impl VerifyClient for Client {
    /// Always reports success.
    ///
    /// There is no API endpoint to verify the API key, so this makes no network
    /// call and does not initialize the underlying SDK client.
    async fn verify(&self) -> Result<(), VerifyError> {
        Ok(())
    }
}