@@ -647,38 +647,64 @@ INLINE static int IsMessageAllowedServer(WOLFSSH *ssh, byte msg)
647647#ifndef NO_WOLFSSH_CLIENT
648648INLINE static int IsMessageAllowedClient(WOLFSSH *ssh, byte msg)
649649{
650- /* Transport Layer Generic messages are always allowed. */
651- if (MSGIDLIMIT_TRANS_GEN(msg)) {
652- return 1;
653- }
654-
655- /* The client should only send the user auth request message,
656- * it should not accept it. */
657- if (msg == MSGID_USERAUTH_REQUEST) {
650+ /* Only the client should send these messages, never receive. */
651+ if (msg == MSGID_SERVICE_REQUEST || msg == MSGID_USERAUTH_REQUEST) {
658652 WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by %s %s",
659653 msg, "client", "ever");
660654 return 0;
661655 }
662656
657+ /* Transport Layer Generic messages are always allowed. */
658+ if (MSGIDLIMIT_TRANS_GEN(msg)) {
659+ return 1;
660+ }
661+
663662 /* Is KEX complete? */
664- if (ssh->connectState < CONNECT_KEYED && ssh->handshake ) {
665- /* If expecting a specific message, and didn't receive it, error. */
666- if (ssh->handshake->expectMsgId != MSGID_NONE) {
667- if (msg != ssh->handshake->expectMsgId ) {
668- WLOG(WS_LOG_DEBUG, "Message ID %u not the expected message %u ",
669- msg, ssh->handshake->expectMsgId );
663+ if (MSGIDLIMIT_TRANS(msg) ) {
664+ if (ssh->isKeying) {
665+ /* MSGID_KEXINIT not allowed when keying. */
666+ if (msg == MSGID_KEXINIT ) {
667+ WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by %s %s ",
668+ msg, "client", "when keying" );
670669 return 0;
671670 }
672- ssh->handshake->expectMsgId = MSGID_NONE;
673- return 1;
671+
672+ /* Error if expecting a specific message and didn't receive. */
673+ if (ssh->handshake && ssh->handshake->expectMsgId != MSGID_NONE) {
674+ if (msg != ssh->handshake->expectMsgId) {
675+ WLOG(WS_LOG_DEBUG,
676+ "Message ID %u not the expected message %u",
677+ msg, ssh->handshake->expectMsgId);
678+ return 0;
679+ }
680+ else {
681+ /* Got the expected message, clear expectation. */
682+ ssh->handshake->expectMsgId = MSGID_NONE;
683+ return 1;
684+ }
685+ }
686+ }
687+ else {
688+ /* MSGID_KEXINIT only allowed when not keying. */
689+ if (msg == MSGID_KEXINIT) {
690+ return 1;
691+ }
692+
693+ /* All other transport KEX and ALGO messages are not allowed
694+ * when not keying. */
695+ WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by %s %s",
696+ msg, "client", "when not keying");
697+ return 0;
674698 }
675699 }
700+
676701 /* Has client userauth started? */
677702 if (ssh->connectState < CONNECT_CLIENT_KEXDH_INIT_SENT) {
678703 if (msg >= MSGID_KEXDH_GEX_REQUEST) {
679704 return 0;
680705 }
681706 }
707+
682708 /* Is client userauth complete? */
683709 if (ssh->connectState < CONNECT_SERVER_USERAUTH_ACCEPT_DONE) {
684710 /* The endpoints should not allow message IDs greater than or
@@ -689,13 +715,6 @@ INLINE static int IsMessageAllowedClient(WOLFSSH *ssh, byte msg)
689715 msg, "client", "before user authentication is complete");
690716 return 0;
691717 }
692- /* Explicitly check for the user authentication request message.
693- * The client only sends the message, it shouldn't receive it. */
694- if (msg == MSGID_USERAUTH_REQUEST) {
695- WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by %s %s",
696- msg, "client", "during user authentication");
697- return 0;
698- }
699718 }
700719 else {
701720 if (MSGIDLIMIT_AUTH(msg)) {
@@ -704,6 +723,7 @@ INLINE static int IsMessageAllowedClient(WOLFSSH *ssh, byte msg)
704723 return 0;
705724 }
706725 }
726+
707727 return 1;
708728}
709729#endif /* NO_WOLFSSH_CLIENT */
@@ -1095,7 +1115,7 @@ WOLFSSH* SshInit(WOLFSSH* ssh, WOLFSSH_CTX* ctx)
10951115 ssh->fs = NULL;
10961116 ssh->acceptState = ACCEPT_BEGIN;
10971117 ssh->clientState = CLIENT_BEGIN;
1098- ssh->isKeying = 1 ;
1118+ ssh->isKeying = 0 ;
10991119 ssh->authId = ID_USERAUTH_PUBLICKEY;
11001120 ssh->supportedAuth[0] = ID_USERAUTH_PUBLICKEY;
11011121 ssh->supportedAuth[1] = ID_USERAUTH_PASSWORD;
@@ -4326,10 +4346,14 @@ static int DoKexInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx)
43264346 byte scratchLen[LENGTH_SZ];
43274347 word32 strSz = 0;
43284348
4329- if (!ssh->isKeying) {
4349+ ssh->kexinitRxd = 1;
4350+ if (!ssh->kexinitTxd) {
43304351 WLOG(WS_LOG_DEBUG, "Keying initiated");
43314352 ret = SendKexInit(ssh);
43324353 }
4354+ else {
4355+ ssh->handshake->expectMsgId = MSGID_KEXINIT;
4356+ }
43334357
43344358 /* account for possible want write case from SendKexInit */
43354359 if (ret == WS_SUCCESS || ret == WS_WANT_WRITE)
@@ -5941,6 +5965,8 @@ static int DoNewKeys(WOLFSSH* ssh, byte* buf, word32 len, word32* idx)
59415965 ssh->rxCount = 0;
59425966 ssh->highwaterFlag = 0;
59435967 ssh->isKeying = 0;
5968+ ssh->kexinitTxd = 0;
5969+ ssh->kexinitRxd = 0;
59445970 HandshakeInfoFree(ssh->handshake, ssh->ctx->heap);
59455971 ssh->handshake = NULL;
59465972 WLOG(WS_LOG_DEBUG, "Keying completed");
@@ -9383,7 +9409,7 @@ static int DoPacket(WOLFSSH* ssh, byte* bufferConsumed)
93839409 case MSGID_KEXINIT:
93849410 WLOG(WS_LOG_DEBUG, "Decoding MSGID_KEXINIT");
93859411 ret = DoKexInit(ssh, buf + idx, payloadSz, &payloadIdx);
9386- if (ssh->isKeying == 1 &&
9412+ if (ssh->kexinitTxd == 1 &&
93879413 ssh->connectState == CONNECT_SERVER_CHANNEL_REQUEST_DONE) {
93889414 if (ssh->handshake->kexId == ID_DH_GEX_SHA256) {
93899415#if !defined(WOLFSSH_NO_DH) && !defined(WOLFSSH_NO_DH_GEX_SHA256)
@@ -10479,7 +10505,6 @@ int SendKexInit(WOLFSSH* ssh)
1047910505 }
1048010506
1048110507 if (ret == WS_SUCCESS) {
10482- ssh->isKeying = 1;
1048310508 if (ssh->handshake == NULL) {
1048410509 ssh->handshake = HandshakeInfoNew(ssh->ctx->heap);
1048510510 if (ssh->handshake == NULL) {
@@ -10602,8 +10627,10 @@ int SendKexInit(WOLFSSH* ssh)
1060210627 ret = BundlePacket(ssh);
1060310628 }
1060410629
10605- if (ret == WS_SUCCESS)
10630+ if (ret == WS_SUCCESS) {
10631+ ssh->kexinitTxd = 1;
1060610632 ret = wolfSSH_SendPacket(ssh);
10633+ }
1060710634
1060810635 if (ret != WS_WANT_WRITE && ret != WS_SUCCESS)
1060910636 PurgePacket(ssh);
0 commit comments