1 // Copyright (C) 2019 The Android Open Source Project
2 // Copyright (C) 2019 Google Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 #include "android/base/threads/AndroidWorkPool.h"
16
17 #include "android/base/threads/AndroidFunctorThread.h"
18 #include "android/base/synchronization/AndroidLock.h"
19 #include "android/base/synchronization/AndroidConditionVariable.h"
20 #include "android/base/synchronization/AndroidMessageChannel.h"
21
#include <atomic>
#include <functional>
#include <memory>
#include <unordered_map>

#include <sys/time.h>
26
27 using android::base::guest::AutoLock;
28 using android::base::guest::ConditionVariable;
29 using android::base::guest::FunctorThread;
30 using android::base::guest::Lock;
31 using android::base::guest::MessageChannel;
32
33 namespace android {
34 namespace base {
35 namespace guest {
36
37 class WaitGroup { // intrusive refcounted
38 public:
39
WaitGroup(int numTasksRemaining)40 WaitGroup(int numTasksRemaining) :
41 mNumTasksInitial(numTasksRemaining),
42 mNumTasksRemaining(numTasksRemaining) { }
43
44 ~WaitGroup() = default;
45
getLock()46 android::base::guest::Lock& getLock() { return mLock; }
47
acquire()48 void acquire() {
49 if (0 == mRefCount.fetch_add(1, std::memory_order_seq_cst)) {
50 ALOGE("%s: goofed, refcount0 acquire\n", __func__);
51 abort();
52 }
53 }
54
release()55 bool release() {
56 if (0 == mRefCount) {
57 ALOGE("%s: goofed, refcount0 release\n", __func__);
58 abort();
59 }
60 if (1 == mRefCount.fetch_sub(1, std::memory_order_seq_cst)) {
61 std::atomic_thread_fence(std::memory_order_acquire);
62 delete this;
63 return true;
64 }
65 return false;
66 }
67
68 // wait on all of or any of the associated tasks to complete.
waitAllLocked(WorkPool::TimeoutUs timeout)69 bool waitAllLocked(WorkPool::TimeoutUs timeout) {
70 return conditionalTimeoutLocked(
71 [this] { return mNumTasksRemaining > 0; },
72 timeout);
73 }
74
waitAnyLocked(WorkPool::TimeoutUs timeout)75 bool waitAnyLocked(WorkPool::TimeoutUs timeout) {
76 return conditionalTimeoutLocked(
77 [this] { return mNumTasksRemaining == mNumTasksInitial; },
78 timeout);
79 }
80
81 // broadcasts to all waiters that there has been a new job that has completed
decrementBroadcast()82 bool decrementBroadcast() {
83 AutoLock lock(mLock);
84 bool done =
85 (1 == mNumTasksRemaining.fetch_sub(1, std::memory_order_seq_cst));
86 std::atomic_thread_fence(std::memory_order_acquire);
87 mCv.broadcast();
88 return done;
89 }
90
91 private:
92
doWait(WorkPool::TimeoutUs timeout)93 bool doWait(WorkPool::TimeoutUs timeout) {
94 if (timeout == ~0ULL) {
95 ALOGV("%s: uncond wait\n", __func__);
96 mCv.wait(&mLock);
97 return true;
98 } else {
99 return mCv.timedWait(&mLock, getDeadline(timeout));
100 }
101 }
102
getDeadline(WorkPool::TimeoutUs relative)103 struct timespec getDeadline(WorkPool::TimeoutUs relative) {
104 struct timeval deadlineUs;
105 struct timespec deadlineNs;
106 gettimeofday(&deadlineUs, 0);
107
108 auto prevDeadlineUs = deadlineUs.tv_usec;
109
110 deadlineUs.tv_usec += relative;
111
112 // Wrap around
113 if (prevDeadlineUs > deadlineUs.tv_usec) {
114 ++deadlineUs.tv_sec;
115 }
116
117 deadlineNs.tv_sec = deadlineUs.tv_sec;
118 deadlineNs.tv_nsec = deadlineUs.tv_usec * 1000LL;
119 return deadlineNs;
120 }
121
currTimeUs()122 uint64_t currTimeUs() {
123 struct timeval tv;
124 gettimeofday(&tv, 0);
125 return (uint64_t)(tv.tv_sec * 1000000LL + tv.tv_usec);
126 }
127
conditionalTimeoutLocked(std::function<bool ()> conditionFunc,WorkPool::TimeoutUs timeout)128 bool conditionalTimeoutLocked(std::function<bool()> conditionFunc, WorkPool::TimeoutUs timeout) {
129 uint64_t currTime = currTimeUs();
130 WorkPool::TimeoutUs currTimeout = timeout;
131
132 while (conditionFunc()) {
133 doWait(currTimeout);
134 if (!conditionFunc()) {
135 // Decrement timeout for wakeups
136 uint64_t nextTime = currTimeUs();
137 WorkPool::TimeoutUs waited =
138 nextTime - currTime;
139 currTime = nextTime;
140
141 if (currTimeout > waited) {
142 currTimeout -= waited;
143 } else {
144 return conditionFunc();
145 }
146 }
147 }
148
149 return true;
150 }
151
152 std::atomic<int> mRefCount = { 1 };
153 int mNumTasksInitial;
154 std::atomic<int> mNumTasksRemaining;
155
156 Lock mLock;
157 ConditionVariable mCv;
158 };
159
160 class WorkPoolThread {
161 public:
162 // State diagram for each work pool thread
163 //
164 // Unacquired: (Start state) When no one else has claimed the thread.
165 // Acquired: When the thread has been claimed for work,
166 // but work has not been issued to it yet.
167 // Scheduled: When the thread is running tasks from the acquirer.
168 // Exiting: cleanup
169 //
170 // Messages:
171 //
172 // Acquire
173 // Run
174 // Exit
175 //
176 // Transitions:
177 //
178 // Note: While task is being run, messages will come back with a failure value.
179 //
180 // Unacquired:
181 // message Acquire -> Acquired. effect: return success value
182 // message Run -> Unacquired. effect: return failure value
183 // message Exit -> Exiting. effect: return success value
184 //
185 // Acquired:
186 // message Acquire -> Acquired. effect: return failure value
187 // message Run -> Scheduled. effect: run the task, return success
188 // message Exit -> Exiting. effect: return success value
189 //
190 // Scheduled:
191 // implicit effect: after task is run, transition back to Unacquired.
192 // message Acquire -> Scheduled. effect: return failure value
193 // message Run -> Scheduled. effect: return failure value
194 // message Exit -> queue up exit message, then transition to Exiting after that is done.
195 // effect: return success value
196 //
197 enum State {
198 Unacquired = 0,
199 Acquired = 1,
200 Scheduled = 2,
201 Exiting = 3,
202 };
203
WorkPoolThread()204 WorkPoolThread() : mThread([this] { threadFunc(); }) {
205 mThread.start();
206 }
207
~WorkPoolThread()208 ~WorkPoolThread() {
209 exit();
210 mThread.wait();
211 }
212
acquire()213 bool acquire() {
214 AutoLock lock(mLock);
215 switch (mState) {
216 case State::Unacquired:
217 mState = State::Acquired;
218 return true;
219 case State::Acquired:
220 case State::Scheduled:
221 case State::Exiting:
222 return false;
223 }
224 }
225
run(WorkPool::WaitGroupHandle waitGroupHandle,WaitGroup * waitGroup,WorkPool::Task task)226 bool run(WorkPool::WaitGroupHandle waitGroupHandle, WaitGroup* waitGroup, WorkPool::Task task) {
227 AutoLock lock(mLock);
228 switch (mState) {
229 case State::Unacquired:
230 return false;
231 case State::Acquired: {
232 mState = State::Scheduled;
233 mToCleanupWaitGroupHandle = waitGroupHandle;
234 waitGroup->acquire();
235 mToCleanupWaitGroup = waitGroup;
236 mShouldCleanupWaitGroup = false;
237 TaskInfo msg = {
238 Command::Run,
239 waitGroup, task,
240 };
241 mRunMessages.send(msg);
242 return true;
243 }
244 case State::Scheduled:
245 case State::Exiting:
246 return false;
247 }
248 }
249
shouldCleanupWaitGroup(WorkPool::WaitGroupHandle * waitGroupHandle,WaitGroup ** waitGroup)250 bool shouldCleanupWaitGroup(WorkPool::WaitGroupHandle* waitGroupHandle, WaitGroup** waitGroup) {
251 AutoLock lock(mLock);
252 bool res = mShouldCleanupWaitGroup;
253 *waitGroupHandle = mToCleanupWaitGroupHandle;
254 *waitGroup = mToCleanupWaitGroup;
255 mShouldCleanupWaitGroup = false;
256 return res;
257 }
258
259 private:
260 enum Command {
261 Run = 0,
262 Exit = 1,
263 };
264
265 struct TaskInfo {
266 Command cmd;
267 WaitGroup* waitGroup = nullptr;
268 WorkPool::Task task = {};
269 };
270
exit()271 bool exit() {
272 AutoLock lock(mLock);
273 TaskInfo msg { Command::Exit, };
274 mRunMessages.send(msg);
275 return true;
276 }
277
threadFunc()278 void threadFunc() {
279 TaskInfo taskInfo;
280 bool done = false;
281
282 while (!done) {
283 mRunMessages.receive(&taskInfo);
284 switch (taskInfo.cmd) {
285 case Command::Run:
286 doRun(taskInfo);
287 break;
288 case Command::Exit: {
289 AutoLock lock(mLock);
290 mState = State::Exiting;
291 break;
292 }
293 }
294 AutoLock lock(mLock);
295 done = mState == State::Exiting;
296 }
297 }
298
299 // Assumption: the wait group refcount is >= 1 when entering
300 // this function (before decrement)..
301 // at least it doesn't get to 0
doRun(TaskInfo & msg)302 void doRun(TaskInfo& msg) {
303 WaitGroup* waitGroup = msg.waitGroup;
304
305 if (msg.task) msg.task();
306
307 bool lastTask =
308 waitGroup->decrementBroadcast();
309
310 AutoLock lock(mLock);
311 mState = State::Unacquired;
312
313 if (lastTask) {
314 mShouldCleanupWaitGroup = true;
315 }
316
317 waitGroup->release();
318 }
319
320 FunctorThread mThread;
321 Lock mLock;
322 State mState = State::Unacquired;
323 MessageChannel<TaskInfo, 4> mRunMessages;
324 WorkPool::WaitGroupHandle mToCleanupWaitGroupHandle = 0;
325 WaitGroup* mToCleanupWaitGroup = nullptr;
326 bool mShouldCleanupWaitGroup = false;
327 };
328
329 class WorkPool::Impl {
330 public:
Impl(int numInitialThreads)331 Impl(int numInitialThreads) : mThreads(numInitialThreads) {
332 for (size_t i = 0; i < mThreads.size(); ++i) {
333 mThreads[i].reset(new WorkPoolThread);
334 }
335 }
336
337 ~Impl() = default;
338
schedule(const std::vector<WorkPool::Task> & tasks)339 WorkPool::WaitGroupHandle schedule(const std::vector<WorkPool::Task>& tasks) {
340
341 if (tasks.empty()) abort();
342
343 AutoLock lock(mLock);
344
345 // Sweep old wait groups
346 for (size_t i = 0; i < mThreads.size(); ++i) {
347 WaitGroupHandle handle;
348 WaitGroup* waitGroup;
349 bool cleanup = mThreads[i]->shouldCleanupWaitGroup(&handle, &waitGroup);
350 if (cleanup) {
351 mWaitGroups.erase(handle);
352 waitGroup->release();
353 }
354 }
355
356 WorkPool::WaitGroupHandle resHandle = genWaitGroupHandleLocked();
357 WaitGroup* waitGroup =
358 new WaitGroup(tasks.size());
359
360 mWaitGroups[resHandle] = waitGroup;
361
362 std::vector<size_t> threadIndices;
363
364 while (threadIndices.size() < tasks.size()) {
365 for (size_t i = 0; i < mThreads.size(); ++i) {
366 if (!mThreads[i]->acquire()) continue;
367 threadIndices.push_back(i);
368 if (threadIndices.size() == tasks.size()) break;
369 }
370 if (threadIndices.size() < tasks.size()) {
371 mThreads.resize(mThreads.size() + 1);
372 mThreads[mThreads.size() - 1].reset(new WorkPoolThread);
373 }
374 }
375
376 // every thread here is acquired
377 for (size_t i = 0; i < threadIndices.size(); ++i) {
378 mThreads[threadIndices[i]]->run(resHandle, waitGroup, tasks[i]);
379 }
380
381 return resHandle;
382 }
383
waitAny(WorkPool::WaitGroupHandle waitGroupHandle,WorkPool::TimeoutUs timeout)384 bool waitAny(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) {
385 AutoLock lock(mLock);
386 auto it = mWaitGroups.find(waitGroupHandle);
387 if (it == mWaitGroups.end()) return true;
388
389 auto waitGroup = it->second;
390 waitGroup->acquire();
391 lock.unlock();
392
393 bool waitRes = false;
394
395 {
396 AutoLock waitGroupLock(waitGroup->getLock());
397 waitRes = waitGroup->waitAnyLocked(timeout);
398 }
399
400 waitGroup->release();
401
402 return waitRes;
403 }
404
waitAll(WorkPool::WaitGroupHandle waitGroupHandle,WorkPool::TimeoutUs timeout)405 bool waitAll(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) {
406 auto waitGroup = acquireWaitGroupFromHandle(waitGroupHandle);
407 if (!waitGroup) return true;
408
409 bool waitRes = false;
410
411 {
412 AutoLock waitGroupLock(waitGroup->getLock());
413 waitRes = waitGroup->waitAllLocked(timeout);
414 }
415
416 waitGroup->release();
417
418 return waitRes;
419 }
420
421 private:
422 // Increments wait group refcount by 1.
acquireWaitGroupFromHandle(WorkPool::WaitGroupHandle waitGroupHandle)423 WaitGroup* acquireWaitGroupFromHandle(WorkPool::WaitGroupHandle waitGroupHandle) {
424 AutoLock lock(mLock);
425 auto it = mWaitGroups.find(waitGroupHandle);
426 if (it == mWaitGroups.end()) return nullptr;
427
428 auto waitGroup = it->second;
429 waitGroup->acquire();
430
431 return waitGroup;
432 }
433
434 using WaitGroupStore = std::unordered_map<WorkPool::WaitGroupHandle, WaitGroup*>;
435
genWaitGroupHandleLocked()436 WorkPool::WaitGroupHandle genWaitGroupHandleLocked() {
437 WorkPool::WaitGroupHandle res = mNextWaitGroupHandle;
438 ++mNextWaitGroupHandle;
439 return res;
440 }
441
442 Lock mLock;
443 uint64_t mNextWaitGroupHandle = 0;
444 WaitGroupStore mWaitGroups;
445 std::vector<std::unique_ptr<WorkPoolThread>> mThreads;
446 };
447
// Public WorkPool API: thin forwarders to the pimpl above.
WorkPool::WorkPool(int numInitialThreads) : mImpl(new WorkPool::Impl(numInitialThreads)) { }
WorkPool::~WorkPool() = default;
450
// Schedules |tasks| on the pool; returns a handle for waitAny()/waitAll().
WorkPool::WaitGroupHandle WorkPool::schedule(const std::vector<WorkPool::Task>& tasks) {
    return mImpl->schedule(tasks);
}
454
// Waits until any task of |waitGroup| completes, up to |timeout| microseconds.
bool WorkPool::waitAny(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) {
    return mImpl->waitAny(waitGroup, timeout);
}
458
// Waits until all tasks of |waitGroup| complete, up to |timeout| microseconds.
bool WorkPool::waitAll(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) {
    return mImpl->waitAll(waitGroup, timeout);
}
462
463 } // namespace guest
464 } // namespace base
465 } // namespace android
466