@@ -77,7 +77,6 @@ fun ecJwkThumbprintSha256(jwk: JSONObject): ByteArray {
7777 put(" y" , jwk.get(" y" ))
7878 }
7979 val md = MessageDigest .getInstance(" SHA-256" )
80- Log .d(" helenqinn" , jwkWithRequired.toString())
8180 return md.digest(jwkWithRequired.toString().toByteArray())
8281}
8382
@@ -188,15 +187,11 @@ fun jwsDeserialization(jws: String): Pair<JSONObject, JSONObject> {
188187 return Pair (header, JSONObject (payload))
189188}
190189
191- /* * ECDH-ES key agreement, A128GCM encryption, JWE Compact Serialization */
192- fun jweSerialization (recipientKeyJwk : JSONObject , plainText : String ): String {
193- val kid = recipientKeyJwk.optString(" kid" )
194- val x = recipientKeyJwk.getString(" x" )
195- val y = recipientKeyJwk.getString(" y" )
190+ fun toEcPublicKey (x : String , y : String ): PublicKey {
196191 val kf = KeyFactory .getInstance(" EC" )
197192 val parameters = AlgorithmParameters .getInstance(" EC" )
198193 parameters.init (ECGenParameterSpec (" secp256r1" ))
199- val publicKey = kf.generatePublic(
194+ return kf.generatePublic(
200195 ECPublicKeySpec (
201196 ECPoint (
202197 BigInteger (1 , x.decodeBase64UrlNoPadding()),
@@ -205,6 +200,12 @@ fun jweSerialization(recipientKeyJwk: JSONObject, plainText: String): String {
205200 parameters.getParameterSpec(ECParameterSpec ::class .java)
206201 )
207202 )
203+ }
204+
205+ /* * ECDH-ES key agreement, A128GCM encryption, JWE Compact Serialization */
206+ fun jweSerialization (recipientKeyJwk : JSONObject , plainText : String ): String {
207+ val kid = recipientKeyJwk.optString(" kid" )
208+ val publicKey = toEcPublicKey(recipientKeyJwk.getString(" x" ), recipientKeyJwk.getString(" y" ))
208209 val kpg = KeyPairGenerator .getInstance(" EC" )
209210 kpg.initialize(ECGenParameterSpec (" secp256r1" ))
210211 val kp = kpg.genKeyPair()
@@ -251,3 +252,63 @@ fun jweSerialization(recipientKeyJwk: JSONObject, plainText: String): String {
251252 val tagEncoded = tag.toBase64UrlNoPadding()
252253 return " ${headerEncoded} ..${ivEncoded} .${ctEncoded} .${tagEncoded} "
253254}
255+
256+ fun jweDecrypt (jwe : String , privateKey : PrivateKey ): String {
257+ val parts = jwe.split(" ." )
258+
259+ val headerB64 = parts[0 ]
260+ val ivB64 = parts[2 ]
261+ val ciphertextB64 = parts[3 ]
262+ val tagB64 = parts[4 ]
263+
264+ val headerJsonStr = String (headerB64.decodeBase64UrlNoPadding(), Charsets .UTF_8 )
265+ val header = JSONObject (headerJsonStr)
266+
267+ val alg = header.optString(" alg" )
268+ val enc = header.optString(" enc" )
269+ require(alg == " ECDH-ES" && enc == " A128GCM" ) { " Unsupported algorithms: alg=$alg , enc=$enc " }
270+
271+ val epk = header.getJSONObject(" epk" )
272+ require(epk.getString(" crv" ) == " P-256" ) { " Only P-256 curve is supported" }
273+
274+ val publicKey = toEcPublicKey(epk.getString(" x" ), epk.getString(" y" ))
275+
276+ val keyAgreement = KeyAgreement .getInstance(" ECDH" )
277+ keyAgreement.init (privateKey)
278+ keyAgreement.doPhase(publicKey, true )
279+ val sharedSecret = keyAgreement.generateSecret()
280+ val concatKdf = ConcatKeyDerivationFunction (" SHA-256" )
281+
282+ val algOctets = " A128GCM" .toByteArray()
283+ val keydatalen = 128
284+
285+ val apu = if (header.has(" apu" )) header.getString(" apu" ).decodeBase64UrlNoPadding() else ByteArray (0 )
286+ val apv = if (header.has(" apv" )) header.getString(" apv" ).decodeBase64UrlNoPadding() else ByteArray (0 )
287+
288+ val derivedKey = concatKdf.kdf(
289+ sharedSecret,
290+ keydatalen,
291+ intToBigEndianByteArray(algOctets.size) + algOctets,
292+ intToBigEndianByteArray(apu.size) + apu,
293+ intToBigEndianByteArray(apv.size) + apv,
294+ intToBigEndianByteArray(keydatalen),
295+ ByteArray (0 )
296+ )
297+
298+ val iv = ivB64.decodeBase64UrlNoPadding()
299+ val ciphertext = ciphertextB64.decodeBase64UrlNoPadding()
300+ val tag = tagB64.decodeBase64UrlNoPadding()
301+
302+ val cipher = Cipher .getInstance(" AES/GCM/NoPadding" )
303+ val secretKey = SecretKeySpec (derivedKey, " AES" )
304+
305+ cipher.init (Cipher .DECRYPT_MODE , secretKey, GCMParameterSpec (128 , iv))
306+
307+ cipher.updateAAD(headerB64.toByteArray())
308+
309+ // In Java/Android, the Cipher expects the ciphertext and authentication tag to be concatenated
310+ val combinedCiphertextAndTag = ciphertext + tag
311+ val plaintextBytes = cipher.doFinal(combinedCiphertextAndTag)
312+
313+ return String (plaintextBytes, Charsets .UTF_8 )
314+ }
0 commit comments