@@ -43,7 +43,6 @@ enum Classification {
4343#[ derive( Debug ) ]
4444pub ( crate ) struct PemEncodedKey {
4545 content : Vec < u8 > ,
46- asn1 : Vec < simple_asn1:: ASN1Block > ,
4746 pem_type : PemType ,
4847 standard : Standard ,
4948}
@@ -53,22 +52,15 @@ impl PemEncodedKey {
5352 pub fn new ( input : & [ u8 ] ) -> Result < PemEncodedKey > {
5453 match pem:: parse ( input) {
5554 Ok ( content) => {
56- let asn1_content = match simple_asn1:: from_der ( content. contents ( ) ) {
57- Ok ( asn1) => asn1,
58- Err ( _) => return Err ( ErrorKind :: InvalidKeyFormat . into ( ) ) ,
59- } ;
60-
6155 match content. tag ( ) {
6256 // This handles a PKCS#1 RSA Private key
6357 "RSA PRIVATE KEY" => Ok ( PemEncodedKey {
6458 content : content. into_contents ( ) ,
65- asn1 : asn1_content,
6659 pem_type : PemType :: RsaPrivate ,
6760 standard : Standard :: Pkcs1 ,
6861 } ) ,
6962 "RSA PUBLIC KEY" => Ok ( PemEncodedKey {
7063 content : content. into_contents ( ) ,
71- asn1 : asn1_content,
7264 pem_type : PemType :: RsaPublic ,
7365 standard : Standard :: Pkcs1 ,
7466 } ) ,
@@ -79,41 +71,22 @@ impl PemEncodedKey {
7971
8072 // This handles PKCS#8 certificates and public & private keys
8173 tag @ "PRIVATE KEY" | tag @ "PUBLIC KEY" | tag @ "CERTIFICATE" => {
82- match classify_pem ( & asn1_content) {
83- Some ( c) => {
84- let is_private = tag == "PRIVATE KEY" ;
85- let pem_type = match c {
86- Classification :: Ec => {
87- if is_private {
88- PemType :: EcPrivate
89- } else {
90- PemType :: EcPublic
91- }
92- }
93- Classification :: Ed => {
94- if is_private {
95- PemType :: EdPrivate
96- } else {
97- PemType :: EdPublic
98- }
99- }
100- Classification :: Rsa => {
101- if is_private {
102- PemType :: RsaPrivate
103- } else {
104- PemType :: RsaPublic
105- }
106- }
107- } ;
108- Ok ( PemEncodedKey {
109- content : content. into_contents ( ) ,
110- asn1 : asn1_content,
111- pem_type,
112- standard : Standard :: Pkcs8 ,
113- } )
114- }
115- None => Err ( ErrorKind :: InvalidKeyFormat . into ( ) ) ,
116- }
74+ let is_private = tag == "PRIVATE KEY" ;
75+ let pem_type = match classify_der ( content. contents ( ) )
76+ . ok_or ( ErrorKind :: InvalidKeyFormat ) ?
77+ {
78+ Classification :: Ec if is_private => PemType :: EcPrivate ,
79+ Classification :: Ec => PemType :: EcPublic ,
80+ Classification :: Ed if is_private => PemType :: EdPrivate ,
81+ Classification :: Ed => PemType :: EdPublic ,
82+ Classification :: Rsa if is_private => PemType :: RsaPrivate ,
83+ Classification :: Rsa => PemType :: RsaPublic ,
84+ } ;
85+ Ok ( PemEncodedKey {
86+ content : content. into_contents ( ) ,
87+ pem_type,
88+ standard : Standard :: Pkcs8 ,
89+ } )
11790 }
11891
11992 // Unknown/unsupported type
@@ -140,7 +113,8 @@ impl PemEncodedKey {
140113 match self . standard {
141114 Standard :: Pkcs1 => Err ( ErrorKind :: InvalidKeyFormat . into ( ) ) ,
142115 Standard :: Pkcs8 => match self . pem_type {
143- PemType :: EcPublic => extract_first_bitstring ( & self . asn1 ) ,
116+ PemType :: EcPublic => extract_first_bitstring_der ( & self . content )
117+ . ok_or_else ( || ErrorKind :: InvalidKeyFormat . into ( ) ) ,
144118 _ => Err ( ErrorKind :: InvalidKeyFormat . into ( ) ) ,
145119 } ,
146120 }
@@ -162,7 +136,8 @@ impl PemEncodedKey {
162136 match self . standard {
163137 Standard :: Pkcs1 => Err ( ErrorKind :: InvalidKeyFormat . into ( ) ) ,
164138 Standard :: Pkcs8 => match self . pem_type {
165- PemType :: EdPublic => extract_first_bitstring ( & self . asn1 ) ,
139+ PemType :: EdPublic => extract_first_bitstring_der ( & self . content )
140+ . ok_or_else ( || ErrorKind :: InvalidKeyFormat . into ( ) ) ,
166141 _ => Err ( ErrorKind :: InvalidKeyFormat . into ( ) ) ,
167142 } ,
168143 }
@@ -173,68 +148,175 @@ impl PemEncodedKey {
173148 match self . standard {
174149 Standard :: Pkcs1 => Ok ( self . content . as_slice ( ) ) ,
175150 Standard :: Pkcs8 => match self . pem_type {
176- PemType :: RsaPrivate => extract_first_bitstring ( & self . asn1 ) ,
177- PemType :: RsaPublic => extract_first_bitstring ( & self . asn1 ) ,
151+ PemType :: RsaPrivate | PemType :: RsaPublic => {
152+ extract_first_bitstring_der ( & self . content )
153+ . ok_or_else ( || ErrorKind :: InvalidKeyFormat . into ( ) )
154+ }
178155 _ => Err ( ErrorKind :: InvalidKeyFormat . into ( ) ) ,
179156 } ,
180157 }
181158 }
182159}
183160
161+ const TAG_BIT_STRING : u8 = 0x03 ;
162+ const TAG_OCTET_STRING : u8 = 0x04 ;
163+ const TAG_OID : u8 = 0x06 ;
164+ const TAG_SEQUENCE : u8 = 0x30 ;
165+
184166// This really just finds and returns the first bitstring or octet string
185167// Which is the x coordinate for EC public keys
186168// And the DER contents of an RSA key
187169// Though PKCS#11 keys shouldn't have anything else.
188170// It will get confusing with certificates.
189- fn extract_first_bitstring ( asn1 : & [ simple_asn1:: ASN1Block ] ) -> Result < & [ u8 ] > {
190- for asn1_entry in asn1. iter ( ) {
191- match asn1_entry {
192- simple_asn1:: ASN1Block :: Sequence ( _, entries) => {
193- if let Ok ( result) = extract_first_bitstring ( entries) {
194- return Ok ( result) ;
171+ fn extract_first_bitstring_der ( bytes : & [ u8 ] ) -> Option < & [ u8 ] > {
172+ let mut stack = vec ! [ bytes] ;
173+
174+ while let Some ( bytes) = stack. pop ( ) {
175+ let Some ( ( tag, value, rest) ) = read_tlv ( bytes) else {
176+ continue ; // Skip invalid TLV
177+ } ;
178+
179+ if !rest. is_empty ( ) {
180+ stack. push ( rest) ;
181+ }
182+
183+ match tag {
184+ TAG_BIT_STRING => {
185+ if value. is_empty ( ) {
186+ return None ; // Missing padding length
187+ } else if value[ 0 ] != 0 {
188+ return None ; // Padding length must be zero for cryptographic keys
195189 }
190+ return Some ( & value[ 1 ..] ) ;
196191 }
197- simple_asn1:: ASN1Block :: BitString ( _, _, value) => {
198- return Ok ( value. as_ref ( ) ) ;
199- }
200- simple_asn1:: ASN1Block :: OctetString ( _, value) => {
201- return Ok ( value. as_ref ( ) ) ;
192+ TAG_OCTET_STRING => return Some ( value) ,
193+ TAG_SEQUENCE => {
194+ stack. push ( value) ;
202195 }
203- _ => ( ) ,
196+ _ => { }
204197 }
205198 }
206199
207- Err ( ErrorKind :: InvalidEcdsaKey . into ( ) )
200+ None
208201}
209202
210203/// Find whether this is EC, RSA, or Ed
211- fn classify_pem ( asn1 : & [ simple_asn1:: ASN1Block ] ) -> Option < Classification > {
212- // These should be constant but the macro requires
213- // #![feature(const_vec_new)]
214- let ec_public_key_oid = simple_asn1:: oid!( 1 , 2 , 840 , 10_045 , 2 , 1 ) ;
215- let rsa_public_key_oid = simple_asn1:: oid!( 1 , 2 , 840 , 113_549 , 1 , 1 , 1 ) ;
216- let ed25519_oid = simple_asn1:: oid!( 1 , 3 , 101 , 112 ) ;
217-
218- for asn1_entry in asn1. iter ( ) {
219- match asn1_entry {
220- simple_asn1:: ASN1Block :: Sequence ( _, entries) => {
221- if let Some ( classification) = classify_pem ( entries) {
222- return Some ( classification) ;
223- }
224- }
225- simple_asn1:: ASN1Block :: ObjectIdentifier ( _, oid) => {
226- if oid == ec_public_key_oid {
227- return Some ( Classification :: Ec ) ;
228- }
229- if oid == rsa_public_key_oid {
230- return Some ( Classification :: Rsa ) ;
231- }
232- if oid == ed25519_oid {
233- return Some ( Classification :: Ed ) ;
234- }
204+ fn classify_der ( bytes : & [ u8 ] ) -> Option < Classification > {
205+ const EC_PUBLIC_KEY_OID : & [ u8 ] = & [ 0x2A , 0x86 , 0x48 , 0xCE , 0x3D , 0x02 , 0x01 ] ; // 1.2.840.10045.2.1
206+ const RSA_PUBLIC_KEY_OID : & [ u8 ] = & [ 0x2A , 0x86 , 0x48 , 0x86 , 0xF7 , 0x0D , 0x01 , 0x01 , 0x01 ] ; // 1.2.840.113549.1.1.1
207+ const ED25519_OID : & [ u8 ] = & [ 0x2B , 0x65 , 0x70 ] ; // 1.3.101.112
208+
209+ let mut stack = vec ! [ bytes] ;
210+
211+ while let Some ( bytes) = stack. pop ( ) {
212+ let Some ( ( tag, value, rest) ) = read_tlv ( bytes) else {
213+ continue ; // Skip invalid TLV
214+ } ;
215+
216+ if !rest. is_empty ( ) {
217+ stack. push ( rest) ;
218+ }
219+
220+ if tag == TAG_OID {
221+ match value {
222+ EC_PUBLIC_KEY_OID => return Some ( Classification :: Ec ) ,
223+ RSA_PUBLIC_KEY_OID => return Some ( Classification :: Rsa ) ,
224+ ED25519_OID => return Some ( Classification :: Ed ) ,
225+ _ => { }
235226 }
236- _ => { }
227+ } else if tag == TAG_SEQUENCE {
228+ stack. push ( value) ;
237229 }
238230 }
231+
239232 None
240233}
234+
235+ /// Returns `Some((tag, value, rest))` or `None` if the TLV is invalid.
236+ fn read_tlv ( mut bytes : & [ u8 ] ) -> Option < ( u8 , & [ u8 ] , & [ u8 ] ) > {
237+ if bytes. len ( ) < 2 {
238+ return None ;
239+ }
240+
241+ let tag = bytes[ 0 ] ;
242+ let len = bytes[ 1 ] ;
243+ bytes = & bytes[ 2 ..] ;
244+
245+ let len = if len < 0x80 {
246+ len as usize
247+ } else {
248+ let len_len = ( len & 0x7f ) as usize ;
249+ if len_len == 0 {
250+ return None ; // Indefinite length
251+ } else if size_of :: < usize > ( ) < len_len {
252+ return None ; // Too long; prevents usize overflow
253+ } else if bytes. len ( ) < len_len {
254+ return None ; // Not enough bytes
255+ }
256+ let len_bytes = & bytes[ ..len_len] ;
257+ bytes = & bytes[ len_len..] ;
258+ len_bytes. iter ( ) . fold ( 0 , |acc, & x| acc * 256 + x as usize )
259+ } ;
260+
261+ if bytes. len ( ) < len {
262+ return None ; // Not enough bytes
263+ }
264+
265+ let ( value, rest) = bytes. split_at ( len) ;
266+ Some ( ( tag, value, rest) )
267+ }
268+
269+ #[ cfg( test) ]
270+ mod tests {
271+ use super :: * ;
272+
273+ #[ test]
274+ fn classify_ec_key ( ) {
275+ let pem = pem:: parse ( include_bytes ! ( "../../tests/ecdsa/public_ecdsa_key.pem" ) ) . unwrap ( ) ;
276+ assert_eq ! ( classify_der( pem. contents( ) ) , Some ( Classification :: Ec ) ) ;
277+ }
278+
279+ #[ test]
280+ fn classify_rsa_key ( ) {
281+ let pem = pem:: parse ( include_bytes ! ( "../../tests/rsa/public_rsa_key_pkcs8.pem" ) ) . unwrap ( ) ;
282+ assert_eq ! ( classify_der( pem. contents( ) ) , Some ( Classification :: Rsa ) ) ;
283+ }
284+
285+ #[ test]
286+ fn classify_ed25519_key ( ) {
287+ let pem = pem:: parse ( include_bytes ! ( "../../tests/eddsa/public_ed25519_key.pem" ) ) . unwrap ( ) ;
288+ assert_eq ! ( classify_der( pem. contents( ) ) , Some ( Classification :: Ed ) ) ;
289+ }
290+
291+ #[ test]
292+ fn ec_public_key_extraction ( ) {
293+ let key =
294+ PemEncodedKey :: new ( include_bytes ! ( "../../tests/ecdsa/public_ecdsa_key.pem" ) ) . unwrap ( ) ;
295+ let bytes = key. as_ec_public_key ( ) . unwrap ( ) ;
296+ assert_eq ! ( bytes[ 0 ] , 0x04 ) ; // uncompressed point
297+ assert_eq ! ( bytes. len( ) , 65 ) ; // 1 + 32 + 32 for P-256
298+ }
299+
300+ #[ test]
301+ fn ed_public_key_extraction ( ) {
302+ let key =
303+ PemEncodedKey :: new ( include_bytes ! ( "../../tests/eddsa/public_ed25519_key.pem" ) ) . unwrap ( ) ;
304+ let bytes = key. as_ed_public_key ( ) . unwrap ( ) ;
305+ assert_eq ! ( bytes. len( ) , 32 ) ;
306+ }
307+
308+ #[ test]
309+ fn rsa_pkcs8_key_extraction ( ) {
310+ let key =
311+ PemEncodedKey :: new ( include_bytes ! ( "../../tests/rsa/public_rsa_key_pkcs8.pem" ) ) . unwrap ( ) ;
312+ let bytes = key. as_rsa_key ( ) . unwrap ( ) ;
313+ assert_eq ! ( bytes[ 0 ] , 0x30 ) ; // SEQUENCE
314+ }
315+ #[ test]
316+ fn rsa_pkcs1_key ( ) {
317+ let key = PemEncodedKey :: new ( include_bytes ! ( "../../tests/rsa/private_rsa_key_pkcs1.pem" ) )
318+ . unwrap ( ) ;
319+ let bytes = key. as_rsa_key ( ) . unwrap ( ) ;
320+ assert_eq ! ( bytes[ 0 ] , 0x30 ) ; // SEQUENCE
321+ }
322+ }
0 commit comments