1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <webrtc/AdbWebSocketHandler.h>
18 
19 #include "Utils.h"
20 
21 #include <https/BaseConnection.h>
22 #include <https/Support.h>
23 
24 #include <android-base/logging.h>
25 
26 #include <unistd.h>
27 
28 using namespace android;
29 
30 struct AdbWebSocketHandler::AdbConnection : public BaseConnection {
31     explicit AdbConnection(
32             AdbWebSocketHandler *parent,
33             std::shared_ptr<RunLoop> runLoop,
34             int sock);
35 
36     void send(const void *_data, size_t size);
37 
38 protected:
39     ssize_t processClientRequest(const void *data, size_t size) override;
40     void onDisconnect(int err) override;
41 
42 private:
43     AdbWebSocketHandler *mParent;
44 };
45 
46 ////////////////////////////////////////////////////////////////////////////////
47 
AdbConnection(AdbWebSocketHandler * parent,std::shared_ptr<RunLoop> runLoop,int sock)48 AdbWebSocketHandler::AdbConnection::AdbConnection(
49         AdbWebSocketHandler *parent,
50         std::shared_ptr<RunLoop> runLoop,
51         int sock)
52     : BaseConnection(runLoop, sock),
53       mParent(parent) {
54 }
55 
// Returns the sum of every byte in [_data, _data + size), accumulated in a
// 32-bit unsigned value that wraps on overflow.
//
// The adb documentation calls this field a "crc32", but it is a plain
// additive checksum rather than a CRC; hence the name.
static uint32_t computeNotACrc32(const void *_data, size_t size) {
    const uint8_t *cursor = static_cast<const uint8_t *>(_data);
    const uint8_t *const end = cursor + size;

    uint32_t total = 0;
    while (cursor != end) {
        total += *cursor++;
    }

    return total;
}
66 
verifyAdbHeader(const void * _data,size_t size,size_t * _payloadLength)67 static int verifyAdbHeader(
68         const void *_data, size_t size, size_t *_payloadLength) {
69     auto data = static_cast<const uint8_t *>(_data);
70 
71     *_payloadLength = 0;
72 
73     if (size < 24) {
74         return -EAGAIN;
75     }
76 
77     uint32_t command = U32LE_AT(data);
78     uint32_t magic = U32LE_AT(data + 20);
79 
80     if (command != (magic ^ 0xffffffff)) {
81         return -EINVAL;
82     }
83 
84     uint32_t payloadLength = U32LE_AT(data + 12);
85 
86     if (size < 24 + payloadLength) {
87         return -EAGAIN;
88     }
89 
90     auto payloadCrc = U32LE_AT(data + 16);
91     auto crc32 = computeNotACrc32(data + 24, payloadLength);
92 
93     if (payloadCrc != crc32) {
94         return -EINVAL;
95     }
96 
97     *_payloadLength = payloadLength;
98 
99     return 0;
100 }
101 
processClientRequest(const void * _data,size_t size)102 ssize_t AdbWebSocketHandler::AdbConnection::processClientRequest(
103         const void *_data, size_t size) {
104     auto data = static_cast<const uint8_t *>(_data);
105 
106     LOG(VERBOSE)
107         << "AdbConnection::processClientRequest (size = " << size << ")";
108 
109     // hexdump(data, size);
110 
111     size_t payloadLength;
112     int err = verifyAdbHeader(data, size, &payloadLength);
113 
114     if (err) {
115         return err;
116     }
117 
118     mParent->sendMessage(
119             data, payloadLength + 24, WebSocketHandler::SendMode::binary);
120 
121     return payloadLength + 24;
122 }
123 
onDisconnect(int err)124 void AdbWebSocketHandler::AdbConnection::onDisconnect(int err) {
125     LOG(INFO) << "AdbConnection::onDisconnect(err=" << err << ")";
126 
127     mParent->sendMessage(
128             nullptr /* data */,
129             0 /* size */,
130             WebSocketHandler::SendMode::closeConnection);
131 }
132 
send(const void * _data,size_t size)133 void AdbWebSocketHandler::AdbConnection::send(const void *_data, size_t size) {
134     BaseConnection::send(_data, size);
135 }
136 
137 ////////////////////////////////////////////////////////////////////////////////
138 
AdbWebSocketHandler(std::shared_ptr<RunLoop> runLoop,const std::string & adb_host_and_port)139 AdbWebSocketHandler::AdbWebSocketHandler(
140         std::shared_ptr<RunLoop> runLoop,
141         const std::string &adb_host_and_port)
142     : mRunLoop(runLoop),
143       mSocket(-1) {
144     LOG(INFO) << "Connecting to " << adb_host_and_port;
145 
146     auto err = setupSocket(adb_host_and_port);
147     CHECK(!err);
148 
149     mAdbConnection = std::make_shared<AdbConnection>(this, mRunLoop, mSocket);
150 }
151 
~AdbWebSocketHandler()152 AdbWebSocketHandler::~AdbWebSocketHandler() {
153     if (mSocket >= 0) {
154         close(mSocket);
155         mSocket = -1;
156     }
157 }
158 
run()159 void AdbWebSocketHandler::run() {
160     mAdbConnection->run();
161 }
162 
setupSocket(const std::string & adb_host_and_port)163 int AdbWebSocketHandler::setupSocket(const std::string &adb_host_and_port) {
164     auto colonPos = adb_host_and_port.find(':');
165     if (colonPos == std::string::npos) {
166         return -EINVAL;
167     }
168 
169     auto host = adb_host_and_port.substr(0, colonPos);
170 
171     const char *portString = adb_host_and_port.c_str() + colonPos + 1;
172     char *end;
173     unsigned long port = strtoul(portString, &end, 10);
174 
175     if (end == portString || *end != '\0' || port > 65535) {
176         return -EINVAL;
177     }
178 
179     int err;
180 
181     int sock = socket(PF_INET, SOCK_STREAM, 0);
182 
183     if (sock < 0) {
184         err = -errno;
185         goto bail;
186     }
187 
188     makeFdNonblocking(sock);
189 
190     sockaddr_in addr;
191     memset(addr.sin_zero, 0, sizeof(addr.sin_zero));
192     addr.sin_family = AF_INET;
193     addr.sin_addr.s_addr = inet_addr(host.c_str());
194     addr.sin_port = htons(port);
195 
196     if (connect(sock,
197                 reinterpret_cast<const sockaddr *>(&addr),
198                 sizeof(addr)) < 0
199             && errno != EINPROGRESS) {
200         err = -errno;
201         goto bail2;
202     }
203 
204     mSocket = sock;
205 
206     return 0;
207 
208 bail2:
209     close(sock);
210     sock = -1;
211 
212 bail:
213     return err;
214 }
215 
handleMessage(uint8_t headerByte,const uint8_t * msg,size_t len)216 int AdbWebSocketHandler::handleMessage(
217         uint8_t headerByte, const uint8_t *msg, size_t len) {
218     LOG(VERBOSE)
219         << "headerByte = "
220         << StringPrintf("0x%02x", (unsigned)headerByte);
221 
222     // hexdump(msg, len);
223 
224     if (!(headerByte & 0x80)) {
225         // I only want to receive whole messages here, not fragments.
226         return -EINVAL;
227     }
228 
229     auto opcode = headerByte & 0x1f;
230     switch (opcode) {
231         case 0x8:
232         {
233             // closeConnection.
234             break;
235         }
236 
237         case 0x2:
238         {
239             // binary
240 
241             size_t payloadLength;
242             int err = verifyAdbHeader(msg, len, &payloadLength);
243 
244             if (err || len != 24 + payloadLength) {
245                 LOG(ERROR) << "websocket message is not a valid adb message.";
246                 return -EINVAL;
247             }
248 
249             mAdbConnection->send(msg, len);
250             break;
251         }
252 
253         default:
254             return -EINVAL;
255     }
256 
257     return 0;
258 }
259 
260