From 2cc1c1b05b85360008ff6a4795a23867c4eec978 Mon Sep 17 00:00:00 2001 From: Alexandre Rodrigues Date: Wed, 29 Apr 2026 06:45:33 +0200 Subject: [PATCH 1/8] Update groovy and spock-core so tests can run in java 21 --- build.gradle | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/build.gradle b/build.gradle index a5365f89..44f08bcc 100644 --- a/build.gradle +++ b/build.gradle @@ -90,7 +90,8 @@ testing { useJUnitJupiter() dependencies { implementation "org.slf4j:slf4j-api:2.0.17" - implementation 'org.spockframework:spock-core:2.3-groovy-3.0' + implementation 'org.apache.groovy:groovy:4.0.30' + implementation 'org.spockframework:spock-core:2.3-groovy-4.0' implementation "org.mockito:mockito-core:5.16.1" implementation "org.assertj:assertj-core:3.27.3" implementation "ru.vyarus:spock-junit5:1.2.0" From 0a8961e61a1463c6529b060fa4a8bd70e0cea8a3 Mon Sep 17 00:00:00 2001 From: Alexandre Rodrigues Date: Wed, 29 Apr 2026 19:09:02 +0200 Subject: [PATCH 2/8] Add support for mlkem768x25519-sha256 Fixes #1017 --- .../java/net/schmizz/sshj/DefaultConfig.java | 2 + .../schmizz/sshj/transport/KeyExchanger.java | 21 +- .../sshj/transport/kex/AbstractDHG.java | 32 +- .../sshj/transport/kex/Curve25519DH.java | 21 +- .../kex/KexHostKeyCertificateVerifier.java | 75 +++++ .../sshj/transport/kex/KeyExchange.java | 16 + .../schmizz/sshj/transport/kex/MLKEM768.java | 108 +++++++ .../transport/kex/MLKEM768X25519SHA256.java | 225 ++++++++++++++ .../sshj/transport/kex/KeyExchangeTest.java | 2 + .../sshj/transport/kex/MLKEM768Test.java | 76 +++++ .../kex/MLKEM768X25519SHA256Test.java | 51 ++++ .../MLKEM768X25519SHA256WireFormatTest.java | 279 ++++++++++++++++++ 12 files changed, 866 insertions(+), 42 deletions(-) create mode 100644 src/main/java/net/schmizz/sshj/transport/kex/KexHostKeyCertificateVerifier.java create mode 100644 src/main/java/net/schmizz/sshj/transport/kex/MLKEM768.java create mode 100644 src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java create mode 100644 src/test/java/net/schmizz/sshj/transport/kex/MLKEM768Test.java create mode 100644 src/test/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256Test.java create mode 100644 src/test/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256WireFormatTest.java diff --git a/src/main/java/net/schmizz/sshj/DefaultConfig.java b/src/main/java/net/schmizz/sshj/DefaultConfig.java index 87b80837..8c977473 100644 --- a/src/main/java/net/schmizz/sshj/DefaultConfig.java +++ b/src/main/java/net/schmizz/sshj/DefaultConfig.java @@ -35,6 +35,7 @@ import net.schmizz.sshj.transport.kex.DHGexSHA1; import net.schmizz.sshj.transport.kex.DHGexSHA256; import net.schmizz.sshj.transport.kex.ECDHNistP; +import net.schmizz.sshj.transport.kex.MLKEM768X25519SHA256; import net.schmizz.sshj.transport.random.JCERandom; import net.schmizz.sshj.transport.random.SingletonRandomFactory; import net.schmizz.sshj.userauth.keyprovider.OpenSSHKeyFile; @@ -105,6 +106,7 @@ public void setLoggerFactory(LoggerFactory loggerFactory) { protected void initKeyExchangeFactories() { setKeyExchangeFactories( + new MLKEM768X25519SHA256.Factory(), new Curve25519SHA256.Factory(), new Curve25519SHA256.FactoryLibSsh(), new DHGexSHA256.Factory(), diff --git a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java index abaf72e1..68a26214 100644 --- a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java +++ b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java @@ -27,7 +27,6 @@ import net.schmizz.sshj.transport.verification.HostKeyVerifier; import org.slf4j.Logger; -import java.math.BigInteger; import java.security.GeneralSecurityException; import java.security.PublicKey; import java.util.*; @@ -309,9 +308,11 @@ private void gotStrictKexInfo(Proposal serverProposal) throws TransportException * * @return the resized key */ - private static byte[] resizedKey(byte[] E, int blockSize, Digest hash, BigInteger K, byte[] H) { + private static byte[] resizedKey(byte[] E, int blockSize, Digest hash, KeyExchange kex, byte[] H) { while (blockSize > E.length) { - Buffer.PlainBuffer buffer = new Buffer.PlainBuffer().putMPInt(K).putRawBytes(H).putRawBytes(E); + Buffer.PlainBuffer buffer = new Buffer.PlainBuffer(); + kex.putSharedSecret(buffer); + buffer.putRawBytes(H).putRawBytes(E); hash.update(buffer.array(), 0, buffer.available()); byte[] foo = hash.digest(); byte[] bar = new byte[E.length + foo.length]; @@ -333,9 +334,9 @@ private void gotNewKeys() { // session id is 'H' from the first key exchange and does not change thereafter sessionID = H; - final Buffer.PlainBuffer hashInput = new Buffer.PlainBuffer() - .putMPInt(kex.getK()) - .putRawBytes(H) + final Buffer.PlainBuffer hashInput = new Buffer.PlainBuffer(); + kex.putSharedSecret(hashInput); + hashInput.putRawBytes(H) .putByte((byte) 0) // .putRawBytes(sessionID); final int pos = hashInput.available() - sessionID.length - 1; // Position of @@ -367,13 +368,13 @@ private void gotNewKeys() { final Cipher cipher_C2S = Factory.Named.Util.create(transport.getConfig().getCipherFactories(), negotiatedAlgs.getClient2ServerCipherAlgorithm()); cipher_C2S.init(Cipher.Mode.Encrypt, - resizedKey(encryptionKey_C2S, cipher_C2S.getBlockSize(), hash, kex.getK(), kex.getH()), + resizedKey(encryptionKey_C2S, cipher_C2S.getBlockSize(), hash, kex, kex.getH()), initialIV_C2S); final Cipher cipher_S2C = Factory.Named.Util.create(transport.getConfig().getCipherFactories(), negotiatedAlgs.getServer2ClientCipherAlgorithm()); cipher_S2C.init(Cipher.Mode.Decrypt, - resizedKey(encryptionKey_S2C, cipher_S2C.getBlockSize(), hash, kex.getK(), kex.getH()), + resizedKey(encryptionKey_S2C, cipher_S2C.getBlockSize(), hash, kex, kex.getH()), initialIV_S2C); /* @@ -386,14 +387,14 @@ private void gotNewKeys() { if(cipher_C2S.getAuthenticationTagSize() == 0) { mac_C2S = Factory.Named.Util.create(transport.getConfig().getMACFactories(), negotiatedAlgs .getClient2ServerMACAlgorithm()); - mac_C2S.init(resizedKey(integrityKey_C2S, mac_C2S.getBlockSize(), hash, kex.getK(), kex.getH())); + mac_C2S.init(resizedKey(integrityKey_C2S, mac_C2S.getBlockSize(), hash, kex, kex.getH())); } MAC mac_S2C = null; if(cipher_S2C.getAuthenticationTagSize() == 0) { mac_S2C = Factory.Named.Util.create(transport.getConfig().getMACFactories(), negotiatedAlgs.getServer2ClientMACAlgorithm()); - mac_S2C.init(resizedKey(integrityKey_S2C, mac_S2C.getBlockSize(), hash, kex.getK(), kex.getH())); + mac_S2C.init(resizedKey(integrityKey_S2C, mac_S2C.getBlockSize(), hash, kex, kex.getH())); } final Compression compression_S2C = diff --git a/src/main/java/net/schmizz/sshj/transport/kex/AbstractDHG.java b/src/main/java/net/schmizz/sshj/transport/kex/AbstractDHG.java index e330ac40..034463df 100644 --- a/src/main/java/net/schmizz/sshj/transport/kex/AbstractDHG.java +++ b/src/main/java/net/schmizz/sshj/transport/kex/AbstractDHG.java @@ -91,41 +91,11 @@ public boolean next(Message msg, SSHPacket packet) throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, "KeyExchange signature verification failed"); - verifyCertificate(K_S); + KexHostKeyCertificateVerifier.verify(trans, hostKey, K_S); return true; } - private void verifyCertificate(byte[] K_S) throws TransportException { - if (hostKey instanceof Certificate && trans.getConfig().isVerifyHostKeyCertificates()) { - final Certificate hostKey = (Certificate) this.hostKey; - String signatureType, caKeyType; - try { - signatureType = new Buffer.PlainBuffer(hostKey.getSignature()).readString(); - } catch (Buffer.BufferException e) { - signatureType = null; - } - try { - caKeyType = new Buffer.PlainBuffer(hostKey.getSignatureKey()).readString(); - } catch (Buffer.BufferException e) { - caKeyType = null; - } - log.debug("Verifying signature of the key with type {} (signature type {}, CA key type {})", - hostKey.getType(), signatureType, caKeyType); - - try { - final String certError = KeyType.CertUtils.verifyHostCertificate(K_S, hostKey, trans.getRemoteHost()); - if (certError != null) { - throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, - "KeyExchange certificate check failed: " + certError); - } - } catch (Buffer.BufferException | SSHRuntimeException e) { - throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, - "KeyExchange certificate check failed", e); - } - } - } - protected abstract void initDH(DHBase dh) throws GeneralSecurityException; diff --git a/src/main/java/net/schmizz/sshj/transport/kex/Curve25519DH.java b/src/main/java/net/schmizz/sshj/transport/kex/Curve25519DH.java index d156cb32..730b6976 100644 --- a/src/main/java/net/schmizz/sshj/transport/kex/Curve25519DH.java +++ b/src/main/java/net/schmizz/sshj/transport/kex/Curve25519DH.java @@ -34,7 +34,8 @@ public class Curve25519DH extends DHBase { private static final String ALGORITHM = "X25519"; - private static final int KEY_LENGTH = 32; + /** Length in bytes of an X25519 public key and of the X25519 shared secret. */ + public static final int KEY_LENGTH = 32; private int encodedKeyLength; @@ -43,6 +44,11 @@ public class Curve25519DH extends DHBase { // Algorithm Identifier is set on Key Agreement Initialization private byte[] algorithmId = new byte[KEY_LENGTH]; + // Raw shared secret bytes captured in computeK; preserved alongside the BigInteger so + // callers that need the fixed-length byte form (e.g. PQ/T hybrid key exchanges) can + // obtain it without having to deal with mpint sign-bit padding. + private byte[] sharedSecretBytes; + public Curve25519DH() { super(ALGORITHM, ALGORITHM); } @@ -61,10 +67,23 @@ void computeK(final byte[] peerPublicKey) throws GeneralSecurityException { agreement.doPhase(generatedPeerPublicKey, true); final byte[] sharedSecretKey = agreement.generateSecret(); + sharedSecretBytes = sharedSecretKey; final BigInteger sharedSecretNumber = new BigInteger(BigInteger.ONE.signum(), sharedSecretKey); setK(sharedSecretNumber); } + /** + * Returns the raw bytes of the most recently computed X25519 shared secret, i.e. the + * unmodified output of the underlying {@link javax.crypto.KeyAgreement}. For X25519 + * the length is always {@value #KEY_LENGTH} bytes. + * + * @return the shared secret bytes, or {@code null} if {@link #computeK(byte[])} has + * not been invoked yet. + */ + public byte[] getSharedSecretBytes() { + return sharedSecretBytes; + } + /** * Initialize Key Agreement with generated Public and Private Key Pair * diff --git a/src/main/java/net/schmizz/sshj/transport/kex/KexHostKeyCertificateVerifier.java b/src/main/java/net/schmizz/sshj/transport/kex/KexHostKeyCertificateVerifier.java new file mode 100644 index 00000000..5d35f1fd --- /dev/null +++ b/src/main/java/net/schmizz/sshj/transport/kex/KexHostKeyCertificateVerifier.java @@ -0,0 +1,75 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.transport.kex; + +import com.hierynomus.sshj.userauth.certificate.Certificate; +import net.schmizz.sshj.common.Buffer; +import net.schmizz.sshj.common.DisconnectReason; +import net.schmizz.sshj.common.KeyType; +import net.schmizz.sshj.common.SSHRuntimeException; +import net.schmizz.sshj.transport.Transport; +import net.schmizz.sshj.transport.TransportException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.security.PublicKey; + +/** + * Shared helper for key-exchange implementations that need to validate an OpenSSH + * host certificate after the host-key signature has been verified. + */ +final class KexHostKeyCertificateVerifier { + + private static final Logger log = LoggerFactory.getLogger(KexHostKeyCertificateVerifier.class); + + private KexHostKeyCertificateVerifier() { + } + + /** + * If {@code hostKey} is an OpenSSH certificate and host-certificate verification is + * enabled in the {@link net.schmizz.sshj.Config}, validate it (signature, principals, + * validity window) using {@link KeyType.CertUtils#verifyHostCertificate}. No-op otherwise. + */ + static void verify(Transport trans, PublicKey publicKey, byte[] K_S) throws TransportException { + if (publicKey instanceof Certificate && trans.getConfig().isVerifyHostKeyCertificates()) { + final Certificate hostKey = (Certificate) publicKey; + String signatureType, caKeyType; + try { + signatureType = new Buffer.PlainBuffer(hostKey.getSignature()).readString(); + } catch (Buffer.BufferException e) { + signatureType = null; + } + try { + caKeyType = new Buffer.PlainBuffer(hostKey.getSignatureKey()).readString(); + } catch (Buffer.BufferException e) { + caKeyType = null; + } + log.debug("Verifying signature of the key with type {} (signature type {}, CA key type {})", + hostKey.getType(), signatureType, caKeyType); + + try { + final String certError = KeyType.CertUtils.verifyHostCertificate(K_S, hostKey, trans.getRemoteHost()); + if (certError != null) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "KeyExchange certificate check failed: " + certError); + } + } catch (Buffer.BufferException | SSHRuntimeException e) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "KeyExchange certificate check failed", e); + } + } + } +} diff --git a/src/main/java/net/schmizz/sshj/transport/kex/KeyExchange.java b/src/main/java/net/schmizz/sshj/transport/kex/KeyExchange.java index a235934e..8c66d685 100644 --- a/src/main/java/net/schmizz/sshj/transport/kex/KeyExchange.java +++ b/src/main/java/net/schmizz/sshj/transport/kex/KeyExchange.java @@ -15,6 +15,7 @@ */ package net.schmizz.sshj.transport.kex; +import net.schmizz.sshj.common.Buffer; import net.schmizz.sshj.common.Message; import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.transport.Transport; @@ -73,4 +74,19 @@ void init(Transport trans, String V_S, String V_C, byte[] I_S, byte[] I_C) boolean next(Message msg, SSHPacket buffer) throws GeneralSecurityException, TransportException; + /** + * Encode the shared secret K and append it to the given buffer when computing + * the exchange hash and deriving session keys. + *

+ * Most key exchange methods encode K as an SSH {@code mpint} (RFC 4253). PQ/T + * hybrid methods such as {@code mlkem768x25519-sha256} encode K as an SSH + * {@code string} (a fixed-length byte array) per the IETF draft + * {@code draft-kampanakis-curdle-ssh-pq-ke}. Implementations that use a + * non-default encoding override this method. + *

+ */ + default void putSharedSecret(Buffer.PlainBuffer buffer) { + buffer.putMPInt(getK()); + } + } diff --git a/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768.java b/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768.java new file mode 100644 index 00000000..0dad3628 --- /dev/null +++ b/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768.java @@ -0,0 +1,108 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.transport.kex; + +import org.bouncycastle.crypto.AsymmetricCipherKeyPair; +import org.bouncycastle.crypto.SecretWithEncapsulation; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMExtractor; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMGenerator; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMKeyGenerationParameters; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMKeyPairGenerator; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMParameters; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters; +import org.bouncycastle.pqc.crypto.mlkem.MLKEMPublicKeyParameters; + +import java.security.GeneralSecurityException; +import java.security.SecureRandom; + +/** + * Helper around the Bouncy Castle lightweight implementation of ML-KEM-768 + * (FIPS 203). Provides client-side key generation and decapsulation, as well + * as server-side encapsulation (used by the unit tests). + * + *

For the parameter set used here, the byte sizes are:

+ *
    + *
  • Public key: {@value #PUBLIC_KEY_LENGTH} bytes
  • + *
  • Ciphertext: {@value #CIPHERTEXT_LENGTH} bytes
  • + *
  • Shared secret: {@value #SHARED_SECRET_LENGTH} bytes
  • + *
+ */ +public final class MLKEM768 { + + /** Length in bytes of an ML-KEM-768 public key. */ + public static final int PUBLIC_KEY_LENGTH = 1184; + + /** Length in bytes of an ML-KEM-768 ciphertext. */ + public static final int CIPHERTEXT_LENGTH = 1088; + + /** Length in bytes of the shared secret produced by ML-KEM-768. */ + public static final int SHARED_SECRET_LENGTH = 32; + + private MLKEMPublicKeyParameters publicKey; + private MLKEMPrivateKeyParameters privateKey; + + /** + * Generate an ephemeral ML-KEM-768 key pair using the provided source of randomness. + * + * @param random source of randomness + * @return the encoded public key (length {@value #PUBLIC_KEY_LENGTH}) + */ + public byte[] generateKeyPair(final SecureRandom random) { + final MLKEMKeyPairGenerator generator = new MLKEMKeyPairGenerator(); + generator.init(new MLKEMKeyGenerationParameters(random, MLKEMParameters.ml_kem_768)); + final AsymmetricCipherKeyPair keyPair = generator.generateKeyPair(); + publicKey = (MLKEMPublicKeyParameters) keyPair.getPublic(); + privateKey = (MLKEMPrivateKeyParameters) keyPair.getPrivate(); + return publicKey.getEncoded(); + } + + /** + * Decapsulate a ciphertext received from the peer using the previously generated private key. + * + * @param ciphertext peer ciphertext (must be exactly {@value #CIPHERTEXT_LENGTH} bytes) + * @return the shared secret (length {@value #SHARED_SECRET_LENGTH}) + * @throws GeneralSecurityException if the ciphertext has an invalid length or no key pair has been generated + */ + public byte[] decapsulate(final byte[] ciphertext) throws GeneralSecurityException { + if (privateKey == null) { + throw new GeneralSecurityException("ML-KEM-768 key pair has not been generated"); + } + if (ciphertext == null || ciphertext.length != CIPHERTEXT_LENGTH) { + throw new GeneralSecurityException( + "ML-KEM-768 ciphertext length must be " + CIPHERTEXT_LENGTH + " bytes"); + } + return new MLKEMExtractor(privateKey).extractSecret(ciphertext); + } + + /** + * Server-side encapsulation against a peer public key. Used by the test suite to simulate + * a server response without requiring an external SSH server. + * + * @param peerPublicKey peer public key (must be exactly {@value #PUBLIC_KEY_LENGTH} bytes) + * @param random source of randomness + * @return the encapsulation result containing the ciphertext and the shared secret + * @throws GeneralSecurityException if the peer public key has an invalid length + */ + public static SecretWithEncapsulation encapsulate(final byte[] peerPublicKey, final SecureRandom random) + throws GeneralSecurityException { + if (peerPublicKey == null || peerPublicKey.length != PUBLIC_KEY_LENGTH) { + throw new GeneralSecurityException( + "ML-KEM-768 public key length must be " + PUBLIC_KEY_LENGTH + " bytes"); + } + final MLKEMPublicKeyParameters peer = new MLKEMPublicKeyParameters(MLKEMParameters.ml_kem_768, peerPublicKey); + return new MLKEMGenerator(random).generateEncapsulated(peer); + } +} diff --git a/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java b/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java new file mode 100644 index 00000000..efc5dd1d --- /dev/null +++ b/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java @@ -0,0 +1,225 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.transport.kex; + +import com.hierynomus.sshj.userauth.certificate.Certificate; +import net.schmizz.sshj.common.Buffer; +import net.schmizz.sshj.common.DisconnectReason; +import net.schmizz.sshj.common.Message; +import net.schmizz.sshj.common.SSHPacket; +import net.schmizz.sshj.signature.Signature; +import net.schmizz.sshj.transport.Transport; +import net.schmizz.sshj.transport.TransportException; +import net.schmizz.sshj.transport.digest.SHA256; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.math.BigInteger; +import java.security.GeneralSecurityException; +import java.security.SecureRandom; + +/** + * Post-Quantum Traditional (PQ/T) hybrid SSH key exchange combining + * {@code curve25519-sha256} with {@code ML-KEM-768}, as defined in the IETF + * draft draft-kampanakis-curdle-ssh-pq-ke and implemented by + * OpenSSH 9.9+ under the algorithm name {@code mlkem768x25519-sha256}. + * + *

Wire protocol (the message numbers 30/31 are reused from RFC 4253):

+ *
+ * client -> server: SSH_MSG_KEX_HYBRID_INIT (30)
+ *   string  C_INIT = C_PK2 || C_PK1
+ * server -> client: SSH_MSG_KEX_HYBRID_REPLY (31)
+ *   string  K_S, server's public host key
+ *   string  S_REPLY = S_CT2 || S_PK1
+ *   string  signature on the exchange hash
+ * 
+ * + *

Where {@code C_PK1} / {@code S_PK1} are 32-byte X25519 public keys and + * {@code C_PK2} / {@code S_CT2} are the ML-KEM-768 client public key + * ({@value MLKEM768#PUBLIC_KEY_LENGTH} bytes) and server ciphertext + * ({@value MLKEM768#CIPHERTEXT_LENGTH} bytes) respectively.

+ * + *

The shared secret K is computed as {@code K = SHA-256(K_PQ || K_CL)} and + * is encoded as an SSH {@code string} (not {@code mpint}) when fed into both + * the exchange hash H and the session key derivation.

+ */ +public class MLKEM768X25519SHA256 extends KeyExchangeBase { + + private static final String NAME = "mlkem768x25519-sha256"; + + /** Named factory for the {@code mlkem768x25519-sha256} key exchange. */ + public static class Factory implements net.schmizz.sshj.common.Factory.Named { + @Override + public KeyExchange create() { + return new MLKEM768X25519SHA256(); + } + + @Override + public String getName() { + return NAME; + } + } + + private final Logger log = LoggerFactory.getLogger(getClass()); + + private final MLKEM768 mlkem = new MLKEM768(); + private final Curve25519DH x25519 = new Curve25519DH(); + + private byte[] cInit; + + private byte[] kEncoded; + + public MLKEM768X25519SHA256() { + super(new SHA256()); + } + + @Override + public void init(final Transport trans, final String V_S, final String V_C, final byte[] I_S, final byte[] I_C) + throws GeneralSecurityException, TransportException { + super.init(trans, V_S, V_C, I_S, I_C); + digest.init(); + + // Generate X25519 ephemeral key pair (C_PK1). + x25519.init(null, trans.getConfig().getRandomFactory()); + + // Generate ML-KEM-768 ephemeral key pair (C_PK2). + final byte[] mlkemPublicKey = mlkem.generateKeyPair(new SecureRandom()); + + // C_INIT is the concatenation C_PK2 || C_PK1. + final byte[] x25519PublicKey = x25519.getE(); + cInit = new byte[MLKEM768.PUBLIC_KEY_LENGTH + Curve25519DH.KEY_LENGTH]; + System.arraycopy(mlkemPublicKey, 0, cInit, 0, MLKEM768.PUBLIC_KEY_LENGTH); + System.arraycopy(x25519PublicKey, 0, cInit, MLKEM768.PUBLIC_KEY_LENGTH, Curve25519DH.KEY_LENGTH); + + log.debug("Sending SSH_MSG_KEX_HYBRID_INIT"); + trans.write(new SSHPacket(Message.KEXDH_INIT).putBytes(cInit)); + } + + @Override + public boolean next(final Message msg, final SSHPacket packet) + throws GeneralSecurityException, TransportException { + if (msg != Message.KEXDH_31) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "Unexpected packet: " + msg); + } + + log.debug("Received SSH_MSG_KEX_HYBRID_REPLY"); + final byte[] K_S; + final byte[] sReply; + final byte[] sig; + try { + K_S = packet.readBytes(); + sReply = packet.readBytes(); + sig = packet.readBytes(); + hostKey = new Buffer.PlainBuffer(K_S).readPublicKey(); + } catch (Buffer.BufferException be) { + throw new TransportException(be); + } + + // S_REPLY = S_CT2 || S_PK1 + final int expectedLength = MLKEM768.CIPHERTEXT_LENGTH + Curve25519DH.KEY_LENGTH; + if (sReply.length != expectedLength) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "S_REPLY length must be " + expectedLength + " bytes but was " + sReply.length); + } + final byte[] sCt2 = new byte[MLKEM768.CIPHERTEXT_LENGTH]; + final byte[] sPk1 = new byte[Curve25519DH.KEY_LENGTH]; + System.arraycopy(sReply, 0, sCt2, 0, MLKEM768.CIPHERTEXT_LENGTH); + System.arraycopy(sReply, MLKEM768.CIPHERTEXT_LENGTH, sPk1, 0, Curve25519DH.KEY_LENGTH); + + // K_PQ: decapsulate ML-KEM-768 ciphertext. + final byte[] kPq = mlkem.decapsulate(sCt2); + + // K_CL: X25519 shared secret in raw byte form (NOT mpint), as required by the draft. + x25519.computeK(sPk1); + final byte[] kCl = x25519.getSharedSecretBytes(); + + // Per RFC 8731, an all-zero output indicates a low-order point and MUST be rejected. + if (isAllZero(kCl)) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "X25519 key agreement produced an all-zero shared secret"); + } + + // K = HASH(K_PQ || K_CL), encoded as a string in H and key derivation. + digest.init(); + digest.update(kPq, 0, kPq.length); + digest.update(kCl, 0, kCl.length); + kEncoded = digest.digest(); + + // Compute exchange hash H over: V_C, V_S, I_C, I_S, K_S, C_INIT, S_REPLY, K (as string). + final Buffer.PlainBuffer hashBuffer = initializedBuffer() + .putString(K_S) + .putString(cInit) + .putString(sReply) + .putString(kEncoded); + + digest.init(); + digest.update(hashBuffer.array(), hashBuffer.rpos(), hashBuffer.available()); + H = digest.digest(); + + // Verify the host key signature on H. + final Signature signature = trans.getHostKeyAlgorithm().newSignature(); + if (hostKey instanceof Certificate) { + signature.initVerify(((Certificate) hostKey).getKey()); + } else { + signature.initVerify(hostKey); + } + signature.update(H, 0, H.length); + if (!signature.verify(sig)) { + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "KeyExchange signature verification failed"); + } + + KexHostKeyCertificateVerifier.verify(trans, hostKey, K_S); + + return true; + } + + /** + * For PQ/T hybrid key exchanges, K is the SHA-256 output of the concatenation of + * the two shared secrets and is encoded as an SSH {@code string} (length-prefixed + * byte array) per draft-kampanakis-curdle-ssh-pq-ke section 2.5, instead of + * the traditional {@code mpint} encoding used by RFC 4253 / RFC 5656 / + * RFC 8731 key exchanges. + */ + @Override + public void putSharedSecret(final Buffer.PlainBuffer buffer) { + buffer.putString(kEncoded); + } + + /** + * Unsupported for the hybrid PQ key exchange. K is a fixed-length byte string + * (the SHA-256 of {@code K_PQ || K_CL}) and is encoded on the wire as an SSH + * {@code string}, not as an {@code mpint}. Callers that legitimately need the + * shared secret bytes for inclusion in the exchange hash or key derivation + * MUST use {@link #putSharedSecret(Buffer.PlainBuffer)}. + * + * @throws UnsupportedOperationException always + */ + @Override + public BigInteger getK() { + throw new UnsupportedOperationException( + "K is a fixed-length string for hybrid KEX; use putSharedSecret(...)"); + } + + private static boolean isAllZero(final byte[] data) { + int acc = 0; + for (final byte b : data) { + acc |= b & 0xff; + } + return acc == 0; + } +} diff --git a/src/test/java/com/hierynomus/sshj/transport/kex/KeyExchangeTest.java b/src/test/java/com/hierynomus/sshj/transport/kex/KeyExchangeTest.java index d0b61011..2940e247 100644 --- a/src/test/java/com/hierynomus/sshj/transport/kex/KeyExchangeTest.java +++ b/src/test/java/com/hierynomus/sshj/transport/kex/KeyExchangeTest.java @@ -23,6 +23,7 @@ import net.schmizz.sshj.transport.kex.DHGexSHA1; import net.schmizz.sshj.transport.kex.DHGexSHA256; import net.schmizz.sshj.transport.kex.ECDHNistP; +import net.schmizz.sshj.transport.kex.MLKEM768X25519SHA256; import net.schmizz.sshj.transport.random.JCERandom; import net.schmizz.sshj.transport.random.SingletonRandomFactory; import org.apache.sshd.common.kex.BuiltinDHFactories; @@ -58,6 +59,7 @@ public static Collection getParameters() { { DHGServer.newFactory(BuiltinDHFactories.dhg16_512), DHGroups.Group16SHA512() }, { DHGServer.newFactory(BuiltinDHFactories.dhg17_512), DHGroups.Group17SHA512() }, { DHGServer.newFactory(BuiltinDHFactories.dhg18_512), DHGroups.Group18SHA512() }, + { DHGServer.newFactory(BuiltinDHFactories.mlkem768x25519), new MLKEM768X25519SHA256.Factory() }, }); } diff --git a/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768Test.java b/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768Test.java new file mode 100644 index 00000000..17fd4a52 --- /dev/null +++ b/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768Test.java @@ -0,0 +1,76 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.transport.kex; + +import org.bouncycastle.crypto.SecretWithEncapsulation; +import org.junit.jupiter.api.Test; + +import java.security.GeneralSecurityException; +import java.security.SecureRandom; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class MLKEM768Test { + + @Test + public void generateKeyPairProducesCorrectlySizedPublicKey() { + final MLKEM768 mlkem = new MLKEM768(); + final byte[] publicKey = mlkem.generateKeyPair(new SecureRandom()); + + assertNotNull(publicKey); + assertEquals(MLKEM768.PUBLIC_KEY_LENGTH, publicKey.length); + } + + @Test + public void encapsulateAndDecapsulateProduceMatchingSecret() throws GeneralSecurityException { + final SecureRandom random = new SecureRandom(); + final MLKEM768 mlkem = new MLKEM768(); + + final byte[] publicKey = mlkem.generateKeyPair(random); + final SecretWithEncapsulation server = MLKEM768.encapsulate(publicKey, random); + + assertEquals(MLKEM768.CIPHERTEXT_LENGTH, server.getEncapsulation().length); + assertEquals(MLKEM768.SHARED_SECRET_LENGTH, server.getSecret().length); + + final byte[] clientSecret = mlkem.decapsulate(server.getEncapsulation()); + + assertEquals(MLKEM768.SHARED_SECRET_LENGTH, clientSecret.length); + assertArrayEquals(server.getSecret(), clientSecret); + } + + @Test + public void decapsulateRejectsCiphertextOfWrongLength() { + final MLKEM768 mlkem = new MLKEM768(); + mlkem.generateKeyPair(new SecureRandom()); + + assertThrows(GeneralSecurityException.class, () -> mlkem.decapsulate(new byte[10])); + } + + @Test + public void decapsulateBeforeKeyGenFails() { + assertThrows(GeneralSecurityException.class, + () -> new MLKEM768().decapsulate(new byte[MLKEM768.CIPHERTEXT_LENGTH])); + } + + @Test + public void encapsulateRejectsPublicKeyOfWrongLength() { + assertThrows(GeneralSecurityException.class, + () -> MLKEM768.encapsulate(new byte[10], new SecureRandom())); + } +} diff --git a/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256Test.java b/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256Test.java new file mode 100644 index 00000000..bcb36e89 --- /dev/null +++ b/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256Test.java @@ -0,0 +1,51 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.transport.kex; + +import net.schmizz.sshj.DefaultConfig; +import net.schmizz.sshj.common.Factory; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotSame; + +public class MLKEM768X25519SHA256Test { + + @Test + public void factoryHasIanaName() { + assertEquals("mlkem768x25519-sha256", new MLKEM768X25519SHA256.Factory().getName()); + } + + @Test + public void factoryProducesFreshInstances() { + final MLKEM768X25519SHA256.Factory factory = new MLKEM768X25519SHA256.Factory(); + final KeyExchange first = factory.create(); + final KeyExchange second = factory.create(); + + assertInstanceOf(MLKEM768X25519SHA256.class, first); + assertInstanceOf(MLKEM768X25519SHA256.class, second); + assertNotSame(first, second); + } + + @Test + public void registeredFirstInDefaultConfig() { + final List> factories = new DefaultConfig().getKeyExchangeFactories(); + assertEquals("mlkem768x25519-sha256", factories.get(0).getName()); + } +} diff --git a/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256WireFormatTest.java b/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256WireFormatTest.java new file mode 100644 index 00000000..a7824feb --- /dev/null +++ b/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256WireFormatTest.java @@ -0,0 +1,279 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.transport.kex; + +import com.hierynomus.sshj.key.KeyAlgorithm; +import net.schmizz.sshj.Config; +import net.schmizz.sshj.common.Buffer; +import net.schmizz.sshj.common.KeyType; +import net.schmizz.sshj.common.Message; +import net.schmizz.sshj.common.SSHPacket; +import net.schmizz.sshj.signature.Signature; +import net.schmizz.sshj.transport.Transport; +import net.schmizz.sshj.transport.digest.SHA256; +import net.schmizz.sshj.transport.random.JCERandom; +import org.bouncycastle.crypto.SecretWithEncapsulation; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.SecureRandom; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Targeted wire-format assertions for {@link MLKEM768X25519SHA256} that don't depend on + * a peer SSH implementation. These verify the byte layout mandated by + * {@code draft-kampanakis-curdle-ssh-pq-ke-05}, in particular invariants that an + * interop test against a peer with the same bug would fail to catch. + */ +public class MLKEM768X25519SHA256WireFormatTest { + + /** + * {@code C_INIT = C_PK2 || C_PK1} where {@code C_PK2} is the 1184-byte ML-KEM-768 public + * key and {@code C_PK1} is the 32-byte X25519 public key, in that order. Asserts the + * exact concatenation length and that the leading 1184 bytes round-trip as a valid + * ML-KEM-768 public key by encapsulating against them. + */ + @Test + public void cInitIsMlkemPublicKeyConcatenatedWithX25519PublicKey() throws Exception { + final SSHPacket initPacket = runInitAndCapturePacket(); + + // First byte: SSH_MSG_KEX_HYBRID_INIT = 30. + assertEquals(Message.KEXDH_INIT, initPacket.readMessageID()); + // Then: C_INIT as an SSH 'string' (uint32 length || bytes). + final byte[] cInit = initPacket.readBytes(); + + assertEquals(MLKEM768.PUBLIC_KEY_LENGTH + Curve25519DH.KEY_LENGTH, cInit.length, + "C_INIT must be exactly PUBLIC_KEY_LENGTH + KEY_LENGTH bytes"); + + // Demonstrate the leading slice is a valid ML-KEM-768 public key by encapsulating + // against it; this would fail if the order were reversed (X25519 key first). + final byte[] mlkemPk = new byte[MLKEM768.PUBLIC_KEY_LENGTH]; + System.arraycopy(cInit, 0, mlkemPk, 0, MLKEM768.PUBLIC_KEY_LENGTH); + final SecretWithEncapsulation enc = MLKEM768.encapsulate(mlkemPk, new SecureRandom()); + assertEquals(MLKEM768.CIPHERTEXT_LENGTH, enc.getEncapsulation().length); + } + + /** + * Per draft-kampanakis section 2.4: {@code K = HASH(K_PQ || K_CL)}, with the PQ secret + * first. Reversing the order changes K and silently breaks interop. We control both + * halves by playing the server. + */ + @Test + public void kIsSha256OfMlkemSecretConcatenatedWithX25519Secret() throws Exception { + final ServerExchange exchange = runFullExchange(); + + // Recompute K with the documented order. + final SHA256 hash = new SHA256(); + hash.init(); + hash.update(exchange.kPq, 0, exchange.kPq.length); + hash.update(exchange.kCl, 0, exchange.kCl.length); + final byte[] expectedK = hash.digest(); + + // And with the wrong order, to make sure the assertion below would fail if the + // implementation accidentally reversed the inputs. + hash.init(); + hash.update(exchange.kCl, 0, exchange.kCl.length); + hash.update(exchange.kPq, 0, exchange.kPq.length); + final byte[] reversedK = hash.digest(); + + assertArrayEquals(expectedK, exchange.kEncoded, + "K must be SHA-256(K_PQ || K_CL) in that exact order"); + assertNotEquals(new BigInteger(1, reversedK), new BigInteger(1, exchange.kEncoded), + "test setup sanity: reversed-order K must differ from documented-order K"); + } + + /** + * Per draft-kampanakis section 2.5: K is encoded as an SSH {@code string} (length-prefixed + * fixed byte array) — NOT as an {@code mpint}. The discriminator: when K's high bit is + * set, {@code mpint} would prepend a 0x00 sign byte, expanding the length to 33; the + * draft mandates exactly 32 bytes with no padding. + * + *

We force the high-bit case by retrying until SHA-256 returns a value whose first + * byte has the high bit set; with random K_PQ/K_CL inputs each attempt has ≈50% + * probability so we converge in a few iterations.

+ */ + @Test + public void putSharedSecretWritesStringNotMpintEvenWhenHighBitSet() throws Exception { + ServerExchange exchange = null; + for (int attempt = 0; attempt < 32; attempt++) { + final ServerExchange candidate = runFullExchange(); + if ((candidate.kEncoded[0] & 0x80) != 0) { + exchange = candidate; + break; + } + } + assertTrue(exchange != null, + "could not produce a K with the high bit set in 32 attempts (extremely unlikely)"); + + // putSharedSecret() must emit: 4-byte big-endian length == 32, then the 32 K bytes. + // mpint encoding of the same value would emit length == 33 with a leading 0x00. + final Buffer.PlainBuffer buf = new Buffer.PlainBuffer(); + exchange.kex.putSharedSecret(buf); + final byte[] wire = buf.getCompactData(); + + assertEquals(4 + 32, wire.length, + "K must be a 32-byte SSH string (4-byte length + 32 bytes), not an mpint"); + final int length = ByteBuffer.wrap(wire, 0, 4).getInt(); + assertEquals(32, length, "string length prefix must be 32"); + final byte[] payload = new byte[32]; + System.arraycopy(wire, 4, payload, 0, 32); + assertArrayEquals(exchange.kEncoded, payload, + "string payload must be exactly the K bytes with no mpint sign-byte padding"); + + // Cross-check against what an mpint would have produced. + final Buffer.PlainBuffer mpintBuf = new Buffer.PlainBuffer(); + mpintBuf.putMPInt(new BigInteger(1, exchange.kEncoded)); + final int mpintLength = ByteBuffer.wrap(mpintBuf.getCompactData(), 0, 4).getInt(); + assertEquals(33, mpintLength, + "test setup sanity: mpint encoding of a high-bit-set 32-byte value must be 33 bytes"); + } + + /** + * For every other KEX in sshj K is a number and callers reasonably assume + * {@code new Buffer.PlainBuffer().putMPInt(kex.getK())} reproduces the exact bytes + * that went into the exchange hash H. For the hybrid PQ KEX that assumption is wrong: + * K is a fixed-length string and is encoded via {@link KeyExchange#putSharedSecret}. + * To prevent silent misuse, {@link MLKEM768X25519SHA256#getK()} must fail loudly. + */ + @Test + public void getKThrowsUnsupportedOperation() throws Exception { + final ServerExchange exchange = runFullExchange(); + final UnsupportedOperationException ex = assertThrows(UnsupportedOperationException.class, + () -> exchange.kex.getK()); + assertTrue(ex.getMessage() != null && ex.getMessage().contains("putSharedSecret"), + "error message should steer callers toward putSharedSecret(...) but was: " + ex.getMessage()); + } + + /** + * Drives {@link MLKEM768X25519SHA256#init} with mocked transport collaborators and + * returns the {@link SSHPacket} the implementation wrote to the wire. + */ + private SSHPacket runInitAndCapturePacket() throws Exception { + final Transport trans = mock(Transport.class); + final Config config = mock(Config.class); + when(trans.getConfig()).thenReturn(config); + when(config.getRandomFactory()).thenReturn(new JCERandom.Factory()); + + final MLKEM768X25519SHA256 kex = new MLKEM768X25519SHA256(); + kex.init(trans, "SSH-2.0-server", "SSH-2.0-client", new byte[]{1}, new byte[]{2}); + + final ArgumentCaptor packetCaptor = ArgumentCaptor.forClass(SSHPacket.class); + verify(trans).write(packetCaptor.capture()); + return packetCaptor.getValue(); + } + + /** + * Drives a complete {@code init} → server reply → {@code next} round trip, with this + * test acting as the server. Returns the per-side secrets ({@code K_PQ}, {@code K_CL}) + * that the client must combine, plus the final {@code K} computed by the client. + */ + private ServerExchange runFullExchange() throws Exception { + // --- Set up a real Ed25519 host key for the signature step --- + final KeyPairGenerator hostKpg = KeyPairGenerator.getInstance("Ed25519"); + final KeyPair hostKeyPair = hostKpg.generateKeyPair(); + final Buffer.PlainBuffer ksBuf = new Buffer.PlainBuffer(); + KeyType.ED25519.putPubKeyIntoBuffer(hostKeyPair.getPublic(), ksBuf); + final byte[] kS = ksBuf.getCompactData(); + + // --- Mock transport (signature.verify is stubbed to true; we don't need to actually sign) --- + final Transport trans = mock(Transport.class); + final Config config = mock(Config.class); + when(trans.getConfig()).thenReturn(config); + when(config.getRandomFactory()).thenReturn(new JCERandom.Factory()); + final KeyAlgorithm hostKeyAlg = mock(KeyAlgorithm.class); + final Signature signature = mock(Signature.class); + when(trans.getHostKeyAlgorithm()).thenReturn(hostKeyAlg); + when(hostKeyAlg.newSignature()).thenReturn(signature); + when(signature.verify(any(byte[].class))).thenReturn(true); + + // --- Drive init() and capture the packet the client emitted --- + final MLKEM768X25519SHA256 kex = new MLKEM768X25519SHA256(); + kex.init(trans, "SSH-2.0-server", "SSH-2.0-client", new byte[]{1}, new byte[]{2}); + final ArgumentCaptor packetCaptor = ArgumentCaptor.forClass(SSHPacket.class); + verify(trans).write(packetCaptor.capture()); + final SSHPacket initPacket = packetCaptor.getValue(); + initPacket.readMessageID(); + final byte[] cInit = initPacket.readBytes(); + + // --- Server side: split C_INIT, encapsulate against C_PK2, agree against C_PK1 --- + final byte[] cPk2 = new byte[MLKEM768.PUBLIC_KEY_LENGTH]; + final byte[] cPk1 = new byte[Curve25519DH.KEY_LENGTH]; + System.arraycopy(cInit, 0, cPk2, 0, MLKEM768.PUBLIC_KEY_LENGTH); + System.arraycopy(cInit, MLKEM768.PUBLIC_KEY_LENGTH, cPk1, 0, Curve25519DH.KEY_LENGTH); + + final SecretWithEncapsulation enc = MLKEM768.encapsulate(cPk2, new SecureRandom()); + final byte[] kPq = enc.getSecret(); + final byte[] sCt2 = enc.getEncapsulation(); + + final Curve25519DH serverDh = new Curve25519DH(); + serverDh.init(null, new JCERandom.Factory()); + serverDh.computeK(cPk1); + final byte[] kCl = serverDh.getSharedSecretBytes(); + final byte[] sPk1 = serverDh.getE(); + + final byte[] sReply = new byte[MLKEM768.CIPHERTEXT_LENGTH + Curve25519DH.KEY_LENGTH]; + System.arraycopy(sCt2, 0, sReply, 0, MLKEM768.CIPHERTEXT_LENGTH); + System.arraycopy(sPk1, 0, sReply, MLKEM768.CIPHERTEXT_LENGTH, Curve25519DH.KEY_LENGTH); + + // --- Build the SSH_MSG_KEX_HYBRID_REPLY and feed it to next() --- + final SSHPacket reply = new SSHPacket(Message.KEXDH_31) + .putBytes(kS) + .putBytes(sReply) + .putBytes(new byte[]{0x00}); // signature payload; verify() is stubbed + reply.readMessageID(); // advance past the message id, as the dispatcher would + kex.next(Message.KEXDH_31, reply); + + // K is not retrievable as a BigInteger for the hybrid KEX (getK() throws); + // extract the on-wire bytes via putSharedSecret(...), then strip the SSH string length prefix. + final Buffer.PlainBuffer sharedSecretBuffer = new Buffer.PlainBuffer(); + kex.putSharedSecret(sharedSecretBuffer); + final byte[] kEncoded; + try { + kEncoded = sharedSecretBuffer.readBytes(); + } catch (final Buffer.BufferException e) { + throw new AssertionError("Failed to read K written by putSharedSecret", e); + } + + return new ServerExchange(kex, kPq, kCl, kEncoded); + } + + private static final class ServerExchange { + final MLKEM768X25519SHA256 kex; + final byte[] kPq; + final byte[] kCl; + final byte[] kEncoded; + + ServerExchange(final MLKEM768X25519SHA256 kex, final byte[] kPq, final byte[] kCl, final byte[] kEncoded) { + this.kex = kex; + this.kPq = kPq; + this.kCl = kCl; + this.kEncoded = kEncoded; + } + } +} From d30af3d28653ac7f67f37c577ce6edc1ba0bee2c Mon Sep 17 00:00:00 2001 From: Alexandre Rodrigues Date: Mon, 4 May 2026 15:52:39 +0200 Subject: [PATCH 3/8] Add MLKEMHybridKexIntegrationTest.java --- .../com/hierynomus/sshj/SshdContainer.java | 15 ++- .../kex/MLKEMHybridKexIntegrationTest.java | 98 +++++++++++++++++++ 2 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 src/itest/java/com/hierynomus/sshj/transport/kex/MLKEMHybridKexIntegrationTest.java diff --git a/src/itest/java/com/hierynomus/sshj/SshdContainer.java b/src/itest/java/com/hierynomus/sshj/SshdContainer.java index 91b531e6..5b01ea43 100644 --- a/src/itest/java/com/hierynomus/sshj/SshdContainer.java +++ b/src/itest/java/com/hierynomus/sshj/SshdContainer.java @@ -103,9 +103,12 @@ public static SshdConfigBuilder defaultBuilder() { } public static class Builder implements Consumer { + private static final String DEFAULT_BASE_IMAGE = "alpine:3.19.0"; + private List hostKeys = new ArrayList<>(); private List certificates = new ArrayList<>(); private @NotNull SshdConfigBuilder sshdConfig = SshdConfigBuilder.defaultBuilder(); + private @NotNull String baseImage = DEFAULT_BASE_IMAGE; public static Builder defaultBuilder() { Builder b = new Builder(); @@ -119,6 +122,16 @@ public static Builder defaultBuilder() { return this; } + /** + * Override the base image used to build the sshd container. Useful for tests that need + * a specific OpenSSH version (for example, OpenSSH ≥10 for the + * {@code mlkem768x25519-sha256} key exchange). + */ + public @NotNull Builder withBaseImage(@NotNull String baseImage) { + this.baseImage = baseImage; + return this; + } + public @NotNull Builder withAllKeys() { this.addHostKey("test-container/ssh_host_ecdsa_key"); this.addHostKey("test-container/ssh_host_ed25519_key"); @@ -148,7 +161,7 @@ public static Builder defaultBuilder() { @Override public void accept(@NotNull DockerfileBuilder builder) { - builder.from("alpine:3.19.0"); + builder.from(baseImage); builder.run("apk add --no-cache openssh"); builder.expose(22); builder.copy("entrypoint.sh", "/entrypoint.sh"); diff --git a/src/itest/java/com/hierynomus/sshj/transport/kex/MLKEMHybridKexIntegrationTest.java b/src/itest/java/com/hierynomus/sshj/transport/kex/MLKEMHybridKexIntegrationTest.java new file mode 100644 index 00000000..bff441b3 --- /dev/null +++ b/src/itest/java/com/hierynomus/sshj/transport/kex/MLKEMHybridKexIntegrationTest.java @@ -0,0 +1,98 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.sshj.transport.kex; + +import com.hierynomus.sshj.SshdContainer; +import com.hierynomus.sshj.SshdContainer.SshdConfigBuilder; +import net.schmizz.sshj.Config; +import net.schmizz.sshj.DefaultConfig; +import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.transport.kex.MLKEM768X25519SHA256; +import net.schmizz.sshj.transport.verification.PromiscuousVerifier; +import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.util.Collections; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Verifies interop with a real OpenSSH server (10.x) for the post-quantum hybrid key + * exchange {@code mlkem768x25519-sha256}. + * + *

The container is built on Alpine 3.22, whose {@code openssh} package is + * 10.0p1 — the first OpenSSH release that ships {@code mlkem768x25519-sha256} (it is + * the default KEX in 10.x). The {@link SshdConfigBuilder} {@code KexAlgorithms} line is + * replaced with one containing only {@code mlkem768x25519-sha256} to ensure negotiation + * cannot fall through to a classical KEX.

+ */ +@Testcontainers +public class MLKEMHybridKexIntegrationTest { + + private static final String OPENSSH_10_BASE_IMAGE = "alpine:3.22"; + private static final String HYBRID_KEX_NAME = "mlkem768x25519-sha256"; + + /** + * sshd_config without a {@code KexAlgorithms} line. Required because in + * {@code sshd_config} the first occurrence of an option wins, so we cannot simply + * append our hybrid-only line on top of {@link SshdConfigBuilder#DEFAULT_SSHD_CONFIG} + * (which already declares a classical-only {@code KexAlgorithms}). We then add the + * hybrid line via {@link SshdConfigBuilder#with(String, String)}. + */ + private static final String SSHD_CONFIG_NO_KEX = "" + + "PermitRootLogin yes\n" + + "AuthorizedKeysFile .ssh/authorized_keys\n" + + "Subsystem sftp /usr/lib/ssh/sftp-server\n" + + "macs hmac-sha2-256-etm@openssh.com,hmac-sha2-512-etm@openssh.com,hmac-sha2-256,hmac-sha2-512\n" + + "TrustedUserCAKeys /etc/ssh/trusted_ca_keys\n" + + "Ciphers aes128-ctr,aes192-ctr,aes256-ctr,aes128-gcm@openssh.com,aes256-gcm@openssh.com,chacha20-poly1305@openssh.com\n" + + "LogLevel DEBUG2\n"; + + @Container + private static final SshdContainer sshd = SshdContainer.Builder.defaultBuilder() + .withBaseImage(OPENSSH_10_BASE_IMAGE) + .withSshdConfig(new SshdConfigBuilder(SSHD_CONFIG_NO_KEX) + .with("KexAlgorithms", HYBRID_KEX_NAME)) + .withAllKeys() + .build(); + + @Test + public void shouldNegotiateMlkem768X25519Sha256WithOpenSsh10() throws Throwable { + final Config config = new DefaultConfig(); + // Force sshj to offer ONLY the hybrid KEX so the assertion below cannot pass by + // falling back to a classical one. + config.setKeyExchangeFactories(Collections.singletonList(new MLKEM768X25519SHA256.Factory())); + + final AtomicReference negotiatedKex = new AtomicReference<>(); + try (SSHClient client = new SSHClient(config)) { + client.addHostKeyVerifier(new PromiscuousVerifier()); + client.addAlgorithmsVerifier(algorithms -> { + negotiatedKex.set(algorithms.getKeyExchangeAlgorithm()); + return true; + }); + client.connect("127.0.0.1", sshd.getFirstMappedPort()); + + client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1"); + assertTrue(client.isAuthenticated(), "public-key auth should succeed over the hybrid KEX"); + } + + assertEquals(HYBRID_KEX_NAME, negotiatedKex.get(), + "transport must have negotiated mlkem768x25519-sha256 with the OpenSSH 10 server"); + } +} From 1e1c174f512f1dda9c374208e6e01d0e33fbff5a Mon Sep 17 00:00:00 2001 From: Alexandre Rodrigues Date: Mon, 4 May 2026 09:12:48 +0200 Subject: [PATCH 4/8] Use SecurityUtils.getKeyPairGenerator --- .../schmizz/sshj/transport/kex/MLKEM768.java | 27 +++++++++++-------- .../transport/kex/MLKEM768X25519SHA256.java | 5 ++-- .../sshj/transport/kex/MLKEM768Test.java | 10 +++---- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768.java b/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768.java index 0dad3628..22fb0c4a 100644 --- a/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768.java +++ b/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768.java @@ -15,17 +15,19 @@ */ package net.schmizz.sshj.transport.kex; -import org.bouncycastle.crypto.AsymmetricCipherKeyPair; +import net.schmizz.sshj.common.SecurityUtils; import org.bouncycastle.crypto.SecretWithEncapsulation; import org.bouncycastle.pqc.crypto.mlkem.MLKEMExtractor; import org.bouncycastle.pqc.crypto.mlkem.MLKEMGenerator; -import org.bouncycastle.pqc.crypto.mlkem.MLKEMKeyGenerationParameters; -import org.bouncycastle.pqc.crypto.mlkem.MLKEMKeyPairGenerator; import org.bouncycastle.pqc.crypto.mlkem.MLKEMParameters; import org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters; import org.bouncycastle.pqc.crypto.mlkem.MLKEMPublicKeyParameters; +import org.bouncycastle.pqc.crypto.util.PrivateKeyFactory; +import org.bouncycastle.pqc.crypto.util.PublicKeyFactory; +import java.io.IOException; import java.security.GeneralSecurityException; +import java.security.KeyPair; import java.security.SecureRandom; /** @@ -55,17 +57,20 @@ public final class MLKEM768 { private MLKEMPrivateKeyParameters privateKey; /** - * Generate an ephemeral ML-KEM-768 key pair using the provided source of randomness. + * Generate an ephemeral ML-KEM-768 key pair via the JCA, using the same path + * as the rest of sshj (see {@link SecurityUtils#getKeyPairGenerator(String)}). * - * @param random source of randomness * @return the encoded public key (length {@value #PUBLIC_KEY_LENGTH}) + * @throws GeneralSecurityException if no JCA provider supports ML-KEM-768 or key conversion fails */ - public byte[] generateKeyPair(final SecureRandom random) { - final MLKEMKeyPairGenerator generator = new MLKEMKeyPairGenerator(); - generator.init(new MLKEMKeyGenerationParameters(random, MLKEMParameters.ml_kem_768)); - final AsymmetricCipherKeyPair keyPair = generator.generateKeyPair(); - publicKey = (MLKEMPublicKeyParameters) keyPair.getPublic(); - privateKey = (MLKEMPrivateKeyParameters) keyPair.getPrivate(); + public byte[] generateKeyPair() throws GeneralSecurityException { + final KeyPair keyPair = SecurityUtils.getKeyPairGenerator("ML-KEM-768").generateKeyPair(); + try { + publicKey = (MLKEMPublicKeyParameters) PublicKeyFactory.createKey(keyPair.getPublic().getEncoded()); + privateKey = (MLKEMPrivateKeyParameters) PrivateKeyFactory.createKey(keyPair.getPrivate().getEncoded()); + } catch (IOException e) { + throw new GeneralSecurityException("Failed to convert ML-KEM-768 JCA key pair to lightweight parameters", e); + } return publicKey.getEncoded(); } diff --git a/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java b/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java index efc5dd1d..5b072ab7 100644 --- a/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java +++ b/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java @@ -29,7 +29,6 @@ import java.math.BigInteger; import java.security.GeneralSecurityException; -import java.security.SecureRandom; /** * Post-Quantum Traditional (PQ/T) hybrid SSH key exchange combining @@ -95,8 +94,8 @@ public void init(final Transport trans, final String V_S, final String V_C, fina // Generate X25519 ephemeral key pair (C_PK1). x25519.init(null, trans.getConfig().getRandomFactory()); - // Generate ML-KEM-768 ephemeral key pair (C_PK2). - final byte[] mlkemPublicKey = mlkem.generateKeyPair(new SecureRandom()); + // Generate ML-KEM-768 ephemeral key pair (C_PK2) via JCA. + final byte[] mlkemPublicKey = mlkem.generateKeyPair(); // C_INIT is the concatenation C_PK2 || C_PK1. final byte[] x25519PublicKey = x25519.getE(); diff --git a/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768Test.java b/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768Test.java index 17fd4a52..34e2c931 100644 --- a/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768Test.java +++ b/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768Test.java @@ -29,9 +29,9 @@ public class MLKEM768Test { @Test - public void generateKeyPairProducesCorrectlySizedPublicKey() { + public void generateKeyPairProducesCorrectlySizedPublicKey() throws GeneralSecurityException { final MLKEM768 mlkem = new MLKEM768(); - final byte[] publicKey = mlkem.generateKeyPair(new SecureRandom()); + final byte[] publicKey = mlkem.generateKeyPair(); assertNotNull(publicKey); assertEquals(MLKEM768.PUBLIC_KEY_LENGTH, publicKey.length); @@ -42,7 +42,7 @@ public void encapsulateAndDecapsulateProduceMatchingSecret() throws GeneralSecur final SecureRandom random = new SecureRandom(); final MLKEM768 mlkem = new MLKEM768(); - final byte[] publicKey = mlkem.generateKeyPair(random); + final byte[] publicKey = mlkem.generateKeyPair(); final SecretWithEncapsulation server = MLKEM768.encapsulate(publicKey, random); assertEquals(MLKEM768.CIPHERTEXT_LENGTH, server.getEncapsulation().length); @@ -55,9 +55,9 @@ public void encapsulateAndDecapsulateProduceMatchingSecret() throws GeneralSecur } @Test - public void decapsulateRejectsCiphertextOfWrongLength() { + public void decapsulateRejectsCiphertextOfWrongLength() throws GeneralSecurityException { final MLKEM768 mlkem = new MLKEM768(); - mlkem.generateKeyPair(new SecureRandom()); + mlkem.generateKeyPair(); assertThrows(GeneralSecurityException.class, () -> mlkem.decapsulate(new byte[10])); } From 7e753215778d869896b0aa8b2d65b0a400aa5ac1 Mon Sep 17 00:00:00 2001 From: Alexandre Rodrigues Date: Fri, 8 May 2026 15:48:03 +0200 Subject: [PATCH 5/8] Remove direct dependency in bouncy castle. --- .../java/net/schmizz/sshj/DefaultConfig.java | 12 +- .../java/net/schmizz/sshj/common/JcaKEM.java | 163 ++++++++++++++++++ .../schmizz/sshj/common/SecurityUtils.java | 55 ++++++ .../java/net/schmizz/sshj/common/SshjKEM.java | 72 ++++++++ .../schmizz/sshj/transport/kex/MLKEM768.java | 107 ++++++++---- .../transport/kex/MLKEM768X25519SHA256.java | 25 +++ .../sshj/transport/kex/MLKEM768Test.java | 16 +- .../MLKEM768X25519SHA256WireFormatTest.java | 13 +- 8 files changed, 410 insertions(+), 53 deletions(-) create mode 100644 src/main/java/net/schmizz/sshj/common/JcaKEM.java create mode 100644 src/main/java/net/schmizz/sshj/common/SshjKEM.java diff --git a/src/main/java/net/schmizz/sshj/DefaultConfig.java b/src/main/java/net/schmizz/sshj/DefaultConfig.java index 8c977473..06c89ff3 100644 --- a/src/main/java/net/schmizz/sshj/DefaultConfig.java +++ b/src/main/java/net/schmizz/sshj/DefaultConfig.java @@ -35,6 +35,7 @@ import net.schmizz.sshj.transport.kex.DHGexSHA1; import net.schmizz.sshj.transport.kex.DHGexSHA256; import net.schmizz.sshj.transport.kex.ECDHNistP; +import net.schmizz.sshj.transport.kex.KeyExchange; import net.schmizz.sshj.transport.kex.MLKEM768X25519SHA256; import net.schmizz.sshj.transport.random.JCERandom; import net.schmizz.sshj.transport.random.SingletonRandomFactory; @@ -44,6 +45,7 @@ import org.slf4j.Logger; import java.util.Arrays; +import java.util.ArrayList; import java.util.List; import java.util.LinkedList; import java.util.ListIterator; @@ -105,8 +107,11 @@ public void setLoggerFactory(LoggerFactory loggerFactory) { } protected void initKeyExchangeFactories() { - setKeyExchangeFactories( - new MLKEM768X25519SHA256.Factory(), + final List> factories = new ArrayList<>(); + if (MLKEM768X25519SHA256.isSupported()) { + factories.add(new MLKEM768X25519SHA256.Factory()); + } + factories.addAll(Arrays.>asList( new Curve25519SHA256.Factory(), new Curve25519SHA256.FactoryLibSsh(), new DHGexSHA256.Factory(), @@ -130,7 +135,8 @@ protected void initKeyExchangeFactories() { ExtendedDHGroups.Group16SHA512AtSSH(), ExtendedDHGroups.Group18SHA512AtSSH(), new ExtInfoClientFactory() - ); + )); + setKeyExchangeFactories(factories); } protected void initKeyAlgorithms() { diff --git a/src/main/java/net/schmizz/sshj/common/JcaKEM.java b/src/main/java/net/schmizz/sshj/common/JcaKEM.java new file mode 100644 index 00000000..9f4f0a1c --- /dev/null +++ b/src/main/java/net/schmizz/sshj/common/JcaKEM.java @@ -0,0 +1,163 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.common; + +import javax.crypto.SecretKey; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.security.GeneralSecurityException; +import java.security.NoSuchAlgorithmException; +import java.security.NoSuchProviderException; +import java.security.PrivateKey; +import java.security.PublicKey; + +/** + * Implementation of {@link SshjKEM} backed by the JDK 21+ {@code javax.crypto.KEM} API, + * accessed reflectively so that this class compiles on Java 8 source level. + * + *

On Java versions older than 21 the {@code javax.crypto.KEM} class is absent and the + * static initializer leaves {@link #API_AVAILABLE} {@code false}. Callers should query + * {@link #isApiAvailable()} (or call through {@link SecurityUtils#getKEM(String)}, which + * throws {@link NoSuchAlgorithmException} when the API is missing) before using this class.

+ */ +final class JcaKEM implements SshjKEM { + + private static final boolean API_AVAILABLE; + private static final Method GET_INSTANCE; + private static final Method GET_INSTANCE_PROVIDER; + private static final Method NEW_ENCAPSULATOR; + private static final Method NEW_DECAPSULATOR; + private static final Method ENCAPSULATE; + private static final Method ENCAPSULATION; + private static final Method KEY; + private static final Method DECAPSULATE; + + static { + Method gi = null; + Method gip = null; + Method ne = null; + Method nd = null; + Method e = null; + Method en = null; + Method k = null; + Method d = null; + boolean available = false; + try { + Class kemClass = Class.forName("javax.crypto.KEM"); + gi = kemClass.getMethod("getInstance", String.class); + gip = kemClass.getMethod("getInstance", String.class, String.class); + ne = kemClass.getMethod("newEncapsulator", PublicKey.class); + nd = kemClass.getMethod("newDecapsulator", PrivateKey.class); + Class encapsulatorClass = Class.forName("javax.crypto.KEM$Encapsulator"); + e = encapsulatorClass.getMethod("encapsulate"); + Class encapsulatedClass = Class.forName("javax.crypto.KEM$Encapsulated"); + en = encapsulatedClass.getMethod("encapsulation"); + k = encapsulatedClass.getMethod("key"); + Class decapsulatorClass = Class.forName("javax.crypto.KEM$Decapsulator"); + d = decapsulatorClass.getMethod("decapsulate", byte[].class); + available = true; + } catch (Throwable t) { + // Java < 21: javax.crypto.KEM not present. API_AVAILABLE stays false. + } + API_AVAILABLE = available; + GET_INSTANCE = gi; + GET_INSTANCE_PROVIDER = gip; + NEW_ENCAPSULATOR = ne; + NEW_DECAPSULATOR = nd; + ENCAPSULATE = e; + ENCAPSULATION = en; + KEY = k; + DECAPSULATE = d; + } + + static boolean isApiAvailable() { + return API_AVAILABLE; + } + + static JcaKEM create(String algorithm, String provider) + throws NoSuchAlgorithmException, NoSuchProviderException { + if (!API_AVAILABLE) { + throw new NoSuchAlgorithmException("javax.crypto.KEM is not available; Java 21 or later is required"); + } + try { + Object kem = (provider == null) + ? GET_INSTANCE.invoke(null, algorithm) + : GET_INSTANCE_PROVIDER.invoke(null, algorithm, provider); + return new JcaKEM(kem); + } catch (InvocationTargetException ite) { + Throwable cause = ite.getCause(); + if (cause instanceof NoSuchAlgorithmException) { + throw (NoSuchAlgorithmException) cause; + } + if (cause instanceof NoSuchProviderException) { + throw (NoSuchProviderException) cause; + } + NoSuchAlgorithmException nae = new NoSuchAlgorithmException( + "Failed to obtain KEM instance for algorithm " + algorithm); + nae.initCause(cause); + throw nae; + } catch (IllegalAccessException iae) { + NoSuchAlgorithmException nae = new NoSuchAlgorithmException("Failed to access javax.crypto.KEM"); + nae.initCause(iae); + throw nae; + } + } + + private final Object kem; + + private JcaKEM(Object kem) { + this.kem = kem; + } + + @Override + public Encapsulated encapsulate(PublicKey peerPublicKey) throws GeneralSecurityException { + try { + Object encapsulator = NEW_ENCAPSULATOR.invoke(kem, peerPublicKey); + Object result = ENCAPSULATE.invoke(encapsulator); + byte[] ciphertext = (byte[]) ENCAPSULATION.invoke(result); + SecretKey sharedSecret = (SecretKey) KEY.invoke(result); + return new Encapsulated(ciphertext, sharedSecret.getEncoded()); + } catch (InvocationTargetException ite) { + throw rethrow(ite); + } catch (IllegalAccessException iae) { + throw new GeneralSecurityException(iae); + } + } + + @Override + public byte[] decapsulate(PrivateKey ourPrivateKey, byte[] ciphertext) throws GeneralSecurityException { + try { + Object decapsulator = NEW_DECAPSULATOR.invoke(kem, ourPrivateKey); + SecretKey sharedSecret = (SecretKey) DECAPSULATE.invoke(decapsulator, (Object) ciphertext); + return sharedSecret.getEncoded(); + } catch (InvocationTargetException ite) { + throw rethrow(ite); + } catch (IllegalAccessException iae) { + throw new GeneralSecurityException(iae); + } + } + + private static GeneralSecurityException rethrow(InvocationTargetException ite) { + Throwable cause = ite.getCause(); + if (cause instanceof GeneralSecurityException) { + return (GeneralSecurityException) cause; + } + if (cause instanceof RuntimeException) { + throw (RuntimeException) cause; + } + return new GeneralSecurityException(cause); + } +} diff --git a/src/main/java/net/schmizz/sshj/common/SecurityUtils.java b/src/main/java/net/schmizz/sshj/common/SecurityUtils.java index eaab755f..0aa0357a 100644 --- a/src/main/java/net/schmizz/sshj/common/SecurityUtils.java +++ b/src/main/java/net/schmizz/sshj/common/SecurityUtils.java @@ -188,6 +188,61 @@ public static synchronized KeyPairGenerator getKeyPairGenerator(String algorithm return KeyPairGenerator.getInstance(algorithm, getSecurityProvider()); } + /** + * Creates a new instance of {@link SshjKEM} for the given algorithm. This wraps + * the JDK 21+ {@code javax.crypto.KEM} API, accessed reflectively so that this + * library still compiles at Java 8 source level. + * + * @param algorithm KEM algorithm name (Bouncy Castle 1.80 registers ML-KEM under {@code "ML-KEM"}; + * the per-parameter-set name {@code "ML-KEM-768"} is selected via the keys passed + * to {@link SshjKEM#encapsulate(java.security.PublicKey)} / + * {@link SshjKEM#decapsulate(java.security.PrivateKey, byte[])}) + * @return new instance + * @throws NoSuchAlgorithmException if no provider supplies the algorithm, or if the runtime + * is older than Java 21 (in which case the underlying API is absent) + * @throws NoSuchProviderException + */ + public static synchronized SshjKEM getKEM(String algorithm) + throws NoSuchAlgorithmException, NoSuchProviderException { + register(); + return JcaKEM.create(algorithm, getSecurityProvider()); + } + + /** + * Tests whether a JCA service of the given type and algorithm is available with the + * currently configured security provider chain (registering Bouncy Castle on demand, + * if enabled, before probing). + * + *

Special-cased for {@code type == "KEM"}: in addition to looking up the JCA service + * descriptor we also verify that the underlying {@code javax.crypto.KEM} API class is + * present, since on Java versions older than 21 the API itself is absent regardless + * of any provider's claims.

+ * + * @param type JCA service type (e.g. {@code "KeyPairGenerator"}, {@code "KeyFactory"}, + * {@code "KEM"}, {@code "Signature"}, ...) + * @param algorithm JCA algorithm name as registered by the provider + * @return {@code true} if a provider on the current chain offers the service + */ + public static synchronized boolean isAlgorithmAvailable(String type, String algorithm) { + register(); + if ("KEM".equals(type) && !JcaKEM.isApiAvailable()) { + return false; + } + Provider[] providers; + if (getSecurityProvider() == null) { + providers = Security.getProviders(); + } else { + Provider single = Security.getProvider(getSecurityProvider()); + providers = (single == null) ? new Provider[0] : new Provider[]{single}; + } + for (Provider p : providers) { + if (p.getService(type, algorithm) != null) { + return true; + } + } + return false; + } + /** * Create a new instance of {@link Mac} with the given algorithm. * diff --git a/src/main/java/net/schmizz/sshj/common/SshjKEM.java b/src/main/java/net/schmizz/sshj/common/SshjKEM.java new file mode 100644 index 00000000..d41f8661 --- /dev/null +++ b/src/main/java/net/schmizz/sshj/common/SshjKEM.java @@ -0,0 +1,72 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.common; + +import java.security.GeneralSecurityException; +import java.security.PrivateKey; +import java.security.PublicKey; + +/** + * sshj-internal abstraction over the JDK 21+ {@code javax.crypto.KEM} API. + * + *

Obtained via {@link SecurityUtils#getKEM(String)}. Hides the reflective lookup + * needed to compile against Java 8 source level while still using the modern KEM + * API at runtime, and translates the four nested {@code KEM} classes into two + * straightforward methods.

+ */ +public interface SshjKEM { + + /** + * Server-side encapsulation against a peer's public key. + * + * @param peerPublicKey peer public key + * @return the produced ciphertext and the raw shared secret bytes + * @throws GeneralSecurityException if encapsulation fails + */ + Encapsulated encapsulate(PublicKey peerPublicKey) throws GeneralSecurityException; + + /** + * Client-side decapsulation of a ciphertext using the local private key. + * + * @param ourPrivateKey local private key + * @param ciphertext peer ciphertext + * @return the raw shared secret bytes + * @throws GeneralSecurityException if decapsulation fails + */ + byte[] decapsulate(PrivateKey ourPrivateKey, byte[] ciphertext) throws GeneralSecurityException; + + /** + * Result of {@link SshjKEM#encapsulate(PublicKey)}: the ciphertext to send to the peer + * and the shared secret bytes for both sides to derive keys from. + */ + final class Encapsulated { + private final byte[] ciphertext; + private final byte[] sharedSecret; + + public Encapsulated(byte[] ciphertext, byte[] sharedSecret) { + this.ciphertext = ciphertext; + this.sharedSecret = sharedSecret; + } + + public byte[] getCiphertext() { + return ciphertext; + } + + public byte[] getSharedSecret() { + return sharedSecret; + } + } +} diff --git a/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768.java b/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768.java index 22fb0c4a..791ae8d5 100644 --- a/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768.java +++ b/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768.java @@ -16,26 +16,27 @@ package net.schmizz.sshj.transport.kex; import net.schmizz.sshj.common.SecurityUtils; -import org.bouncycastle.crypto.SecretWithEncapsulation; -import org.bouncycastle.pqc.crypto.mlkem.MLKEMExtractor; -import org.bouncycastle.pqc.crypto.mlkem.MLKEMGenerator; -import org.bouncycastle.pqc.crypto.mlkem.MLKEMParameters; -import org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters; -import org.bouncycastle.pqc.crypto.mlkem.MLKEMPublicKeyParameters; -import org.bouncycastle.pqc.crypto.util.PrivateKeyFactory; -import org.bouncycastle.pqc.crypto.util.PublicKeyFactory; +import net.schmizz.sshj.common.SshjKEM; -import java.io.IOException; import java.security.GeneralSecurityException; +import java.security.KeyFactory; import java.security.KeyPair; -import java.security.SecureRandom; +import java.security.PublicKey; +import java.security.spec.X509EncodedKeySpec; /** - * Helper around the Bouncy Castle lightweight implementation of ML-KEM-768 - * (FIPS 203). Provides client-side key generation and decapsulation, as well - * as server-side encapsulation (used by the unit tests). + * Helper around the JCA implementation of ML-KEM-768 (FIPS 203). Provides + * client-side key generation and decapsulation, as well as server-side + * encapsulation (used by the unit tests). * - *

For the parameter set used here, the byte sizes are:

+ *

All cryptographic operations route through {@link SecurityUtils}: key generation + * via {@link SecurityUtils#getKeyPairGenerator(String)}, encapsulation/decapsulation via + * {@link SecurityUtils#getKEM(String)} (the JDK 21+ {@code javax.crypto.KEM} API), + * and public-key reconstruction from the SSH wire format via + * {@link SecurityUtils#getKeyFactory(String)}. No dependency on Bouncy Castle classes + * or any other specific provider remains here.

+ * + *

For this parameter set, the byte sizes are:

*
    *
  • Public key: {@value #PUBLIC_KEY_LENGTH} bytes
  • *
  • Ciphertext: {@value #CIPHERTEXT_LENGTH} bytes
  • @@ -53,25 +54,59 @@ public final class MLKEM768 { /** Length in bytes of the shared secret produced by ML-KEM-768. */ public static final int SHARED_SECRET_LENGTH = 32; - private MLKEMPublicKeyParameters publicKey; - private MLKEMPrivateKeyParameters privateKey; + /** + * Algorithm name to pass to {@link SecurityUtils#getKeyPairGenerator(String)} and + * {@link SecurityUtils#getKeyFactory(String)}. The JCA selects the parameter set from + * this name. + */ + static final String KEY_ALGORITHM = "ML-KEM-768"; + + /** + * Algorithm family name to pass to {@link SecurityUtils#getKEM(String)}. The JCA + * {@code javax.crypto.KEM} provider only registers under the family name; the + * parameter set is inferred from the {@link java.security.PublicKey} or + * {@link java.security.PrivateKey} passed to {@code newEncapsulator} / + * {@code newDecapsulator}. + */ + static final String KEM_ALGORITHM = "ML-KEM"; /** - * Generate an ephemeral ML-KEM-768 key pair via the JCA, using the same path - * as the rest of sshj (see {@link SecurityUtils#getKeyPairGenerator(String)}). + * Constant DER prefix for an X.509 {@code SubjectPublicKeyInfo} wrapping a 1184-byte + * ML-KEM-768 public key. {@code AlgorithmIdentifier} OID is + * {@code 2.16.840.1.101.3.4.4.2}; {@code BIT STRING} length is 1185 (raw key + the + * leading "0 unused bits" byte). + */ + private static final byte[] SPKI_PREFIX = new byte[] { + (byte) 0x30, (byte) 0x82, (byte) 0x04, (byte) 0xb2, + (byte) 0x30, (byte) 0x0b, (byte) 0x06, (byte) 0x09, + (byte) 0x60, (byte) 0x86, (byte) 0x48, (byte) 0x01, + (byte) 0x65, (byte) 0x03, (byte) 0x04, (byte) 0x04, + (byte) 0x02, (byte) 0x03, (byte) 0x82, (byte) 0x04, + (byte) 0xa1, (byte) 0x00, + }; + + private KeyPair keyPair; + + /** + * Generate an ephemeral ML-KEM-768 key pair via the JCA. * - * @return the encoded public key (length {@value #PUBLIC_KEY_LENGTH}) - * @throws GeneralSecurityException if no JCA provider supports ML-KEM-768 or key conversion fails + * @return the encoded public key (length {@value #PUBLIC_KEY_LENGTH}) in the raw + * wire format expected by the SSH hybrid KEX (the trailing portion of the + * SPKI encoding) + * @throws GeneralSecurityException if no JCA provider supports ML-KEM-768 or the + * encoded public key is malformed */ public byte[] generateKeyPair() throws GeneralSecurityException { - final KeyPair keyPair = SecurityUtils.getKeyPairGenerator("ML-KEM-768").generateKeyPair(); - try { - publicKey = (MLKEMPublicKeyParameters) PublicKeyFactory.createKey(keyPair.getPublic().getEncoded()); - privateKey = (MLKEMPrivateKeyParameters) PrivateKeyFactory.createKey(keyPair.getPrivate().getEncoded()); - } catch (IOException e) { - throw new GeneralSecurityException("Failed to convert ML-KEM-768 JCA key pair to lightweight parameters", e); + keyPair = SecurityUtils.getKeyPairGenerator(KEY_ALGORITHM).generateKeyPair(); + final byte[] spki = keyPair.getPublic().getEncoded(); + if (spki.length != SPKI_PREFIX.length + PUBLIC_KEY_LENGTH) { + throw new GeneralSecurityException( + "Unexpected ML-KEM-768 SPKI length " + spki.length + + " (expected " + (SPKI_PREFIX.length + PUBLIC_KEY_LENGTH) + ")"); } - return publicKey.getEncoded(); + final byte[] raw = new byte[PUBLIC_KEY_LENGTH]; + System.arraycopy(spki, SPKI_PREFIX.length, raw, 0, PUBLIC_KEY_LENGTH); + return raw; } /** @@ -82,14 +117,14 @@ public byte[] generateKeyPair() throws GeneralSecurityException { * @throws GeneralSecurityException if the ciphertext has an invalid length or no key pair has been generated */ public byte[] decapsulate(final byte[] ciphertext) throws GeneralSecurityException { - if (privateKey == null) { + if (keyPair == null) { throw new GeneralSecurityException("ML-KEM-768 key pair has not been generated"); } if (ciphertext == null || ciphertext.length != CIPHERTEXT_LENGTH) { throw new GeneralSecurityException( "ML-KEM-768 ciphertext length must be " + CIPHERTEXT_LENGTH + " bytes"); } - return new MLKEMExtractor(privateKey).extractSecret(ciphertext); + return SecurityUtils.getKEM(KEM_ALGORITHM).decapsulate(keyPair.getPrivate(), ciphertext); } /** @@ -97,17 +132,21 @@ public byte[] decapsulate(final byte[] ciphertext) throws GeneralSecurityExcepti * a server response without requiring an external SSH server. * * @param peerPublicKey peer public key (must be exactly {@value #PUBLIC_KEY_LENGTH} bytes) - * @param random source of randomness * @return the encapsulation result containing the ciphertext and the shared secret - * @throws GeneralSecurityException if the peer public key has an invalid length + * @throws GeneralSecurityException if the peer public key has an invalid length or + * no JCA provider supports ML-KEM-768 */ - public static SecretWithEncapsulation encapsulate(final byte[] peerPublicKey, final SecureRandom random) + public static SshjKEM.Encapsulated encapsulate(final byte[] peerPublicKey) throws GeneralSecurityException { if (peerPublicKey == null || peerPublicKey.length != PUBLIC_KEY_LENGTH) { throw new GeneralSecurityException( "ML-KEM-768 public key length must be " + PUBLIC_KEY_LENGTH + " bytes"); } - final MLKEMPublicKeyParameters peer = new MLKEMPublicKeyParameters(MLKEMParameters.ml_kem_768, peerPublicKey); - return new MLKEMGenerator(random).generateEncapsulated(peer); + final byte[] spki = new byte[SPKI_PREFIX.length + PUBLIC_KEY_LENGTH]; + System.arraycopy(SPKI_PREFIX, 0, spki, 0, SPKI_PREFIX.length); + System.arraycopy(peerPublicKey, 0, spki, SPKI_PREFIX.length, PUBLIC_KEY_LENGTH); + final KeyFactory kf = SecurityUtils.getKeyFactory(KEY_ALGORITHM); + final PublicKey reconstructed = kf.generatePublic(new X509EncodedKeySpec(spki)); + return SecurityUtils.getKEM(KEM_ALGORITHM).encapsulate(reconstructed); } } diff --git a/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java b/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java index 5b072ab7..e6802ec5 100644 --- a/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java +++ b/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java @@ -20,6 +20,7 @@ import net.schmizz.sshj.common.DisconnectReason; import net.schmizz.sshj.common.Message; import net.schmizz.sshj.common.SSHPacket; +import net.schmizz.sshj.common.SecurityUtils; import net.schmizz.sshj.signature.Signature; import net.schmizz.sshj.transport.Transport; import net.schmizz.sshj.transport.TransportException; @@ -59,10 +60,34 @@ public class MLKEM768X25519SHA256 extends KeyExchangeBase { private static final String NAME = "mlkem768x25519-sha256"; + /** + * Whether this hybrid key exchange can be used at runtime. Requires the + * {@code javax.crypto.KEM} API (Java 21+) and a JCA provider that supplies + * both an {@code ML-KEM-768} {@link java.security.KeyPairGenerator} / + * {@link java.security.KeyFactory} and an {@code ML-KEM} KEM service. On older + * runtimes (or when no such provider is registered, e.g. when Bouncy Castle is + * not on the classpath and the JDK does not yet ship an ML-KEM provider of its + * own) callers should refrain from advertising the algorithm. + * + *

    The result is cached after the first call.

    + * + * @return {@code true} iff a working ML-KEM-768 implementation is reachable through the JCA + */ + public static boolean isSupported() { + return SecurityUtils.isAlgorithmAvailable("KeyPairGenerator", MLKEM768.KEY_ALGORITHM) + && SecurityUtils.isAlgorithmAvailable("KeyFactory", MLKEM768.KEY_ALGORITHM) + && SecurityUtils.isAlgorithmAvailable("KEM", MLKEM768.KEM_ALGORITHM); + } + /** Named factory for the {@code mlkem768x25519-sha256} key exchange. */ public static class Factory implements net.schmizz.sshj.common.Factory.Named { @Override public KeyExchange create() { + if (!isSupported()) { + throw new IllegalStateException( + "mlkem768x25519-sha256 is not supported on this runtime: requires Java 21+ " + + "and a JCA provider for ML-KEM-768"); + } return new MLKEM768X25519SHA256(); } diff --git a/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768Test.java b/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768Test.java index 34e2c931..ab181929 100644 --- a/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768Test.java +++ b/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768Test.java @@ -15,11 +15,10 @@ */ package net.schmizz.sshj.transport.kex; -import org.bouncycastle.crypto.SecretWithEncapsulation; +import net.schmizz.sshj.common.SshjKEM; import org.junit.jupiter.api.Test; import java.security.GeneralSecurityException; -import java.security.SecureRandom; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -39,19 +38,18 @@ public void generateKeyPairProducesCorrectlySizedPublicKey() throws GeneralSecur @Test public void encapsulateAndDecapsulateProduceMatchingSecret() throws GeneralSecurityException { - final SecureRandom random = new SecureRandom(); final MLKEM768 mlkem = new MLKEM768(); final byte[] publicKey = mlkem.generateKeyPair(); - final SecretWithEncapsulation server = MLKEM768.encapsulate(publicKey, random); + final SshjKEM.Encapsulated server = MLKEM768.encapsulate(publicKey); - assertEquals(MLKEM768.CIPHERTEXT_LENGTH, server.getEncapsulation().length); - assertEquals(MLKEM768.SHARED_SECRET_LENGTH, server.getSecret().length); + assertEquals(MLKEM768.CIPHERTEXT_LENGTH, server.getCiphertext().length); + assertEquals(MLKEM768.SHARED_SECRET_LENGTH, server.getSharedSecret().length); - final byte[] clientSecret = mlkem.decapsulate(server.getEncapsulation()); + final byte[] clientSecret = mlkem.decapsulate(server.getCiphertext()); assertEquals(MLKEM768.SHARED_SECRET_LENGTH, clientSecret.length); - assertArrayEquals(server.getSecret(), clientSecret); + assertArrayEquals(server.getSharedSecret(), clientSecret); } @Test @@ -71,6 +69,6 @@ public void decapsulateBeforeKeyGenFails() { @Test public void encapsulateRejectsPublicKeyOfWrongLength() { assertThrows(GeneralSecurityException.class, - () -> MLKEM768.encapsulate(new byte[10], new SecureRandom())); + () -> MLKEM768.encapsulate(new byte[10])); } } diff --git a/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256WireFormatTest.java b/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256WireFormatTest.java index a7824feb..42dfcc4d 100644 --- a/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256WireFormatTest.java +++ b/src/test/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256WireFormatTest.java @@ -21,11 +21,11 @@ import net.schmizz.sshj.common.KeyType; import net.schmizz.sshj.common.Message; import net.schmizz.sshj.common.SSHPacket; +import net.schmizz.sshj.common.SshjKEM; import net.schmizz.sshj.signature.Signature; import net.schmizz.sshj.transport.Transport; import net.schmizz.sshj.transport.digest.SHA256; import net.schmizz.sshj.transport.random.JCERandom; -import org.bouncycastle.crypto.SecretWithEncapsulation; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; @@ -33,7 +33,6 @@ import java.nio.ByteBuffer; import java.security.KeyPair; import java.security.KeyPairGenerator; -import java.security.SecureRandom; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -75,8 +74,8 @@ public void cInitIsMlkemPublicKeyConcatenatedWithX25519PublicKey() throws Except // against it; this would fail if the order were reversed (X25519 key first). final byte[] mlkemPk = new byte[MLKEM768.PUBLIC_KEY_LENGTH]; System.arraycopy(cInit, 0, mlkemPk, 0, MLKEM768.PUBLIC_KEY_LENGTH); - final SecretWithEncapsulation enc = MLKEM768.encapsulate(mlkemPk, new SecureRandom()); - assertEquals(MLKEM768.CIPHERTEXT_LENGTH, enc.getEncapsulation().length); + final SshjKEM.Encapsulated enc = MLKEM768.encapsulate(mlkemPk); + assertEquals(MLKEM768.CIPHERTEXT_LENGTH, enc.getCiphertext().length); } /** @@ -227,9 +226,9 @@ private ServerExchange runFullExchange() throws Exception { System.arraycopy(cInit, 0, cPk2, 0, MLKEM768.PUBLIC_KEY_LENGTH); System.arraycopy(cInit, MLKEM768.PUBLIC_KEY_LENGTH, cPk1, 0, Curve25519DH.KEY_LENGTH); - final SecretWithEncapsulation enc = MLKEM768.encapsulate(cPk2, new SecureRandom()); - final byte[] kPq = enc.getSecret(); - final byte[] sCt2 = enc.getEncapsulation(); + final SshjKEM.Encapsulated enc = MLKEM768.encapsulate(cPk2); + final byte[] kPq = enc.getSharedSecret(); + final byte[] sCt2 = enc.getCiphertext(); final Curve25519DH serverDh = new Curve25519DH(); serverDh.init(null, new JCERandom.Factory()); From 49103990461a4b628fff5abce1bf4a557aeda9b3 Mon Sep 17 00:00:00 2001 From: Alexandre Rodrigues Date: Fri, 8 May 2026 18:31:47 +0200 Subject: [PATCH 6/8] Add a fallback to Bouncy castle using reflection --- .../schmizz/sshj/common/BouncyCastleKEM.java | 172 ++++++++++++++++++ .../schmizz/sshj/common/SecurityUtils.java | 68 ++++++- .../transport/kex/MLKEM768X25519SHA256.java | 27 +-- .../sshj/common/BouncyCastleKEMTest.java | 78 ++++++++ 4 files changed, 325 insertions(+), 20 deletions(-) create mode 100644 src/main/java/net/schmizz/sshj/common/BouncyCastleKEM.java create mode 100644 src/test/java/net/schmizz/sshj/common/BouncyCastleKEMTest.java diff --git a/src/main/java/net/schmizz/sshj/common/BouncyCastleKEM.java b/src/main/java/net/schmizz/sshj/common/BouncyCastleKEM.java new file mode 100644 index 00000000..6df9c59c --- /dev/null +++ b/src/main/java/net/schmizz/sshj/common/BouncyCastleKEM.java @@ -0,0 +1,172 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.common; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.security.GeneralSecurityException; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.SecureRandom; + +/** + * Implementation of {@link SshjKEM} backed by Bouncy Castle's lightweight ML-KEM API + * ({@code org.bouncycastle.pqc.crypto.mlkem}), accessed entirely via reflection so + * that this class compiles, loads and verifies even when Bouncy Castle is absent + * from the runtime classpath (e.g. shaded out by a downstream consumer). + * + *

    This is a fallback used by {@link SecurityUtils#getKEM(String)} when the JDK + * 21+ {@code javax.crypto.KEM} API is not available (i.e. on Java 8–20). + * Callers should query {@link #isAvailable()} first.

    + */ +final class BouncyCastleKEM implements SshjKEM { + + /** BC ML-KEM family name (parameter set inferred from the encoded key). */ + private static final String ML_KEM = "ML-KEM"; + + private static final boolean AVAILABLE; + private static final Constructor GENERATOR_CTOR; + private static final Method GENERATE_ENCAPSULATED; + private static final Method GET_ENCAPSULATION; + private static final Method GET_SECRET; + private static final Method DESTROY; + private static final Constructor EXTRACTOR_CTOR; + private static final Class MLKEM_PRIVATE_KEY_PARAMETERS; + private static final Method EXTRACT_SECRET; + private static final Method PUBLIC_KEY_FACTORY_CREATE; + private static final Method PRIVATE_KEY_FACTORY_CREATE; + + static { + boolean available = false; + Constructor generatorCtor = null; + Method generateEncapsulated = null; + Method getEncapsulation = null; + Method getSecret = null; + Method destroy = null; + Constructor extractorCtor = null; + Class mlkemPrivateKeyParameters = null; + Method extractSecret = null; + Method publicKeyFactoryCreate = null; + Method privateKeyFactoryCreate = null; + try { + Class generator = Class.forName("org.bouncycastle.pqc.crypto.mlkem.MLKEMGenerator"); + Class extractor = Class.forName("org.bouncycastle.pqc.crypto.mlkem.MLKEMExtractor"); + mlkemPrivateKeyParameters = Class.forName("org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters"); + Class asymmetricKeyParameter = Class.forName("org.bouncycastle.crypto.params.AsymmetricKeyParameter"); + Class secretWithEncapsulation = Class.forName("org.bouncycastle.crypto.SecretWithEncapsulation"); + Class publicKeyFactory = Class.forName("org.bouncycastle.pqc.crypto.util.PublicKeyFactory"); + Class privateKeyFactory = Class.forName("org.bouncycastle.pqc.crypto.util.PrivateKeyFactory"); + + generatorCtor = generator.getConstructor(SecureRandom.class); + generateEncapsulated = generator.getMethod("generateEncapsulated", asymmetricKeyParameter); + getEncapsulation = secretWithEncapsulation.getMethod("getEncapsulation"); + getSecret = secretWithEncapsulation.getMethod("getSecret"); + destroy = secretWithEncapsulation.getMethod("destroy"); + extractorCtor = extractor.getConstructor(mlkemPrivateKeyParameters); + extractSecret = extractor.getMethod("extractSecret", byte[].class); + publicKeyFactoryCreate = publicKeyFactory.getMethod("createKey", byte[].class); + privateKeyFactoryCreate = privateKeyFactory.getMethod("createKey", byte[].class); + + available = true; + } catch (Throwable t) { + // Bouncy Castle PQC absent or incompatible: fallback unavailable. + } + AVAILABLE = available; + GENERATOR_CTOR = generatorCtor; + GENERATE_ENCAPSULATED = generateEncapsulated; + GET_ENCAPSULATION = getEncapsulation; + GET_SECRET = getSecret; + DESTROY = destroy; + EXTRACTOR_CTOR = extractorCtor; + MLKEM_PRIVATE_KEY_PARAMETERS = mlkemPrivateKeyParameters; + EXTRACT_SECRET = extractSecret; + PUBLIC_KEY_FACTORY_CREATE = publicKeyFactoryCreate; + PRIVATE_KEY_FACTORY_CREATE = privateKeyFactoryCreate; + } + + static boolean isAvailable() { + return AVAILABLE; + } + + static BouncyCastleKEM create(String algorithm) throws NoSuchAlgorithmException { + if (!AVAILABLE) { + throw new NoSuchAlgorithmException( + "Bouncy Castle PQC is not available; cannot fall back from javax.crypto.KEM"); + } + if (!ML_KEM.equals(algorithm)) { + throw new NoSuchAlgorithmException( + "Bouncy Castle KEM fallback only supports " + ML_KEM + ", requested " + algorithm); + } + return new BouncyCastleKEM(); + } + + private BouncyCastleKEM() { + } + + @Override + public Encapsulated encapsulate(PublicKey peerPublicKey) throws GeneralSecurityException { + try { + Object params = PUBLIC_KEY_FACTORY_CREATE.invoke(null, (Object) peerPublicKey.getEncoded()); + Object generator = GENERATOR_CTOR.newInstance(new SecureRandom()); + Object result = GENERATE_ENCAPSULATED.invoke(generator, params); + try { + byte[] ciphertext = (byte[]) GET_ENCAPSULATION.invoke(result); + byte[] sharedSecret = (byte[]) GET_SECRET.invoke(result); + return new Encapsulated(ciphertext, sharedSecret); + } finally { + try { + DESTROY.invoke(result); + } catch (Throwable ignore) { + // best-effort wipe + } + } + } catch (InvocationTargetException ite) { + throw rethrow(ite, "Failed to encapsulate via Bouncy Castle"); + } catch (ReflectiveOperationException roe) { + throw new GeneralSecurityException("Failed to invoke Bouncy Castle ML-KEM API", roe); + } + } + + @Override + public byte[] decapsulate(PrivateKey ourPrivateKey, byte[] ciphertext) throws GeneralSecurityException { + try { + Object params = PRIVATE_KEY_FACTORY_CREATE.invoke(null, (Object) ourPrivateKey.getEncoded()); + if (!MLKEM_PRIVATE_KEY_PARAMETERS.isInstance(params)) { + throw new GeneralSecurityException( + "Expected ML-KEM private key but got " + params.getClass().getName()); + } + Object extractor = EXTRACTOR_CTOR.newInstance(params); + return (byte[]) EXTRACT_SECRET.invoke(extractor, (Object) ciphertext); + } catch (InvocationTargetException ite) { + throw rethrow(ite, "Failed to decapsulate via Bouncy Castle"); + } catch (ReflectiveOperationException roe) { + throw new GeneralSecurityException("Failed to invoke Bouncy Castle ML-KEM API", roe); + } + } + + private static GeneralSecurityException rethrow(InvocationTargetException ite, String message) { + Throwable cause = ite.getCause(); + if (cause instanceof GeneralSecurityException) { + return (GeneralSecurityException) cause; + } + if (cause instanceof RuntimeException) { + throw (RuntimeException) cause; + } + return new GeneralSecurityException(message, cause); + } +} diff --git a/src/main/java/net/schmizz/sshj/common/SecurityUtils.java b/src/main/java/net/schmizz/sshj/common/SecurityUtils.java index 0aa0357a..d250cefb 100644 --- a/src/main/java/net/schmizz/sshj/common/SecurityUtils.java +++ b/src/main/java/net/schmizz/sshj/common/SecurityUtils.java @@ -189,23 +189,52 @@ public static synchronized KeyPairGenerator getKeyPairGenerator(String algorithm } /** - * Creates a new instance of {@link SshjKEM} for the given algorithm. This wraps - * the JDK 21+ {@code javax.crypto.KEM} API, accessed reflectively so that this - * library still compiles at Java 8 source level. + * Creates a new instance of {@link SshjKEM} for the given algorithm. + * + *

    Two backends are tried, in order:

    + *
      + *
    1. The JDK 21+ {@code javax.crypto.KEM} API (accessed reflectively so this library + * still compiles at Java 8 source level), dispatched through the configured JCA + * provider chain.
    2. + *
    3. If the JCA path is unusable—either because the {@code javax.crypto.KEM} + * class is absent, or because no registered provider offers the requested KEM + * service—a Bouncy Castle lightweight-API fallback + * ({@code org.bouncycastle.pqc.crypto.mlkem}) is used when those classes are on + * the classpath. (BC 1.80 ships the lightweight ML-KEM API on every JDK but + * only registers the JCA {@code KEM} service on JDK 21+; the fallback covers + * older JDKs where BC's KeyPairGenerator/KeyFactory are registered yet + * its JCA KEM service is not.)
    4. + *
    * * @param algorithm KEM algorithm name (Bouncy Castle 1.80 registers ML-KEM under {@code "ML-KEM"}; * the per-parameter-set name {@code "ML-KEM-768"} is selected via the keys passed * to {@link SshjKEM#encapsulate(java.security.PublicKey)} / * {@link SshjKEM#decapsulate(java.security.PrivateKey, byte[])}) * @return new instance - * @throws NoSuchAlgorithmException if no provider supplies the algorithm, or if the runtime - * is older than Java 21 (in which case the underlying API is absent) + * @throws NoSuchAlgorithmException if neither backend can supply the algorithm * @throws NoSuchProviderException */ public static synchronized SshjKEM getKEM(String algorithm) throws NoSuchAlgorithmException, NoSuchProviderException { register(); - return JcaKEM.create(algorithm, getSecurityProvider()); + if (JcaKEM.isApiAvailable()) { + try { + return JcaKEM.create(algorithm, getSecurityProvider()); + } catch (NoSuchAlgorithmException jcaFailure) { + if (!BouncyCastleKEM.isAvailable()) { + throw jcaFailure; + } + // Fall through to BC fallback: JCA KEM API present but no provider offers + // the requested algorithm. Common on JDK 17/20 with BC 1.80, where BC + // registers ML-KEM as KeyPairGenerator/KeyFactory but not as a KEM service. + } + } + if (BouncyCastleKEM.isAvailable()) { + return BouncyCastleKEM.create(algorithm); + } + throw new NoSuchAlgorithmException( + "No KEM implementation available for " + algorithm + + " (requires Java 21+ for javax.crypto.KEM, or Bouncy Castle PQC on the classpath)"); } /** @@ -223,11 +252,34 @@ public static synchronized SshjKEM getKEM(String algorithm) * @param algorithm JCA algorithm name as registered by the provider * @return {@code true} if a provider on the current chain offers the service */ + /** + * Tests whether a JCA service of the given type and algorithm is available with the + * currently configured security provider chain (registering Bouncy Castle on demand, + * if enabled, before probing). + * + *

    Special-cased for {@code type == "KEM"}: the JCA {@code javax.crypto.KEM} class + * was introduced in Java 21, so on older runtimes a JCA provider's claim to + * support a "KEM" service is moot. We therefore additionally check that either + * the {@code javax.crypto.KEM} API class is present and a provider offers + * the service, or the Bouncy Castle PQC fallback is available.

    + * + * @param type JCA service type (e.g. {@code "KeyPairGenerator"}, {@code "KeyFactory"}, + * {@code "KEM"}, {@code "Signature"}, ...) + * @param algorithm JCA algorithm name as registered by the provider + * @return {@code true} if a provider on the current chain offers the service + */ public static synchronized boolean isAlgorithmAvailable(String type, String algorithm) { register(); - if ("KEM".equals(type) && !JcaKEM.isApiAvailable()) { - return false; + if ("KEM".equals(type)) { + if (JcaKEM.isApiAvailable() && hasProviderService(type, algorithm)) { + return true; + } + return BouncyCastleKEM.isAvailable(); } + return hasProviderService(type, algorithm); + } + + private static boolean hasProviderService(String type, String algorithm) { Provider[] providers; if (getSecurityProvider() == null) { providers = Security.getProviders(); diff --git a/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java b/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java index e6802ec5..dad55cf7 100644 --- a/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java +++ b/src/main/java/net/schmizz/sshj/transport/kex/MLKEM768X25519SHA256.java @@ -61,17 +61,19 @@ public class MLKEM768X25519SHA256 extends KeyExchangeBase { private static final String NAME = "mlkem768x25519-sha256"; /** - * Whether this hybrid key exchange can be used at runtime. Requires the - * {@code javax.crypto.KEM} API (Java 21+) and a JCA provider that supplies - * both an {@code ML-KEM-768} {@link java.security.KeyPairGenerator} / - * {@link java.security.KeyFactory} and an {@code ML-KEM} KEM service. On older - * runtimes (or when no such provider is registered, e.g. when Bouncy Castle is - * not on the classpath and the JDK does not yet ship an ML-KEM provider of its - * own) callers should refrain from advertising the algorithm. + * Whether this hybrid key exchange can be used at runtime. Requires a JCA provider + * that supplies an {@code ML-KEM-768} {@link java.security.KeyPairGenerator} and + * {@link java.security.KeyFactory}, plus one of: + *
      + *
    • the JDK 21+ {@code javax.crypto.KEM} API together with a provider that + * registers an {@code ML-KEM} KEM service, or
    • + *
    • the Bouncy Castle PQC lightweight API + * ({@code org.bouncycastle.pqc.crypto.mlkem}) on the classpath, which works + * on any JDK.
    • + *
    + * When neither is reachable callers should refrain from advertising the algorithm. * - *

    The result is cached after the first call.

    - * - * @return {@code true} iff a working ML-KEM-768 implementation is reachable through the JCA + * @return {@code true} iff a working ML-KEM-768 implementation is reachable */ public static boolean isSupported() { return SecurityUtils.isAlgorithmAvailable("KeyPairGenerator", MLKEM768.KEY_ALGORITHM) @@ -85,8 +87,9 @@ public static class Factory implements net.schmizz.sshj.common.Factory.Named Date: Sat, 9 May 2026 08:54:08 +0200 Subject: [PATCH 7/8] Remove duplicated comment and re-use SecureRandom --- .../net/schmizz/sshj/common/BouncyCastleKEM.java | 9 ++++++++- .../net/schmizz/sshj/common/SecurityUtils.java | 15 --------------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/main/java/net/schmizz/sshj/common/BouncyCastleKEM.java b/src/main/java/net/schmizz/sshj/common/BouncyCastleKEM.java index 6df9c59c..346cc137 100644 --- a/src/main/java/net/schmizz/sshj/common/BouncyCastleKEM.java +++ b/src/main/java/net/schmizz/sshj/common/BouncyCastleKEM.java @@ -39,6 +39,13 @@ final class BouncyCastleKEM implements SshjKEM { /** BC ML-KEM family name (parameter set inferred from the encoded key). */ private static final String ML_KEM = "ML-KEM"; + /** + * Shared {@link SecureRandom}. {@code SecureRandom} is documented thread-safe, and + * lazily seeded by the JDK on first use, so a single instance avoids paying the + * (potentially blocking) seed cost on every encapsulation. + */ + private static final SecureRandom SECURE_RANDOM = new SecureRandom(); + private static final boolean AVAILABLE; private static final Constructor GENERATOR_CTOR; private static final Method GENERATE_ENCAPSULATED; @@ -122,7 +129,7 @@ private BouncyCastleKEM() { public Encapsulated encapsulate(PublicKey peerPublicKey) throws GeneralSecurityException { try { Object params = PUBLIC_KEY_FACTORY_CREATE.invoke(null, (Object) peerPublicKey.getEncoded()); - Object generator = GENERATOR_CTOR.newInstance(new SecureRandom()); + Object generator = GENERATOR_CTOR.newInstance(SECURE_RANDOM); Object result = GENERATE_ENCAPSULATED.invoke(generator, params); try { byte[] ciphertext = (byte[]) GET_ENCAPSULATION.invoke(result); diff --git a/src/main/java/net/schmizz/sshj/common/SecurityUtils.java b/src/main/java/net/schmizz/sshj/common/SecurityUtils.java index d250cefb..af0e7165 100644 --- a/src/main/java/net/schmizz/sshj/common/SecurityUtils.java +++ b/src/main/java/net/schmizz/sshj/common/SecurityUtils.java @@ -237,21 +237,6 @@ public static synchronized SshjKEM getKEM(String algorithm) + " (requires Java 21+ for javax.crypto.KEM, or Bouncy Castle PQC on the classpath)"); } - /** - * Tests whether a JCA service of the given type and algorithm is available with the - * currently configured security provider chain (registering Bouncy Castle on demand, - * if enabled, before probing). - * - *

    Special-cased for {@code type == "KEM"}: in addition to looking up the JCA service - * descriptor we also verify that the underlying {@code javax.crypto.KEM} API class is - * present, since on Java versions older than 21 the API itself is absent regardless - * of any provider's claims.

    - * - * @param type JCA service type (e.g. {@code "KeyPairGenerator"}, {@code "KeyFactory"}, - * {@code "KEM"}, {@code "Signature"}, ...) - * @param algorithm JCA algorithm name as registered by the provider - * @return {@code true} if a provider on the current chain offers the service - */ /** * Tests whether a JCA service of the given type and algorithm is available with the * currently configured security provider chain (registering Bouncy Castle on demand, From 72bd51f2bdc1c5609b81a1e4ea4e01ae54c7479d Mon Sep 17 00:00:00 2001 From: Alexandre Rodrigues Date: Mon, 11 May 2026 07:31:12 +0200 Subject: [PATCH 8/8] Add JcaKEMTest --- .../net/schmizz/sshj/common/JcaKEMTest.java | 171 ++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 src/test/java/net/schmizz/sshj/common/JcaKEMTest.java diff --git a/src/test/java/net/schmizz/sshj/common/JcaKEMTest.java b/src/test/java/net/schmizz/sshj/common/JcaKEMTest.java new file mode 100644 index 00000000..d9f33494 --- /dev/null +++ b/src/test/java/net/schmizz/sshj/common/JcaKEMTest.java @@ -0,0 +1,171 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.common; + +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.security.GeneralSecurityException; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.security.NoSuchProviderException; +import java.security.Security; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +/** + * Exercises {@link JcaKEM} directly. {@code JcaKEM} accesses the JDK 21+ + * {@code javax.crypto.KEM} API via reflection, so we want to verify that: + * + *
      + *
    • {@link JcaKEM#isApiAvailable()} agrees with the actual presence of the API + * class on the runtime,
    • + *
    • a successful {@code create(...)} produces an instance that can round-trip + * encap/decap to a matching shared secret,
    • + *
    • the reflective exception-unwrapping paths translate + * {@code InvocationTargetException} into the expected + * {@link NoSuchAlgorithmException} / {@link NoSuchProviderException} / + * {@link GeneralSecurityException} types instead of leaking the reflective + * wrapper.
    • + *
    + * + *

    All tests are gated on {@link JcaKEM#isApiAvailable()} so they skip cleanly + * on Java < 21 (the existing {@code BouncyCastleKEMTest} covers the + * fallback path on older runtimes).

    + */ +public class JcaKEMTest { + + @BeforeAll + public static void registerProvider() { + if (Security.getProvider(BouncyCastleProvider.PROVIDER_NAME) == null) { + Security.addProvider(new BouncyCastleProvider()); + } + } + + @Test + public void apiAvailabilityMatchesClassPresence() { + boolean classPresent; + try { + Class.forName("javax.crypto.KEM"); + classPresent = true; + } catch (ClassNotFoundException e) { + classPresent = false; + } + assertEquals(classPresent, JcaKEM.isApiAvailable(), + "JcaKEM.isApiAvailable() must reflect actual javax.crypto.KEM presence"); + } + + @Test + public void roundTripProducesMatchingSecretWithDefaultProvider() throws Exception { + assumeTrue(JcaKEM.isApiAvailable(), "javax.crypto.KEM not available on this JRE"); + assumeTrue(providerOffersService("KEM", "ML-KEM"), + "No JCA provider registers the ML-KEM KEM service on this JRE"); + + KeyPair kp = generateMlKem768KeyPair(); + SshjKEM kem = JcaKEM.create("ML-KEM", null); + + SshjKEM.Encapsulated encapsulated = kem.encapsulate(kp.getPublic()); + assertNotNull(encapsulated); + assertEquals(1088, encapsulated.getCiphertext().length, "ML-KEM-768 ciphertext length"); + assertEquals(32, encapsulated.getSharedSecret().length, "ML-KEM-768 shared secret length"); + + byte[] decapsulated = kem.decapsulate(kp.getPrivate(), encapsulated.getCiphertext()); + assertArrayEquals(encapsulated.getSharedSecret(), decapsulated, + "decapsulated secret must equal encapsulated secret"); + } + + @Test + public void roundTripProducesMatchingSecretWithExplicitProvider() throws Exception { + assumeTrue(JcaKEM.isApiAvailable(), "javax.crypto.KEM not available on this JRE"); + assumeTrue(Security.getProvider(BouncyCastleProvider.PROVIDER_NAME) + .getService("KEM", "ML-KEM") != null, + "BouncyCastle does not register ML-KEM KEM service on this JRE"); + + KeyPair kp = generateMlKem768KeyPair(); + SshjKEM kem = JcaKEM.create("ML-KEM", BouncyCastleProvider.PROVIDER_NAME); + + SshjKEM.Encapsulated encapsulated = kem.encapsulate(kp.getPublic()); + byte[] decapsulated = kem.decapsulate(kp.getPrivate(), encapsulated.getCiphertext()); + assertArrayEquals(encapsulated.getSharedSecret(), decapsulated); + } + + @Test + public void createUnwrapsNoSuchAlgorithmException() { + assumeTrue(JcaKEM.isApiAvailable(), "javax.crypto.KEM not available on this JRE"); + + NoSuchAlgorithmException ex = assertThrows(NoSuchAlgorithmException.class, + () -> JcaKEM.create("BOGUS-KEM-ALGORITHM", null)); + // Must be a direct NSAE, not a reflective wrapper like + // InvocationTargetException or some generic GeneralSecurityException. + assertEquals(NoSuchAlgorithmException.class, ex.getClass(), + "exception type must be exactly NoSuchAlgorithmException"); + } + + @Test + public void createUnwrapsNoSuchProviderException() { + assumeTrue(JcaKEM.isApiAvailable(), "javax.crypto.KEM not available on this JRE"); + + NoSuchProviderException ex = assertThrows(NoSuchProviderException.class, + () -> JcaKEM.create("ML-KEM", "ThisProviderDoesNotExist")); + assertEquals(NoSuchProviderException.class, ex.getClass(), + "exception type must be exactly NoSuchProviderException"); + } + + @Test + public void decapsulateRejectsWrongLengthCiphertext() throws Exception { + assumeTrue(JcaKEM.isApiAvailable(), "javax.crypto.KEM not available on this JRE"); + assumeTrue(providerOffersService("KEM", "ML-KEM"), + "No JCA provider registers the ML-KEM KEM service on this JRE"); + + KeyPair kp = generateMlKem768KeyPair(); + SshjKEM kem = JcaKEM.create("ML-KEM", null); + + byte[] tooShort = new byte[10]; + GeneralSecurityException ex = assertThrows(GeneralSecurityException.class, + () -> kem.decapsulate(kp.getPrivate(), tooShort)); + // The reflective layer must translate any thrown checked exception into + // a GeneralSecurityException (or subclass), never let an + // InvocationTargetException or RuntimeException leak. + assertNotNull(ex.getMessage() != null ? ex.getMessage() : ex.getCause(), + "exception should carry a message or cause"); + } + + private static KeyPair generateMlKem768KeyPair() throws Exception { + KeyPairGenerator kpg; + try { + kpg = KeyPairGenerator.getInstance("ML-KEM-768"); + } catch (NoSuchAlgorithmException firstTry) { + // SunJCE on JDK 21 doesn't register ML-KEM-768; explicitly fall back to BC. + kpg = KeyPairGenerator.getInstance("ML-KEM-768", BouncyCastleProvider.PROVIDER_NAME); + } + return kpg.generateKeyPair(); + } + + private static boolean providerOffersService(String type, String algorithm) { + for (java.security.Provider p : Security.getProviders()) { + if (p.getService(type, algorithm) != null) { + return true; + } + } + return false; + } +}