1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.net.module.util
18 
19 import java.util.concurrent.TimeUnit
20 import java.util.concurrent.locks.Condition
21 import java.util.concurrent.locks.ReentrantLock
22 import kotlin.concurrent.withLock
23 
/**
 * A [List] that can additionally be appended to via [add], and that can return, via [poll], the
 * first element at or after a given position matching a predicate, optionally waiting for such
 * an element to become available.
 */
interface TrackRecord<E> : List<E> {
    /**
     * Appends an element to this record, waking up threads waiting for one. Returns true, as
     * per the contract for java.util.Collection.add.
     */
    fun add(e: E): Boolean

    /**
     * Returns the first element at or after position [pos], possibly blocking until one is
     * available, or null if no such element can be found within the timeout.
     * If a predicate is given, only elements matching the predicate are returned.
     *
     * @param timeoutMs how long, in milliseconds, to wait at most (best effort approximation).
     * @param pos the position at which to start polling (inclusive).
     * @param predicate an optional predicate to filter elements to be returned.
     * @return an element matching the predicate, or null if timeout.
     */
    fun poll(timeoutMs: Long, pos: Int, predicate: (E) -> Boolean = { true }): E?
}
47 
/**
 * A thread-safe implementation of [TrackRecord] that is backed by an [ArrayList].
 *
 * Every access to the backing list happens while holding an internal lock, and elements are only
 * ever appended. Iterators and sub-lists returned by this class are snapshots copied under that
 * lock, so they are not affected by elements added afterwards.
 *
 * This class also supports the creation of a read-head for easier single-thread access.
 * Refer to the documentation of [ArrayTrackRecord.ReadHead].
 */
class ArrayTrackRecord<E> : TrackRecord<E> {
    private val lock = ReentrantLock()
    // Signalled (while holding lock) every time an element is appended, to wake up pollers.
    private val condition = lock.newCondition()
    // Backing store. This stores the elements in this ArrayTrackRecord. It is only ever mutated
    // by appending, and must only be accessed while holding lock.
    private val elements = ArrayList<E>()

    /**
     * A [ListIterator] over the elements of [list] at indices below [end], whose cursor starts
     * at [start].
     *
     * [ArrayTrackRecord.listIterator] passes a private copy of the record taken under the lock,
     * so the iterator walks a snapshot of the record as it was when the iterator was created and
     * never reads the live backing list while other threads may be appending to it.
     *
     * As required by the [ListIterator] contract, [next] and [previous] throw
     * [NoSuchElementException] when there is no element in the requested direction.
     */
    class ArrayTrackRecordIterator<E>(
        private val list: ArrayList<E>,
        start: Int,
        private val end: Int
    ) : ListIterator<E> {
        /** Index of the element that the next call to [next] would return. */
        var index = start
        override fun hasNext() = index < end
        override fun next(): E {
            if (!hasNext()) throw NoSuchElementException("No element at index $index")
            return list[index++]
        }
        override fun hasPrevious() = index > 0
        // Per the ListIterator contract this is the index of the element next() would return.
        override fun nextIndex() = index
        override fun previous(): E {
            if (!hasPrevious()) throw NoSuchElementException("No element before index 0")
            return list[--index]
        }
        override fun previousIndex() = index - 1
    }

    // List<E> implementation
    override val size get() = lock.withLock { elements.size }
    override fun contains(element: E) = lock.withLock { elements.contains(element) }
    // The parameter shadows the backing store, hence the explicit this.elements.
    override fun containsAll(elements: Collection<E>) = lock.withLock {
        this.elements.containsAll(elements)
    }
    override operator fun get(index: Int) = lock.withLock { elements[index] }
    override fun indexOf(element: E): Int = lock.withLock { elements.indexOf(element) }
    override fun lastIndexOf(element: E): Int = lock.withLock { elements.lastIndexOf(element) }
    override fun isEmpty() = lock.withLock { elements.isEmpty() }

    /**
     * Returns an iterator over a snapshot of this record, positioned at [index].
     *
     * The snapshot is copied under the lock because the backing ArrayList must not be read
     * without the lock while other threads may be appending to it.
     *
     * @throws IndexOutOfBoundsException if [index] is not in 0..size.
     */
    override fun listIterator(index: Int): ArrayTrackRecordIterator<E> = lock.withLock {
        if (index < 0 || index > elements.size) {
            throw IndexOutOfBoundsException("Index: $index, Size: ${elements.size}")
        }
        ArrayTrackRecordIterator(ArrayList(elements), index, elements.size)
    }
    override fun listIterator() = listIterator(0)
    override fun iterator() = listIterator()

    /**
     * Returns a copy of the elements in [fromIndex, toIndex).
     *
     * A copy is returned rather than a view: a view of the backing ArrayList would be read
     * without the lock, and would throw ConcurrentModificationException after the next add().
     * Because this record is append-only, the copy holds exactly what a view would have shown.
     */
    override fun subList(fromIndex: Int, toIndex: Int): List<E> = lock.withLock {
        ArrayList(elements.subList(fromIndex, toIndex))
    }

    // TrackRecord<E> implementation
    override fun add(e: E): Boolean {
        lock.withLock {
            elements.add(e)
            // Wake up every poller so each can re-check its own predicate.
            condition.signalAll()
        }
        return true
    }

    override fun poll(timeoutMs: Long, pos: Int, predicate: (E) -> Boolean) = lock.withLock {
        // getOrNull(-1) is null, which is what a timeout must return.
        elements.getOrNull(pollForIndexReadLocked(timeoutMs, pos, predicate))
    }

    /**
     * Returns the first element at or after [pos] that matches [predicate], or null if there is
     * none or if [pos] is outside 0..size. Never blocks.
     */
    fun getOrNull(pos: Int, predicate: (E) -> Boolean) = lock.withLock {
        if (pos < 0 || pos > elements.size) {
            null
        } else {
            elements.subList(pos, elements.size).find(predicate)
        }
    }

    // Returns the index of the next element whose position is >= pos matching the predicate, if
    // necessary waiting until such a time that such an element is available, with a timeout.
    // If no such element is found within the timeout -1 is returned.
    // Must be called while holding lock (awaitNanos releases it while waiting).
    private fun pollForIndexReadLocked(timeoutMs: Long, pos: Int, predicate: (E) -> Boolean): Int {
        // Track the remaining time with awaitNanos, which is based on a monotonic clock, rather
        // than with the wall clock, which can jump (e.g. when the system time is set).
        var remainingNs = TimeUnit.MILLISECONDS.toNanos(timeoutMs)
        var index = pos
        do {
            while (index < elements.size) {
                if (predicate(elements[index])) return index
                ++index
            }
            val timedOut = remainingNs <= 0L
            // awaitNanos returns an estimate of the time left; <= 0 means the timeout elapsed.
            // The loop then rescans once, catching an element added right at the deadline.
            if (!timedOut) remainingNs = condition.awaitNanos(remainingNs)
        } while (!timedOut)
        return -1
    }

    /**
     * Returns a ReadHead over this ArrayTrackRecord. The returned ReadHead is tied to the
     * current thread.
     */
    fun newReadHead() = ReadHead()

    /**
     * ReadHead is an object that helps users of ArrayTrackRecord keep track of how far
     * they have read in the ArrayTrackRecord. A ReadHead is always associated with
     * a single instance of ArrayTrackRecord. Multiple ReadHeads can be created and used
     * on the same instance of ArrayTrackRecord concurrently, and the ArrayTrackRecord
     * instance can also be used concurrently. ReadHead maintains the current index that is
     * the next to be read, and calls this the "mark".
     *
     * A ReadHead delegates all TrackRecord methods to its associated ArrayTrackRecord, and
     * inherits its thread-safe properties. However, the additional methods that ReadHead
     * offers on top of TrackRecord do not share these properties and can only be used by
     * the thread that created the ReadHead. This is because by construction it does not
     * make sense to use a ReadHead on multiple threads concurrently (see below for details).
     *
     * In a ReadHead, [poll] without a position works similarly to a LinkedBlockingQueue.
     * It can be called repeatedly and will return the elements as they arrive.
     *
     * Intended usage looks something like this :
     * val record = ArrayTrackRecord<MyObject>().newReadHead()
     * Thread {
     *   // do stuff
     *   record.add(something)
     *   // do stuff
     * }.start()
     *
     * val obj1 = record.poll(timeout)
     * // do something with obj1
     * val obj2 = record.poll(timeout)
     * // do something with obj2
     *
     * The point is that the caller does not have to track the mark like it would have to if
     * it was using ArrayTrackRecord directly.
     *
     * Note that if multiple threads were using poll() concurrently on the same ReadHead, what
     * happens to the mark and the return values could be well defined, but it could not
     * be useful because there is no way to provide either a guarantee not to skip objects nor
     * a guarantee about the mark position at the exit of poll(). This is even more true in the
     * presence of a predicate to filter returned elements, because one thread might be
     * filtering out the events the other is interested in.
     * Instead, this use case is supported by creating multiple ReadHeads on the same instance
     * of ArrayTrackRecord. Each ReadHead is then guaranteed to see all events always and
     * guarantees are made on the value of the mark upon return. See [poll] for details.
     * Be careful to create each ReadHead on the thread it is meant to be used on.
     *
     * Users of a ReadHead can ask for the current position of the mark at any time. This mark
     * can be used later to replay the history of events either on this ReadHead, on the associated
     * ArrayTrackRecord or on another ReadHead associated with the same ArrayTrackRecord. It
     * might look like this in the reader thread :
     *
     * val trackRecord = ArrayTrackRecord<MyObject>()
     * val record = trackRecord.newReadHead()
     * val markAtStart = record.mark
     * // Start processing interesting events
     * while (true) {
     *   val element = record.poll(timeout) { it.isInteresting() } ?: break
     *   // Do something with element
     * }
     * // Look for stuff that happened while searching for interesting events
     * val firstElementReceived = record.getOrNull(markAtStart)
     * val firstSpecialElement = trackRecord.getOrNull(markAtStart) { it.isSpecial() }
     * // Get the first special element since markAtStart, possibly blocking until one is available
     * val specialElement = record.poll(timeout, markAtStart) { it.isSpecial() }
     */
    inner class ReadHead : TrackRecord<E> by this@ArrayTrackRecord {
        private val owningThread = Thread.currentThread()
        private var readHead = 0

        /**
         * The current value of the mark, i.e. the index of the next element to be read.
         * Setting it is equivalent to calling [rewind]. Both reading and writing it throw
         * IllegalStateException when done from a thread other than the creating thread.
         */
        var mark
            get() = readHead.also { checkThread() }
            set(v: Int) = rewind(v)

        /**
         * Moves the mark to [v]. Throws IllegalStateException when called from a thread other
         * than the one that created this ReadHead.
         */
        fun rewind(v: Int) {
            checkThread()
            readHead = v
        }

        private fun checkThread() = check(Thread.currentThread() == owningThread) {
            "Must be called by the thread that created this object"
        }

        /**
         * Returns the first element at or after the mark, optionally blocking until one is
         * available, or null if no such element can be found within the timeout.
         * If a predicate is given, only elements matching the predicate are returned.
         *
         * Upon return the mark will be set to immediately after the returned element, or after
         * the last element in the record if null is returned. This means this method will always
         * skip elements that do not match the predicate, even if it returns null.
         *
         * This method can only be used by the thread that created this ReadHead.
         * If used on another thread, this throws IllegalStateException.
         *
         * @param timeoutMs how long, in milliseconds, to wait at most (best effort approximation).
         * @param predicate an optional predicate to filter elements to be returned.
         * @return an element matching the predicate, or null if timeout.
         */
        fun poll(timeoutMs: Long, predicate: (E) -> Boolean = { true }): E? {
            checkThread()
            lock.withLock {
                val index = pollForIndexReadLocked(timeoutMs, readHead, predicate)
                readHead = if (index < 0) elements.size else index + 1
                return elements.getOrNull(index)
            }
        }

        /**
         * Returns the element at the mark, or null if there is none yet. This never blocks and
         * does not move the mark.
         *
         * This method can only be used by the thread that created this ReadHead.
         * If used on another thread, this throws IllegalStateException.
         */
        fun peek(): E? {
            checkThread()
            return lock.withLock { elements.getOrNull(readHead) }
        }
    }
}
245 
246 // Private helper
/**
 * Waits on this condition for at most [timeoutMs] milliseconds.
 *
 * Thin millisecond-based shorthand for [Condition.await] with a [TimeUnit]. The caller must hold
 * the associated lock. Returns false if the waiting time elapsed before return, true otherwise.
 */
private fun Condition.await(timeoutMs: Long): Boolean {
    val unit = TimeUnit.MILLISECONDS
    return await(timeoutMs, unit)
}
248