Skip to content

Commit 611a2a1

Browse files
Add service-name check and regress test
1 parent 834e60f commit 611a2a1

3 files changed

Lines changed: 222 additions & 38 deletions

File tree

src/internal.c

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8346,6 +8346,7 @@ static int DoUserAuthRequest(WOLFSSH* ssh,
83468346
word32 begin;
83478347
int ret = WS_SUCCESS;
83488348
byte authNameId;
8349+
byte serviceValid = 1;
83498350
WS_UserAuthData authData;
83508351

83518352
WLOG(WS_LOG_DEBUG, "Entering DoUserAuthRequest()");
@@ -8356,37 +8357,32 @@ static int DoUserAuthRequest(WOLFSSH* ssh,
83568357
if (ret == WS_SUCCESS) {
83578358
begin = *idx;
83588359
WMEMSET(&authData, 0, sizeof(authData));
8359-
ret = GetSize(&authData.usernameSz, buf, len, &begin);
8360-
}
8361-
8362-
if (ret == WS_SUCCESS) {
8363-
authData.username = buf + begin;
8364-
begin += authData.usernameSz;
8365-
8366-
ret = GetUint32(&authData.serviceNameSz, buf, len, &begin);
8360+
ret = GetStringRef(&authData.usernameSz, &authData.username,
8361+
buf, len, &begin);
83678362
}
83688363

83698364
if (ret == WS_SUCCESS) {
8370-
ret = wolfSSH_SetUsernameRaw(ssh, authData.username, authData.usernameSz);
8365+
ret = GetStringRef(&authData.serviceNameSz, &authData.serviceName,
8366+
buf, len, &begin);
83718367
}
83728368

83738369
if (ret == WS_SUCCESS) {
8374-
if (authData.serviceNameSz > len - begin) {
8375-
ret = WS_BUFFER_E;
8370+
if (NameToId((const char*)authData.serviceName, authData.serviceNameSz)
8371+
!= ID_SERVICE_CONNECTION) {
8372+
WLOG(WS_LOG_DEBUG, "DUAR: Invalid service name");
8373+
serviceValid = 0;
8374+
ret = SendUserAuthFailure(ssh, 0);
8375+
/* Consume all remaining data */
8376+
*idx = len;
8377+
}
8378+
else {
8379+
ret = GetStringRef(&authData.authNameSz, &authData.authName,
8380+
buf, len, &begin);
83768381
}
83778382
}
83788383

8379-
if (ret == WS_SUCCESS) {
8380-
authData.serviceName = buf + begin;
8381-
begin += authData.serviceNameSz;
8382-
8383-
ret = GetSize(&authData.authNameSz, buf, len, &begin);
8384-
}
8385-
8386-
if (ret == WS_SUCCESS) {
8387-
authData.authName = buf + begin;
8388-
begin += authData.authNameSz;
8389-
authNameId = NameToId((char*)authData.authName, authData.authNameSz);
8384+
if (ret == WS_SUCCESS && serviceValid) {
8385+
authNameId = NameToId((const char*)authData.authName, authData.authNameSz);
83908386
ssh->authId = authNameId;
83918387

83928388
if (authNameId == ID_USERAUTH_PASSWORD)
@@ -8409,8 +8405,10 @@ static int DoUserAuthRequest(WOLFSSH* ssh,
84098405
#endif
84108406
else {
84118407
WLOG(WS_LOG_DEBUG,
8412-
"invalid userauth type: %s", IdToName(authNameId));
8408+
"DUAR: invalid userauth type: %s", IdToName(authNameId));
84138409
ret = SendUserAuthFailure(ssh, 0);
8410+
/* Consume all remaining data */
8411+
begin = len;
84148412
}
84158413

84168414
if (ret == WS_SUCCESS) {
@@ -17976,6 +17974,12 @@ int wolfSSH_TestChannelPutData(WOLFSSH_CHANNEL* channel, byte* data,
1797617974
return ChannelPutData(channel, data, dataSz);
1797717975
}
1797817976

17977+
int wolfSSH_TestDoUserAuthRequest(WOLFSSH* ssh, byte* buf, word32 len,
17978+
word32* idx)
17979+
{
17980+
return DoUserAuthRequest(ssh, buf, len, idx);
17981+
}
17982+
1797917983
#ifndef WOLFSSH_NO_DH_GEX_SHA256
1798017984

1798117985
int wolfSSH_TestDoKexDhGexRequest(WOLFSSH* ssh, byte* buf, word32 len,

tests/unit.c

Lines changed: 193 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,13 +1181,25 @@ static int test_ChannelPutData(void)
11811181
return result;
11821182
}
11831183

1184+
/* Plaintext SSH packet from IoSend (before encryption/MAC): LENGTH_SZ,
1185+
* PAD_LENGTH_SZ, then payload starting with the message ID (RFC 4253;
1186+
* wolfSSH PreparePacket/BundlePacket). Not for encrypted payloads or
1187+
* arbitrary truncated chunks. */
1188+
static int CaptureMsgId(const byte* buf, word32 len)
1189+
{
1190+
word32 off = LENGTH_SZ + PAD_LENGTH_SZ;
1191+
1192+
if (len <= off)
1193+
return -1;
1194+
return (int)buf[off];
1195+
}
1196+
11841197
/* Verify DoChannelRequest sends CHANNEL_SUCCESS for known types and
11851198
* CHANNEL_FAILURE for unrecognized ones (RFC 4254 Section 5.4).
11861199
*
11871200
* A custom IoSend callback captures the outgoing packet in plaintext
1188-
* (no cipher negotiated on a fresh session). The SSH packet layout is:
1189-
* [4-byte packet_length][1-byte padding_length][1-byte msg_id]...
1190-
* so the message ID lives at byte offset 5. */
1201+
* (no cipher negotiated on a fresh session). Message ID is read via
1202+
* CaptureMsgId() using LENGTH_SZ + PAD_LENGTH_SZ. */
11911203
static byte s_chanReqCapture[256];
11921204
static word32 s_chanReqCaptureSz = 0;
11931205

@@ -1289,19 +1301,23 @@ static int test_DoChannelRequest(void)
12891301
goto done;
12901302
}
12911303

1292-
if (s_chanReqCaptureSz <= 5) {
1293-
printf("DoChannelRequest[%s]: captured packet too short (%u)\n",
1294-
cases[i].label, s_chanReqCaptureSz);
1295-
result = -410 - i;
1296-
goto done;
1297-
}
1304+
{
1305+
int capMsgId = CaptureMsgId(s_chanReqCapture, s_chanReqCaptureSz);
12981306

1299-
if (s_chanReqCapture[5] != cases[i].expectMsgId) {
1300-
printf("DoChannelRequest[%s]: msg_id=0x%02x, expected=0x%02x\n",
1301-
cases[i].label,
1302-
s_chanReqCapture[5], cases[i].expectMsgId);
1303-
result = -420 - i;
1304-
goto done;
1307+
if (capMsgId < 0) {
1308+
printf("DoChannelRequest[%s]: captured packet too short (%u)\n",
1309+
cases[i].label, s_chanReqCaptureSz);
1310+
result = -410 - i;
1311+
goto done;
1312+
}
1313+
1314+
if (capMsgId != (int)cases[i].expectMsgId) {
1315+
printf("DoChannelRequest[%s]: msg_id=0x%02x, expected=0x%02x\n",
1316+
cases[i].label,
1317+
capMsgId, cases[i].expectMsgId);
1318+
result = -420 - i;
1319+
goto done;
1320+
}
13051321
}
13061322
}
13071323

@@ -1311,6 +1327,163 @@ static int test_DoChannelRequest(void)
13111327
return result;
13121328
}
13131329

1330+
/* Capture buffer for the service-name unit test. Separate from the channel-
1331+
* request capture so the two tests can run independently in any order. */
1332+
static byte s_authSvcCapture[256];
1333+
static word32 s_authSvcCaptureSz = 0;
1334+
1335+
static int CaptureIoSendAuthSvc(WOLFSSH* ssh, void* buf, word32 sz, void* ctx)
1336+
{
1337+
(void)ssh; (void)ctx;
1338+
s_authSvcCaptureSz = (sz < (word32)sizeof(s_authSvcCapture))
1339+
? sz : (word32)sizeof(s_authSvcCapture);
1340+
WMEMCPY(s_authSvcCapture, buf, s_authSvcCaptureSz);
1341+
return (int)sz;
1342+
}
1343+
1344+
/* Verify DoUserAuthRequest rejects non-"ssh-connection" service names per
1345+
* RFC 4252 Section 5. For each case we assert:
1346+
* 1. ret == WS_SUCCESS (connection stays open for retry)
1347+
* 2. SSH_MSG_USERAUTH_FAILURE is actually sent (see CaptureMsgId():
1348+
* LENGTH_SZ + PAD_LENGTH_SZ then msg id)
1349+
* 3. *idx == len (entire payload consumed; buffer stays aligned)
1350+
*
1351+
* For invalid-service cases the auth-method field is intentionally omitted
1352+
* from the payload. DoUserAuthRequest must short-circuit at the service-name
1353+
* check and still satisfy all three assertions — proving it never tries to
1354+
* parse the missing auth-method field. If the short-circuit were absent,
1355+
* GetSize() for authNameSz would hit end-of-buffer and return WS_BUFFER_E,
1356+
* failing assertion 1.
1357+
*
1358+
* For the valid-service case, auth method "xyz-unknown" (always unsupported
1359+
* regardless of compile-time options) is included. The function reaches
1360+
* auth-method dispatch, falls to the unknown-method else-branch, and sends
1361+
* USERAUTH_FAILURE via that normal path.
1362+
*
1363+
* A second valid-service row appends fake password-style bytes after the
1364+
* method name. That proves DoUserAuthRequest() consumes trailing
1365+
* method-specific payload (begin = len in the unknown-method branch); without
1366+
* it, DoReceive() could advance inputBuffer.idx short of the packet end and
1367+
* misalign decoding. */
1368+
static const byte s_unknownAuthTrailingFakePassword[] = {
1369+
0x00, /* "change password" FALSE */
1370+
0x00, 0x00, 0x00, 0x08,
1371+
'p', 'a', 's', 's', 'w', 'o', 'r', 'd',
1372+
};
1373+
1374+
static int test_DoUserAuthRequest_serviceName(void)
1375+
{
1376+
WOLFSSH_CTX* ctx = NULL;
1377+
WOLFSSH* ssh = NULL;
1378+
int result = 0;
1379+
struct {
1380+
const char* svcName;
1381+
word32 svcNameSz;
1382+
const char* authMethod; /* NULL = omit field (proves short-circuit) */
1383+
word32 authMethodSz;
1384+
int expectRet;
1385+
const char* label;
1386+
const byte* authTrailing; /* bytes after auth method; NULL if none */
1387+
word32 authTrailingSz;
1388+
} cases[] = {
1389+
/* valid service: auth dispatch fires, fails on unknown method */
1390+
{ "ssh-connection", 14, "xyz-unknown", 11, WS_SUCCESS,
1391+
"valid svc unknown auth", NULL, 0 },
1392+
/* same but trailing junk must be skipped so *idx reaches len */
1393+
{ "ssh-connection", 14, "xyz-unknown", 11, WS_SUCCESS,
1394+
"valid svc unknown auth trailing junk",
1395+
s_unknownAuthTrailingFakePassword,
1396+
(word32)sizeof(s_unknownAuthTrailingFakePassword) },
1397+
/* invalid service: short-circuit, auth-method field absent */
1398+
{ "ssh-agent", 9, NULL, 0, WS_SUCCESS,
1399+
"invalid ssh-agent svc", NULL, 0 },
1400+
{ "bad", 3, NULL, 0, WS_SUCCESS,
1401+
"invalid bad svc", NULL, 0 },
1402+
/* zero-length service name: NameToId("",0)==ID_UNKNOWN, must reject */
1403+
{ "", 0, NULL, 0, WS_SUCCESS,
1404+
"zero-length svc", NULL, 0 },
1405+
/* ssh-userauth: NameToId returns ID_SERVICE_USERAUTH, not
1406+
* ID_SERVICE_CONNECTION, so must also be rejected */
1407+
{ "ssh-userauth", 12, NULL, 0, WS_SUCCESS,
1408+
"invalid ssh-userauth svc", NULL, 0 },
1409+
};
1410+
int i;
1411+
1412+
ctx = wolfSSH_CTX_new(WOLFSSH_ENDPOINT_SERVER, NULL);
1413+
if (ctx == NULL) return -500;
1414+
wolfSSH_SetIOSend(ctx, CaptureIoSendAuthSvc);
1415+
1416+
ssh = wolfSSH_new(ctx);
1417+
if (ssh == NULL) { wolfSSH_CTX_free(ctx); return -501; }
1418+
1419+
for (i = 0; i < (int)(sizeof(cases)/sizeof(cases[0])); i++) {
1420+
byte buf[128];
1421+
word32 len = 0, idx = 0;
1422+
word32 snsz = cases[i].svcNameSz;
1423+
int ret;
1424+
1425+
s_authSvcCaptureSz = 0;
1426+
WMEMSET(s_authSvcCapture, 0, sizeof(s_authSvcCapture));
1427+
1428+
/* username: "user" */
1429+
buf[len++] = 0; buf[len++] = 0; buf[len++] = 0; buf[len++] = 4;
1430+
WMEMCPY(buf + len, "user", 4); len += 4;
1431+
1432+
/* service name */
1433+
buf[len++] = (byte)(snsz >> 24); buf[len++] = (byte)(snsz >> 16);
1434+
buf[len++] = (byte)(snsz >> 8); buf[len++] = (byte)snsz;
1435+
WMEMCPY(buf + len, cases[i].svcName, snsz); len += snsz;
1436+
1437+
/* auth method: omit for invalid-service cases to prove short-circuit */
1438+
if (cases[i].authMethod != NULL) {
1439+
word32 amsz = cases[i].authMethodSz;
1440+
buf[len++] = (byte)(amsz >> 24); buf[len++] = (byte)(amsz >> 16);
1441+
buf[len++] = (byte)(amsz >> 8); buf[len++] = (byte)amsz;
1442+
WMEMCPY(buf + len, cases[i].authMethod, amsz); len += amsz;
1443+
if (cases[i].authTrailingSz > 0U) {
1444+
WMEMCPY(buf + len, cases[i].authTrailing,
1445+
cases[i].authTrailingSz);
1446+
len += cases[i].authTrailingSz;
1447+
}
1448+
}
1449+
1450+
ret = wolfSSH_TestDoUserAuthRequest(ssh, buf, len, &idx);
1451+
1452+
if (ret != cases[i].expectRet) {
1453+
printf("DoUserAuthRequest_svcName[%s]: ret=%d expected=%d\n",
1454+
cases[i].label, ret, cases[i].expectRet);
1455+
result = -502 - i;
1456+
break;
1457+
}
1458+
1459+
/* MSGID_USERAUTH_FAILURE must be in the captured packet. */
1460+
{
1461+
int capMsgId = CaptureMsgId(s_authSvcCapture, s_authSvcCaptureSz);
1462+
1463+
if (capMsgId < 0 || capMsgId != MSGID_USERAUTH_FAILURE) {
1464+
printf("DoUserAuthRequest_svcName[%s]: USERAUTH_FAILURE not "
1465+
"sent (capSz=%u msg_id=0x%02x)\n", cases[i].label,
1466+
s_authSvcCaptureSz,
1467+
capMsgId >= 0 ? capMsgId : 0);
1468+
result = -520 - i;
1469+
break;
1470+
}
1471+
}
1472+
1473+
/* All cases must consume the entire payload. */
1474+
if (idx != len) {
1475+
printf("DoUserAuthRequest_svcName[%s]: idx=%u expected len=%u\n",
1476+
cases[i].label, idx, len);
1477+
result = -510 - i;
1478+
break;
1479+
}
1480+
}
1481+
1482+
wolfSSH_free(ssh);
1483+
wolfSSH_CTX_free(ctx);
1484+
return result;
1485+
}
1486+
13141487
#if !defined(WOLFSSH_NO_RSA)
13151488

13161489
/* 2048-bit RSA private key (PKCS#1 DER).
@@ -1609,6 +1782,11 @@ int wolfSSH_UnitTest(int argc, char** argv)
16091782
unitResult = test_ChannelPutData();
16101783
printf("ChannelPutData: %s\n", (unitResult == 0 ? "SUCCESS" : "FAILED"));
16111784
testResult = testResult || unitResult;
1785+
1786+
unitResult = test_DoUserAuthRequest_serviceName();
1787+
printf("DoUserAuthRequest_serviceName: %s\n",
1788+
(unitResult == 0 ? "SUCCESS" : "FAILED"));
1789+
testResult = testResult || unitResult;
16121790
#endif
16131791

16141792
#ifdef WOLFSSH_KEYGEN

wolfssh/internal.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,6 +1350,8 @@ enum WS_MessageIdLimits {
13501350
WOLFSSH_API int wolfSSH_TestDoKexDhReply(WOLFSSH* ssh, byte* buf,
13511351
word32 len, word32* idx);
13521352
WOLFSSH_API int wolfSSH_TestChannelPutData(WOLFSSH_CHANNEL*, byte*, word32);
1353+
WOLFSSH_API int wolfSSH_TestDoUserAuthRequest(WOLFSSH* ssh, byte* buf,
1354+
word32 len, word32* idx);
13531355
#ifndef WOLFSSH_NO_DH_GEX_SHA256
13541356
WOLFSSH_API int wolfSSH_TestDoKexDhGexRequest(WOLFSSH* ssh, byte* buf,
13551357
word32 len, word32* idx);

0 commit comments

Comments
 (0)