1 /*
2  * Copyright (C) 2008 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "SocketListener"
18 
19 #include <errno.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <sys/poll.h>
23 #include <sys/socket.h>
24 #include <sys/time.h>
25 #include <sys/types.h>
26 #include <sys/un.h>
27 #include <unistd.h>
28 
29 #include <vector>
30 
31 #include <cutils/sockets.h>
32 #include <log/log.h>
33 #include <sysutils/SocketListener.h>
34 #include <sysutils/SocketClient.h>
35 
// Single-byte commands sent over the control pipe (mCtrlPipe) to the listener
// thread. They are typed constants rather than macros so they are scoped to
// this translation unit and have the same width as the byte that is actually
// written to / read from the pipe.
namespace {
// Tells runListener() to leave its loop so the thread can be joined.
constexpr char CtrlPipe_Shutdown = 0;
// Interrupts poll() so runListener() rebuilds its descriptor set.
constexpr char CtrlPipe_Wakeup = 1;
}  // namespace
38 
SocketListener(const char * socketName,bool listen)39 SocketListener::SocketListener(const char *socketName, bool listen) {
40     init(socketName, -1, listen, false);
41 }
42 
SocketListener(int socketFd,bool listen)43 SocketListener::SocketListener(int socketFd, bool listen) {
44     init(nullptr, socketFd, listen, false);
45 }
46 
SocketListener(const char * socketName,bool listen,bool useCmdNum)47 SocketListener::SocketListener(const char *socketName, bool listen, bool useCmdNum) {
48     init(socketName, -1, listen, useCmdNum);
49 }
50 
init(const char * socketName,int socketFd,bool listen,bool useCmdNum)51 void SocketListener::init(const char *socketName, int socketFd, bool listen, bool useCmdNum) {
52     mListen = listen;
53     mSocketName = socketName;
54     mSock = socketFd;
55     mUseCmdNum = useCmdNum;
56     pthread_mutex_init(&mClientsLock, nullptr);
57 }
58 
~SocketListener()59 SocketListener::~SocketListener() {
60     if (mSocketName && mSock > -1)
61         close(mSock);
62 
63     if (mCtrlPipe[0] != -1) {
64         close(mCtrlPipe[0]);
65         close(mCtrlPipe[1]);
66     }
67     for (auto pair : mClients) {
68         pair.second->decRef();
69     }
70 }
71 
startListener()72 int SocketListener::startListener() {
73     return startListener(4);
74 }
75 
startListener(int backlog)76 int SocketListener::startListener(int backlog) {
77 
78     if (!mSocketName && mSock == -1) {
79         SLOGE("Failed to start unbound listener");
80         errno = EINVAL;
81         return -1;
82     } else if (mSocketName) {
83         if ((mSock = android_get_control_socket(mSocketName)) < 0) {
84             SLOGE("Obtaining file descriptor socket '%s' failed: %s",
85                  mSocketName, strerror(errno));
86             return -1;
87         }
88         SLOGV("got mSock = %d for %s", mSock, mSocketName);
89         fcntl(mSock, F_SETFD, FD_CLOEXEC);
90     }
91 
92     if (mListen && listen(mSock, backlog) < 0) {
93         SLOGE("Unable to listen on socket (%s)", strerror(errno));
94         return -1;
95     } else if (!mListen)
96         mClients[mSock] = new SocketClient(mSock, false, mUseCmdNum);
97 
98     if (pipe2(mCtrlPipe, O_CLOEXEC)) {
99         SLOGE("pipe failed (%s)", strerror(errno));
100         return -1;
101     }
102 
103     if (pthread_create(&mThread, nullptr, SocketListener::threadStart, this)) {
104         SLOGE("pthread_create (%s)", strerror(errno));
105         return -1;
106     }
107 
108     return 0;
109 }
110 
stopListener()111 int SocketListener::stopListener() {
112     char c = CtrlPipe_Shutdown;
113     int  rc;
114 
115     rc = TEMP_FAILURE_RETRY(write(mCtrlPipe[1], &c, 1));
116     if (rc != 1) {
117         SLOGE("Error writing to control pipe (%s)", strerror(errno));
118         return -1;
119     }
120 
121     void *ret;
122     if (pthread_join(mThread, &ret)) {
123         SLOGE("Error joining to listener thread (%s)", strerror(errno));
124         return -1;
125     }
126     close(mCtrlPipe[0]);
127     close(mCtrlPipe[1]);
128     mCtrlPipe[0] = -1;
129     mCtrlPipe[1] = -1;
130 
131     if (mSocketName && mSock > -1) {
132         close(mSock);
133         mSock = -1;
134     }
135 
136     for (auto pair : mClients) {
137         delete pair.second;
138     }
139     mClients.clear();
140     return 0;
141 }
142 
threadStart(void * obj)143 void *SocketListener::threadStart(void *obj) {
144     SocketListener *me = reinterpret_cast<SocketListener *>(obj);
145 
146     me->runListener();
147     pthread_exit(nullptr);
148     return nullptr;
149 }
150 
runListener()151 void SocketListener::runListener() {
152     while (true) {
153         std::vector<pollfd> fds;
154 
155         pthread_mutex_lock(&mClientsLock);
156         fds.reserve(2 + mClients.size());
157         fds.push_back({.fd = mCtrlPipe[0], .events = POLLIN});
158         if (mListen) fds.push_back({.fd = mSock, .events = POLLIN});
159         for (auto pair : mClients) {
160             // NB: calling out to an other object with mClientsLock held (safe)
161             const int fd = pair.second->getSocket();
162             if (fd != pair.first) SLOGE("fd mismatch: %d != %d", fd, pair.first);
163             fds.push_back({.fd = fd, .events = POLLIN});
164         }
165         pthread_mutex_unlock(&mClientsLock);
166 
167         SLOGV("mListen=%d, mSocketName=%s", mListen, mSocketName);
168         int rc = TEMP_FAILURE_RETRY(poll(fds.data(), fds.size(), -1));
169         if (rc < 0) {
170             SLOGE("poll failed (%s) mListen=%d", strerror(errno), mListen);
171             sleep(1);
172             continue;
173         }
174 
175         if (fds[0].revents & (POLLIN | POLLERR)) {
176             char c = CtrlPipe_Shutdown;
177             TEMP_FAILURE_RETRY(read(mCtrlPipe[0], &c, 1));
178             if (c == CtrlPipe_Shutdown) {
179                 break;
180             }
181             continue;
182         }
183         if (mListen && (fds[1].revents & (POLLIN | POLLERR))) {
184             int c = TEMP_FAILURE_RETRY(accept4(mSock, nullptr, nullptr, SOCK_CLOEXEC));
185             if (c < 0) {
186                 SLOGE("accept failed (%s)", strerror(errno));
187                 sleep(1);
188                 continue;
189             }
190             pthread_mutex_lock(&mClientsLock);
191             mClients[c] = new SocketClient(c, true, mUseCmdNum);
192             pthread_mutex_unlock(&mClientsLock);
193         }
194 
195         // Add all active clients to the pending list first, so we can release
196         // the lock before invoking the callbacks.
197         std::vector<SocketClient*> pending;
198         pthread_mutex_lock(&mClientsLock);
199         const int size = fds.size();
200         for (int i = mListen ? 2 : 1; i < size; ++i) {
201             const struct pollfd& p = fds[i];
202             if (p.revents & (POLLIN | POLLERR)) {
203                 auto it = mClients.find(p.fd);
204                 if (it == mClients.end()) {
205                     SLOGE("fd vanished: %d", p.fd);
206                     continue;
207                 }
208                 SocketClient* c = it->second;
209                 pending.push_back(c);
210                 c->incRef();
211             }
212         }
213         pthread_mutex_unlock(&mClientsLock);
214 
215         for (SocketClient* c : pending) {
216             // Process it, if false is returned, remove from the map
217             SLOGV("processing fd %d", c->getSocket());
218             if (!onDataAvailable(c)) {
219                 release(c, false);
220             }
221             c->decRef();
222         }
223     }
224 }
225 
release(SocketClient * c,bool wakeup)226 bool SocketListener::release(SocketClient* c, bool wakeup) {
227     bool ret = false;
228     /* if our sockets are connection-based, remove and destroy it */
229     if (mListen && c) {
230         /* Remove the client from our map */
231         SLOGV("going to zap %d for %s", c->getSocket(), mSocketName);
232         pthread_mutex_lock(&mClientsLock);
233         ret = (mClients.erase(c->getSocket()) != 0);
234         pthread_mutex_unlock(&mClientsLock);
235         if (ret) {
236             ret = c->decRef();
237             if (wakeup) {
238                 char b = CtrlPipe_Wakeup;
239                 TEMP_FAILURE_RETRY(write(mCtrlPipe[1], &b, 1));
240             }
241         }
242     }
243     return ret;
244 }
245 
snapshotClients()246 std::vector<SocketClient*> SocketListener::snapshotClients() {
247     std::vector<SocketClient*> clients;
248     pthread_mutex_lock(&mClientsLock);
249     clients.reserve(mClients.size());
250     for (auto pair : mClients) {
251         SocketClient* c = pair.second;
252         c->incRef();
253         clients.push_back(c);
254     }
255     pthread_mutex_unlock(&mClientsLock);
256 
257     return clients;
258 }
259 
sendBroadcast(int code,const char * msg,bool addErrno)260 void SocketListener::sendBroadcast(int code, const char *msg, bool addErrno) {
261     for (SocketClient* c : snapshotClients()) {
262         // broadcasts are unsolicited and should not include a cmd number
263         if (c->sendMsg(code, msg, addErrno, false)) {
264             SLOGW("Error sending broadcast (%s)", strerror(errno));
265         }
266         c->decRef();
267     }
268 }
269 
runOnEachSocket(SocketClientCommand * command)270 void SocketListener::runOnEachSocket(SocketClientCommand *command) {
271     for (SocketClient* c : snapshotClients()) {
272         command->runSocketCommand(c);
273         c->decRef();
274     }
275 }
276