4747
4848using namespace std ::literals::string_view_literals;
4949
50+ enum ExitCodes : int {
51+ SUCCESS = 0 ,
52+
53+ MAIN_TEST_FAILED = 1 ,
54+
55+ CLIENT_PIPE_READ_ADDR_FAILED = 2 ,
56+ CLIENT_FAILED = 3 ,
57+ CLIENT_TEST_RESULT_MISMATCH = 4 ,
58+
59+ SERVER_BIND_ERROR = 5 ,
60+ SERVER_STDOUT_REDIRECT_FAILED = 6 ,
61+
62+ WAITPID_ANY_STATUS = -1 ,
63+ };
64+
5065enum Role : uint8_t { MAIN, CLIENT, SERVER, UNDETERMINED, LAST };
5166
5267constexpr std::array<std::string_view, Role::LAST> ROLE_STRING{ " MAIN" sv, " CLIENT" sv, " SERVER" sv, " UNDETERMINED" sv };
@@ -118,29 +133,38 @@ struct TLSCreds {
118133 std::string certBytes;
119134 std::string keyBytes;
120135 std::string caBytes;
136+ std::string password;
121137};
122138
123- TLSCreds makeCreds (const ChainLength chainLen, const mkcert::ESide side) {
139+ TLSCreds makeCreds (const ChainLength chainLen, const mkcert::ESide side, StringRef password = {} ) {
124140 if (chainLen == 0 || chainLen == NO_TLS) {
125- return TLSCreds{ chainLen == NO_TLS, " " , " " , " " };
141+ return TLSCreds{ chainLen == NO_TLS, " " , " " , " " , " " };
126142 }
127143 auto arena = Arena ();
128144 auto ret = TLSCreds{};
129- auto specs = mkcert::makeCertChainSpec (arena, std::labs (chainLen), side);
130- if (chainLen < 0 ) {
131- specs[0 ].offsetNotBefore = -60l * 60 * 24 * 365 ;
132- specs[0 ].offsetNotAfter = -10l ; // cert that expired 10 seconds ago
133- }
134- auto chain = mkcert::makeCertChain (arena, specs, {} /* create root CA cert from spec*/ );
135- if (chain.size () == 1 ) {
136- ret.certBytes = concatCertChain (arena, chain).toString ();
145+ if (!password.empty ()) {
146+ ret.password = password.toString ();
147+ auto certAndKeyPem = mkcert::makePasswCert (arena, password);
148+ ret.certBytes = certAndKeyPem.certPem .toString ();
149+ ret.keyBytes = certAndKeyPem.privateKeyPem .toString ();
150+ ret.caBytes = ret.certBytes ;
137151 } else {
138- auto nonRootChain = chain;
139- nonRootChain.pop_back ();
140- ret.certBytes = concatCertChain (arena, nonRootChain).toString ();
152+ auto specs = mkcert::makeCertChainSpec (arena, std::labs (chainLen), side);
153+ if (chainLen < 0 ) {
154+ specs[0 ].offsetNotBefore = -60l * 60 * 24 * 365 ;
155+ specs[0 ].offsetNotAfter = -10l ; // cert that expired 10 seconds ago
156+ }
157+ auto chain = mkcert::makeCertChain (arena, specs, {} /* create root CA cert from spec*/ );
158+ if (chain.size () == 1 ) {
159+ ret.certBytes = concatCertChain (arena, chain).toString ();
160+ } else {
161+ auto nonRootChain = chain;
162+ nonRootChain.pop_back ();
163+ ret.certBytes = concatCertChain (arena, nonRootChain).toString ();
164+ }
165+ ret.caBytes = chain.back ().certPem .toString ();
166+ ret.keyBytes = chain.front ().privateKeyPem .toString ();
141167 }
142- ret.caBytes = chain.back ().certPem .toString ();
143- ret.keyBytes = chain.front ().privateKeyPem .toString ();
144168 return ret;
145169}
146170
@@ -255,6 +279,7 @@ int runHost(TLSCreds creds, int addrPipe, int completionPipe, Result expect) {
255279 tlsConfig.setCertificateBytes (creds.certBytes );
256280 tlsConfig.setCABytes (creds.caBytes );
257281 tlsConfig.setKeyBytes (creds.keyBytes );
282+ tlsConfig.setPassword (creds.password );
258283 }
259284 g_network = newNet2 (tlsConfig);
260285 openTraceFile ({}, 10 << 20 , 10 << 20 , " ." , IsServer ? " authz_tls_unittest_server" : " authz_tls_unittest_client" );
@@ -264,7 +289,12 @@ int runHost(TLSCreds creds, int addrPipe, int completionPipe, Result expect) {
264289 auto addr = NetworkAddress::parse (noTls ? " 127.0.0.1:0" : " 127.0.0.1:0:tls" );
265290 auto endpoint = Endpoint ();
266291 auto receiver = SessionProbeReceiver ();
267- auto listenFuture = transport.bind (addr, addr);
292+ try {
293+ transport.bind (addr, addr);
294+ } catch (const Error& err) {
295+ log (" CAUGHT Error in bind: code={} what={}" , err.code (), err.what ());
296+ return SERVER_BIND_ERROR;
297+ }
268298 transport.addEndpoint (endpoint, &receiver, TaskPriority::ReadSocket);
269299 auto thread = std::thread ([]() {
270300 g_network->run ();
@@ -275,13 +305,13 @@ int runHost(TLSCreds creds, int addrPipe, int completionPipe, Result expect) {
275305 g_network->stop ();
276306 thread.join ();
277307 });
278- return 0 ;
308+ return SUCCESS ;
279309 } else {
280310 auto dest = Endpoint ();
281311 auto & serverAddr = dest.addresses .address ;
282312 if (sizeof (serverAddr) != ::read (addrPipe, &serverAddr, sizeof (serverAddr))) {
283313 log (" Failed to read server addr from pipe: {}" , strerror (errno));
284- return 1 ;
314+ return CLIENT_PIPE_READ_ADDR_FAILED ;
285315 }
286316 if (noTls)
287317 serverAddr.flags &= ~NetworkAddress::FLAG_TLS;
@@ -290,14 +320,14 @@ int runHost(TLSCreds creds, int addrPipe, int completionPipe, Result expect) {
290320 auto & token = dest.token ;
291321 if (sizeof (token) != ::read (addrPipe, &token, sizeof (token))) {
292322 log (" Failed to read server endpoint token from pipe: {}" , strerror (errno));
293- return 2 ;
323+ return CLIENT_FAILED ;
294324 }
295325 log (" Server address is {}{}" , serverAddr.toString (), noTls ? " (TLS suffix removed)" : " " );
296326 log (" Server endpoint token is {}" , token.toString ());
297327 auto sessionProbeReq = SessionProbeRequest{};
298328 transport.sendUnreliable (SerializeSource (sessionProbeReq), dest, true /* openConnection*/ );
299329 log (" Request is sent" );
300- auto rc = 0 ;
330+ auto rc = SUCCESS ;
301331 auto result = Result::ERROR;
302332 {
303333 auto timeout = delay (expect == Result::TIMEOUT ? 0.5 : 5 );
@@ -308,12 +338,12 @@ int runHost(TLSCreds creds, int addrPipe, int completionPipe, Result expect) {
308338 auto done = true ;
309339 if (sizeof (done) != ::write (completionPipe, &done, sizeof (done))) {
310340 log (" Failed to signal server to terminate: {}" , strerror (errno));
311- rc = 4 ;
341+ rc = CLIENT_FAILED ;
312342 }
313- if (rc == 0 ) {
343+ if (rc == SUCCESS ) {
314344 if (expect != result) {
315345 log (" Test failed: expected {}, got {}" , expect, result);
316- rc = 5 ;
346+ rc = CLIENT_TEST_RESULT_MISMATCH ;
317347 } else {
318348 log (" Response OK: got {} as expected" , result);
319349 }
@@ -374,7 +404,7 @@ std::pair<bool, std::string> waitPidStatusInterpreter(const char* procName, cons
374404 return { false , message };
375405}
376406
377- bool waitPid (pid_t subProcPid, const char * procName) {
407+ bool waitPid (pid_t subProcPid, const char * procName, int expectStatus = WAITPID_ANY_STATUS ) {
378408 auto status = int {};
379409 auto pid = ::waitpid (subProcPid, &status, 0 );
380410
@@ -386,33 +416,56 @@ bool waitPid(pid_t subProcPid, const char* procName) {
386416 auto [ok, message] = waitPidStatusInterpreter (procName, status);
387417 log (" {}" , message);
388418
389- return ok;
419+ return ok || (expectStatus != WAITPID_ANY_STATUS && WEXITSTATUS (status) == expectStatus) ;
390420 }
391421}
392422
393- int runTlsTest (ChainLength serverChainLen, ChainLength clientChainLen) {
394- log (" ==== BEGIN TESTCASE ====" );
395- auto const expect = getExpectedResult (serverChainLen, clientChainLen);
396- using namespace std ::literals::string_literals;
397- log (" Cert chain length: server={} client={}" , serverChainLen, clientChainLen);
398- auto arena = Arena ();
399- auto serverCreds = makeCreds (serverChainLen, mkcert::ESide::Server);
400- auto clientCreds = makeCreds (clientChainLen, mkcert::ESide::Client);
401- // make server and client trust each other
402- std::swap (serverCreds.caBytes , clientCreds.caBytes );
423+ int runTlsTest (ChainLength serverChainLen, ChainLength clientChainLen, std::string_view passwordTestCase = " " ) {
424+ auto expect = Result::TRUSTED;
425+ TLSCreds serverCreds;
426+ TLSCreds clientCreds;
427+ int expectStatusServer = WAITPID_ANY_STATUS;
428+ int expectStatusClient = WAITPID_ANY_STATUS;
429+
430+ if (passwordTestCase.empty ()) {
431+ log (" ==== BEGIN TESTCASE ====" );
432+ expect = getExpectedResult (serverChainLen, clientChainLen);
433+ log (" Cert chain length: server={} client={}" , serverChainLen, clientChainLen);
434+ serverCreds = makeCreds (serverChainLen, mkcert::ESide::Server);
435+ clientCreds = makeCreds (clientChainLen, mkcert::ESide::Client);
436+ // make server and client trust each other
437+ std::swap (serverCreds.caBytes , clientCreds.caBytes );
438+ } else {
439+ const auto password = " abc123" _sr;
440+ serverCreds = makeCreds (serverChainLen, mkcert::ESide::Server, password);
441+ clientCreds = serverCreds;
442+
443+ if (passwordTestCase == " client" ) {
444+ log (" ==== BEGIN CLIENT BAD PASSWORD TESTCASE ====" );
445+ expect = Result::TIMEOUT;
446+ clientCreds.password = " bad" ;
447+ } else if (passwordTestCase == " server" ) {
448+ log (" ==== BEGIN SERVER BAD PASSWORD TESTCASE ====" );
449+ serverCreds.password = " bad" ;
450+ expectStatusServer = SERVER_BIND_ERROR;
451+ expectStatusClient = CLIENT_PIPE_READ_ADDR_FAILED;
452+ } else {
453+ log (" ==== BEGIN PASSWORD PROTECTED TESTCASE ====" );
454+ }
455+ }
403456 auto clientPid = pid_t {};
404457 auto serverPid = pid_t {};
405458 int addrPipe[2 ], completionPipe[2 ], serverStdoutPipe[2 ], clientStdoutPipe[2 ];
406459 if (::pipe (addrPipe) || ::pipe (completionPipe) || ::pipe (serverStdoutPipe) || ::pipe (clientStdoutPipe)) {
407460 log (" Pipe open failed: {}" , strerror (errno));
408- return 1 ;
461+ return MAIN_TEST_FAILED ;
409462 }
410463 auto ok = true ;
411464 {
412465 serverPid = fork ();
413466 if (serverPid == -1 ) {
414467 log (" fork() for server subprocess failed: {}" , strerror (errno));
415- return 1 ;
468+ return MAIN_TEST_FAILED ;
416469 } else if (serverPid == 0 ) {
417470 role = Role::SERVER;
418471 // server subprocess
@@ -429,24 +482,25 @@ int runTlsTest(ChainLength serverChainLen, ChainLength clientChainLen) {
429482 if (-1 == ::dup2 (serverStdoutPipe[1 ], STDOUT_FILENO)) {
430483 log (" Failed to redirect server stdout to pipe: {}" , strerror (errno));
431484 ::close (serverStdoutPipe[1 ]);
432- return 1 ;
485+ return SERVER_STDOUT_REDIRECT_FAILED ;
433486 }
434487 _exit (runHost<true >(std::move (serverCreds), addrPipe[1 ], completionPipe[0 ], expect));
435488 }
436- auto serverProcCleanup = ScopeExit ([&ok, serverPid]() {
437- if (!waitPid (serverPid, " Server" ))
489+ auto serverProcCleanup = ScopeExit ([&ok, serverPid, expectStatusServer ]() {
490+ if (!waitPid (serverPid, " Server" , expectStatusServer ))
438491 ok = false ;
439492 });
493+ ::close (addrPipe[1 ]);
494+ ::close (completionPipe[0 ]);
495+ ::close (serverStdoutPipe[1 ]);
496+
440497 clientPid = fork ();
441498 if (clientPid == -1 ) {
442499 log (" fork() for client subprocess failed: {}" , strerror (errno));
443- return 1 ;
500+ return MAIN_TEST_FAILED ;
444501 } else if (clientPid == 0 ) {
445502 role = Role::CLIENT;
446- ::close (addrPipe[1 ]);
447- ::close (completionPipe[0 ]);
448503 ::close (serverStdoutPipe[0 ]);
449- ::close (serverStdoutPipe[1 ]);
450504 ::close (clientStdoutPipe[0 ]);
451505 auto pipeCleanup = ScopeExit ([&addrPipe, &completionPipe]() {
452506 ::close (addrPipe[0 ]);
@@ -455,21 +509,18 @@ int runTlsTest(ChainLength serverChainLen, ChainLength clientChainLen) {
455509 if (-1 == ::dup2 (clientStdoutPipe[1 ], STDOUT_FILENO)) {
456510 log (" Failed to redirect client stdout to pipe: {}" , strerror (errno));
457511 ::close (clientStdoutPipe[1 ]);
458- return 1 ;
512+ return CLIENT_FAILED ;
459513 }
460514 _exit (runHost<false >(std::move (clientCreds), addrPipe[0 ], completionPipe[1 ], expect));
461515 }
462- auto clientProcCleanup = ScopeExit ([&ok, clientPid]() {
463- if (!waitPid (clientPid, " Client" ))
516+ auto clientProcCleanup = ScopeExit ([&ok, clientPid, expectStatusClient ]() {
517+ if (!waitPid (clientPid, " Client" , expectStatusClient ))
464518 ok = false ;
465519 });
466520 }
467521 // main process
468522 ::close (addrPipe[0 ]);
469- ::close (addrPipe[1 ]);
470- ::close (completionPipe[0 ]);
471523 ::close (completionPipe[1 ]);
472- ::close (serverStdoutPipe[1 ]);
473524 ::close (clientStdoutPipe[1 ]);
474525 auto pipeCleanup = ScopeExit ([&]() {
475526 ::close (serverStdoutPipe[0 ]);
@@ -484,7 +535,7 @@ int runTlsTest(ChainLength serverChainLen, ChainLength clientChainLen) {
484535 logRaw (fmt::runtime (serverStdout));
485536 log (" /// End Server STDOUT ///" );
486537 log (fmt::runtime (ok ? " OK" : " FAILED" ));
487- return !ok ;
538+ return ok ? SUCCESS : MAIN_TEST_FAILED ;
488539}
489540
490541int main (int argc, char ** argv) {
@@ -514,12 +565,29 @@ int main(int argc, char** argv) {
514565 if (runTlsTest (serverChainLen, clientChainLen))
515566 failed.push_back ({ serverChainLen, clientChainLen });
516567 }
568+
569+ constexpr auto singleChainPair = std::pair (ChainLength (1 ), ChainLength (1 ));
570+ inputs.insert (inputs.end (), 3 , singleChainPair);
571+
572+ std::vector<std::string_view> failedPasswordTests;
573+ for (const auto & testCase : std::array{ " no_bad_password" , " client" , " server" }) {
574+ if (runTlsTest (singleChainPair.first , singleChainPair.second , testCase)) {
575+ failed.push_back (singleChainPair);
576+ failedPasswordTests.push_back (testCase);
577+ }
578+ }
579+
517580 if (!failed.empty ()) {
581+ if (!failedPasswordTests.empty ()) {
582+ for (const auto & test : failedPasswordTests) {
583+ log (" {}, failed" , test);
584+ }
585+ }
518586 log (" Test Failed: {}/{} cases: {}" , failed.size (), inputs.size (), failed);
519- return 1 ;
587+ return MAIN_TEST_FAILED ;
520588 } else {
521589 log (" Test OK: {}/{} cases passed" , inputs.size (), inputs.size ());
522- return 0 ;
590+ return SUCCESS ;
523591 }
524592}
525593#else // _WIN32
0 commit comments