1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Netd"
18 
19 #include "SockDiag.h"
20 
21 #include <errno.h>
22 #include <linux/inet_diag.h>
23 #include <linux/netlink.h>
24 #include <linux/sock_diag.h>
25 #include <netdb.h>
26 #include <netinet/in.h>
27 #include <netinet/tcp.h>
28 #include <string.h>
29 #include <sys/socket.h>
30 #include <sys/uio.h>
31 
32 #include <cinttypes>
33 
34 #include <android-base/strings.h>
35 #include <log/log.h>
36 #include <netdutils/InternetAddresses.h>
37 #include <netdutils/Stopwatch.h>
38 
39 #include "Permission.h"
40 
41 #ifndef SOCK_DESTROY
42 #define SOCK_DESTROY 21
43 #endif
44 
45 #define INET_DIAG_BC_MARK_COND 10
46 
47 namespace android {
48 
49 using netdutils::ScopedAddrinfo;
50 using netdutils::Stopwatch;
51 
52 namespace net {
53 namespace {
54 
// Checks, without blocking, whether a netlink error message is queued on |fd|.
//
// Returns:
//   0        if nothing is queued (the peek fails with EAGAIN), or if something is queued that is
//            not an NLMSG_ERROR of exactly the expected size; that data is left unread for the
//            caller.
//   -errno   if the peek fails for any other reason.
//   error    otherwise: the message is consumed and its nlmsgerr.error field is returned
//            (NOTE(review): netlink conventionally stores a negative errno there - confirm).
int checkError(int fd) {
    struct {
        nlmsghdr h;
        nlmsgerr err;
    } __attribute__((__packed__)) ack;

    // MSG_PEEK leaves the data queued so that non-error replies can still be read by the caller.
    const ssize_t peeked = recv(fd, &ack, sizeof(ack), MSG_DONTWAIT | MSG_PEEK);
    if (peeked == -1) {
        // Nothing to read is the good case; anything else is a genuine read failure.
        if (errno == EAGAIN) {
            return 0;
        }
        return -errno;
    }

    const bool isErrorReply =
            peeked == static_cast<ssize_t>(sizeof(ack)) && ack.h.nlmsg_type == NLMSG_ERROR;
    if (!isErrorReply) {
        // The kernel replied with something else. Leave it to the caller.
        return 0;
    }

    // We got an error. Consume it so it does not get mistaken for a later reply.
    recv(fd, &ack, sizeof(ack), 0);
    return ack.err.error;
}
73 
74 }  // namespace
75 
open()76 bool SockDiag::open() {
77     if (hasSocks()) {
78         return false;
79     }
80 
81     mSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
82     mWriteSock = socket(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, NETLINK_INET_DIAG);
83     if (!hasSocks()) {
84         closeSocks();
85         return false;
86     }
87 
88     sockaddr_nl nl = { .nl_family = AF_NETLINK };
89     if ((connect(mSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1) ||
90         (connect(mWriteSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1)) {
91         closeSocks();
92         return false;
93     }
94 
95     return true;
96 }
97 
sendDumpRequest(uint8_t proto,uint8_t family,uint8_t extensions,uint32_t states,iovec * iov,int iovcnt)98 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint8_t extensions, uint32_t states,
99                               iovec *iov, int iovcnt) {
100     struct {
101         nlmsghdr nlh;
102         inet_diag_req_v2 req;
103     } __attribute__((__packed__)) request = {
104         .nlh = {
105             .nlmsg_type = SOCK_DIAG_BY_FAMILY,
106             .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
107         },
108         .req = {
109             .sdiag_family = family,
110             .sdiag_protocol = proto,
111             .idiag_ext = extensions,
112             .idiag_states = states,
113         },
114     };
115 
116     size_t len = 0;
117     iov[0].iov_base = &request;
118     iov[0].iov_len = sizeof(request);
119     for (int i = 0; i < iovcnt; i++) {
120         len += iov[i].iov_len;
121     }
122     request.nlh.nlmsg_len = len;
123 
124     if (writev(mSock, iov, iovcnt) != (ssize_t) len) {
125         return -errno;
126     }
127 
128     return checkError(mSock);
129 }
130 
sendDumpRequest(uint8_t proto,uint8_t family,uint32_t states)131 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states) {
132     iovec iov[] = {
133         { nullptr, 0 },
134     };
135     return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
136 }
137 
sendDumpRequest(uint8_t proto,uint8_t family,const char * addrstr)138 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr) {
139     addrinfo hints = { .ai_flags = AI_NUMERICHOST };
140     addrinfo *res;
141     in6_addr mapped = { .s6_addr32 = { 0, 0, htonl(0xffff), 0 } };
142 
143     // TODO: refactor the netlink parsing code out of system/core, bring it into netd, and stop
144     // doing string conversions when they're not necessary.
145     int ret = getaddrinfo(addrstr, nullptr, &hints, &res);
146     if (ret != 0) return -EINVAL;
147 
148     // So we don't have to call freeaddrinfo on every failure path.
149     ScopedAddrinfo resP(res);
150 
151     void *addr;
152     uint8_t addrlen;
153     if (res->ai_family == AF_INET && family == AF_INET) {
154         in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
155         addr = &ina;
156         addrlen = sizeof(ina);
157     } else if (res->ai_family == AF_INET && family == AF_INET6) {
158         in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
159         mapped.s6_addr32[3] = ina.s_addr;
160         addr = &mapped;
161         addrlen = sizeof(mapped);
162     } else if (res->ai_family == AF_INET6 && family == AF_INET6) {
163         in6_addr& in6a = reinterpret_cast<sockaddr_in6*>(res->ai_addr)->sin6_addr;
164         addr = &in6a;
165         addrlen = sizeof(in6a);
166     } else {
167         return -EAFNOSUPPORT;
168     }
169 
170     uint8_t prefixlen = addrlen * 8;
171     uint8_t yesjump = sizeof(inet_diag_bc_op) + sizeof(inet_diag_hostcond) + addrlen;
172     uint8_t nojump = yesjump + 4;
173 
174     struct {
175         nlattr nla;
176         inet_diag_bc_op op;
177         inet_diag_hostcond cond;
178     } __attribute__((__packed__)) attrs = {
179         .nla = {
180             .nla_type = INET_DIAG_REQ_BYTECODE,
181         },
182         .op = {
183             INET_DIAG_BC_S_COND,
184             yesjump,
185             nojump,
186         },
187         .cond = {
188             family,
189             prefixlen,
190             -1,
191             {}
192         },
193     };
194 
195     attrs.nla.nla_len = sizeof(attrs) + addrlen;
196 
197     iovec iov[] = {
198         { nullptr,           0 },
199         { &attrs,            sizeof(attrs) },
200         { addr,              addrlen },
201     };
202 
203     uint32_t states = ~(1 << TCP_TIME_WAIT);
204     return sendDumpRequest(proto, family, 0, states, iov, ARRAY_SIZE(iov));
205 }
206 
readDiagMsg(uint8_t proto,const SockDiag::DestroyFilter & shouldDestroy)207 int SockDiag::readDiagMsg(uint8_t proto, const SockDiag::DestroyFilter& shouldDestroy) {
208     NetlinkDumpCallback callback = [this, proto, shouldDestroy] (nlmsghdr *nlh) {
209         const inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
210         if (shouldDestroy(proto, msg)) {
211             sockDestroy(proto, msg);
212         }
213     };
214 
215     return processNetlinkDump(mSock, callback);
216 }
217 
readDiagMsgWithTcpInfo(const TcpInfoReader & tcpInfoReader)218 int SockDiag::readDiagMsgWithTcpInfo(const TcpInfoReader& tcpInfoReader) {
219     NetlinkDumpCallback callback = [tcpInfoReader] (nlmsghdr *nlh) {
220         if (nlh->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
221             ALOGE("expected nlmsg_type=SOCK_DIAG_BY_FAMILY, got nlmsg_type=%d", nlh->nlmsg_type);
222             return;
223         }
224         Fwmark mark;
225         struct tcp_info *tcpinfo = nullptr;
226         uint32_t tcpinfoLength = 0;
227         inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
228         uint32_t attr_len = nlh->nlmsg_len - NLMSG_LENGTH(sizeof(*msg));
229         struct rtattr *attr = reinterpret_cast<struct rtattr*>(msg+1);
230         while (RTA_OK(attr, attr_len)) {
231             if (attr->rta_type == INET_DIAG_INFO) {
232                 tcpinfo = reinterpret_cast<struct tcp_info*>(RTA_DATA(attr));
233                 tcpinfoLength = RTA_PAYLOAD(attr);
234             }
235             if (attr->rta_type == INET_DIAG_MARK) {
236                 mark.intValue = *reinterpret_cast<uint32_t*>(RTA_DATA(attr));
237             }
238             attr = RTA_NEXT(attr, attr_len);
239         }
240 
241         tcpInfoReader(mark, msg, tcpinfo, tcpinfoLength);
242     };
243 
244     return processNetlinkDump(mSock, callback);
245 }
246 
247 // Determines whether a socket is a loopback socket. Does not check socket state.
isLoopbackSocket(const inet_diag_msg * msg)248 bool SockDiag::isLoopbackSocket(const inet_diag_msg *msg) {
249     switch (msg->idiag_family) {
250         case AF_INET:
251             // Old kernels only copy the IPv4 address and leave the other 12 bytes uninitialized.
252             return IN_LOOPBACK(htonl(msg->id.idiag_src[0])) ||
253                    IN_LOOPBACK(htonl(msg->id.idiag_dst[0])) ||
254                    msg->id.idiag_src[0] == msg->id.idiag_dst[0];
255 
256         case AF_INET6: {
257             const struct in6_addr *src = (const struct in6_addr *) &msg->id.idiag_src;
258             const struct in6_addr *dst = (const struct in6_addr *) &msg->id.idiag_dst;
259             return (IN6_IS_ADDR_V4MAPPED(src) && IN_LOOPBACK(src->s6_addr32[3])) ||
260                    (IN6_IS_ADDR_V4MAPPED(dst) && IN_LOOPBACK(dst->s6_addr32[3])) ||
261                    IN6_IS_ADDR_LOOPBACK(src) || IN6_IS_ADDR_LOOPBACK(dst) ||
262                    !memcmp(src, dst, sizeof(*src));
263         }
264         default:
265             return false;
266     }
267 }
268 
sockDestroy(uint8_t proto,const inet_diag_msg * msg)269 int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
270     if (msg == nullptr) {
271        return 0;
272     }
273 
274     DestroyRequest request = {
275         .nlh = {
276             .nlmsg_type = SOCK_DESTROY,
277             .nlmsg_flags = NLM_F_REQUEST,
278         },
279         .req = {
280             .sdiag_family = msg->idiag_family,
281             .sdiag_protocol = proto,
282             .idiag_states = (uint32_t) (1 << msg->idiag_state),
283             .id = msg->id,
284         },
285     };
286     request.nlh.nlmsg_len = sizeof(request);
287 
288     if (write(mWriteSock, &request, sizeof(request)) < (ssize_t) sizeof(request)) {
289         return -errno;
290     }
291 
292     int ret = checkError(mWriteSock);
293     if (!ret) mSocketsDestroyed++;
294     return ret;
295 }
296 
destroySockets(uint8_t proto,int family,const char * addrstr)297 int SockDiag::destroySockets(uint8_t proto, int family, const char *addrstr) {
298     if (!hasSocks()) {
299         return -EBADFD;
300     }
301 
302     if (int ret = sendDumpRequest(proto, family, addrstr)) {
303         return ret;
304     }
305 
306     auto destroyAll = [] (uint8_t, const inet_diag_msg*) { return true; };
307 
308     return readDiagMsg(proto, destroyAll);
309 }
310 
destroySockets(const char * addrstr)311 int SockDiag::destroySockets(const char *addrstr) {
312     Stopwatch s;
313     mSocketsDestroyed = 0;
314 
315     if (!strchr(addrstr, ':')) {
316         if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr)) {
317             ALOGE("Failed to destroy IPv4 sockets on %s: %s", addrstr, strerror(-ret));
318             return ret;
319         }
320     }
321     if (int ret = destroySockets(IPPROTO_TCP, AF_INET6, addrstr)) {
322         ALOGE("Failed to destroy IPv6 sockets on %s: %s", addrstr, strerror(-ret));
323         return ret;
324     }
325 
326     if (mSocketsDestroyed > 0) {
327         ALOGI("Destroyed %d sockets on %s in %" PRId64 "us", mSocketsDestroyed, addrstr,
328               s.timeTakenUs());
329     }
330 
331     return mSocketsDestroyed;
332 }
333 
destroyLiveSockets(const DestroyFilter & destroyFilter,const char * what,iovec * iov,int iovcnt)334 int SockDiag::destroyLiveSockets(const DestroyFilter& destroyFilter, const char *what,
335                                  iovec *iov, int iovcnt) {
336     const int proto = IPPROTO_TCP;
337     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
338 
339     for (const int family : {AF_INET, AF_INET6}) {
340         const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
341         if (int ret = sendDumpRequest(proto, family, 0, states, iov, iovcnt)) {
342             ALOGE("Failed to dump %s sockets for %s: %s", familyName, what, strerror(-ret));
343             return ret;
344         }
345         if (int ret = readDiagMsg(proto, destroyFilter)) {
346             ALOGE("Failed to destroy %s sockets for %s: %s", familyName, what, strerror(-ret));
347             return ret;
348         }
349     }
350 
351     return 0;
352 }
353 
getLiveTcpInfos(const TcpInfoReader & tcpInfoReader)354 int SockDiag::getLiveTcpInfos(const TcpInfoReader& tcpInfoReader) {
355     const int proto = IPPROTO_TCP;
356     const uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
357     const uint8_t extensions = (1 << INET_DIAG_MEMINFO); // flag for dumping struct tcp_info.
358 
359     iovec iov[] = {
360         { nullptr, 0 },
361     };
362 
363     for (const int family : {AF_INET, AF_INET6}) {
364         const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
365         if (int ret = sendDumpRequest(proto, family, extensions, states, iov, ARRAY_SIZE(iov))) {
366             ALOGE("Failed to dump %s sockets struct tcp_info: %s", familyName, strerror(-ret));
367             return ret;
368         }
369         if (int ret = readDiagMsgWithTcpInfo(tcpInfoReader)) {
370             ALOGE("Failed to read %s sockets struct tcp_info: %s", familyName, strerror(-ret));
371             return ret;
372         }
373     }
374 
375     return 0;
376 }
377 
destroySockets(uint8_t proto,const uid_t uid,bool excludeLoopback)378 int SockDiag::destroySockets(uint8_t proto, const uid_t uid, bool excludeLoopback) {
379     mSocketsDestroyed = 0;
380     Stopwatch s;
381 
382     auto shouldDestroy = [uid, excludeLoopback] (uint8_t, const inet_diag_msg *msg) {
383         return msg != nullptr &&
384                msg->idiag_uid == uid &&
385                !(excludeLoopback && isLoopbackSocket(msg));
386     };
387 
388     for (const int family : {AF_INET, AF_INET6}) {
389         const char *familyName = family == AF_INET ? "IPv4" : "IPv6";
390         uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
391         if (int ret = sendDumpRequest(proto, family, states)) {
392             ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
393             return ret;
394         }
395         if (int ret = readDiagMsg(proto, shouldDestroy)) {
396             ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
397             return ret;
398         }
399     }
400 
401     if (mSocketsDestroyed > 0) {
402         ALOGI("Destroyed %d sockets for UID in %" PRId64 "us", mSocketsDestroyed, s.timeTakenUs());
403     }
404 
405     return 0;
406 }
407 
destroySockets(const UidRanges & uidRanges,const std::set<uid_t> & skipUids,bool excludeLoopback)408 int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids,
409                              bool excludeLoopback) {
410     mSocketsDestroyed = 0;
411     Stopwatch s;
412 
413     auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
414         return msg != nullptr &&
415                uidRanges.hasUid(msg->idiag_uid) &&
416                skipUids.find(msg->idiag_uid) == skipUids.end() &&
417                !(excludeLoopback && isLoopbackSocket(msg));
418     };
419 
420     iovec iov[] = {
421         { nullptr, 0 },
422     };
423 
424     if (int ret = destroyLiveSockets(shouldDestroy, "UID", iov, ARRAY_SIZE(iov))) {
425         return ret;
426     }
427 
428     if (mSocketsDestroyed > 0) {
429         ALOGI("Destroyed %d sockets for %s skip={%s} in %" PRId64 "us", mSocketsDestroyed,
430               uidRanges.toString().c_str(), android::base::Join(skipUids, " ").c_str(),
431               s.timeTakenUs());
432     }
433 
434     return 0;
435 }
436 
437 // Destroys all "live" (CONNECTED, SYN_SENT, SYN_RECV) TCP sockets on the specified netId where:
438 // 1. The opening app no longer has permission to use this network, or:
439 // 2. The opening app does have permission, but did not explicitly select this network.
440 //
441 // We destroy sockets without the explicit bit because we want to avoid the situation where a
442 // privileged app uses its privileges without knowing it is doing so. For example, a privileged app
443 // might have opened a socket on this network just because it was the default network at the
444 // time. If we don't kill these sockets, those apps could continue to use them without realizing
445 // that they are now sending and receiving traffic on a network that is now restricted.
destroySocketsLackingPermission(unsigned netId,Permission permission,bool excludeLoopback)446 int SockDiag::destroySocketsLackingPermission(unsigned netId, Permission permission,
447                                               bool excludeLoopback) {
448     struct markmatch {
449         inet_diag_bc_op op;
450         // TODO: switch to inet_diag_markcond
451         __u32 mark;
452         __u32 mask;
453     } __attribute__((packed));
454     constexpr uint8_t matchlen = sizeof(markmatch);
455 
456     Fwmark netIdMark, netIdMask;
457     netIdMark.netId = netId;
458     netIdMask.netId = 0xffff;
459 
460     Fwmark controlMark;
461     controlMark.explicitlySelected = true;
462     controlMark.permission = permission;
463 
464     // A SOCK_DIAG bytecode program that accepts the sockets we intend to destroy.
465     struct bytecode {
466         markmatch netIdMatch;
467         markmatch controlMatch;
468         inet_diag_bc_op controlJump;
469     } __attribute__((packed)) bytecode;
470 
471     // The length of the INET_DIAG_BC_JMP instruction.
472     constexpr uint8_t jmplen = sizeof(inet_diag_bc_op);
473     // Jump exactly this far past the end of the program to reject.
474     constexpr uint8_t rejectoffset = sizeof(inet_diag_bc_op);
475     // Total length of the program.
476     constexpr uint8_t bytecodelen = sizeof(bytecode);
477 
478     bytecode = (struct bytecode) {
479         // If netId matches, continue, otherwise, reject (i.e., leave socket alone).
480         { { INET_DIAG_BC_MARK_COND, matchlen, bytecodelen + rejectoffset },
481           netIdMark.intValue, netIdMask.intValue },
482 
483         // If explicit and permission bits match, go to the JMP below which rejects the socket
484         // (i.e., we leave it alone). Otherwise, jump to the end of the program, which accepts the
485         // socket (so we destroy it).
486         { { INET_DIAG_BC_MARK_COND, matchlen, matchlen + jmplen },
487           controlMark.intValue, controlMark.intValue },
488 
489         // This JMP unconditionally rejects the packet by jumping to the reject target. It is
490         // necessary to keep the kernel bytecode verifier happy. If we don't have a JMP the bytecode
491         // is invalid because the target of every no jump must always be reachable by yes jumps.
492         // Without this JMP, the accept target is not reachable by yes jumps and the program will
493         // be rejected by the validator.
494         { INET_DIAG_BC_JMP, jmplen, jmplen + rejectoffset },
495 
496         // We have reached the end of the program. Accept the socket, and destroy it below.
497     };
498 
499     struct nlattr nla = {
500             .nla_len = sizeof(struct nlattr) + bytecodelen,
501             .nla_type = INET_DIAG_REQ_BYTECODE,
502     };
503 
504     iovec iov[] = {
505         { nullptr,   0 },
506         { &nla,      sizeof(nla) },
507         { &bytecode, bytecodelen },
508     };
509 
510     mSocketsDestroyed = 0;
511     Stopwatch s;
512 
513     auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
514         return msg != nullptr && !(excludeLoopback && isLoopbackSocket(msg));
515     };
516 
517     if (int ret = destroyLiveSockets(shouldDestroy, "permission change", iov, ARRAY_SIZE(iov))) {
518         return ret;
519     }
520 
521     if (mSocketsDestroyed > 0) {
522         ALOGI("Destroyed %d sockets for netId %d permission=%d in %" PRId64 "us", mSocketsDestroyed,
523               netId, permission, s.timeTakenUs());
524     }
525 
526     return 0;
527 }
528 
529 }  // namespace net
530 }  // namespace android
531