1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <source/HostToGuestComms.h>
18
19 #include <https/SafeCallbackable.h>
20 #include <https/Support.h>
21
22 #include <android-base/logging.h>
23
HostToGuestComms(std::shared_ptr<RunLoop> runLoop,bool isServer,int fd,ReceiveCb onReceive)24 HostToGuestComms::HostToGuestComms(
25 std::shared_ptr<RunLoop> runLoop,
26 bool isServer,
27 int fd,
28 ReceiveCb onReceive)
29 : mRunLoop(runLoop),
30 mIsServer(isServer),
31 mOnReceive(onReceive),
32 mServerSock(-1),
33 mSock(-1),
34 mInBufferLen(0),
35 mSendPending(false),
36 mConnected(false) {
37 makeFdNonblocking(fd);
38 if (mIsServer) {
39 mServerSock = fd;
40 } else {
41 mSock = fd;
42 }
43 }
44
HostToGuestComms(std::shared_ptr<RunLoop> runLoop,bool isServer,uint32_t cid,uint16_t port,ReceiveCb onReceive)45 HostToGuestComms::HostToGuestComms(
46 std::shared_ptr<RunLoop> runLoop,
47 bool isServer,
48 uint32_t cid,
49 uint16_t port,
50 ReceiveCb onReceive)
51 : mRunLoop(runLoop),
52 mIsServer(isServer),
53 mOnReceive(onReceive),
54 mServerSock(-1),
55 mSock(-1),
56 mInBufferLen(0),
57 mSendPending(false),
58 mConnected(false) {
59 int s = socket(AF_VSOCK, SOCK_STREAM, 0);
60 CHECK_GE(s, 0);
61
62 LOG(INFO) << "HostToGuestComms created socket " << s;
63
64 makeFdNonblocking(s);
65
66 sockaddr_vm addr;
67 memset(&addr, 0, sizeof(addr));
68 addr.svm_family = AF_VSOCK;
69 addr.svm_port = port;
70 addr.svm_cid = cid;
71
72 int res;
73 if (mIsServer) {
74 LOG(INFO)
75 << "Binding to cid "
76 << (addr.svm_cid == VMADDR_CID_ANY)
77 ? "VMADDR_CID_ANY" : std::to_string(addr.svm_cid);
78
79 res = bind(s, reinterpret_cast<const sockaddr *>(&addr), sizeof(addr));
80
81 if (res) {
82 LOG(ERROR)
83 << (mIsServer ? "bind" : "connect")
84 << " FAILED w/ errno "
85 << errno
86 << " ("
87 << strerror(errno)
88 << ")";
89 }
90
91 CHECK(!res);
92
93 res = listen(s, 4);
94 CHECK(!res);
95
96 mServerSock = s;
97 } else {
98 mSock = s;
99 mConnectToAddr = addr;
100 }
101 }
102
~HostToGuestComms()103 HostToGuestComms::~HostToGuestComms() {
104 if (mSock >= 0) {
105 mRunLoop->cancelSocket(mSock);
106
107 close(mSock);
108 mSock = -1;
109 }
110
111 if (mServerSock >= 0) {
112 mRunLoop->cancelSocket(mServerSock);
113
114 close(mServerSock);
115 mServerSock = -1;
116 }
117 }
118
start()119 void HostToGuestComms::start() {
120 if (mIsServer) {
121 mRunLoop->postSocketRecv(
122 mServerSock,
123 makeSafeCallback(this, &HostToGuestComms::onServerConnection));
124 } else {
125 mRunLoop->postWithDelay(
126 std::chrono::milliseconds(5000),
127 makeSafeCallback(
128 this,
129 &HostToGuestComms::onAttemptToConnect,
130 mConnectToAddr));
131 }
132 }
133
send(const void * data,size_t size,bool addFraming)134 void HostToGuestComms::send(const void *data, size_t size, bool addFraming) {
135 if (!size) {
136 return;
137 }
138
139 std::lock_guard autoLock(mLock);
140
141 size_t offset = mOutBuffer.size();
142
143 if (addFraming) {
144 uint32_t packetLen = size;
145 size_t totalSize = sizeof(packetLen) + size;
146
147 mOutBuffer.resize(offset + totalSize);
148 memcpy(mOutBuffer.data() + offset, &packetLen, sizeof(packetLen));
149 memcpy(mOutBuffer.data() + offset + sizeof(packetLen), data, size);
150 } else {
151 mOutBuffer.resize(offset + size);
152 memcpy(mOutBuffer.data() + offset, data, size);
153 }
154
155 if (mSock >= 0 && (mIsServer || mConnected) && !mSendPending) {
156 mSendPending = true;
157 mRunLoop->postSocketSend(
158 mSock,
159 makeSafeCallback(this, &HostToGuestComms::onSocketSend));
160 }
161 }
162
onServerConnection()163 void HostToGuestComms::onServerConnection() {
164 int s = accept(mServerSock, nullptr, nullptr);
165
166 if (s >= 0) {
167 if (mSock >= 0) {
168 LOG(INFO) << "Rejecting client, we already have one.";
169
170 // We already have a client.
171 close(s);
172 s = -1;
173 } else {
174 LOG(INFO) << "Accepted client socket " << s << ".";
175
176 makeFdNonblocking(s);
177
178 mSock = s;
179 mRunLoop->postSocketRecv(
180 mSock,
181 makeSafeCallback(this, &HostToGuestComms::onSocketReceive));
182
183 std::lock_guard autoLock(mLock);
184 if (!mOutBuffer.empty()) {
185 CHECK(!mSendPending);
186
187 mSendPending = true;
188 mRunLoop->postSocketSend(
189 mSock,
190 makeSafeCallback(
191 this, &HostToGuestComms::onSocketSend));
192 }
193 }
194 }
195
196 mRunLoop->postSocketRecv(
197 mServerSock,
198 makeSafeCallback(this, &HostToGuestComms::onServerConnection));
199 }
200
onSocketReceive()201 void HostToGuestComms::onSocketReceive() {
202 ssize_t n;
203 for (;;) {
204 static constexpr size_t kChunkSize = 65536;
205
206 mInBuffer.resize(mInBufferLen + kChunkSize);
207
208 do {
209 n = recv(mSock, mInBuffer.data() + mInBufferLen, kChunkSize, 0);
210 } while (n < 0 && errno == EINTR);
211
212 if (n <= 0) {
213 break;
214 }
215
216 mInBufferLen += static_cast<size_t>(n);
217 }
218
219 int savedErrno = errno;
220
221 drainInBuffer();
222
223 if ((n < 0 && savedErrno != EAGAIN && savedErrno != EWOULDBLOCK)
224 || n == 0) {
225 LOG(ERROR) << "Client is gone.";
226
227 // Client is gone.
228 mRunLoop->cancelSocket(mSock);
229
230 mSendPending = false;
231
232 close(mSock);
233 mSock = -1;
234 return;
235 }
236
237 mRunLoop->postSocketRecv(
238 mSock,
239 makeSafeCallback(this, &HostToGuestComms::onSocketReceive));
240 }
241
drainInBuffer()242 void HostToGuestComms::drainInBuffer() {
243 for (;;) {
244 uint32_t packetLen;
245
246 if (mInBufferLen < sizeof(packetLen)) {
247 return;
248 }
249
250 memcpy(&packetLen, mInBuffer.data(), sizeof(packetLen));
251
252 size_t totalLen = sizeof(packetLen) + packetLen;
253
254 if (mInBufferLen < totalLen) {
255 return;
256 }
257
258 if (mOnReceive) {
259 // LOG(INFO) << "Dispatching packet of size " << packetLen;
260
261 mOnReceive(mInBuffer.data() + sizeof(packetLen), packetLen);
262 }
263
264 mInBuffer.erase(mInBuffer.begin(), mInBuffer.begin() + totalLen);
265 mInBufferLen -= totalLen;
266 }
267 }
268
onSocketSend()269 void HostToGuestComms::onSocketSend() {
270 std::lock_guard autoLock(mLock);
271
272 CHECK(mSendPending);
273 mSendPending = false;
274
275 if (mSock < 0) {
276 return;
277 }
278
279 ssize_t n;
280 while (!mOutBuffer.empty()) {
281 do {
282 n = ::send(mSock, mOutBuffer.data(), mOutBuffer.size(), 0);
283 } while (n < 0 && errno == EINTR);
284
285 if (n <= 0) {
286 break;
287 }
288
289 mOutBuffer.erase(mOutBuffer.begin(), mOutBuffer.begin() + n);
290 }
291
292 if ((n < 0 && errno != EAGAIN && errno != EWOULDBLOCK) || n == 0) {
293 LOG(ERROR) << "Client is gone.";
294
295 // Client is gone.
296 mRunLoop->cancelSocket(mSock);
297
298 close(mSock);
299 mSock = -1;
300 return;
301 }
302
303 if (!mOutBuffer.empty()) {
304 mSendPending = true;
305 mRunLoop->postSocketSend(
306 mSock,
307 makeSafeCallback(this, &HostToGuestComms::onSocketSend));
308 }
309 }
310
onAttemptToConnect(const sockaddr_vm & addr)311 void HostToGuestComms::onAttemptToConnect(const sockaddr_vm &addr) {
312 LOG(VERBOSE) << "Attempting to connect to cid " << addr.svm_cid;
313
314 int res;
315 do {
316 res = connect(
317 mSock, reinterpret_cast<const sockaddr *>(&addr), sizeof(addr));
318 } while (res < 0 && errno == EINTR);
319
320 if (res < 0) {
321 if (errno == EINPROGRESS) {
322 LOG(VERBOSE) << "EINPROGRESS, waiting to check the connection.";
323
324 mRunLoop->postSocketSend(
325 mSock,
326 makeSafeCallback(
327 this, &HostToGuestComms::onCheckConnection, addr));
328
329 return;
330 }
331
332 LOG(INFO)
333 << "Our attempt to connect to the guest FAILED w/ error "
334 << errno
335 << " ("
336 << strerror(errno)
337 << "), will try again shortly.";
338
339 mRunLoop->postWithDelay(
340 std::chrono::milliseconds(5000),
341 makeSafeCallback(
342 this, &HostToGuestComms::onAttemptToConnect, addr));
343
344 return;
345 }
346
347 onConnected();
348 }
349
onCheckConnection(const sockaddr_vm & addr)350 void HostToGuestComms::onCheckConnection(const sockaddr_vm &addr) {
351 int err;
352
353 int res;
354 do {
355 socklen_t errSize = sizeof(err);
356
357 res = getsockopt(mSock, SOL_SOCKET, SO_ERROR, &err, &errSize);
358 } while (res < 0 && errno == EINTR);
359
360 CHECK(!res);
361
362 if (!err) {
363 onConnected();
364 } else {
365 LOG(VERBOSE)
366 << "Connection failed w/ error "
367 << err
368 << " ("
369 << strerror(err)
370 << "), will try again shortly.";
371
372 // Is there a better way of cancelling the (failed) connection that
373 // somehow is still in progress on the socket and restarting it?
374 mRunLoop->cancelSocket(mSock);
375
376 close(mSock);
377 mSock = socket(AF_VSOCK, SOCK_STREAM, 0);
378 CHECK_GE(mSock, 0);
379
380 makeFdNonblocking(mSock);
381
382 mRunLoop->postWithDelay(
383 std::chrono::milliseconds(5000),
384 makeSafeCallback(
385 this, &HostToGuestComms::onAttemptToConnect, addr));
386 }
387 }
388
onConnected()389 void HostToGuestComms::onConnected() {
390 LOG(INFO) << "Connected to guest.";
391
392 std::lock_guard autoLock(mLock);
393
394 mConnected = true;
395 CHECK(!mSendPending);
396
397 if (!mOutBuffer.empty()) {
398 mSendPending = true;
399 mRunLoop->postSocketSend(
400 mSock,
401 makeSafeCallback(this, &HostToGuestComms::onSocketSend));
402 }
403
404 mRunLoop->postSocketRecv(
405 mSock,
406 makeSafeCallback(this, &HostToGuestComms::onSocketReceive));
407 }
408
409