@@ -35,20 +35,30 @@ pub struct StateFetcher {
3535 cdn_url : String ,
3636 etag : RwLock < Option < String > > ,
3737 sdk : Option < Sdk > ,
38+ encryption_key : Option < Vec < u8 > > ,
3839}
3940
4041impl StateFetcher {
4142 /// Create a new state fetcher with the given client, client secret, and sdk identity.
42- pub fn new ( client : ClientWithMiddleware , client_secret : String , sdk : Option < Sdk > ) -> Self {
43+ pub fn new (
44+ client : ClientWithMiddleware ,
45+ client_secret : String ,
46+ sdk : Option < Sdk > ,
47+ encryption_key_hex : Option < String > ,
48+ ) -> Self {
4349 let hash = Self :: hash_client_secret ( & client_secret) ;
4450 let cdn_url = format ! ( "{}/{}" , CDN_BASE_URL , hash) ;
51+ let encryption_key = encryption_key_hex. map ( |hex_str| {
52+ hex:: decode ( & hex_str) . expect ( "encryption_key must be valid hex" )
53+ } ) ;
4554
4655 Self {
4756 client,
4857 client_secret,
4958 cdn_url,
5059 etag : RwLock :: new ( None ) ,
5160 sdk,
61+ encryption_key,
5262 }
5363 }
5464
@@ -100,23 +110,88 @@ impl StateFetcher {
100110 }
101111 }
102112
113+ // Check encryption header
114+ let encrypted_header = response
115+ . headers ( )
116+ . get ( "x-amz-meta-encrypted" )
117+ . and_then ( |v| v. to_str ( ) . ok ( ) )
118+ == Some ( "true" ) ;
119+
103120 // Parse response body
104- let bytes = response. bytes ( ) . await ?;
105- let request = SetResolverStateRequest :: decode ( bytes) . map_err ( |e| {
106- Error :: StateParse ( format ! ( "Failed to decode SetResolverStateRequest: {}" , e) )
121+ let raw_bytes = response. bytes ( ) . await ?;
122+
123+ // Try unencrypted path first (header check + protobuf fallback)
124+ let decrypted_bytes = if !encrypted_header {
125+ match SetResolverStateRequest :: decode ( raw_bytes. clone ( ) ) {
126+ Ok ( request) => {
127+ let state_pb = ResolverStatePb :: decode ( request. state ) . map_err ( |e| {
128+ Error :: StateParse ( format ! ( "Failed to decode ResolverState: {}" , e) )
129+ } ) ?;
130+ let state =
131+ ResolverState :: from_proto ( state_pb, & request. account_id , self . sdk . clone ( ) )
132+ . map_err ( |e| {
133+ Error :: StateParse ( format ! (
134+ "Failed to create ResolverState: {:?}" ,
135+ e
136+ ) )
137+ } ) ?;
138+ return Ok ( Some ( ( state, request. account_id ) ) ) ;
139+ }
140+ Err ( _) => {
141+ tracing:: warn!( "Protobuf decode failed, treating state as encrypted" ) ;
142+ Self :: decrypt ( & raw_bytes, & self . encryption_key ) ?
143+ }
144+ }
145+ } else {
146+ Self :: decrypt ( & raw_bytes, & self . encryption_key ) ?
147+ } ;
148+
149+ let request = SetResolverStateRequest :: decode ( decrypted_bytes. as_slice ( ) ) . map_err ( |e| {
150+ Error :: StateParse ( format ! (
151+ "Failed to decode decrypted SetResolverStateRequest: {}" ,
152+ e
153+ ) )
107154 } ) ?;
108155
109- // Parse the inner ResolverState
110156 let state_pb = ResolverStatePb :: decode ( request. state )
111157 . map_err ( |e| Error :: StateParse ( format ! ( "Failed to decode ResolverState: {}" , e) ) ) ?;
112158
113- // Convert to ResolverState
114159 let state = ResolverState :: from_proto ( state_pb, & request. account_id , self . sdk . clone ( ) )
115160 . map_err ( |e| Error :: StateParse ( format ! ( "Failed to create ResolverState: {:?}" , e) ) ) ?;
116161
117162 Ok ( Some ( ( state, request. account_id ) ) )
118163 }
119164
165+ /// Decrypt AES-256-GCM encrypted state (Tink NO_PREFIX format).
166+ fn decrypt ( data : & [ u8 ] , key : & Option < Vec < u8 > > ) -> Result < Vec < u8 > > {
167+ use aes_gcm:: { aead:: Aead , Aes256Gcm , KeyInit , Nonce } ;
168+
169+ let key_bytes = key. as_ref ( ) . ok_or_else ( || {
170+ Error :: StateParse (
171+ "Resolver state is encrypted but no encryption_key was provided. \
172+ Set the encryption key for this client credential."
173+ . to_string ( ) ,
174+ )
175+ } ) ?;
176+
177+ if data. len ( ) < 12 {
178+ return Err ( Error :: StateParse (
179+ "Encrypted state too short (missing nonce)" . to_string ( ) ,
180+ ) ) ;
181+ }
182+
183+ let cipher = Aes256Gcm :: new_from_slice ( key_bytes)
184+ . map_err ( |e| Error :: StateParse ( format ! ( "Invalid encryption key: {}" , e) ) ) ?;
185+ let nonce = Nonce :: from_slice ( & data[ ..12 ] ) ;
186+ cipher
187+ . decrypt ( nonce, & data[ 12 ..] )
188+ . map_err ( |_| {
189+ Error :: StateParse (
190+ "Failed to decrypt resolver state: invalid key or corrupted data" . to_string ( ) ,
191+ )
192+ } )
193+ }
194+
120195 /// Get the client secret.
121196 pub fn client_secret ( & self ) -> & str {
122197 & self . client_secret
@@ -171,6 +246,16 @@ impl Default for SharedState {
171246mod tests {
172247 use super :: * ;
173248 use crate :: test_utils:: { create_minimal_state, create_state_with_flag} ;
249+ use std:: path:: PathBuf ;
250+
251+ fn data_dir ( ) -> PathBuf {
252+ PathBuf :: from ( env ! ( "CARGO_MANIFEST_DIR" ) )
253+ . parent ( )
254+ . unwrap ( )
255+ . parent ( )
256+ . unwrap ( )
257+ . join ( "data" )
258+ }
174259
175260 #[ test]
176261 fn test_hash_client_secret ( ) {
@@ -259,4 +344,37 @@ mod tests {
259344 Some ( "custom-account-id" . to_string( ) )
260345 ) ;
261346 }
347+
348+ #[ test]
349+ fn test_decrypt_encrypted_state ( ) {
350+ let encrypted = std:: fs:: read ( data_dir ( ) . join ( "resolver_state_encrypted.pb" ) ) . unwrap ( ) ;
351+ let hex_key =
352+ std:: fs:: read_to_string ( data_dir ( ) . join ( "encryption_key_test.hex" ) ) . unwrap ( ) ;
353+ let key = Some ( hex:: decode ( hex_key. trim ( ) ) . unwrap ( ) ) ;
354+
355+ let decrypted = StateFetcher :: decrypt ( & encrypted, & key) . unwrap ( ) ;
356+ let request = SetResolverStateRequest :: decode ( decrypted. as_slice ( ) ) . unwrap ( ) ;
357+ assert_eq ! ( request. account_id, "confidence-test" ) ;
358+
359+ let state_pb = ResolverStatePb :: decode ( request. state ) . unwrap ( ) ;
360+ let state = ResolverState :: from_proto ( state_pb, & request. account_id , None ) . unwrap ( ) ;
361+ assert ! ( !state. flags. is_empty( ) ) ;
362+ }
363+
364+ #[ test]
365+ fn test_decrypt_rejects_wrong_key ( ) {
366+ use aes_gcm:: { Aes256Gcm , KeyInit , aead:: OsRng } ;
367+ let encrypted = std:: fs:: read ( data_dir ( ) . join ( "resolver_state_encrypted.pb" ) ) . unwrap ( ) ;
368+ let wrong_key = Aes256Gcm :: generate_key ( OsRng ) . to_vec ( ) ;
369+ let result = StateFetcher :: decrypt ( & encrypted, & Some ( wrong_key) ) ;
370+ assert ! ( result. is_err( ) ) ;
371+ }
372+
373+ #[ test]
374+ fn test_decrypt_rejects_missing_key ( ) {
375+ let encrypted = std:: fs:: read ( data_dir ( ) . join ( "resolver_state_encrypted.pb" ) ) . unwrap ( ) ;
376+
377+ let result = StateFetcher :: decrypt ( & encrypted, & None ) ;
378+ assert ! ( result. is_err( ) ) ;
379+ }
262380}
0 commit comments