1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <source/HostToGuestComms.h>
18 
19 #include <https/SafeCallbackable.h>
20 #include <https/Support.h>
21 
22 #include <android-base/logging.h>
23 
HostToGuestComms(std::shared_ptr<RunLoop> runLoop,bool isServer,int fd,ReceiveCb onReceive)24 HostToGuestComms::HostToGuestComms(
25         std::shared_ptr<RunLoop> runLoop,
26         bool isServer,
27         int fd,
28         ReceiveCb onReceive)
29     : mRunLoop(runLoop),
30       mIsServer(isServer),
31       mOnReceive(onReceive),
32       mServerSock(-1),
33       mSock(-1),
34       mInBufferLen(0),
35       mSendPending(false),
36       mConnected(false) {
37     makeFdNonblocking(fd);
38     if (mIsServer) {
39         mServerSock = fd;
40     } else {
41         mSock = fd;
42     }
43 }
44 
HostToGuestComms(std::shared_ptr<RunLoop> runLoop,bool isServer,uint32_t cid,uint16_t port,ReceiveCb onReceive)45 HostToGuestComms::HostToGuestComms(
46         std::shared_ptr<RunLoop> runLoop,
47         bool isServer,
48         uint32_t cid,
49         uint16_t port,
50         ReceiveCb onReceive)
51     : mRunLoop(runLoop),
52       mIsServer(isServer),
53       mOnReceive(onReceive),
54       mServerSock(-1),
55       mSock(-1),
56       mInBufferLen(0),
57       mSendPending(false),
58       mConnected(false) {
59     int s = socket(AF_VSOCK, SOCK_STREAM, 0);
60     CHECK_GE(s, 0);
61 
62     LOG(INFO) << "HostToGuestComms created socket " << s;
63 
64     makeFdNonblocking(s);
65 
66     sockaddr_vm addr;
67     memset(&addr, 0, sizeof(addr));
68     addr.svm_family = AF_VSOCK;
69     addr.svm_port = port;
70     addr.svm_cid = cid;
71 
72     int res;
73     if (mIsServer) {
74         LOG(INFO)
75             << "Binding to cid "
76             << (addr.svm_cid == VMADDR_CID_ANY)
77                     ? "VMADDR_CID_ANY" : std::to_string(addr.svm_cid);
78 
79         res = bind(s, reinterpret_cast<const sockaddr *>(&addr), sizeof(addr));
80 
81         if (res) {
82             LOG(ERROR)
83                 << (mIsServer ? "bind" : "connect")
84                 << " FAILED w/ errno "
85                 << errno
86                 << " ("
87                 << strerror(errno)
88                 << ")";
89         }
90 
91         CHECK(!res);
92 
93         res = listen(s, 4);
94         CHECK(!res);
95 
96         mServerSock = s;
97     } else {
98         mSock = s;
99         mConnectToAddr = addr;
100     }
101 }
102 
~HostToGuestComms()103 HostToGuestComms::~HostToGuestComms() {
104     if (mSock >= 0) {
105         mRunLoop->cancelSocket(mSock);
106 
107         close(mSock);
108         mSock = -1;
109     }
110 
111     if (mServerSock >= 0) {
112         mRunLoop->cancelSocket(mServerSock);
113 
114         close(mServerSock);
115         mServerSock = -1;
116     }
117 }
118 
start()119 void HostToGuestComms::start() {
120     if (mIsServer) {
121         mRunLoop->postSocketRecv(
122                 mServerSock,
123                 makeSafeCallback(this, &HostToGuestComms::onServerConnection));
124     } else {
125         mRunLoop->postWithDelay(
126                 std::chrono::milliseconds(5000),
127                 makeSafeCallback(
128                     this,
129                     &HostToGuestComms::onAttemptToConnect,
130                     mConnectToAddr));
131     }
132 }
133 
send(const void * data,size_t size,bool addFraming)134 void HostToGuestComms::send(const void *data, size_t size, bool addFraming) {
135     if (!size) {
136         return;
137     }
138 
139     std::lock_guard autoLock(mLock);
140 
141     size_t offset = mOutBuffer.size();
142 
143     if (addFraming) {
144         uint32_t packetLen = size;
145         size_t totalSize = sizeof(packetLen) + size;
146 
147         mOutBuffer.resize(offset + totalSize);
148         memcpy(mOutBuffer.data() + offset, &packetLen, sizeof(packetLen));
149         memcpy(mOutBuffer.data() + offset + sizeof(packetLen), data, size);
150     } else {
151         mOutBuffer.resize(offset + size);
152         memcpy(mOutBuffer.data() + offset, data, size);
153     }
154 
155     if (mSock >= 0 && (mIsServer || mConnected) && !mSendPending) {
156         mSendPending = true;
157         mRunLoop->postSocketSend(
158                 mSock,
159                 makeSafeCallback(this, &HostToGuestComms::onSocketSend));
160     }
161 }
162 
onServerConnection()163 void HostToGuestComms::onServerConnection() {
164     int s = accept(mServerSock, nullptr, nullptr);
165 
166     if (s >= 0) {
167         if (mSock >= 0) {
168             LOG(INFO) << "Rejecting client, we already have one.";
169 
170             // We already have a client.
171             close(s);
172             s = -1;
173         } else {
174             LOG(INFO) << "Accepted client socket " << s << ".";
175 
176             makeFdNonblocking(s);
177 
178             mSock = s;
179             mRunLoop->postSocketRecv(
180                     mSock,
181                     makeSafeCallback(this, &HostToGuestComms::onSocketReceive));
182 
183             std::lock_guard autoLock(mLock);
184             if (!mOutBuffer.empty()) {
185                 CHECK(!mSendPending);
186 
187                 mSendPending = true;
188                 mRunLoop->postSocketSend(
189                         mSock,
190                         makeSafeCallback(
191                             this, &HostToGuestComms::onSocketSend));
192             }
193         }
194     }
195 
196     mRunLoop->postSocketRecv(
197             mServerSock,
198             makeSafeCallback(this, &HostToGuestComms::onServerConnection));
199 }
200 
onSocketReceive()201 void HostToGuestComms::onSocketReceive() {
202     ssize_t n;
203     for (;;) {
204         static constexpr size_t kChunkSize = 65536;
205 
206         mInBuffer.resize(mInBufferLen + kChunkSize);
207 
208         do {
209             n = recv(mSock, mInBuffer.data() + mInBufferLen, kChunkSize, 0);
210         } while (n < 0 && errno == EINTR);
211 
212         if (n <= 0) {
213             break;
214         }
215 
216         mInBufferLen += static_cast<size_t>(n);
217     }
218 
219     int savedErrno = errno;
220 
221     drainInBuffer();
222 
223     if ((n < 0 && savedErrno != EAGAIN && savedErrno != EWOULDBLOCK)
224             || n == 0) {
225         LOG(ERROR) << "Client is gone.";
226 
227         // Client is gone.
228         mRunLoop->cancelSocket(mSock);
229 
230         mSendPending = false;
231 
232         close(mSock);
233         mSock = -1;
234         return;
235     }
236 
237     mRunLoop->postSocketRecv(
238             mSock,
239             makeSafeCallback(this, &HostToGuestComms::onSocketReceive));
240 }
241 
drainInBuffer()242 void HostToGuestComms::drainInBuffer() {
243     for (;;) {
244         uint32_t packetLen;
245 
246         if (mInBufferLen < sizeof(packetLen)) {
247             return;
248         }
249 
250         memcpy(&packetLen, mInBuffer.data(), sizeof(packetLen));
251 
252         size_t totalLen = sizeof(packetLen) + packetLen;
253 
254         if (mInBufferLen < totalLen) {
255             return;
256         }
257 
258         if (mOnReceive) {
259             // LOG(INFO) << "Dispatching packet of size " << packetLen;
260 
261             mOnReceive(mInBuffer.data() + sizeof(packetLen), packetLen);
262         }
263 
264         mInBuffer.erase(mInBuffer.begin(), mInBuffer.begin() + totalLen);
265         mInBufferLen -= totalLen;
266     }
267 }
268 
onSocketSend()269 void HostToGuestComms::onSocketSend() {
270     std::lock_guard autoLock(mLock);
271 
272     CHECK(mSendPending);
273     mSendPending = false;
274 
275     if (mSock < 0) {
276         return;
277     }
278 
279     ssize_t n;
280     while (!mOutBuffer.empty()) {
281         do {
282             n = ::send(mSock, mOutBuffer.data(), mOutBuffer.size(), 0);
283         } while (n < 0 && errno == EINTR);
284 
285         if (n <= 0) {
286             break;
287         }
288 
289         mOutBuffer.erase(mOutBuffer.begin(), mOutBuffer.begin() + n);
290     }
291 
292     if ((n < 0 && errno != EAGAIN && errno != EWOULDBLOCK) || n == 0) {
293         LOG(ERROR) << "Client is gone.";
294 
295         // Client is gone.
296         mRunLoop->cancelSocket(mSock);
297 
298         close(mSock);
299         mSock = -1;
300         return;
301     }
302 
303     if (!mOutBuffer.empty()) {
304         mSendPending = true;
305         mRunLoop->postSocketSend(
306                 mSock,
307                 makeSafeCallback(this, &HostToGuestComms::onSocketSend));
308     }
309 }
310 
onAttemptToConnect(const sockaddr_vm & addr)311 void HostToGuestComms::onAttemptToConnect(const sockaddr_vm &addr) {
312     LOG(VERBOSE) << "Attempting to connect to cid " << addr.svm_cid;
313 
314     int res;
315     do {
316         res = connect(
317             mSock, reinterpret_cast<const sockaddr *>(&addr), sizeof(addr));
318     } while (res < 0 && errno == EINTR);
319 
320     if (res < 0) {
321         if (errno == EINPROGRESS) {
322             LOG(VERBOSE) << "EINPROGRESS, waiting to check the connection.";
323 
324             mRunLoop->postSocketSend(
325                     mSock,
326                     makeSafeCallback(
327                         this, &HostToGuestComms::onCheckConnection, addr));
328 
329             return;
330         }
331 
332         LOG(INFO)
333             << "Our attempt to connect to the guest FAILED w/ error "
334             << errno
335             << " ("
336             << strerror(errno)
337             << "), will try again shortly.";
338 
339         mRunLoop->postWithDelay(
340                 std::chrono::milliseconds(5000),
341                 makeSafeCallback(
342                     this, &HostToGuestComms::onAttemptToConnect, addr));
343 
344         return;
345     }
346 
347     onConnected();
348 }
349 
onCheckConnection(const sockaddr_vm & addr)350 void HostToGuestComms::onCheckConnection(const sockaddr_vm &addr) {
351     int err;
352 
353     int res;
354     do {
355         socklen_t errSize = sizeof(err);
356 
357         res = getsockopt(mSock, SOL_SOCKET, SO_ERROR, &err, &errSize);
358     } while (res < 0 && errno == EINTR);
359 
360     CHECK(!res);
361 
362     if (!err) {
363         onConnected();
364     } else {
365         LOG(VERBOSE)
366             << "Connection failed w/ error "
367             << err
368             << " ("
369             << strerror(err)
370             << "), will try again shortly.";
371 
372         // Is there a better way of cancelling the (failed) connection that
373         // somehow is still in progress on the socket and restarting it?
374         mRunLoop->cancelSocket(mSock);
375 
376         close(mSock);
377         mSock = socket(AF_VSOCK, SOCK_STREAM, 0);
378         CHECK_GE(mSock, 0);
379 
380         makeFdNonblocking(mSock);
381 
382         mRunLoop->postWithDelay(
383                 std::chrono::milliseconds(5000),
384                 makeSafeCallback(
385                     this, &HostToGuestComms::onAttemptToConnect, addr));
386     }
387 }
388 
onConnected()389 void HostToGuestComms::onConnected() {
390     LOG(INFO) << "Connected to guest.";
391 
392     std::lock_guard autoLock(mLock);
393 
394     mConnected = true;
395     CHECK(!mSendPending);
396 
397     if (!mOutBuffer.empty()) {
398         mSendPending = true;
399         mRunLoop->postSocketSend(
400                 mSock,
401                 makeSafeCallback(this, &HostToGuestComms::onSocketSend));
402     }
403 
404     mRunLoop->postSocketRecv(
405             mSock,
406             makeSafeCallback(this, &HostToGuestComms::onSocketReceive));
407 }
408 
409