1 /*
2  * Copyright 2017, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "socket.h"
18 
19 #include "message.h"
20 #include "utils.h"
21 
22 #include <errno.h>
23 #include <linux/if_packet.h>
24 #include <netinet/ip.h>
25 #include <netinet/udp.h>
26 #include <string.h>
27 #include <sys/socket.h>
28 #include <sys/types.h>
29 #include <sys/uio.h>
30 #include <unistd.h>
31 
// Combine the checksum of |buffer| with |size| bytes with |checksum|. This is
// used for checksum calculations for IP and UDP.
//
// The data is summed as 16-bit words read in memory order and all carries are
// folded back into the lower 16 bits, so the returned value always fits in 16
// bits and can be passed as |checksum| to a later call. If |size| is odd the
// last byte is treated as the first byte of a 16-bit word whose second byte
// is zero; because of that only the last buffer in a chain of calls may have
// an odd size.
static uint32_t addChecksum(const uint8_t* buffer,
                            size_t size,
                            uint32_t checksum) {
    // Accumulate in 64 bits so the running sum does not wrap around for any
    // realistic buffer size.
    uint64_t sum = checksum;
    while (size > 1) {
        // |buffer| is not guaranteed to be 16-bit aligned so copy each word
        // out instead of dereferencing a casted pointer.
        uint16_t word;
        memcpy(&word, buffer, sizeof(word));
        sum += word;
        buffer += sizeof(word);
        size -= sizeof(word);
    }
    if (size > 0) {
        // Odd size, pad the last byte with a trailing zero byte. Building the
        // word in memory order keeps this correct regardless of byte order.
        const uint8_t padded[sizeof(uint16_t)] = { *buffer, 0 };
        uint16_t word;
        memcpy(&word, padded, sizeof(word));
        sum += word;
    }
    // Fold everything above the lower 16 bits back into the lower 16 bits
    while ((sum >> 16) != 0) {
        sum = (sum & 0xFFFF) + (sum >> 16);
    }
    return static_cast<uint32_t>(sum);
}
52 
53 // Convenienct template function for checksum calculation
54 template<typename T>
addChecksum(const T & data,uint32_t checksum)55 static uint32_t addChecksum(const T& data, uint32_t checksum) {
56     return addChecksum(reinterpret_cast<const uint8_t*>(&data), sizeof(T), checksum);
57 }
58 
// Finalize the IP or UDP |checksum|: keep only its lower 16 bits and return
// their bitwise inverse.
static uint32_t finishChecksum(uint32_t checksum) {
    const uint32_t lowerWord = checksum & 0xFFFF;
    return lowerWord ^ 0xFFFF;
}
63 
Socket()64 Socket::Socket() : mSocketFd(-1) {
65 }
66 
~Socket()67 Socket::~Socket() {
68     if (mSocketFd != -1) {
69         ::close(mSocketFd);
70         mSocketFd = -1;
71     }
72 }
73 
74 
open(int domain,int type,int protocol)75 Result Socket::open(int domain, int type, int protocol) {
76     if (mSocketFd != -1) {
77         return Result::error("Socket already open");
78     }
79     mSocketFd = ::socket(domain, type, protocol);
80     if (mSocketFd == -1) {
81         return Result::error("Failed to open socket: %s", strerror(errno));
82     }
83     return Result::success();
84 }
85 
bind(const void * sockaddr,size_t sockaddrLength)86 Result Socket::bind(const void* sockaddr, size_t sockaddrLength) {
87     if (mSocketFd == -1) {
88         return Result::error("Socket not open");
89     }
90 
91     int status = ::bind(mSocketFd,
92                         reinterpret_cast<const struct sockaddr*>(sockaddr),
93                         sockaddrLength);
94     if (status != 0) {
95         return Result::error("Unable to bind raw socket: %s", strerror(errno));
96     }
97 
98     return Result::success();
99 }
100 
bindIp(in_addr_t address,uint16_t port)101 Result Socket::bindIp(in_addr_t address, uint16_t port) {
102     struct sockaddr_in sockaddr;
103     memset(&sockaddr, 0, sizeof(sockaddr));
104     sockaddr.sin_family = AF_INET;
105     sockaddr.sin_port = htons(port);
106     sockaddr.sin_addr.s_addr = address;
107 
108     return bind(&sockaddr, sizeof(sockaddr));
109 }
110 
bindRaw(unsigned int interfaceIndex)111 Result Socket::bindRaw(unsigned int interfaceIndex) {
112     struct sockaddr_ll sockaddr;
113     memset(&sockaddr, 0, sizeof(sockaddr));
114     sockaddr.sll_family = AF_PACKET;
115     sockaddr.sll_protocol = htons(ETH_P_IP);
116     sockaddr.sll_ifindex = interfaceIndex;
117 
118     return bind(&sockaddr, sizeof(sockaddr));
119 }
120 
sendOnInterface(unsigned int interfaceIndex,in_addr_t destinationAddress,uint16_t destinationPort,const Message & message)121 Result Socket::sendOnInterface(unsigned int interfaceIndex,
122                                in_addr_t destinationAddress,
123                                uint16_t destinationPort,
124                                const Message& message) {
125     if (mSocketFd == -1) {
126         return Result::error("Socket not open");
127     }
128 
129     char controlData[CMSG_SPACE(sizeof(struct in_pktinfo))] = { 0 };
130     struct sockaddr_in addr;
131     memset(&addr, 0, sizeof(addr));
132     addr.sin_family = AF_INET;
133     addr.sin_port = htons(destinationPort);
134     addr.sin_addr.s_addr = destinationAddress;
135 
136     struct msghdr header;
137     memset(&header, 0, sizeof(header));
138     struct iovec iov;
139     // The struct member is non-const since it's used for receiving but it's
140     // safe to cast away const for sending.
141     iov.iov_base = const_cast<uint8_t*>(message.data());
142     iov.iov_len = message.size();
143     header.msg_name = &addr;
144     header.msg_namelen = sizeof(addr);
145     header.msg_iov = &iov;
146     header.msg_iovlen = 1;
147     header.msg_control = &controlData;
148     header.msg_controllen = sizeof(controlData);
149 
150     struct cmsghdr* controlHeader = CMSG_FIRSTHDR(&header);
151     controlHeader->cmsg_level = IPPROTO_IP;
152     controlHeader->cmsg_type = IP_PKTINFO;
153     controlHeader->cmsg_len = CMSG_LEN(sizeof(struct in_pktinfo));
154     auto packetInfo =
155         reinterpret_cast<struct in_pktinfo*>(CMSG_DATA(controlHeader));
156     memset(packetInfo, 0, sizeof(*packetInfo));
157     packetInfo->ipi_ifindex = interfaceIndex;
158 
159     ssize_t status = ::sendmsg(mSocketFd, &header, 0);
160     if (status <= 0) {
161         return Result::error("Failed to send packet: %s", strerror(errno));
162     }
163     return Result::success();
164 }
165 
sendRawUdp(in_addr_t source,uint16_t sourcePort,in_addr_t destination,uint16_t destinationPort,unsigned int interfaceIndex,const Message & message)166 Result Socket::sendRawUdp(in_addr_t source,
167                           uint16_t sourcePort,
168                           in_addr_t destination,
169                           uint16_t destinationPort,
170                           unsigned int interfaceIndex,
171                           const Message& message) {
172     struct iphdr ip;
173     struct udphdr udp;
174 
175     ip.version = IPVERSION;
176     ip.ihl = sizeof(ip) >> 2;
177     ip.tos = 0;
178     ip.tot_len = htons(sizeof(ip) + sizeof(udp) + message.size());
179     ip.id = 0;
180     ip.frag_off = 0;
181     ip.ttl = IPDEFTTL;
182     ip.protocol = IPPROTO_UDP;
183     ip.check = 0;
184     ip.saddr = source;
185     ip.daddr = destination;
186     ip.check = finishChecksum(addChecksum(ip, 0));
187 
188     udp.source = htons(sourcePort);
189     udp.dest = htons(destinationPort);
190     udp.len = htons(sizeof(udp) + message.size());
191     udp.check = 0;
192 
193     uint32_t udpChecksum = 0;
194     udpChecksum = addChecksum(ip.saddr, udpChecksum);
195     udpChecksum = addChecksum(ip.daddr, udpChecksum);
196     udpChecksum = addChecksum(htons(IPPROTO_UDP), udpChecksum);
197     udpChecksum = addChecksum(udp.len, udpChecksum);
198     udpChecksum = addChecksum(udp, udpChecksum);
199     udpChecksum = addChecksum(message.data(), message.size(), udpChecksum);
200     udp.check = finishChecksum(udpChecksum);
201 
202     struct iovec iov[3];
203 
204     iov[0].iov_base = static_cast<void*>(&ip);
205     iov[0].iov_len = sizeof(ip);
206     iov[1].iov_base = static_cast<void*>(&udp);
207     iov[1].iov_len = sizeof(udp);
208     // sendmsg requires these to be non-const but for sending won't modify them
209     iov[2].iov_base = static_cast<void*>(const_cast<uint8_t*>(message.data()));
210     iov[2].iov_len = message.size();
211 
212     struct sockaddr_ll dest;
213     memset(&dest, 0, sizeof(dest));
214     dest.sll_family = AF_PACKET;
215     dest.sll_protocol = htons(ETH_P_IP);
216     dest.sll_ifindex = interfaceIndex;
217     dest.sll_halen = ETH_ALEN;
218     memset(dest.sll_addr, 0xFF, ETH_ALEN);
219 
220     struct msghdr header;
221     memset(&header, 0, sizeof(header));
222     header.msg_name = &dest;
223     header.msg_namelen = sizeof(dest);
224     header.msg_iov = iov;
225     header.msg_iovlen = sizeof(iov) / sizeof(iov[0]);
226 
227     ssize_t res = ::sendmsg(mSocketFd, &header, 0);
228     if (res == -1) {
229         return Result::error("Failed to send message: %s", strerror(errno));
230     }
231     return Result::success();
232 }
233 
receiveFromInterface(Message * message,unsigned int * interfaceIndex)234 Result Socket::receiveFromInterface(Message* message,
235                                     unsigned int* interfaceIndex) {
236     char controlData[CMSG_SPACE(sizeof(struct in_pktinfo))];
237     struct msghdr header;
238     memset(&header, 0, sizeof(header));
239     struct iovec iov;
240     iov.iov_base = message->data();
241     iov.iov_len = message->capacity();
242     header.msg_iov = &iov;
243     header.msg_iovlen = 1;
244     header.msg_control = &controlData;
245     header.msg_controllen = sizeof(controlData);
246 
247     ssize_t bytesRead = ::recvmsg(mSocketFd, &header, 0);
248     if (bytesRead < 0) {
249         return Result::error("Error receiving on socket: %s", strerror(errno));
250     }
251     message->setSize(static_cast<size_t>(bytesRead));
252     if (header.msg_controllen >= sizeof(struct cmsghdr)) {
253         for (struct cmsghdr* ctrl = CMSG_FIRSTHDR(&header);
254              ctrl;
255              ctrl = CMSG_NXTHDR(&header, ctrl)) {
256             if (ctrl->cmsg_level == SOL_IP &&
257                 ctrl->cmsg_type == IP_PKTINFO) {
258                 auto packetInfo =
259                     reinterpret_cast<struct in_pktinfo*>(CMSG_DATA(ctrl));
260                 *interfaceIndex = packetInfo->ipi_ifindex;
261             }
262         }
263     }
264     return Result::success();
265 }
266 
receiveRawUdp(uint16_t expectedPort,Message * message,bool * isValid)267 Result Socket::receiveRawUdp(uint16_t expectedPort,
268                              Message* message,
269                              bool* isValid) {
270     struct iphdr ip;
271     struct udphdr udp;
272 
273     struct iovec iov[3];
274     iov[0].iov_base = &ip;
275     iov[0].iov_len = sizeof(ip);
276     iov[1].iov_base = &udp;
277     iov[1].iov_len = sizeof(udp);
278     iov[2].iov_base = message->data();
279     iov[2].iov_len = message->capacity();
280 
281     ssize_t bytesRead = ::readv(mSocketFd, iov, 3);
282     if (bytesRead < 0) {
283         return Result::error("Unable to read from socket: %s", strerror(errno));
284     }
285     if (static_cast<size_t>(bytesRead) < sizeof(ip) + sizeof(udp)) {
286         // Not enough bytes to even cover IP and UDP headers
287         *isValid = false;
288         return Result::success();
289     }
290     *isValid = ip.version == IPVERSION &&
291                ip.ihl == (sizeof(ip) >> 2) &&
292                ip.protocol == IPPROTO_UDP &&
293                udp.dest == htons(expectedPort);
294 
295     message->setSize(bytesRead - sizeof(ip) - sizeof(udp));
296     return Result::success();
297 }
298 
enableOption(int level,int optionName)299 Result Socket::enableOption(int level, int optionName) {
300     if (mSocketFd == -1) {
301         return Result::error("Socket not open");
302     }
303 
304     int enabled = 1;
305     int status = ::setsockopt(mSocketFd,
306                               level,
307                               optionName,
308                               &enabled,
309                               sizeof(enabled));
310     if (status == -1) {
311         return Result::error("Failed to set socket option: %s",
312                              strerror(errno));
313     }
314     return Result::success();
315 }
316