@@ -598,34 +598,43 @@ static void HandshakeInfoFree(HandshakeInfo* hs, void* heap)
598598#ifndef NO_WOLFSSH_SERVER
599599INLINE static int IsMessageAllowedServer(WOLFSSH *ssh, byte msg)
600600{
601+ /* Transport Layer Generic messages are always allowed. */
602+ if (MSGIDLIMIT_TRANS_GEN(msg)) {
603+ return 1;
604+ }
605+
601606 /* Has client userauth started? */
607+ /* Allows the server to receive up to KEXDH GEX Request during KEX. */
602608 if (ssh->acceptState < ACCEPT_KEYED) {
603- if (msg > MSGID_KEXDH_LIMIT ) {
609+ if (msg > MSGID_KEXDH_GEX_REQUEST ) {
604610 return 0;
605611 }
606612 }
607613 /* Is server userauth complete? */
608614 if (ssh->acceptState < ACCEPT_SERVER_USERAUTH_SENT) {
615+ /* The server should only receive the user auth request message,
616+ * it should not accept the other user auth messages, it sends
617+ * them. (>50) */
609618 /* Explicitly check for messages not allowed before user
610619 * authentication has comleted. */
611- if (msg >= MSGID_USERAUTH_LIMIT ) {
612- WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by server "
613- " before user authentication is complete", msg );
620+ if (MSGIDLIMIT_POST_USERAUTH( msg) ) {
621+ WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by %s %s",
622+ msg, "server", " before user authentication is complete");
614623 return 0;
615624 }
616625 /* Explicitly check for the user authentication messages that
617626 * only the server sends, it shouldn't receive them. */
618- if ((msg > MSGID_USERAUTH_RESTRICT ) &&
627+ if ((msg > MSGID_USERAUTH_REQUEST ) &&
619628 (msg != MSGID_USERAUTH_INFO_RESPONSE)) {
620- WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by server "
621- " during user authentication", msg );
629+ WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by %s %s",
630+ msg, "server", " during user authentication");
622631 return 0;
623632 }
624633 }
625634 else {
626- if (msg >= MSGID_USERAUTH_RESTRICT && msg < MSGID_USERAUTH_LIMIT ) {
627- WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by server "
628- " after user authentication", msg );
635+ if (msg >= MSGID_USERAUTH_REQUEST && msg < MSGID_GLOBAL_REQUEST ) {
636+ WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by %s %s",
637+ msg, "server", " after user authentication");
629638 return 0;
630639 }
631640 }
@@ -638,49 +647,94 @@ INLINE static int IsMessageAllowedServer(WOLFSSH *ssh, byte msg)
638647#ifndef NO_WOLFSSH_CLIENT
639648INLINE static int IsMessageAllowedClient(WOLFSSH *ssh, byte msg)
640649{
650+ /* Only the client should send these messages, never receive. */
651+ if (msg == MSGID_SERVICE_REQUEST || msg == MSGID_USERAUTH_REQUEST) {
652+ WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by %s %s",
653+ msg, "client", "ever");
654+ return 0;
655+ }
656+
657+ if (msg == MSGID_SERVICE_ACCEPT) {
658+ if (ssh->connectState == CONNECT_CLIENT_USERAUTH_REQUEST_SENT) {
659+ return 1;
660+ }
661+ else {
662+ WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by %s %s",
663+ msg, "client", "after starting user auth");
664+ return 0;
665+ }
666+ }
667+
668+ /* Transport Layer Generic messages are always allowed. */
669+ if (MSGIDLIMIT_TRANS_GEN(msg)) {
670+ return 1;
671+ }
672+
641673 /* Is KEX complete? */
642- if (ssh->connectState < CONNECT_KEYED && ssh->handshake ) {
643- /* If expecting a specific message, and didn't receive it, error. */
644- if (ssh->handshake->expectMsgId != MSGID_NONE) {
645- if (msg != ssh->handshake->expectMsgId ) {
646- WLOG(WS_LOG_DEBUG, "Message ID %u not the expected message %u ",
647- msg, ssh->handshake->expectMsgId );
674+ if (MSGIDLIMIT_TRANS(msg) ) {
675+ if (ssh->isKeying) {
676+ /* MSGID_KEXINIT not allowed when keying. */
677+ if (msg == MSGID_KEXINIT ) {
678+ WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by %s %s ",
679+ msg, "client", "when keying" );
648680 return 0;
649681 }
650- ssh->handshake->expectMsgId = MSGID_NONE;
682+
683+ /* Error if expecting a specific message and didn't receive. */
684+ if (ssh->handshake && ssh->handshake->expectMsgId != MSGID_NONE) {
685+ if (msg != ssh->handshake->expectMsgId) {
686+ WLOG(WS_LOG_DEBUG,
687+ "Message ID %u not the expected message %u",
688+ msg, ssh->handshake->expectMsgId);
689+ return 0;
690+ }
691+ else {
692+ /* Got the expected message, clear expectation. */
693+ ssh->handshake->expectMsgId = MSGID_NONE;
694+ return 1;
695+ }
696+ }
651697 }
652- }
653- /* Has client userauth started? */
654- if (ssh->connectState < CONNECT_CLIENT_KEXDH_INIT_SENT) {
655- if (msg >= MSGID_KEXDH_LIMIT) {
698+ else {
699+ /* MSGID_KEXINIT only allowed when not keying. */
700+ if (msg == MSGID_KEXINIT) {
701+ return 1;
702+ }
703+
704+ /* All other transport KEX and ALGO messages are not allowed
705+ * when not keying. */
706+ WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by %s %s",
707+ msg, "client", "when not keying");
656708 return 0;
657709 }
658710 }
711+
659712 /* Is client userauth complete? */
660713 if (ssh->connectState < CONNECT_SERVER_USERAUTH_ACCEPT_DONE) {
661- /* Explicitly check for messages not allowed before user
662- * authentication has comleted. */
663- if (msg >= MSGID_USERAUTH_LIMIT) {
664- WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by client "
665- "before user authentication is complete", msg);
714+ /* The endpoints should not allow message IDs greater than or
715+ * equal to msgid 80 before user authentication is complete.
716+ * Per RFC 4252 section 6. */
717+ if (MSGIDLIMIT_POST_USERAUTH(msg)) {
718+ WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by %s %s",
719+ msg, "client", "before user authentication is complete");
666720 return 0;
667721 }
668- /* Explicitly check for the user authentication message that
669- * only the client sends, it shouldn't receive it. */
670- if (msg == MSGID_USERAUTH_RESTRICT) {
671- WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by client "
672- "during user authentication", msg);
673- return 0;
722+ else if (MSGIDLIMIT_AUTH(msg)) {
723+ return 1;
674724 }
675725 }
676726 else {
677- if (msg >= MSGID_USERAUTH_RESTRICT && msg < MSGID_USERAUTH_LIMIT) {
678- WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by client "
679- "after user authentication", msg);
727+ if (MSGIDLIMIT_POST_USERAUTH(msg)) {
728+ return 1;
729+ }
730+ else if (MSGIDLIMIT_AUTH(msg)) {
731+ WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by %s %s",
732+ msg, "client", "after user authentication");
680733 return 0;
681734 }
682735 }
683- return 1;
736+
737+ return 0;
684738}
685739#endif /* NO_WOLFSSH_CLIENT */
686740
@@ -1071,7 +1125,7 @@ WOLFSSH* SshInit(WOLFSSH* ssh, WOLFSSH_CTX* ctx)
10711125 ssh->fs = NULL;
10721126 ssh->acceptState = ACCEPT_BEGIN;
10731127 ssh->clientState = CLIENT_BEGIN;
1074- ssh->isKeying = 1 ;
1128+ ssh->isKeying = 0 ;
10751129 ssh->authId = ID_USERAUTH_PUBLICKEY;
10761130 ssh->supportedAuth[0] = ID_USERAUTH_PUBLICKEY;
10771131 ssh->supportedAuth[1] = ID_USERAUTH_PASSWORD;
@@ -4302,10 +4356,12 @@ static int DoKexInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx)
43024356 byte scratchLen[LENGTH_SZ];
43034357 word32 strSz = 0;
43044358
4305- if (!ssh->isKeying) {
4359+ ssh->kexinitRxd = 1;
4360+ if (!ssh->kexinitTxd) {
43064361 WLOG(WS_LOG_DEBUG, "Keying initiated");
43074362 ret = SendKexInit(ssh);
43084363 }
4364+ ssh->isKeying = ssh->kexinitTxd;
43094365
43104366 /* account for possible want write case from SendKexInit */
43114367 if (ret == WS_SUCCESS || ret == WS_WANT_WRITE)
@@ -5835,8 +5891,10 @@ static int DoKexDhReply(WOLFSSH* ssh, byte* buf, word32 len, word32* idx)
58355891 ret = GenerateKeys(ssh, hashId, !ssh->handshake->useEccMlKem);
58365892 }
58375893
5838- if (ret == WS_SUCCESS)
5894+ if (ret == WS_SUCCESS) {
58395895 ret = SendNewKeys(ssh);
5896+ ssh->handshake->expectMsgId = MSGID_NEWKEYS;
5897+ }
58405898
58415899 if (sigKeyBlock_ptr)
58425900 WFREE(sigKeyBlock_ptr, ssh->ctx->heap, DYNTYPE_PRIVKEY);
@@ -5917,6 +5975,8 @@ static int DoNewKeys(WOLFSSH* ssh, byte* buf, word32 len, word32* idx)
59175975 ssh->rxCount = 0;
59185976 ssh->highwaterFlag = 0;
59195977 ssh->isKeying = 0;
5978+ ssh->kexinitTxd = 0;
5979+ ssh->kexinitRxd = 0;
59205980 HandshakeInfoFree(ssh->handshake, ssh->ctx->heap);
59215981 ssh->handshake = NULL;
59225982 WLOG(WS_LOG_DEBUG, "Keying completed");
@@ -9359,7 +9419,7 @@ static int DoPacket(WOLFSSH* ssh, byte* bufferConsumed)
93599419 case MSGID_KEXINIT:
93609420 WLOG(WS_LOG_DEBUG, "Decoding MSGID_KEXINIT");
93619421 ret = DoKexInit(ssh, buf + idx, payloadSz, &payloadIdx);
9362- if (ssh->isKeying == 1 &&
9422+ if (ssh->kexinitTxd == 1 &&
93639423 ssh->connectState == CONNECT_SERVER_CHANNEL_REQUEST_DONE) {
93649424 if (ssh->handshake->kexId == ID_DH_GEX_SHA256) {
93659425#if !defined(WOLFSSH_NO_DH) && !defined(WOLFSSH_NO_DH_GEX_SHA256)
@@ -10455,7 +10515,6 @@ int SendKexInit(WOLFSSH* ssh)
1045510515 }
1045610516
1045710517 if (ret == WS_SUCCESS) {
10458- ssh->isKeying = 1;
1045910518 if (ssh->handshake == NULL) {
1046010519 ssh->handshake = HandshakeInfoNew(ssh->ctx->heap);
1046110520 if (ssh->handshake == NULL) {
@@ -10578,8 +10637,11 @@ int SendKexInit(WOLFSSH* ssh)
1057810637 ret = BundlePacket(ssh);
1057910638 }
1058010639
10581- if (ret == WS_SUCCESS)
10640+ if (ret == WS_SUCCESS) {
10641+ ssh->kexinitTxd = 1;
10642+ ssh->isKeying = ssh->kexinitRxd;
1058210643 ret = wolfSSH_SendPacket(ssh);
10644+ }
1058310645
1058410646 if (ret != WS_WANT_WRITE && ret != WS_SUCCESS)
1058510647 PurgePacket(ssh);
0 commit comments