/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <https/RunLoop.h>

#include <https/Support.h>

#include <android-base/logging.h>

#include <cerrno>
#include <cstring>
#include <fcntl.h>
#include <iostream>
#include <poll.h>
#include <unistd.h>

#include <condition_variable>
#include <map>
#include <mutex>
#include <vector>

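// Ordering predicate used to keep mQueue sorted: events without a trigger
// time sort ahead of timed events, and ties compare as "<=" so that insert()
// preserves FIFO order among equal elements.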
bool RunLoop::QueueElem::operator<=(const QueueElem &other) const {
    if (mWhen) {
        if (other.mWhen) {
            return mWhen <= other.mWhen;
        }

        return false;
    }

    if (other.mWhen) {
        return true;
    }

    // This ensures that two events posted without a trigger time are queued in
    // the order they were post()ed in.
    return true;
}

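// The default constructor creates the control pipe ("self-pipe" pattern):
// interrupt() writes a byte to the write end to knock a blocking poll() in
// run() out of its wait. Only the read end needs to be non-blocking, since
// run() drains it until the read would block.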
RunLoop::RunLoop()
    : mDone(false),
      mPThread(0),
      mNextToken(1) {
    int res = pipe(mControlFds);
    CHECK_GE(res, 0);

    makeFdNonblocking(mControlFds[0]);
}

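// This variant spawns a dedicated worker thread that executes the loop; the
// default constructor, by contrast, does not start a thread of its own.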
RunLoop::RunLoop(std::string_view name)
    : RunLoop() {
    mName = name;

    mThread = std::thread([this]{ run(); });
}

RunLoop::~RunLoop() {
    stop();

    close(mControlFds[1]);
    mControlFds[1] = -1;

    close(mControlFds[0]);
    mControlFds[0] = -1;
}

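// Signals the loop to exit, wakes it via the control pipe and joins the
// worker thread (if this instance owns one).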
void RunLoop::stop() {
    mDone = true;
    interrupt();

    if (mThread.joinable()) {
        mThread.join();
    }
}

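// Schedules fn to run on the loop thread as soon as possible. The returned
// token can be passed to cancelToken() while the event is still pending.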
RunLoop::Token RunLoop::post(AsyncFunction fn) {
    CHECK(fn != nullptr);

    auto token = mNextToken++;
    insert({ std::nullopt, fn, token });

    return token;
}

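// Posts fn and blocks until it has executed. Returns true if the call
// actually waited; when invoked from the loop thread itself the function is
// merely post()ed and false is returned, since waiting there would deadlock.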
bool RunLoop::postAndAwait(AsyncFunction fn) {
    if (isCurrentThread()) {
        // Waiting on the runloop's own thread would deadlock.
        post(fn);
        return false;
    }

    std::mutex mtx;
    bool ran = false;
    std::condition_variable cond_var;

    post([&cond_var, &mtx, &ran, fn](){
        fn();
        {
            std::unique_lock<std::mutex> lock(mtx);
            ran = true;
            // Notify while holding the mutex, otherwise the condition variable
            // could be destroyed before the call to notify_all.
            cond_var.notify_all();
        }
    });

    {
        std::unique_lock<std::mutex> lock(mtx);
        cond_var.wait(lock, [&ran](){ return ran; });
    }
    return ran;
}

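// Like post(), but the event only becomes runnable once the given delay has
// elapsed (measured against std::chrono::steady_clock).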
RunLoop::Token RunLoop::postWithDelay(
        std::chrono::steady_clock::duration delay, AsyncFunction fn) {
    CHECK(fn != nullptr);

    auto token = mNextToken++;
    insert({ std::chrono::steady_clock::now() + delay, fn, token });

    return token;
}

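// Removes a still-pending event from the queue. If the cancelled event was at
// the head, the loop's current poll() timeout may be based on it, so the loop
// is interrupted to recompute its deadline.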
bool RunLoop::cancelToken(Token token) {
    std::lock_guard<std::mutex> autoLock(mLock);

    bool found = false;
    for (auto it = mQueue.begin(); it != mQueue.end(); ++it) {
        if (it->mToken == token) {
            // Check for the head position *before* erasing, erase()
            // invalidates the iterator.
            const bool wasFront = (it == mQueue.begin());
            mQueue.erase(it);

            if (wasFront) {
                interrupt();
            }

            found = true;
            break;
        }
    }

    return found;
}

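// postSocketRecv()/postSocketSend() register a one-shot callback that fires
// once the socket becomes readable/writable; cancelSocket() drops any such
// callbacks. All three merely queue a request under the lock and interrupt()
// the loop, which applies the change from its own thread.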
void RunLoop::postSocketRecv(int sock, AsyncFunction fn) {
    CHECK_GE(sock, 0);
    CHECK(fn != nullptr);

    std::lock_guard<std::mutex> autoLock(mLock);
    mAddInfos.push_back({ sock, InfoType::RECV, fn });
    interrupt();
}

void RunLoop::postSocketSend(int sock, AsyncFunction fn) {
    CHECK_GE(sock, 0);
    CHECK(fn != nullptr);

    std::lock_guard<std::mutex> autoLock(mLock);
    mAddInfos.push_back({ sock, InfoType::SEND, fn });
    interrupt();
}

void RunLoop::cancelSocket(int sock) {
    CHECK_GE(sock, 0);

    std::lock_guard<std::mutex> autoLock(mLock);
    mAddInfos.push_back({ sock, InfoType::CANCEL, nullptr });
    interrupt();
}

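// Inserts elem in sorted order. If it ends up at the head of the queue, the
// timeout the loop is currently poll()ing with is stale, so wake it up.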
void RunLoop::insert(const QueueElem &elem) {
    std::lock_guard<std::mutex> autoLock(mLock);

    auto it = mQueue.begin();
    while (it != mQueue.end() && *it <= elem) {
        ++it;
    }

    if (it == mQueue.begin()) {
        interrupt();
    }

    mQueue.insert(it, elem);
}

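// The heart of the run loop. Each iteration (while holding the lock) applies
// queued socket registrations, computes a poll() timeout from the head of the
// event queue, then poll()s on the control pipe plus all registered sockets.
// Ready socket callbacks, or the due event at the head of the queue, are
// collected and invoked only after the lock has been released.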
void RunLoop::run() {
    mPThread = pthread_self();

    std::map<int, SocketCallbacks> socketCallbacksByFd;
    std::vector<pollfd> pollFds;

    auto removePollFdAt = [&socketCallbacksByFd, &pollFds](size_t i) {
        if (i + 1 == pollFds.size()) {
            pollFds.pop_back();
        } else {
            // Instead of leaving a hole in the middle of the
            // pollFds vector, we copy the last item into
            // that hole and reduce the size of the vector by 1,
            // taking care of updating the corresponding callback
            // with the correct, new index.
            pollFds[i] = pollFds.back();
            pollFds.pop_back();
            socketCallbacksByFd[pollFds[i].fd].mPollFdIndex = i;
        }
    };

    // The control channel's pollFd will always be at index 0.
    pollFds.push_back({ mControlFds[0], POLLIN, 0 });

    for (;;) {
        int timeoutMs = -1;  // wait forever

        {
            std::lock_guard<std::mutex> autoLock(mLock);

            if (mDone) {
                break;
            }

            for (const auto &addInfo : mAddInfos) {
                const int sock = addInfo.mSock;
                const auto fn = addInfo.mFn;

                auto it = socketCallbacksByFd.find(sock);

                switch (addInfo.mType) {
                    case InfoType::RECV:
                    {
                        if (it == socketCallbacksByFd.end()) {
                            socketCallbacksByFd[sock] = { fn, nullptr, pollFds.size() };
                            pollFds.push_back({ sock, POLLIN, 0 });
                        } else {
                            // There's already a pollFd for this socket.
                            CHECK(it->second.mSendFn != nullptr);

                            CHECK(it->second.mRecvFn == nullptr);
                            it->second.mRecvFn = fn;

                            pollFds[it->second.mPollFdIndex].events |= POLLIN;
                        }
                        break;
                    }

                    case InfoType::SEND:
                    {
                        if (it == socketCallbacksByFd.end()) {
                            socketCallbacksByFd[sock] = { nullptr, fn, pollFds.size() };
                            pollFds.push_back({ sock, POLLOUT, 0 });
                        } else {
                            // There's already a pollFd for this socket.
                            if (it->second.mRecvFn == nullptr) {
                                LOG(ERROR)
                                    << "There's an entry but no recvFn "
                                       "notification for socket "
                                    << sock;
                            }

                            CHECK(it->second.mRecvFn != nullptr);

                            if (it->second.mSendFn != nullptr) {
                                LOG(ERROR)
                                    << "There's already a pending send "
                                       "notification for socket "
                                    << sock;
                            }
                            CHECK(it->second.mSendFn == nullptr);
                            it->second.mSendFn = fn;

                            pollFds[it->second.mPollFdIndex].events |= POLLOUT;
                        }
                        break;
                    }

                    case InfoType::CANCEL:
                    {
                        if (it != socketCallbacksByFd.end()) {
                            const size_t i = it->second.mPollFdIndex;

                            socketCallbacksByFd.erase(it);
                            removePollFdAt(i);
                        }
                        break;
                    }
                }
            }

            mAddInfos.clear();

            if (!mQueue.empty()) {
                timeoutMs = 0;

                if (mQueue.front().mWhen) {
                    auto duration =
                        *mQueue.front().mWhen - std::chrono::steady_clock::now();

                    auto durationMs =
                        std::chrono::duration_cast<std::chrono::milliseconds>(
                                duration);

                    if (durationMs.count() > 0) {
                        timeoutMs = static_cast<int>(durationMs.count());
                    }
                }
            }
        }

        int pollRes = 0;
        if (timeoutMs != 0) {
            // NOTE: The inequality is on purpose: we want to execute this
            // code if timeoutMs == -1 (infinite) or timeoutMs > 0, but not
            // if it's 0.

            pollRes = poll(
                    pollFds.data(),
                    static_cast<nfds_t>(pollFds.size()),
                    timeoutMs);
        }

        if (pollRes < 0) {
            if (errno != EINTR) {
                std::cerr
                    << "poll FAILED w/ "
                    << errno
                    << " ("
                    << strerror(errno)
                    << ")"
                    << std::endl;
            }

            CHECK_EQ(errno, EINTR);
            continue;
        }

        std::vector<AsyncFunction> fnArray;

        {
            std::lock_guard<std::mutex> autoLock(mLock);

            if (pollRes > 0) {
                if (pollFds[0].revents & POLLIN) {
                    // Drain the control pipe completely; its contents carry
                    // no meaning, the pipe only exists to wake us up.
                    ssize_t res;
                    do {
                        uint8_t c[32];
                        while ((res = read(mControlFds[0], c, sizeof(c))) < 0
                                && errno == EINTR) {
                        }
                    } while (res > 0);
                    CHECK(res < 0 && errno == EWOULDBLOCK);

                    --pollRes;
                }

                // NOTE: Skip index 0, as we already handled it above.
                // Also, bail early if we exhausted all actionable pollFds
                // according to pollRes.
                for (size_t i = pollFds.size(); pollRes && i-- > 1;) {
                    pollfd &pollFd = pollFds[i];
                    const short revents = pollFd.revents;

                    if (revents) {
                        --pollRes;
                    }

                    const bool readable = (revents & POLLIN);
                    const bool writable = (revents & POLLOUT);
                    const bool dead = (revents & POLLNVAL);

                    bool removeCallback = dead;

                    if (readable || writable || dead) {
                        const int sock = pollFd.fd;

                        const auto it = socketCallbacksByFd.find(sock);
                        CHECK(it != socketCallbacksByFd.end());

                        auto &cb = it->second;
                        CHECK_EQ(cb.mPollFdIndex, i);

                        if (readable) {
                            CHECK(cb.mRecvFn != nullptr);
                            fnArray.push_back(cb.mRecvFn);
                            cb.mRecvFn = nullptr;
                            pollFd.events &= ~POLLIN;

                            removeCallback |= (cb.mSendFn == nullptr);
                        }

                        if (writable) {
                            CHECK(cb.mSendFn != nullptr);
                            fnArray.push_back(cb.mSendFn);
                            cb.mSendFn = nullptr;
                            pollFd.events &= ~POLLOUT;

                            removeCallback |= (cb.mRecvFn == nullptr);
                        }

                        if (removeCallback) {
                            socketCallbacksByFd.erase(it);
                            removePollFdAt(i);
                        }
                    }
                }
            } else {
                // No interrupt and no socket notifications: the event at the
                // head of the queue must be due.
                fnArray.push_back(mQueue.front().mFn);
                mQueue.pop_front();
            }
        }

        for (const auto &fn : fnArray) {
            fn();
        }
    }
}

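// Wakes up a poll() in progress inside run() by writing a single byte to the
// control pipe.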
void RunLoop::interrupt() {
    uint8_t c = 1;
    ssize_t res;
    while ((res = write(mControlFds[1], &c, 1)) < 0 && errno == EINTR) {
    }

    CHECK_EQ(res, 1);
}


struct MainRunLoop : public RunLoop {
};

static std::mutex gLock;
static std::shared_ptr<RunLoop> gMainRunLoop;

// static
std::shared_ptr<RunLoop> RunLoop::main() {
    std::lock_guard<std::mutex> autoLock(gLock);
    if (!gMainRunLoop) {
        gMainRunLoop = std::make_shared<MainRunLoop>();
    }
    return gMainRunLoop;
}

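// Returns true iff the caller is running on this RunLoop's thread. Note that
// mPThread is only assigned once run() starts executing.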
bool RunLoop::isCurrentThread() const {
    return pthread_equal(pthread_self(), mPThread);
}