Skip to content

Commit f3941ac

Browse files
authored
XDB-478 Fix segmentation fault, properly initialize activeTlsPolicy on password errors (#210)
1 parent cadebf6 commit f3941ac

6 files changed

Lines changed: 154 additions & 62 deletions

File tree

fdbrpc/tests/AuthzTlsTest.actor.cpp

Lines changed: 121 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,21 @@
4747

4848
using namespace std::literals::string_view_literals;
4949

50+
enum ExitCodes : int {
51+
SUCCESS = 0,
52+
53+
MAIN_TEST_FAILED = 1,
54+
55+
CLIENT_PIPE_READ_ADDR_FAILED = 2,
56+
CLIENT_FAILED = 3,
57+
CLIENT_TEST_RESULT_MISMATCH = 4,
58+
59+
SERVER_BIND_ERROR = 5,
60+
SERVER_STDOUT_REDIRECT_FAILED = 6,
61+
62+
WAITPID_ANY_STATUS = -1,
63+
};
64+
5065
enum Role : uint8_t { MAIN, CLIENT, SERVER, UNDETERMINED, LAST };
5166

5267
constexpr std::array<std::string_view, Role::LAST> ROLE_STRING{ "MAIN"sv, "CLIENT"sv, "SERVER"sv, "UNDETERMINED"sv };
@@ -118,29 +133,38 @@ struct TLSCreds {
118133
std::string certBytes;
119134
std::string keyBytes;
120135
std::string caBytes;
136+
std::string password;
121137
};
122138

123-
TLSCreds makeCreds(const ChainLength chainLen, const mkcert::ESide side) {
139+
TLSCreds makeCreds(const ChainLength chainLen, const mkcert::ESide side, StringRef password = {}) {
124140
if (chainLen == 0 || chainLen == NO_TLS) {
125-
return TLSCreds{ chainLen == NO_TLS, "", "", "" };
141+
return TLSCreds{ chainLen == NO_TLS, "", "", "", "" };
126142
}
127143
auto arena = Arena();
128144
auto ret = TLSCreds{};
129-
auto specs = mkcert::makeCertChainSpec(arena, std::labs(chainLen), side);
130-
if (chainLen < 0) {
131-
specs[0].offsetNotBefore = -60l * 60 * 24 * 365;
132-
specs[0].offsetNotAfter = -10l; // cert that expired 10 seconds ago
133-
}
134-
auto chain = mkcert::makeCertChain(arena, specs, {} /* create root CA cert from spec*/);
135-
if (chain.size() == 1) {
136-
ret.certBytes = concatCertChain(arena, chain).toString();
145+
if (!password.empty()) {
146+
ret.password = password.toString();
147+
auto certAndKeyPem = mkcert::makePasswCert(arena, password);
148+
ret.certBytes = certAndKeyPem.certPem.toString();
149+
ret.keyBytes = certAndKeyPem.privateKeyPem.toString();
150+
ret.caBytes = ret.certBytes;
137151
} else {
138-
auto nonRootChain = chain;
139-
nonRootChain.pop_back();
140-
ret.certBytes = concatCertChain(arena, nonRootChain).toString();
152+
auto specs = mkcert::makeCertChainSpec(arena, std::labs(chainLen), side);
153+
if (chainLen < 0) {
154+
specs[0].offsetNotBefore = -60l * 60 * 24 * 365;
155+
specs[0].offsetNotAfter = -10l; // cert that expired 10 seconds ago
156+
}
157+
auto chain = mkcert::makeCertChain(arena, specs, {} /* create root CA cert from spec*/);
158+
if (chain.size() == 1) {
159+
ret.certBytes = concatCertChain(arena, chain).toString();
160+
} else {
161+
auto nonRootChain = chain;
162+
nonRootChain.pop_back();
163+
ret.certBytes = concatCertChain(arena, nonRootChain).toString();
164+
}
165+
ret.caBytes = chain.back().certPem.toString();
166+
ret.keyBytes = chain.front().privateKeyPem.toString();
141167
}
142-
ret.caBytes = chain.back().certPem.toString();
143-
ret.keyBytes = chain.front().privateKeyPem.toString();
144168
return ret;
145169
}
146170

@@ -255,6 +279,7 @@ int runHost(TLSCreds creds, int addrPipe, int completionPipe, Result expect) {
255279
tlsConfig.setCertificateBytes(creds.certBytes);
256280
tlsConfig.setCABytes(creds.caBytes);
257281
tlsConfig.setKeyBytes(creds.keyBytes);
282+
tlsConfig.setPassword(creds.password);
258283
}
259284
g_network = newNet2(tlsConfig);
260285
openTraceFile({}, 10 << 20, 10 << 20, ".", IsServer ? "authz_tls_unittest_server" : "authz_tls_unittest_client");
@@ -264,7 +289,12 @@ int runHost(TLSCreds creds, int addrPipe, int completionPipe, Result expect) {
264289
auto addr = NetworkAddress::parse(noTls ? "127.0.0.1:0" : "127.0.0.1:0:tls");
265290
auto endpoint = Endpoint();
266291
auto receiver = SessionProbeReceiver();
267-
auto listenFuture = transport.bind(addr, addr);
292+
try {
293+
transport.bind(addr, addr);
294+
} catch (const Error& err) {
295+
log("CAUGHT Error in bind: code={} what={}", err.code(), err.what());
296+
return SERVER_BIND_ERROR;
297+
}
268298
transport.addEndpoint(endpoint, &receiver, TaskPriority::ReadSocket);
269299
auto thread = std::thread([]() {
270300
g_network->run();
@@ -275,13 +305,13 @@ int runHost(TLSCreds creds, int addrPipe, int completionPipe, Result expect) {
275305
g_network->stop();
276306
thread.join();
277307
});
278-
return 0;
308+
return SUCCESS;
279309
} else {
280310
auto dest = Endpoint();
281311
auto& serverAddr = dest.addresses.address;
282312
if (sizeof(serverAddr) != ::read(addrPipe, &serverAddr, sizeof(serverAddr))) {
283313
log("Failed to read server addr from pipe: {}", strerror(errno));
284-
return 1;
314+
return CLIENT_PIPE_READ_ADDR_FAILED;
285315
}
286316
if (noTls)
287317
serverAddr.flags &= ~NetworkAddress::FLAG_TLS;
@@ -290,14 +320,14 @@ int runHost(TLSCreds creds, int addrPipe, int completionPipe, Result expect) {
290320
auto& token = dest.token;
291321
if (sizeof(token) != ::read(addrPipe, &token, sizeof(token))) {
292322
log("Failed to read server endpoint token from pipe: {}", strerror(errno));
293-
return 2;
323+
return CLIENT_FAILED;
294324
}
295325
log("Server address is {}{}", serverAddr.toString(), noTls ? " (TLS suffix removed)" : "");
296326
log("Server endpoint token is {}", token.toString());
297327
auto sessionProbeReq = SessionProbeRequest{};
298328
transport.sendUnreliable(SerializeSource(sessionProbeReq), dest, true /*openConnection*/);
299329
log("Request is sent");
300-
auto rc = 0;
330+
auto rc = SUCCESS;
301331
auto result = Result::ERROR;
302332
{
303333
auto timeout = delay(expect == Result::TIMEOUT ? 0.5 : 5);
@@ -308,12 +338,12 @@ int runHost(TLSCreds creds, int addrPipe, int completionPipe, Result expect) {
308338
auto done = true;
309339
if (sizeof(done) != ::write(completionPipe, &done, sizeof(done))) {
310340
log("Failed to signal server to terminate: {}", strerror(errno));
311-
rc = 4;
341+
rc = CLIENT_FAILED;
312342
}
313-
if (rc == 0) {
343+
if (rc == SUCCESS) {
314344
if (expect != result) {
315345
log("Test failed: expected {}, got {}", expect, result);
316-
rc = 5;
346+
rc = CLIENT_TEST_RESULT_MISMATCH;
317347
} else {
318348
log("Response OK: got {} as expected", result);
319349
}
@@ -374,7 +404,7 @@ std::pair<bool, std::string> waitPidStatusInterpreter(const char* procName, cons
374404
return { false, message };
375405
}
376406

377-
bool waitPid(pid_t subProcPid, const char* procName) {
407+
bool waitPid(pid_t subProcPid, const char* procName, int expectStatus = WAITPID_ANY_STATUS) {
378408
auto status = int{};
379409
auto pid = ::waitpid(subProcPid, &status, 0);
380410

@@ -386,33 +416,56 @@ bool waitPid(pid_t subProcPid, const char* procName) {
386416
auto [ok, message] = waitPidStatusInterpreter(procName, status);
387417
log("{}", message);
388418

389-
return ok;
419+
return ok || (expectStatus != WAITPID_ANY_STATUS && WEXITSTATUS(status) == expectStatus);
390420
}
391421
}
392422

393-
int runTlsTest(ChainLength serverChainLen, ChainLength clientChainLen) {
394-
log("==== BEGIN TESTCASE ====");
395-
auto const expect = getExpectedResult(serverChainLen, clientChainLen);
396-
using namespace std::literals::string_literals;
397-
log("Cert chain length: server={} client={}", serverChainLen, clientChainLen);
398-
auto arena = Arena();
399-
auto serverCreds = makeCreds(serverChainLen, mkcert::ESide::Server);
400-
auto clientCreds = makeCreds(clientChainLen, mkcert::ESide::Client);
401-
// make server and client trust each other
402-
std::swap(serverCreds.caBytes, clientCreds.caBytes);
423+
int runTlsTest(ChainLength serverChainLen, ChainLength clientChainLen, std::string_view passwordTestCase = "") {
424+
auto expect = Result::TRUSTED;
425+
TLSCreds serverCreds;
426+
TLSCreds clientCreds;
427+
int expectStatusServer = WAITPID_ANY_STATUS;
428+
int expectStatusClient = WAITPID_ANY_STATUS;
429+
430+
if (passwordTestCase.empty()) {
431+
log("==== BEGIN TESTCASE ====");
432+
expect = getExpectedResult(serverChainLen, clientChainLen);
433+
log("Cert chain length: server={} client={}", serverChainLen, clientChainLen);
434+
serverCreds = makeCreds(serverChainLen, mkcert::ESide::Server);
435+
clientCreds = makeCreds(clientChainLen, mkcert::ESide::Client);
436+
// make server and client trust each other
437+
std::swap(serverCreds.caBytes, clientCreds.caBytes);
438+
} else {
439+
const auto password = "abc123"_sr;
440+
serverCreds = makeCreds(serverChainLen, mkcert::ESide::Server, password);
441+
clientCreds = serverCreds;
442+
443+
if (passwordTestCase == "client") {
444+
log("==== BEGIN CLIENT BAD PASSWORD TESTCASE ====");
445+
expect = Result::TIMEOUT;
446+
clientCreds.password = "bad";
447+
} else if (passwordTestCase == "server") {
448+
log("==== BEGIN SERVER BAD PASSWORD TESTCASE ====");
449+
serverCreds.password = "bad";
450+
expectStatusServer = SERVER_BIND_ERROR;
451+
expectStatusClient = CLIENT_PIPE_READ_ADDR_FAILED;
452+
} else {
453+
log("==== BEGIN PASSWORD PROTECTED TESTCASE ====");
454+
}
455+
}
403456
auto clientPid = pid_t{};
404457
auto serverPid = pid_t{};
405458
int addrPipe[2], completionPipe[2], serverStdoutPipe[2], clientStdoutPipe[2];
406459
if (::pipe(addrPipe) || ::pipe(completionPipe) || ::pipe(serverStdoutPipe) || ::pipe(clientStdoutPipe)) {
407460
log("Pipe open failed: {}", strerror(errno));
408-
return 1;
461+
return MAIN_TEST_FAILED;
409462
}
410463
auto ok = true;
411464
{
412465
serverPid = fork();
413466
if (serverPid == -1) {
414467
log("fork() for server subprocess failed: {}", strerror(errno));
415-
return 1;
468+
return MAIN_TEST_FAILED;
416469
} else if (serverPid == 0) {
417470
role = Role::SERVER;
418471
// server subprocess
@@ -429,24 +482,25 @@ int runTlsTest(ChainLength serverChainLen, ChainLength clientChainLen) {
429482
if (-1 == ::dup2(serverStdoutPipe[1], STDOUT_FILENO)) {
430483
log("Failed to redirect server stdout to pipe: {}", strerror(errno));
431484
::close(serverStdoutPipe[1]);
432-
return 1;
485+
return SERVER_STDOUT_REDIRECT_FAILED;
433486
}
434487
_exit(runHost<true>(std::move(serverCreds), addrPipe[1], completionPipe[0], expect));
435488
}
436-
auto serverProcCleanup = ScopeExit([&ok, serverPid]() {
437-
if (!waitPid(serverPid, "Server"))
489+
auto serverProcCleanup = ScopeExit([&ok, serverPid, expectStatusServer]() {
490+
if (!waitPid(serverPid, "Server", expectStatusServer))
438491
ok = false;
439492
});
493+
::close(addrPipe[1]);
494+
::close(completionPipe[0]);
495+
::close(serverStdoutPipe[1]);
496+
440497
clientPid = fork();
441498
if (clientPid == -1) {
442499
log("fork() for client subprocess failed: {}", strerror(errno));
443-
return 1;
500+
return MAIN_TEST_FAILED;
444501
} else if (clientPid == 0) {
445502
role = Role::CLIENT;
446-
::close(addrPipe[1]);
447-
::close(completionPipe[0]);
448503
::close(serverStdoutPipe[0]);
449-
::close(serverStdoutPipe[1]);
450504
::close(clientStdoutPipe[0]);
451505
auto pipeCleanup = ScopeExit([&addrPipe, &completionPipe]() {
452506
::close(addrPipe[0]);
@@ -455,21 +509,18 @@ int runTlsTest(ChainLength serverChainLen, ChainLength clientChainLen) {
455509
if (-1 == ::dup2(clientStdoutPipe[1], STDOUT_FILENO)) {
456510
log("Failed to redirect client stdout to pipe: {}", strerror(errno));
457511
::close(clientStdoutPipe[1]);
458-
return 1;
512+
return CLIENT_FAILED;
459513
}
460514
_exit(runHost<false>(std::move(clientCreds), addrPipe[0], completionPipe[1], expect));
461515
}
462-
auto clientProcCleanup = ScopeExit([&ok, clientPid]() {
463-
if (!waitPid(clientPid, "Client"))
516+
auto clientProcCleanup = ScopeExit([&ok, clientPid, expectStatusClient]() {
517+
if (!waitPid(clientPid, "Client", expectStatusClient))
464518
ok = false;
465519
});
466520
}
467521
// main process
468522
::close(addrPipe[0]);
469-
::close(addrPipe[1]);
470-
::close(completionPipe[0]);
471523
::close(completionPipe[1]);
472-
::close(serverStdoutPipe[1]);
473524
::close(clientStdoutPipe[1]);
474525
auto pipeCleanup = ScopeExit([&]() {
475526
::close(serverStdoutPipe[0]);
@@ -484,7 +535,7 @@ int runTlsTest(ChainLength serverChainLen, ChainLength clientChainLen) {
484535
logRaw(fmt::runtime(serverStdout));
485536
log("/// End Server STDOUT ///");
486537
log(fmt::runtime(ok ? "OK" : "FAILED"));
487-
return !ok;
538+
return ok ? SUCCESS : MAIN_TEST_FAILED;
488539
}
489540

490541
int main(int argc, char** argv) {
@@ -514,12 +565,29 @@ int main(int argc, char** argv) {
514565
if (runTlsTest(serverChainLen, clientChainLen))
515566
failed.push_back({ serverChainLen, clientChainLen });
516567
}
568+
569+
constexpr auto singleChainPair = std::pair(ChainLength(1), ChainLength(1));
570+
inputs.insert(inputs.end(), 3, singleChainPair);
571+
572+
std::vector<std::string_view> failedPasswordTests;
573+
for (const auto& testCase : std::array{ "no_bad_password", "client", "server" }) {
574+
if (runTlsTest(singleChainPair.first, singleChainPair.second, testCase)) {
575+
failed.push_back(singleChainPair);
576+
failedPasswordTests.push_back(testCase);
577+
}
578+
}
579+
517580
if (!failed.empty()) {
581+
if (!failedPasswordTests.empty()) {
582+
for (const auto& test : failedPasswordTests) {
583+
log(" {}, failed", test);
584+
}
585+
}
518586
log("Test Failed: {}/{} cases: {}", failed.size(), inputs.size(), failed);
519-
return 1;
587+
return MAIN_TEST_FAILED;
520588
} else {
521589
log("Test OK: {}/{} cases passed", inputs.size(), inputs.size());
522-
return 0;
590+
return SUCCESS;
523591
}
524592
}
525593
#else // _WIN32

flow/MkCert.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,13 @@ struct CertAndKeyNative {
9191
return ret;
9292
}
9393

94-
PemType toPem(Arena& arena) {
94+
PemType toPem(Arena& arena, StringRef password = StringRef()) {
9595
auto ret = PemType{};
9696
if (null())
9797
return ret;
9898
ASSERT(valid());
9999
ret.certPem = writeX509CertPem(arena, cert);
100-
ret.privateKeyPem = privateKey.writePem(arena);
100+
ret.privateKeyPem = privateKey.writePem(arena, password);
101101
return ret;
102102
}
103103
};
@@ -255,10 +255,10 @@ CertAndKeyNative makeCertNative(CertSpecRef spec, CertAndKeyNative issuer) {
255255
return ret;
256256
}
257257

258-
CertAndKeyRef CertAndKeyRef::make(Arena& arena, CertSpecRef spec, CertAndKeyRef issuerPem) {
258+
CertAndKeyRef CertAndKeyRef::make(Arena& arena, CertSpecRef spec, CertAndKeyRef issuerPem, StringRef password) {
259259
auto issuer = CertAndKeyNative::fromPem(issuerPem);
260260
auto newCertAndKey = makeCertNative(spec, issuer);
261-
return newCertAndKey.toPem(arena);
261+
return newCertAndKey.toPem(arena, password);
262262
}
263263

264264
CertSpecRef CertSpecRef::make(Arena& arena, CertKind kind) {
@@ -370,4 +370,9 @@ StringRef CertKind::getCommonName(StringRef prefix, Arena& arena) const {
370370
}
371371
}
372372

373+
CertAndKeyRef makePasswCert(Arena& arena, StringRef password) {
374+
auto spec = CertSpecRef::make(arena, CertKind(Server{}));
375+
return CertAndKeyRef::make(arena, spec, CertAndKeyRef{}, password);
376+
}
377+
373378
} // namespace mkcert

flow/Net2.actor.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,6 +1422,8 @@ void Net2::initTLS(ETLSInitState targetState) {
14221422
sslContextVar.set(ReferencedObject<boost::asio::ssl::context>::from(std::move(newContext)));
14231423
} catch (Error& e) {
14241424
TraceEvent("Net2TLSInitError").error(e);
1425+
flushTraceFileVoid();
1426+
throw;
14251427
}
14261428
backgroundCertRefresh =
14271429
reloadCertificatesOnChange(tlsConfig, onPolicyFailure, &sslContextVar, &activeTlsPolicy);

0 commit comments

Comments
 (0)