@@ -51,10 +51,10 @@ std::future<bool> VsockConnection::ConnectAsync(
5151}
5252
5353void VsockConnection::Disconnect () {
54- // We need to serialize all accesses to the SharedFD.
55- std::lock_guard<std::recursive_mutex> read_lock (read_mutex_);
56- std::lock_guard<std::recursive_mutex> write_lock (write_mutex_) ;
57-
54+ std::lock_guard<std::recursive_mutex> state_lock (state_mutex_);
55+ if (!fd_-> IsOpen ()) {
56+ return ;
57+ }
5858 LOG (INFO ) << " Disconnecting with fd status:" << fd_->StrError ();
5959 fd_->Shutdown (SHUT_RDWR );
6060 if (disconnect_callback_) {
@@ -64,37 +64,39 @@ void VsockConnection::Disconnect() {
6464}
6565
6666void VsockConnection::SetDisconnectCallback (std::function<void ()> callback) {
67+ std::lock_guard<std::recursive_mutex> state_lock (state_mutex_);
6768 disconnect_callback_ = callback;
6869}
6970
70- // This method created due to a race condition in IsConnected().
71- // TODO(b/345285391): remove this method once a fix found
72- bool VsockConnection::IsConnected_Unguarded () { return fd_->IsOpen (); }
73-
7471bool VsockConnection::IsConnected () {
75- // We need to serialize all accesses to the SharedFD.
76- std::lock_guard<std::recursive_mutex> read_lock (read_mutex_);
77- std::lock_guard<std::recursive_mutex> write_lock (write_mutex_);
78-
72+ std::lock_guard<std::recursive_mutex> state_lock (state_mutex_);
7973 return fd_->IsOpen ();
8074}
8175
8276bool VsockConnection::DataAvailable () {
77+ SharedFD local_fd;
78+ {
79+ std::lock_guard<std::recursive_mutex> state_lock (state_mutex_);
80+ local_fd = fd_;
81+ }
82+ if (!local_fd->IsOpen ()) {
83+ return false ;
84+ }
8385 SharedFDSet read_set;
84-
85- // We need to serialize all accesses to the SharedFD.
86- std::lock_guard<std::recursive_mutex> read_lock (read_mutex_);
87- std::lock_guard<std::recursive_mutex> write_lock (write_mutex_);
88-
89- read_set.Set (fd_);
86+ read_set.Set (local_fd);
9087 struct timeval timeout = {0 , 0 };
9188 return Select (&read_set, nullptr , nullptr , &timeout) > 0 ;
9289}
9390
9491int32_t VsockConnection::Read () {
9592 std::lock_guard<std::recursive_mutex> lock (read_mutex_);
93+ SharedFD local_fd;
94+ {
95+ std::lock_guard<std::recursive_mutex> state_lock (state_mutex_);
96+ local_fd = fd_;
97+ }
9698 int32_t result;
97- if (ReadExactBinary (fd_ , &result) != sizeof (result)) {
99+ if (ReadExactBinary (local_fd , &result) != sizeof (result)) {
98100 Disconnect ();
99101 return 0 ;
100102 }
@@ -103,16 +105,26 @@ int32_t VsockConnection::Read() {
103105
104106bool VsockConnection::Read (std::vector<char >& data) {
105107 std::lock_guard<std::recursive_mutex> lock (read_mutex_);
106- return ReadExact (fd_, &data) == data.size ();
108+ SharedFD local_fd;
109+ {
110+ std::lock_guard<std::recursive_mutex> state_lock (state_mutex_);
111+ local_fd = fd_;
112+ }
113+ return ReadExact (local_fd, &data) == data.size ();
107114}
108115
109116std::vector<char > VsockConnection::Read (size_t size) {
110117 if (size == 0 ) {
111118 return {};
112119 }
113120 std::lock_guard<std::recursive_mutex> lock (read_mutex_);
121+ SharedFD local_fd;
122+ {
123+ std::lock_guard<std::recursive_mutex> state_lock (state_mutex_);
124+ local_fd = fd_;
125+ }
114126 std::vector<char > result (size);
115- if (ReadExact (fd_ , &result) != size) {
127+ if (ReadExact (local_fd , &result) != size) {
116128 Disconnect ();
117129 return {};
118130 }
@@ -167,7 +179,12 @@ std::future<Json::Value> VsockConnection::ReadJsonMessageAsync() {
167179
168180bool VsockConnection::Write (int32_t data) {
169181 std::lock_guard<std::recursive_mutex> lock (write_mutex_);
170- if (WriteAllBinary (fd_, &data) != sizeof (data)) {
182+ SharedFD local_fd;
183+ {
184+ std::lock_guard<std::recursive_mutex> state_lock (state_mutex_);
185+ local_fd = fd_;
186+ }
187+ if (WriteAllBinary (local_fd, &data) != sizeof (data)) {
171188 Disconnect ();
172189 return false ;
173190 }
@@ -176,7 +193,12 @@ bool VsockConnection::Write(int32_t data) {
176193
177194bool VsockConnection::Write (const char * data, unsigned int size) {
178195 std::lock_guard<std::recursive_mutex> lock (write_mutex_);
179- if (WriteAll (fd_, data, size) != size) {
196+ SharedFD local_fd;
197+ {
198+ std::lock_guard<std::recursive_mutex> state_lock (state_mutex_);
199+ local_fd = fd_;
200+ }
201+ if (WriteAll (local_fd, data, size) != size) {
180202 Disconnect ();
181203 return false ;
182204 }
@@ -216,6 +238,7 @@ bool VsockConnection::WriteStrides(const char* data, unsigned int size,
216238
217239bool VsockClientConnection::Connect (unsigned int port, unsigned int cid,
218240 std::optional<int > vhost_user) {
241+ std::lock_guard<std::recursive_mutex> state_lock (state_mutex_);
219242 fd_ =
220243 SharedFD::VsockClient (cid, port, SOCK_STREAM , vhost_user ? true : false );
221244 if (!fd_->IsOpen ()) {
@@ -227,6 +250,7 @@ bool VsockClientConnection::Connect(unsigned int port, unsigned int cid,
227250VsockServerConnection::~VsockServerConnection () { ServerShutdown (); }
228251
229252void VsockServerConnection::ServerShutdown () {
253+ std::lock_guard<std::recursive_mutex> state_lock (state_mutex_);
230254 if (server_fd_->IsOpen ()) {
231255 LOG (INFO ) << __FUNCTION__
232256 << " : server fd status:" << server_fd_->StrError ();
@@ -237,12 +261,17 @@ void VsockServerConnection::ServerShutdown() {
237261
238262bool VsockServerConnection::Connect (unsigned int port, unsigned int cid,
239263 std::optional<int > vhost_user_vsock_cid) {
240- if (!server_fd_->IsOpen ()) {
241- server_fd_ = cuttlefish::SharedFD::VsockServer (port, SOCK_STREAM ,
242- vhost_user_vsock_cid, cid);
264+ {
265+ std::lock_guard<std::recursive_mutex> state_lock (state_mutex_);
266+ if (!server_fd_->IsOpen ()) {
267+ server_fd_ = cuttlefish::SharedFD::VsockServer (port, SOCK_STREAM ,
268+ vhost_user_vsock_cid, cid);
269+ }
243270 }
244271 if (server_fd_->IsOpen ()) {
245- fd_ = SharedFD::Accept (*server_fd_);
272+ auto accepted_fd = SharedFD::Accept (*server_fd_);
273+ std::lock_guard<std::recursive_mutex> state_lock (state_mutex_);
274+ fd_ = accepted_fd;
246275 return fd_->IsOpen ();
247276 } else {
248277 return false ;
0 commit comments