4141except ImportError : # pragma: NO COVER
4242 from collections import Mapping # type: ignore
4343import abc
44+ import base64
4445import json
4546import os
4647from typing import NamedTuple
@@ -145,9 +146,88 @@ def get_subject_token(self, context, request):
145146class _X509Supplier (SubjectTokenSupplier ):
146147 """Internal supplier for X509 workload credentials. This class is used internally and always returns an empty string as the subject token."""
147148
149+ def __init__ (self , trust_chain_path , leaf_cert_callback ):
150+ self ._trust_chain_path = trust_chain_path
151+ self ._leaf_cert_callback = leaf_cert_callback
152+
148153 @_helpers .copy_docstring (SubjectTokenSupplier )
149154 def get_subject_token (self , context , request ):
150- return ""
155+ # Import OpennSSL inline because it is an extra import only required by customers
156+ # using mTLS.
157+ from OpenSSL import crypto
158+
159+ leaf_cert = crypto .load_certificate (
160+ crypto .FILETYPE_PEM , self ._leaf_cert_callback ()
161+ )
162+ trust_chain = self ._read_trust_chain ()
163+ cert_chain = []
164+
165+ cert_chain .append (_X509Supplier ._encode_cert (leaf_cert ))
166+
167+ if trust_chain is None or len (trust_chain ) == 0 :
168+ return json .dumps (cert_chain )
169+
170+ # Append the first cert if it is not the leaf cert.
171+ first_cert = _X509Supplier ._encode_cert (trust_chain [0 ])
172+ if first_cert != cert_chain [0 ]:
173+ cert_chain .append (first_cert )
174+
175+ for i in range (1 , len (trust_chain )):
176+ encoded = _X509Supplier ._encode_cert (trust_chain [i ])
177+ # Check if the current cert is the leaf cert and raise an exception if it is.
178+ if encoded == cert_chain [0 ]:
179+ raise exceptions .RefreshError (
180+ "The leaf certificate must be at the top of the trust chain file"
181+ )
182+ else :
183+ cert_chain .append (encoded )
184+ return json .dumps (cert_chain )
185+
186+ def _read_trust_chain (self ):
187+ # Import OpennSSL inline because it is an extra import only required by customers
188+ # using mTLS.
189+ from OpenSSL import crypto
190+
191+ certificate_trust_chain = []
192+ # If no trust chain path was provided, return an empty list.
193+ if self ._trust_chain_path is None or self ._trust_chain_path == "" :
194+ return certificate_trust_chain
195+ try :
196+ # Open the trust chain file.
197+ with open (self ._trust_chain_path , "rb" ) as f :
198+ trust_chain_data = f .read ()
199+ # Split PEM data into individual certificates.
200+ cert_blocks = trust_chain_data .split (b"-----BEGIN CERTIFICATE-----" )
201+ for cert_block in cert_blocks :
202+ # Skip empty blocks.
203+ if cert_block .strip ():
204+ cert_data = b"-----BEGIN CERTIFICATE-----" + cert_block
205+ try :
206+ # Load each certificate and add it to the trust chain.
207+ cert = crypto .load_certificate (
208+ crypto .FILETYPE_PEM , cert_data
209+ )
210+ certificate_trust_chain .append (cert )
211+ except Exception as e :
212+ raise exceptions .RefreshError (
213+ "Error loading PEM certificates from the trust chain file '{}'" .format (
214+ self ._trust_chain_path
215+ )
216+ ) from e
217+ return certificate_trust_chain
218+ except FileNotFoundError :
219+ raise exceptions .RefreshError (
220+ "Trust chain file '{}' was not found." .format (self ._trust_chain_path )
221+ )
222+
223+ def _encode_cert (cert ):
224+ # Import OpennSSL inline because it is an extra import only required by customers
225+ # using mTLS.
226+ from OpenSSL import crypto
227+
228+ return base64 .b64encode (
229+ crypto .dump_certificate (crypto .FILETYPE_ASN1 , cert )
230+ ).decode ("utf-8" )
151231
152232
153233def _parse_token_data (token_content , format_type = "text" , subject_token_field_name = None ):
@@ -296,7 +376,9 @@ def __init__(
296376 self ._credential_source_headers ,
297377 )
298378 else : # self._credential_source_certificate
299- self ._subject_token_supplier = _X509Supplier ()
379+ self ._subject_token_supplier = _X509Supplier (
380+ self ._trust_chain_path , self ._get_cert_bytes
381+ )
300382
301383 @_helpers .copy_docstring (external_account .Credentials )
302384 def retrieve_subject_token (self , request ):
@@ -314,6 +396,10 @@ def _get_mtls_cert_and_key_paths(self):
314396 self ._certificate_config_location
315397 )
316398
399+ def _get_cert_bytes (self ):
400+ cert_path , _ = self ._get_mtls_cert_and_key_paths ()
401+ return _mtls_helper ._read_cert_file (cert_path )
402+
317403 def _mtls_required (self ):
318404 return self ._credential_source_certificate is not None
319405
@@ -350,6 +436,9 @@ def _validate_certificate_config(self):
350436 use_default = self ._credential_source_certificate .get (
351437 "use_default_certificate_config"
352438 )
439+ self ._trust_chain_path = self ._credential_source_certificate .get (
440+ "trust_chain_path"
441+ )
353442 if self ._certificate_config_location and use_default :
354443 raise exceptions .MalformedError (
355444 "Invalid certificate configuration, certificate_config_location cannot be specified when use_default_certificate_config = true."
0 commit comments