1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <https/RunLoop.h>
18 
19 #include <https/Support.h>
20 
21 #include <android-base/logging.h>
22 
23 #include <cstring>
24 #include <fcntl.h>
25 #include <iostream>
26 #include <unistd.h>
27 
28 #include <mutex>
29 #include <condition_variable>
30 
operator <=(const QueueElem & other) const31 bool RunLoop::QueueElem::operator<=(const QueueElem &other) const {
32     if (mWhen) {
33         if (other.mWhen) {
34             return mWhen <= other.mWhen;
35         }
36 
37         return false;
38     }
39 
40     if (other.mWhen) {
41         return true;
42     }
43 
44     // This ensures that two events posted without a trigger time are queued in
45     // the order they were post()ed in.
46     return true;
47 }
48 
RunLoop()49 RunLoop::RunLoop()
50     : mDone(false),
51       mPThread(0),
52       mNextToken(1) {
53     int res = pipe(mControlFds);
54     CHECK_GE(res, 0);
55 
56     makeFdNonblocking(mControlFds[0]);
57 }
58 
RunLoop(std::string_view name)59 RunLoop::RunLoop(std::string_view name)
60     : RunLoop() {
61     mName = name;
62 
63     mThread = std::thread([this]{ run(); });
64 }
65 
~RunLoop()66 RunLoop::~RunLoop() {
67     stop();
68 
69     close(mControlFds[1]);
70     mControlFds[1] = -1;
71 
72     close(mControlFds[0]);
73     mControlFds[0] = -1;
74 }
75 
stop()76 void RunLoop::stop() {
77     mDone = true;
78     interrupt();
79 
80     if (mThread.joinable()) {
81         mThread.join();
82     }
83 }
84 
post(AsyncFunction fn)85 RunLoop::Token RunLoop::post(AsyncFunction fn) {
86     CHECK(fn != nullptr);
87 
88     auto token = mNextToken++;
89     insert({ std::nullopt, fn, token });
90 
91     return token;
92 }
93 
postAndAwait(AsyncFunction fn)94 bool RunLoop::postAndAwait(AsyncFunction fn) {
95     if (isCurrentThread()) {
96         // To wait from the runloop's thread would cause deadlock
97         post(fn);
98         return false;
99     }
100 
101     std::mutex mtx;
102     bool ran = false;
103     std::condition_variable cond_var;
104 
105     post([&cond_var, &mtx, &ran, fn](){
106         fn();
107         {
108             std::unique_lock<std::mutex> lock(mtx);
109             ran = true;
110             // Notify while holding the mutex, otherwise the condition variable
111             // could be destroyed before the call to notify_all.
112             cond_var.notify_all();
113         }
114     });
115 
116     {
117         std::unique_lock<std::mutex> lock(mtx);
118         cond_var.wait(lock, [&ran](){ return ran;});
119     }
120     return ran;
121 }
122 
postWithDelay(std::chrono::steady_clock::duration delay,AsyncFunction fn)123 RunLoop::Token RunLoop::postWithDelay(
124         std::chrono::steady_clock::duration delay, AsyncFunction fn) {
125     CHECK(fn != nullptr);
126 
127     auto token = mNextToken++;
128     insert({ std::chrono::steady_clock::now() + delay, fn, token });
129 
130     return token;
131 }
132 
cancelToken(Token token)133 bool RunLoop::cancelToken(Token token) {
134     std::lock_guard<std::mutex> autoLock(mLock);
135 
136     bool found = false;
137     for (auto it = mQueue.begin(); it != mQueue.end(); ++it) {
138         if (it->mToken == token) {
139             mQueue.erase(it);
140 
141             if (it == mQueue.begin()) {
142                 interrupt();
143             }
144 
145             found = true;
146             break;
147         }
148     }
149 
150     return found;
151 }
152 
postSocketRecv(int sock,AsyncFunction fn)153 void RunLoop::postSocketRecv(int sock, AsyncFunction fn) {
154     CHECK_GE(sock, 0);
155     CHECK(fn != nullptr);
156 
157     std::lock_guard<std::mutex> autoLock(mLock);
158     mAddInfos.push_back({ sock, InfoType::RECV, fn });
159     interrupt();
160 }
161 
postSocketSend(int sock,AsyncFunction fn)162 void RunLoop::postSocketSend(int sock, AsyncFunction fn) {
163     CHECK_GE(sock, 0);
164     CHECK(fn != nullptr);
165 
166     std::lock_guard<std::mutex> autoLock(mLock);
167     mAddInfos.push_back({ sock, InfoType::SEND, fn });
168     interrupt();
169 }
170 
cancelSocket(int sock)171 void RunLoop::cancelSocket(int sock) {
172     CHECK_GE(sock, 0);
173 
174     std::lock_guard<std::mutex> autoLock(mLock);
175     mAddInfos.push_back({ sock, InfoType::CANCEL, nullptr });
176     interrupt();
177 }
178 
insert(const QueueElem & elem)179 void RunLoop::insert(const QueueElem &elem) {
180     std::lock_guard<std::mutex> autoLock(mLock);
181 
182     auto it = mQueue.begin();
183     while (it != mQueue.end() && *it <= elem) {
184         ++it;
185     }
186 
187     if (it == mQueue.begin()) {
188         interrupt();
189     }
190 
191     mQueue.insert(it, elem);
192 }
193 
run()194 void RunLoop::run() {
195     mPThread = pthread_self();
196 
197     std::map<int, SocketCallbacks> socketCallbacksByFd;
198     std::vector<pollfd> pollFds;
199 
200     auto removePollFdAt = [&socketCallbacksByFd, &pollFds](size_t i) {
201         if (i + 1 == pollFds.size()) {
202             pollFds.pop_back();
203         } else {
204             // Instead of leaving a hole in the middle of the
205             // pollFds vector, we copy the last item into
206             // that hole and reduce the size of the vector by 1,
207             // taking are of updating the corresponding callback
208             // with the correct, new index.
209             pollFds[i] = pollFds.back();
210             pollFds.pop_back();
211             socketCallbacksByFd[pollFds[i].fd].mPollFdIndex = i;
212         }
213     };
214 
215     // The control channel's pollFd will always be at index 0.
216     pollFds.push_back({ mControlFds[0], POLLIN, 0 });
217 
218     for (;;) {
219         int timeoutMs = -1;  // wait Forever
220 
221         {
222             std::lock_guard<std::mutex> autoLock(mLock);
223 
224             if (mDone) {
225                 break;
226             }
227 
228             for (const auto &addInfo : mAddInfos) {
229                 const int sock = addInfo.mSock;
230                 const auto fn = addInfo.mFn;
231 
232                 auto it = socketCallbacksByFd.find(sock);
233 
234                 switch (addInfo.mType) {
235                     case InfoType::RECV:
236                     {
237                         if (it == socketCallbacksByFd.end()) {
238                             socketCallbacksByFd[sock] = { fn, nullptr, pollFds.size() };
239                             pollFds.push_back({ sock, POLLIN, 0 });
240                         } else {
241                             // There's already a pollFd for this socket.
242                             CHECK(it->second.mSendFn != nullptr);
243 
244                             CHECK(it->second.mRecvFn == nullptr);
245                             it->second.mRecvFn = fn;
246 
247                             pollFds[it->second.mPollFdIndex].events |= POLLIN;
248                         }
249                         break;
250                     }
251 
252                     case InfoType::SEND:
253                     {
254                         if (it == socketCallbacksByFd.end()) {
255                             socketCallbacksByFd[sock] = { nullptr, fn, pollFds.size() };
256                             pollFds.push_back({ sock, POLLOUT, 0 });
257                         } else {
258                             // There's already a pollFd for this socket.
259                             if (it->second.mRecvFn == nullptr) {
260                                 LOG(ERROR)
261                                     << "There's an entry but no recvFn "
262                                        "notification for socket "
263                                     << sock;
264                             }
265 
266                             CHECK(it->second.mRecvFn != nullptr);
267 
268                             if (it->second.mSendFn != nullptr) {
269                                 LOG(ERROR)
270                                     << "There's already a pending send "
271                                        "notification for socket "
272                                     << sock;
273                             }
274                             CHECK(it->second.mSendFn == nullptr);
275                             it->second.mSendFn = fn;
276 
277                             pollFds[it->second.mPollFdIndex].events |= POLLOUT;
278                         }
279                         break;
280                     }
281 
282                     case InfoType::CANCEL:
283                     {
284                         if (it != socketCallbacksByFd.end()) {
285                             const size_t i = it->second.mPollFdIndex;
286 
287                             socketCallbacksByFd.erase(it);
288                             removePollFdAt(i);
289                         }
290                         break;
291                     }
292                 }
293             }
294 
295             mAddInfos.clear();
296 
297             if (!mQueue.empty()) {
298                 timeoutMs = 0;
299 
300                 if (mQueue.front().mWhen) {
301                     auto duration =
302                         *mQueue.front().mWhen - std::chrono::steady_clock::now();
303 
304                     auto durationMs =
305                         std::chrono::duration_cast<std::chrono::milliseconds>(duration);
306 
307                     if (durationMs.count() > 0) {
308                         timeoutMs = static_cast<int>(durationMs.count());
309                     }
310                 }
311             }
312         }
313 
314         int pollRes = 0;
315         if (timeoutMs != 0) {
316             // NOTE: The inequality is on purpose, we'll want to execute this
317             // code if timeoutMs == -1 (infinite) or timeoutMs > 0, but not
318             // if it's 0.
319 
320             pollRes = poll(
321                     pollFds.data(),
322                     static_cast<nfds_t>(pollFds.size()),
323                     timeoutMs);
324         }
325 
326         if (pollRes < 0) {
327             if (errno != EINTR) {
328                 std::cerr
329                     << "poll FAILED w/ "
330                     << errno
331                     << " ("
332                     << strerror(errno)
333                     << ")"
334                     << std::endl;
335             }
336 
337             CHECK_EQ(errno, EINTR);
338             continue;
339         }
340 
341         std::vector<AsyncFunction> fnArray;
342 
343         {
344             std::lock_guard<std::mutex> autoLock(mLock);
345 
346             if (pollRes > 0) {
347                 if (pollFds[0].revents & POLLIN) {
348                     ssize_t res;
349                     do {
350                         uint8_t c[32];
351                         while ((res = read(mControlFds[0], c, sizeof(c))) < 0
352                                 && errno == EINTR) {
353                         }
354                     } while (res > 0);
355                     CHECK(res < 0 && errno == EWOULDBLOCK);
356 
357                     --pollRes;
358                 }
359 
360                 // NOTE: Skip index 0, as we already handled it above.
361                 // Also, bail early if we exhausted all actionable pollFds
362                 // according to pollRes.
363                 for (size_t i = pollFds.size(); pollRes && i-- > 1;) {
364                     pollfd &pollFd = pollFds[i];
365                     const short revents = pollFd.revents;
366 
367                     if (revents) {
368                         --pollRes;
369                     }
370 
371                     const bool readable = (revents & POLLIN);
372                     const bool writable = (revents & POLLOUT);
373                     const bool dead = (revents & POLLNVAL);
374 
375                     bool removeCallback = dead;
376 
377                     if (readable || writable || dead) {
378                         const int sock = pollFd.fd;
379 
380                         const auto &it = socketCallbacksByFd.find(sock);
381                         auto &cb = it->second;
382                         CHECK_EQ(cb.mPollFdIndex, i);
383 
384                         if (readable) {
385                             CHECK(cb.mRecvFn != nullptr);
386                             fnArray.push_back(cb.mRecvFn);
387                             cb.mRecvFn = nullptr;
388                             pollFd.events &= ~POLLIN;
389 
390                             removeCallback |= (cb.mSendFn == nullptr);
391                         }
392 
393                         if (writable) {
394                             CHECK(cb.mSendFn != nullptr);
395                             fnArray.push_back(cb.mSendFn);
396                             cb.mSendFn = nullptr;
397                             pollFd.events &= ~POLLOUT;
398 
399                             removeCallback |= (cb.mRecvFn == nullptr);
400                         }
401 
402                         if (removeCallback) {
403                             socketCallbacksByFd.erase(it);
404                             removePollFdAt(i);
405                         }
406                     }
407                 }
408             } else {
409                 // No interrupt, no socket notifications.
410                 fnArray.push_back(mQueue.front().mFn);
411                 mQueue.pop_front();
412             }
413         }
414 
415         for (const auto &fn : fnArray) {
416             fn();
417         }
418     }
419 }
420 
interrupt()421 void RunLoop::interrupt() {
422     uint8_t c = 1;
423     ssize_t res;
424     while ((res = write(mControlFds[1], &c, 1)) < 0 && errno == EINTR) {
425     }
426 
427     CHECK_EQ(res, 1);
428 }
429 
430 struct MainRunLoop : public RunLoop {
431 };
432 
433 static std::mutex gLock;
434 static std::shared_ptr<RunLoop> gMainRunLoop;
435 
436 // static
main()437 std::shared_ptr<RunLoop> RunLoop::main() {
438     std::lock_guard<std::mutex> autoLock(gLock);
439     if (!gMainRunLoop) {
440         gMainRunLoop = std::make_shared<MainRunLoop>();
441     }
442     return gMainRunLoop;
443 }
444 
isCurrentThread() const445 bool RunLoop::isCurrentThread() const {
446     return pthread_equal(pthread_self(), mPThread);
447 }
448 
449