Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 56 additions & 27 deletions base/cvd/cuttlefish/common/libs/utils/vsock_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ std::future<bool> VsockConnection::ConnectAsync(
}

void VsockConnection::Disconnect() {
// We need to serialize all accesses to the SharedFD.
std::lock_guard<std::recursive_mutex> read_lock(read_mutex_);
std::lock_guard<std::recursive_mutex> write_lock(write_mutex_);

std::lock_guard<std::recursive_mutex> state_lock(state_mutex_);
if (!fd_->IsOpen()) {
return;
}
LOG(INFO) << "Disconnecting with fd status:" << fd_->StrError();
fd_->Shutdown(SHUT_RDWR);
if (disconnect_callback_) {
Expand All @@ -64,37 +64,39 @@ void VsockConnection::Disconnect() {
}

void VsockConnection::SetDisconnectCallback(std::function<void()> callback) {
std::lock_guard<std::recursive_mutex> state_lock(state_mutex_);
disconnect_callback_ = callback;
}

// This method created due to a race condition in IsConnected().
// TODO(b/345285391): remove this method once a fix found
bool VsockConnection::IsConnected_Unguarded() { return fd_->IsOpen(); }

bool VsockConnection::IsConnected() {
// We need to serialize all accesses to the SharedFD.
std::lock_guard<std::recursive_mutex> read_lock(read_mutex_);
std::lock_guard<std::recursive_mutex> write_lock(write_mutex_);

std::lock_guard<std::recursive_mutex> state_lock(state_mutex_);
return fd_->IsOpen();
}

bool VsockConnection::DataAvailable() {
SharedFD local_fd;
{
std::lock_guard<std::recursive_mutex> state_lock(state_mutex_);
local_fd = fd_;
}
if (!local_fd->IsOpen()) {
return false;
}
SharedFDSet read_set;

// We need to serialize all accesses to the SharedFD.
std::lock_guard<std::recursive_mutex> read_lock(read_mutex_);
std::lock_guard<std::recursive_mutex> write_lock(write_mutex_);

read_set.Set(fd_);
read_set.Set(local_fd);
struct timeval timeout = {0, 0};
return Select(&read_set, nullptr, nullptr, &timeout) > 0;
}

int32_t VsockConnection::Read() {
std::lock_guard<std::recursive_mutex> lock(read_mutex_);
SharedFD local_fd;
{
std::lock_guard<std::recursive_mutex> state_lock(state_mutex_);
local_fd = fd_;
}
int32_t result;
if (ReadExactBinary(fd_, &result) != sizeof(result)) {
if (ReadExactBinary(local_fd, &result) != sizeof(result)) {
Disconnect();
return 0;
}
Expand All @@ -103,16 +105,26 @@ int32_t VsockConnection::Read() {

bool VsockConnection::Read(std::vector<char>& data) {
std::lock_guard<std::recursive_mutex> lock(read_mutex_);
return ReadExact(fd_, &data) == data.size();
SharedFD local_fd;
{
std::lock_guard<std::recursive_mutex> state_lock(state_mutex_);
local_fd = fd_;
}
return ReadExact(local_fd, &data) == data.size();
}

std::vector<char> VsockConnection::Read(size_t size) {
if (size == 0) {
return {};
}
std::lock_guard<std::recursive_mutex> lock(read_mutex_);
SharedFD local_fd;
{
std::lock_guard<std::recursive_mutex> state_lock(state_mutex_);
local_fd = fd_;
}
std::vector<char> result(size);
if (ReadExact(fd_, &result) != size) {
if (ReadExact(local_fd, &result) != size) {
Disconnect();
return {};
}
Expand Down Expand Up @@ -167,7 +179,12 @@ std::future<Json::Value> VsockConnection::ReadJsonMessageAsync() {

bool VsockConnection::Write(int32_t data) {
std::lock_guard<std::recursive_mutex> lock(write_mutex_);
if (WriteAllBinary(fd_, &data) != sizeof(data)) {
SharedFD local_fd;
{
std::lock_guard<std::recursive_mutex> state_lock(state_mutex_);
local_fd = fd_;
}
if (WriteAllBinary(local_fd, &data) != sizeof(data)) {
Disconnect();
return false;
}
Expand All @@ -176,7 +193,12 @@ bool VsockConnection::Write(int32_t data) {

bool VsockConnection::Write(const char* data, unsigned int size) {
std::lock_guard<std::recursive_mutex> lock(write_mutex_);
if (WriteAll(fd_, data, size) != size) {
SharedFD local_fd;
{
std::lock_guard<std::recursive_mutex> state_lock(state_mutex_);
local_fd = fd_;
}
if (WriteAll(local_fd, data, size) != size) {
Disconnect();
return false;
}
Expand Down Expand Up @@ -216,6 +238,7 @@ bool VsockConnection::WriteStrides(const char* data, unsigned int size,

bool VsockClientConnection::Connect(unsigned int port, unsigned int cid,
std::optional<int> vhost_user) {
std::lock_guard<std::recursive_mutex> state_lock(state_mutex_);
fd_ =
SharedFD::VsockClient(cid, port, SOCK_STREAM, vhost_user ? true : false);
if (!fd_->IsOpen()) {
Expand All @@ -227,6 +250,7 @@ bool VsockClientConnection::Connect(unsigned int port, unsigned int cid,
VsockServerConnection::~VsockServerConnection() { ServerShutdown(); }

void VsockServerConnection::ServerShutdown() {
std::lock_guard<std::recursive_mutex> state_lock(state_mutex_);
if (server_fd_->IsOpen()) {
LOG(INFO) << __FUNCTION__
<< ": server fd status:" << server_fd_->StrError();
Expand All @@ -237,12 +261,17 @@ void VsockServerConnection::ServerShutdown() {

bool VsockServerConnection::Connect(unsigned int port, unsigned int cid,
std::optional<int> vhost_user_vsock_cid) {
if (!server_fd_->IsOpen()) {
server_fd_ = cuttlefish::SharedFD::VsockServer(port, SOCK_STREAM,
vhost_user_vsock_cid, cid);
{
std::lock_guard<std::recursive_mutex> state_lock(state_mutex_);
if (!server_fd_->IsOpen()) {
server_fd_ = cuttlefish::SharedFD::VsockServer(port, SOCK_STREAM,
vhost_user_vsock_cid, cid);
}
}
if (server_fd_->IsOpen()) {
fd_ = SharedFD::Accept(*server_fd_);
auto accepted_fd = SharedFD::Accept(*server_fd_);
std::lock_guard<std::recursive_mutex> state_lock(state_mutex_);
fd_ = accepted_fd;
return fd_->IsOpen();
} else {
return false;
Expand Down
2 changes: 1 addition & 1 deletion base/cvd/cuttlefish/common/libs/utils/vsock_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class VsockConnection {
std::optional<int> vhost_user_vsock_cid);
void SetDisconnectCallback(std::function<void()> callback);

bool IsConnected_Unguarded();
bool IsConnected();
bool DataAvailable();
int32_t Read();
Expand All @@ -63,6 +62,7 @@ class VsockConnection {
unsigned int num_strides, int stride_size);

protected:
std::recursive_mutex state_mutex_;
std::recursive_mutex read_mutex_;
std::recursive_mutex write_mutex_;
std::function<void()> disconnect_callback_;
Expand Down
Loading