1 // Copyright (C) 2019 The Android Open Source Project
2 // Copyright (C) 2019 Google Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 #include "android/base/threads/AndroidWorkPool.h"
16 
17 #include "android/base/threads/AndroidFunctorThread.h"
18 #include "android/base/synchronization/AndroidLock.h"
19 #include "android/base/synchronization/AndroidConditionVariable.h"
20 #include "android/base/synchronization/AndroidMessageChannel.h"
21 
22 #include <atomic>
23 #include <memory>
24 #include <unordered_map>
25 #include <sys/time.h>
26 
27 using android::base::guest::AutoLock;
28 using android::base::guest::ConditionVariable;
29 using android::base::guest::FunctorThread;
30 using android::base::guest::Lock;
31 using android::base::guest::MessageChannel;
32 
33 namespace android {
34 namespace base {
35 namespace guest {
36 
37 class WaitGroup { // intrusive refcounted
38 public:
39 
WaitGroup(int numTasksRemaining)40     WaitGroup(int numTasksRemaining) :
41         mNumTasksInitial(numTasksRemaining),
42         mNumTasksRemaining(numTasksRemaining) { }
43 
44     ~WaitGroup() = default;
45 
getLock()46     android::base::guest::Lock& getLock() { return mLock; }
47 
acquire()48     void acquire() {
49         if (0 == mRefCount.fetch_add(1, std::memory_order_seq_cst)) {
50             ALOGE("%s: goofed, refcount0 acquire\n", __func__);
51             abort();
52         }
53     }
54 
release()55     bool release() {
56         if (0 == mRefCount) {
57             ALOGE("%s: goofed, refcount0 release\n", __func__);
58             abort();
59         }
60         if (1 == mRefCount.fetch_sub(1, std::memory_order_seq_cst)) {
61             std::atomic_thread_fence(std::memory_order_acquire);
62             delete this;
63             return true;
64         }
65         return false;
66     }
67 
68     // wait on all of or any of the associated tasks to complete.
waitAllLocked(WorkPool::TimeoutUs timeout)69     bool waitAllLocked(WorkPool::TimeoutUs timeout) {
70         return conditionalTimeoutLocked(
71             [this] { return mNumTasksRemaining > 0; },
72             timeout);
73     }
74 
waitAnyLocked(WorkPool::TimeoutUs timeout)75     bool waitAnyLocked(WorkPool::TimeoutUs timeout) {
76         return conditionalTimeoutLocked(
77             [this] { return mNumTasksRemaining == mNumTasksInitial; },
78             timeout);
79     }
80 
81     // broadcasts to all waiters that there has been a new job that has completed
decrementBroadcast()82     bool decrementBroadcast() {
83         AutoLock lock(mLock);
84         bool done =
85             (1 == mNumTasksRemaining.fetch_sub(1, std::memory_order_seq_cst));
86         std::atomic_thread_fence(std::memory_order_acquire);
87         mCv.broadcast();
88         return done;
89     }
90 
91 private:
92 
doWait(WorkPool::TimeoutUs timeout)93     bool doWait(WorkPool::TimeoutUs timeout) {
94         if (timeout == ~0ULL) {
95             ALOGV("%s: uncond wait\n", __func__);
96             mCv.wait(&mLock);
97             return true;
98         } else {
99             return mCv.timedWait(&mLock, getDeadline(timeout));
100         }
101     }
102 
getDeadline(WorkPool::TimeoutUs relative)103     struct timespec getDeadline(WorkPool::TimeoutUs relative) {
104         struct timeval deadlineUs;
105         struct timespec deadlineNs;
106         gettimeofday(&deadlineUs, 0);
107 
108         auto prevDeadlineUs = deadlineUs.tv_usec;
109 
110         deadlineUs.tv_usec += relative;
111 
112         // Wrap around
113         if (prevDeadlineUs > deadlineUs.tv_usec) {
114             ++deadlineUs.tv_sec;
115         }
116 
117         deadlineNs.tv_sec = deadlineUs.tv_sec;
118         deadlineNs.tv_nsec = deadlineUs.tv_usec * 1000LL;
119         return deadlineNs;
120     }
121 
currTimeUs()122     uint64_t currTimeUs() {
123         struct timeval tv;
124         gettimeofday(&tv, 0);
125         return (uint64_t)(tv.tv_sec * 1000000LL + tv.tv_usec);
126     }
127 
conditionalTimeoutLocked(std::function<bool ()> conditionFunc,WorkPool::TimeoutUs timeout)128     bool conditionalTimeoutLocked(std::function<bool()> conditionFunc, WorkPool::TimeoutUs timeout) {
129         uint64_t currTime = currTimeUs();
130         WorkPool::TimeoutUs currTimeout = timeout;
131 
132         while (conditionFunc()) {
133             doWait(currTimeout);
134             if (!conditionFunc()) {
135                 // Decrement timeout for wakeups
136                 uint64_t nextTime = currTimeUs();
137                 WorkPool::TimeoutUs waited =
138                     nextTime - currTime;
139                 currTime = nextTime;
140 
141                 if (currTimeout > waited) {
142                     currTimeout -= waited;
143                 } else {
144                     return conditionFunc();
145                 }
146             }
147         }
148 
149         return true;
150     }
151 
152     std::atomic<int> mRefCount = { 1 };
153     int mNumTasksInitial;
154     std::atomic<int> mNumTasksRemaining;
155 
156     Lock mLock;
157     ConditionVariable mCv;
158 };
159 
160 class WorkPoolThread {
161 public:
162     // State diagram for each work pool thread
163     //
164     // Unacquired: (Start state) When no one else has claimed the thread.
165     // Acquired: When the thread has been claimed for work,
166     // but work has not been issued to it yet.
167     // Scheduled: When the thread is running tasks from the acquirer.
168     // Exiting: cleanup
169     //
170     // Messages:
171     //
172     // Acquire
173     // Run
174     // Exit
175     //
176     // Transitions:
177     //
178     // Note: While task is being run, messages will come back with a failure value.
179     //
180     // Unacquired:
181     //     message Acquire -> Acquired. effect: return success value
182     //     message Run -> Unacquired. effect: return failure value
183     //     message Exit -> Exiting. effect: return success value
184     //
185     // Acquired:
186     //     message Acquire -> Acquired. effect: return failure value
187     //     message Run -> Scheduled. effect: run the task, return success
188     //     message Exit -> Exiting. effect: return success value
189     //
190     // Scheduled:
191     //     implicit effect: after task is run, transition back to Unacquired.
192     //     message Acquire -> Scheduled. effect: return failure value
193     //     message Run -> Scheduled. effect: return failure value
194     //     message Exit -> queue up exit message, then transition to Exiting after that is done.
195     //         effect: return success value
196     //
197     enum State {
198         Unacquired = 0,
199         Acquired = 1,
200         Scheduled = 2,
201         Exiting = 3,
202     };
203 
WorkPoolThread()204     WorkPoolThread() : mThread([this] { threadFunc(); }) {
205         mThread.start();
206     }
207 
~WorkPoolThread()208     ~WorkPoolThread() {
209         exit();
210         mThread.wait();
211     }
212 
acquire()213     bool acquire() {
214         AutoLock lock(mLock);
215         switch (mState) {
216             case State::Unacquired:
217                 mState = State::Acquired;
218                 return true;
219             case State::Acquired:
220             case State::Scheduled:
221             case State::Exiting:
222                 return false;
223         }
224     }
225 
run(WorkPool::WaitGroupHandle waitGroupHandle,WaitGroup * waitGroup,WorkPool::Task task)226     bool run(WorkPool::WaitGroupHandle waitGroupHandle, WaitGroup* waitGroup, WorkPool::Task task) {
227         AutoLock lock(mLock);
228         switch (mState) {
229             case State::Unacquired:
230                 return false;
231             case State::Acquired: {
232                 mState = State::Scheduled;
233                 mToCleanupWaitGroupHandle = waitGroupHandle;
234                 waitGroup->acquire();
235                 mToCleanupWaitGroup = waitGroup;
236                 mShouldCleanupWaitGroup = false;
237                 TaskInfo msg = {
238                     Command::Run,
239                     waitGroup, task,
240                 };
241                 mRunMessages.send(msg);
242                 return true;
243             }
244             case State::Scheduled:
245             case State::Exiting:
246                 return false;
247         }
248     }
249 
shouldCleanupWaitGroup(WorkPool::WaitGroupHandle * waitGroupHandle,WaitGroup ** waitGroup)250     bool shouldCleanupWaitGroup(WorkPool::WaitGroupHandle* waitGroupHandle, WaitGroup** waitGroup) {
251         AutoLock lock(mLock);
252         bool res = mShouldCleanupWaitGroup;
253         *waitGroupHandle = mToCleanupWaitGroupHandle;
254         *waitGroup = mToCleanupWaitGroup;
255         mShouldCleanupWaitGroup = false;
256         return res;
257     }
258 
259 private:
260     enum Command {
261         Run = 0,
262         Exit = 1,
263     };
264 
265     struct TaskInfo {
266         Command cmd;
267         WaitGroup* waitGroup = nullptr;
268         WorkPool::Task task = {};
269     };
270 
exit()271     bool exit() {
272         AutoLock lock(mLock);
273         TaskInfo msg { Command::Exit, };
274         mRunMessages.send(msg);
275         return true;
276     }
277 
threadFunc()278     void threadFunc() {
279         TaskInfo taskInfo;
280         bool done = false;
281 
282         while (!done) {
283             mRunMessages.receive(&taskInfo);
284             switch (taskInfo.cmd) {
285                 case Command::Run:
286                     doRun(taskInfo);
287                     break;
288                 case Command::Exit: {
289                     AutoLock lock(mLock);
290                     mState = State::Exiting;
291                     break;
292                 }
293             }
294             AutoLock lock(mLock);
295             done = mState == State::Exiting;
296         }
297     }
298 
299     // Assumption: the wait group refcount is >= 1 when entering
300     // this function (before decrement)..
301     // at least it doesn't get to 0
doRun(TaskInfo & msg)302     void doRun(TaskInfo& msg) {
303         WaitGroup* waitGroup = msg.waitGroup;
304 
305         if (msg.task) msg.task();
306 
307         bool lastTask =
308             waitGroup->decrementBroadcast();
309 
310         AutoLock lock(mLock);
311         mState = State::Unacquired;
312 
313         if (lastTask) {
314             mShouldCleanupWaitGroup = true;
315         }
316 
317         waitGroup->release();
318     }
319 
320     FunctorThread mThread;
321     Lock mLock;
322     State mState = State::Unacquired;
323     MessageChannel<TaskInfo, 4> mRunMessages;
324     WorkPool::WaitGroupHandle mToCleanupWaitGroupHandle = 0;
325     WaitGroup* mToCleanupWaitGroup = nullptr;
326     bool mShouldCleanupWaitGroup = false;
327 };
328 
329 class WorkPool::Impl {
330 public:
Impl(int numInitialThreads)331     Impl(int numInitialThreads) : mThreads(numInitialThreads) {
332         for (size_t i = 0; i < mThreads.size(); ++i) {
333             mThreads[i].reset(new WorkPoolThread);
334         }
335     }
336 
337     ~Impl() = default;
338 
schedule(const std::vector<WorkPool::Task> & tasks)339     WorkPool::WaitGroupHandle schedule(const std::vector<WorkPool::Task>& tasks) {
340 
341         if (tasks.empty()) abort();
342 
343         AutoLock lock(mLock);
344 
345         // Sweep old wait groups
346         for (size_t i = 0; i < mThreads.size(); ++i) {
347             WaitGroupHandle handle;
348             WaitGroup* waitGroup;
349             bool cleanup = mThreads[i]->shouldCleanupWaitGroup(&handle, &waitGroup);
350             if (cleanup) {
351                 mWaitGroups.erase(handle);
352                 waitGroup->release();
353             }
354         }
355 
356         WorkPool::WaitGroupHandle resHandle = genWaitGroupHandleLocked();
357         WaitGroup* waitGroup =
358             new WaitGroup(tasks.size());
359 
360         mWaitGroups[resHandle] = waitGroup;
361 
362         std::vector<size_t> threadIndices;
363 
364         while (threadIndices.size() < tasks.size()) {
365             for (size_t i = 0; i < mThreads.size(); ++i) {
366                 if (!mThreads[i]->acquire()) continue;
367                 threadIndices.push_back(i);
368                 if (threadIndices.size() == tasks.size()) break;
369             }
370             if (threadIndices.size() < tasks.size()) {
371                 mThreads.resize(mThreads.size() + 1);
372                 mThreads[mThreads.size() - 1].reset(new WorkPoolThread);
373             }
374         }
375 
376         // every thread here is acquired
377         for (size_t i = 0; i < threadIndices.size(); ++i) {
378             mThreads[threadIndices[i]]->run(resHandle, waitGroup, tasks[i]);
379         }
380 
381         return resHandle;
382     }
383 
waitAny(WorkPool::WaitGroupHandle waitGroupHandle,WorkPool::TimeoutUs timeout)384     bool waitAny(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) {
385         AutoLock lock(mLock);
386         auto it = mWaitGroups.find(waitGroupHandle);
387         if (it == mWaitGroups.end()) return true;
388 
389         auto waitGroup = it->second;
390         waitGroup->acquire();
391         lock.unlock();
392 
393         bool waitRes = false;
394 
395         {
396             AutoLock waitGroupLock(waitGroup->getLock());
397             waitRes = waitGroup->waitAnyLocked(timeout);
398         }
399 
400         waitGroup->release();
401 
402         return waitRes;
403     }
404 
waitAll(WorkPool::WaitGroupHandle waitGroupHandle,WorkPool::TimeoutUs timeout)405     bool waitAll(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) {
406         auto waitGroup = acquireWaitGroupFromHandle(waitGroupHandle);
407         if (!waitGroup) return true;
408 
409         bool waitRes = false;
410 
411         {
412             AutoLock waitGroupLock(waitGroup->getLock());
413             waitRes = waitGroup->waitAllLocked(timeout);
414         }
415 
416         waitGroup->release();
417 
418         return waitRes;
419     }
420 
421 private:
422     // Increments wait group refcount by 1.
acquireWaitGroupFromHandle(WorkPool::WaitGroupHandle waitGroupHandle)423     WaitGroup* acquireWaitGroupFromHandle(WorkPool::WaitGroupHandle waitGroupHandle) {
424         AutoLock lock(mLock);
425         auto it = mWaitGroups.find(waitGroupHandle);
426         if (it == mWaitGroups.end()) return nullptr;
427 
428         auto waitGroup = it->second;
429         waitGroup->acquire();
430 
431         return waitGroup;
432     }
433 
434     using WaitGroupStore = std::unordered_map<WorkPool::WaitGroupHandle, WaitGroup*>;
435 
genWaitGroupHandleLocked()436     WorkPool::WaitGroupHandle genWaitGroupHandleLocked() {
437         WorkPool::WaitGroupHandle res = mNextWaitGroupHandle;
438         ++mNextWaitGroupHandle;
439         return res;
440     }
441 
442     Lock mLock;
443     uint64_t mNextWaitGroupHandle = 0;
444     WaitGroupStore mWaitGroups;
445     std::vector<std::unique_ptr<WorkPoolThread>> mThreads;
446 };
447 
WorkPool(int numInitialThreads)448 WorkPool::WorkPool(int numInitialThreads) : mImpl(new WorkPool::Impl(numInitialThreads)) { }
449 WorkPool::~WorkPool() = default;
450 
schedule(const std::vector<WorkPool::Task> & tasks)451 WorkPool::WaitGroupHandle WorkPool::schedule(const std::vector<WorkPool::Task>& tasks) {
452     return mImpl->schedule(tasks);
453 }
454 
waitAny(WorkPool::WaitGroupHandle waitGroup,WorkPool::TimeoutUs timeout)455 bool WorkPool::waitAny(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) {
456     return mImpl->waitAny(waitGroup, timeout);
457 }
458 
waitAll(WorkPool::WaitGroupHandle waitGroup,WorkPool::TimeoutUs timeout)459 bool WorkPool::waitAll(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) {
460     return mImpl->waitAll(waitGroup, timeout);
461 }
462 
463 } // namespace guest
464 } // namespace base
465 } // namespace android
466