Skip to content

Commit d03cfa9

Browse files
committed
Cleaned up so all _decrypt respond in the same manner.
1 parent 20482bc commit d03cfa9

File tree

3 files changed

+31
-26
lines changed

3 files changed

+31
-26
lines changed

src/jwkest/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from binascii import unhexlify
1616

17-
__version__ = "1.3.5"
17+
__version__ = "1.4.0"
1818

1919
logger = logging.getLogger(__name__)
2020

src/jwkest/jwe.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -387,9 +387,9 @@ def _decrypt(enc, key, ctxt, auth_data, iv, tag):
387387
try:
388388
text = gcm.decrypt(bytes_to_long(iv), ctxt, bytes_to_long(tag),
389389
auth_data)
390-
return text, True
390+
return text
391391
except DecryptionFailed:
392-
return None, False
392+
raise
393393
elif enc in ["A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512"]:
394394
return aes_cbc_hmac_decrypt(key, iv, auth_data, ctxt, tag)
395395
else:
@@ -563,12 +563,10 @@ def decrypt(self, token, key, cek=None):
563563
except AssertionError:
564564
raise NotSupportedAlgorithm(enc)
565565

566-
msg, flag = self._decrypt(enc, cek, jwe.ciphertext(),
567-
jwe.b64_protected_header(),
568-
jwe.initialization_vector(),
569-
jwe.authentication_tag())
570-
if flag is False:
571-
raise DecryptionFailed()
566+
msg = self._decrypt(enc, cek, jwe.ciphertext(),
567+
jwe.b64_protected_header(),
568+
jwe.initialization_vector(),
569+
jwe.authentication_tag())
572570

573571
if "zip" in jwe.headers and jwe.headers["zip"] == "DEF":
574572
msg = zlib.decompress(msg)
@@ -603,7 +601,8 @@ def enc_setup(self, msg, auth_data, key=None, **kwargs):
603601
# Generate an ephemeral key pair if none is given
604602
curve = NISTEllipticCurve.by_name(key.crv)
605603
if "epk" in kwargs:
606-
epk = kwargs["epk"] if isinstance(kwargs["epk"], ECKey) else ECKey(kwargs["epk"])
604+
epk = kwargs["epk"] if isinstance(kwargs["epk"], ECKey) else ECKey(
605+
kwargs["epk"])
607606
else:
608607
epk = ECKey().load_key(key=NISTEllipticCurve.by_name(key.crv))
609608

@@ -650,7 +649,8 @@ def dec_setup(self, token, key=None, **kwargs):
650649

651650
# Handle EPK / Curve
652651
if "epk" not in self.headers or "crv" not in self.headers["epk"]:
653-
raise Exception("Ephemeral Public Key Missing in ECDH-ES Computation")
652+
raise Exception(
653+
"Ephemeral Public Key Missing in ECDH-ES Computation")
654654

655655
epubkey = ECKey(**self.headers["epk"])
656656
apu = apv = ""
@@ -716,12 +716,12 @@ def decrypt(self, token=None, key=None, **kwargs):
716716
if not self.cek:
717717
raise Exception("Content Encryption Key is Not Yet Set")
718718

719-
msg, valid = super(JWE_EC, self)._decrypt(self.headers["enc"], self.cek,
720-
self.ctxt,
721-
jwe.b64part[0],
722-
self.iv, self.tag)
719+
msg = super(JWE_EC, self)._decrypt(self.headers["enc"], self.cek,
720+
self.ctxt,
721+
jwe.b64part[0],
722+
self.iv, self.tag)
723723
self.msg = msg
724-
self.msg_valid = valid
724+
self.msg_valid = True
725725
return msg
726726

727727

@@ -782,7 +782,9 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs):
782782

783783
if not keys:
784784
logger.error(
785-
"Could not find any suitable encryption key for alg='{}'".format(_alg))
785+
"Could not find any suitable encryption key for alg='{"
786+
"}'".format(
787+
_alg))
786788
raise NoSuitableEncryptionKey(_alg)
787789

788790
# Determine Encryption Class by Algorithm

tests/test_4_jwe.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,8 @@ def full_path(local_file):
193193
rsa = RSA.importKey(open(KEY, 'r').read())
194194
plain = b'Now is the time for all good men to come to the aid of their country.'
195195

196-
def test_cek_reuse_encryption_rsaes_rsa15():
197196

197+
def test_cek_reuse_encryption_rsaes_rsa15():
198198
_rsa = JWE_RSA(plain, alg="RSA1_5", enc="A128CBC-HS256")
199199
jwt = _rsa.encrypt(rsa)
200200
dec = JWE_RSA()
@@ -209,8 +209,8 @@ def test_cek_reuse_encryption_rsaes_rsa15():
209209

210210
assert msg == plain
211211

212-
def test_cek_reuse_encryption_rsaes_rsa_oaep():
213212

213+
def test_cek_reuse_encryption_rsaes_rsa_oaep():
214214
_rsa = JWE_RSA(plain, alg="RSA-OAEP", enc="A256GCM")
215215
jwt = _rsa.encrypt(rsa)
216216
dec = JWE_RSA()
@@ -225,6 +225,7 @@ def test_cek_reuse_encryption_rsaes_rsa_oaep():
225225

226226
assert msg == plain
227227

228+
228229
def test_rsa_encrypt_decrypt_rsa_cbc():
229230
_rsa = JWE_RSA(plain, alg="RSA1_5", enc="A128CBC-HS256")
230231
jwt = _rsa.encrypt(rsa)
@@ -239,8 +240,8 @@ def test_rsa_encrypt_decrypt_rsa_oaep_gcm():
239240
msg = JWE_RSA().decrypt(jwt, rsa)
240241

241242
assert msg == plain
242-
243-
243+
244+
244245
def test_rsa_encrypt_decrypt_rsa_oaep_256_gcm():
245246
jwt = JWE_RSA(plain[:1], alg="RSA-OAEP-256", enc="A256GCM").encrypt(rsa)
246247
msg = JWE_RSA().decrypt(jwt, rsa)
@@ -277,10 +278,11 @@ def test_rsa_with_kid():
277278
localpriv, localpub = curve.key_pair()
278279

279280
localkey = ECKey(crv=curve.name(), d=localpriv, x=localpub[0], y=localpub[1])
280-
remotekey = ECKey(crv=curve.name(), d=remotepriv, x=remotepub[0], y=remotepub[1])
281+
remotekey = ECKey(crv=curve.name(), d=remotepriv, x=remotepub[0],
282+
y=remotepub[1])
281283

282-
def test_ecdh_encrypt_decrypt_direct_key():
283284

285+
def test_ecdh_encrypt_decrypt_direct_key():
284286
jwenc = JWE_EC(plain, alg="ECDH-ES", enc="A128GCM")
285287
cek, encrypted_key, iv, params, ret_epk = jwenc.enc_setup(plain, '',
286288
key=remotekey,
@@ -333,13 +335,14 @@ def test_ecdh_encrypt_decrypt_keywrapped_key():
333335

334336

335337
def test_sym_encrypt_decrypt():
336-
encryption_key = SYMKey(use="enc", key='DukeofHazardpass', kid="some-key-id")
338+
encryption_key = SYMKey(use="enc", key='DukeofHazardpass',
339+
kid="some-key-id")
337340
jwe = JWE_SYM("some content", alg="A128KW", enc="A128CBC-HS256")
338341
_jwe = jwe.encrypt(key=encryption_key, kid="some-key-id")
339342
jwdec = JWE_SYM()
340343

341344
resp = jwdec.decrypt(_jwe, encryption_key)
342-
assert resp[0] == b'some content'
345+
assert resp == b'some content'
343346

344347

345348
def test_ecdh_no_setup_dynamic_epk():
@@ -348,4 +351,4 @@ def test_ecdh_no_setup_dynamic_epk():
348351
assert jwt
349352
ret_jwe = factory(jwt)
350353
res = ret_jwe.decrypt(jwt, [remotekey])
351-
assert res
354+
assert res

0 commit comments

Comments
 (0)