1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.testutils
18 
19 import android.net.ConnectivityManager.NetworkCallback
20 import android.net.LinkProperties
21 import android.net.Network
22 import android.net.NetworkCapabilities
23 import android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED
24 import com.android.net.module.util.ArrayTrackRecord
25 import com.android.testutils.RecorderCallback.CallbackEntry.Available
26 import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
27 import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
28 import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
29 import com.android.testutils.RecorderCallback.CallbackEntry.Losing
30 import com.android.testutils.RecorderCallback.CallbackEntry.Lost
31 import com.android.testutils.RecorderCallback.CallbackEntry.Resumed
32 import com.android.testutils.RecorderCallback.CallbackEntry.Suspended
33 import com.android.testutils.RecorderCallback.CallbackEntry.Unavailable
34 import kotlin.reflect.KClass
35 import kotlin.test.assertEquals
36 import kotlin.test.assertNotNull
37 import kotlin.test.assertTrue
38 import kotlin.test.fail
39 
/**
 * Sentinel [Network] standing in for "no network". [RecorderCallback.CallbackEntry.Unavailable]
 * carries it, and the KClass-based `expectCallback` overload of [TestableNetworkCallback]
 * substitutes it when the caller passes a null network.
 */
object NULL_NETWORK : Network(-1)

/**
 * Wildcard [Network] for the reified `TestableNetworkCallback.expectCallback`: when it is passed
 * (compared by identity), the network of the received callback is not checked.
 */
object ANY_NETWORK : Network(-2)

// Name of this capability value as returned by NetworkCapabilities.capabilityNameOf.
// NOTE(review): not referenced anywhere else in this file.
private val Int.capabilityName get() = NetworkCapabilities.capabilityNameOf(this)
44 
/**
 * A [NetworkCallback] that converts every callback it receives into a [CallbackEntry] and appends
 * it to an [ArrayTrackRecord], so tests can inspect or poll the sequence afterwards.
 *
 * A recorder built from another one (see the protected constructor) shares that recorder's
 * backing record; every instance still gets its own [history] read head over it.
 */
open class RecorderCallback private constructor(
    private val backingRecord: ArrayTrackRecord<CallbackEntry>
) : NetworkCallback() {
    /** Creates a recorder backed by a brand-new record. */
    constructor() : this(ArrayTrackRecord<CallbackEntry>())

    /** Creates a recorder sharing [src]'s record, or backed by a new record if [src] is null. */
    protected constructor(src: RecorderCallback?) : this(
        if (src == null) ArrayTrackRecord<CallbackEntry>() else src.backingRecord
    )

    /** One recorded callback invocation. */
    sealed class CallbackEntry {
        // Subclasses are data classes so equals(), hashCode(), componentN() etc. come for free.
        // A data class may extend another class, but every primary-constructor parameter must be
        // a property, so it cannot forward a plain argument to CallbackEntry's constructor.
        // Instead every subclass overrides this abstract property in its own constructor.
        abstract val network: Network

        /** Recorded by onAvailable. */
        data class Available(override val network: Network) : CallbackEntry()

        /** Recorded by onCapabilitiesChanged. */
        data class CapabilitiesChanged(
            override val network: Network,
            val caps: NetworkCapabilities
        ) : CallbackEntry()

        /** Recorded by onLinkPropertiesChanged. */
        data class LinkPropertiesChanged(
            override val network: Network,
            val lp: LinkProperties
        ) : CallbackEntry()

        /** Recorded by onNetworkSuspended. */
        data class Suspended(override val network: Network) : CallbackEntry()

        /** Recorded by onNetworkResumed. */
        data class Resumed(override val network: Network) : CallbackEntry()

        /** Recorded by onLosing. */
        data class Losing(override val network: Network, val maxMsToLive: Int) : CallbackEntry()

        /** Recorded by onLost. */
        data class Lost(override val network: Network) : CallbackEntry()

        /** Recorded by onUnavailable; that callback has no network, so [NULL_NETWORK] is used. */
        data class Unavailable private constructor(
            override val network: Network
        ) : CallbackEntry() {
            constructor() : this(NULL_NETWORK)
        }

        /** Recorded by onBlockedStatusChanged. */
        data class BlockedStatus(
            override val network: Network,
            val blocked: Boolean
        ) : CallbackEntry()

        // Shorthand type tokens for the entry classes, exposed to Java as static fields.
        companion object {
            @JvmField
            val AVAILABLE: KClass<Available> = Available::class
            @JvmField
            val NETWORK_CAPS_UPDATED: KClass<CapabilitiesChanged> = CapabilitiesChanged::class
            @JvmField
            val LINK_PROPERTIES_CHANGED: KClass<LinkPropertiesChanged> =
                LinkPropertiesChanged::class
            @JvmField
            val SUSPENDED: KClass<Suspended> = Suspended::class
            @JvmField
            val RESUMED: KClass<Resumed> = Resumed::class
            @JvmField
            val LOSING: KClass<Losing> = Losing::class
            @JvmField
            val LOST: KClass<Lost> = Lost::class
            @JvmField
            val UNAVAILABLE: KClass<Unavailable> = Unavailable::class
            @JvmField
            val BLOCKED_STATUS: KClass<BlockedStatus> = BlockedStatus::class
        }
    }

    /** This instance's read head over the (possibly shared) backing record. */
    val history = backingRecord.newReadHead()

    /** Current mark of [history]. */
    val mark: Int get() = history.mark

    // Single funnel through which every callback below is appended to the record.
    private fun record(entry: CallbackEntry) {
        history.add(entry)
    }

    override fun onAvailable(network: Network) = record(Available(network))

    // PreCheck is not used in the tests today. For backward compatibility with existing tests that
    // expect the callbacks not to record this, do not listen to PreCheck here.

    override fun onCapabilitiesChanged(network: Network, caps: NetworkCapabilities) =
        record(CapabilitiesChanged(network, caps))

    override fun onLinkPropertiesChanged(network: Network, lp: LinkProperties) =
        record(LinkPropertiesChanged(network, lp))

    override fun onBlockedStatusChanged(network: Network, blocked: Boolean) =
        record(BlockedStatus(network, blocked))

    override fun onNetworkSuspended(network: Network) = record(Suspended(network))

    override fun onNetworkResumed(network: Network) = record(Resumed(network))

    override fun onLosing(network: Network, maxMsToLive: Int) =
        record(Losing(network, maxMsToLive))

    override fun onLost(network: Network) = record(Lost(network))

    override fun onUnavailable() = record(Unavailable())
}
148 
/** Default timeout, in milliseconds, used by [TestableNetworkCallback]'s helpers. */
private const val DEFAULT_TIMEOUT: Long = 200
150 
/**
 * A [RecorderCallback] with assertion helpers for tests. The helpers poll the recorded callbacks
 * and fail through kotlin.test when a callback is missing or does not match expectations.
 *
 * @param defaultTimeoutMs timeout used by every helper whose timeout argument is omitted.
 */
open class TestableNetworkCallback private constructor(
    src: TestableNetworkCallback?,
    val defaultTimeoutMs: Long = DEFAULT_TIMEOUT
) : RecorderCallback(src) {
    /** Creates a callback backed by its own, unshared record. */
    @JvmOverloads
    constructor(timeoutMs: Long = DEFAULT_TIMEOUT) : this(null, timeoutMs)

    /**
     * Returns a new callback that shares this callback's record and default timeout. The copy is
     * given its own read head (see [RecorderCallback.history]).
     */
    fun createLinkedCopy() = TestableNetworkCallback(this, defaultTimeoutMs)

    // The last available network, or null if any network was lost since the last call to
    // onAvailable. TODO : fix this by fixing the tests that rely on this behavior
    val lastAvailableNetwork: Network?
        get() = when (val last = history.lastOrNull { it is Available || it is Lost }) {
            is Available -> last.network
            else -> null
        }

    /** Returns the next recorded callback, waiting up to [timeoutMs]; fails if none arrives. */
    fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackEntry {
        return history.poll(timeoutMs) ?: fail("Did not receive callback after ${timeoutMs}ms")
    }

    /** Fails if any callback is received within [timeoutMs]. */
    // Make open for use in ConnectivityServiceTest which is the only one knowing its handlers.
    @JvmOverloads
    open fun assertNoCallback(timeoutMs: Long = defaultTimeoutMs) {
        val cb = history.poll(timeoutMs)
        if (null != cb) fail("Expected no callback but got $cb")
    }

    /**
     * Expects the next callback to be of type [T] on [network] within [timeoutMs], and returns it.
     * Fails if no callback arrives, or if the next one has a different type or network. Passing
     * [ANY_NETWORK] (the default) skips the network check.
     */
    inline fun <reified T : CallbackEntry> expectCallback(
        network: Network = ANY_NETWORK,
        timeoutMs: Long = defaultTimeoutMs
    ): T = pollForNextCallback(timeoutMs).let {
        if (it !is T || (ANY_NETWORK !== network && it.network != network)) {
            fail("Unexpected callback : $it, expected ${T::class} with Network[$network]")
        } else {
            it
        }
    }

    /**
     * Expects a callback of type [T] matching [predicate] within [timeoutMs], searching from
     * position [from] (the current [mark] by default). Callbacks that don't match are skipped;
     * fails only if no matching callback is found in time.
     */
    inline fun <reified T : CallbackEntry> eventuallyExpect(
        timeoutMs: Long = defaultTimeoutMs,
        from: Int = mark,
        crossinline predicate: (T) -> Boolean = { true }
    ): T = eventuallyExpectOrNull(timeoutMs, from, predicate).also {
        assertNotNull(it, "Callback ${T::class} not received within ${timeoutMs}ms")
    } as T

    /**
     * Like [eventuallyExpect], but returns null instead of failing when no matching callback of
     * type [T] is found within [timeoutMs].
     */
    // TODO (b/157405399) straighten and unify the method names
    inline fun <reified T : CallbackEntry> eventuallyExpectOrNull(
        timeoutMs: Long = defaultTimeoutMs,
        from: Int = mark,
        crossinline predicate: (T) -> Boolean = { true }
    ) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T?

    /** Polls the next callback and fails unless [valid] accepts it; returns the callback. */
    fun expectCallbackThat(
        timeoutMs: Long = defaultTimeoutMs,
        valid: (CallbackEntry) -> Boolean
    ) = pollForNextCallback(timeoutMs).also { assertTrue(valid(it), "Unexpected callback : $it") }

    /**
     * Expects the next callback to be [CapabilitiesChanged] on [net] and fails unless [valid]
     * accepts its capabilities. Returns the callback.
     */
    fun expectCapabilitiesThat(
        net: Network,
        tmt: Long = defaultTimeoutMs,
        valid: (NetworkCapabilities) -> Boolean
    ): CapabilitiesChanged {
        return expectCallback<CapabilitiesChanged>(net, tmt).also {
            assertTrue(valid(it.caps), "Capabilities don't match expectations ${it.caps}")
        }
    }

    /**
     * Expects the next callback to be [LinkPropertiesChanged] on [net] and fails unless [valid]
     * accepts its link properties. Returns the callback.
     */
    fun expectLinkPropertiesThat(
        net: Network,
        tmt: Long = defaultTimeoutMs,
        valid: (LinkProperties) -> Boolean
    ): LinkPropertiesChanged {
        return expectCallback<LinkPropertiesChanged>(net, tmt).also {
            assertTrue(valid(it.lp), "LinkProperties don't match expectations ${it.lp}")
        }
    }

    // Expects onAvailable and the callbacks that follow it. These are:
    // - onNetworkSuspended, iff the network was suspended when the callbacks fire.
    // - onCapabilitiesChanged.
    // - onLinkPropertiesChanged.
    // - onBlockedStatusChanged.
    //
    // @param net the network to expect the callbacks on.
    // @param suspended whether to expect a SUSPENDED callback.
    // @param validated the expected value of the VALIDATED capability in the
    //        onCapabilitiesChanged callback.
    // @param blocked the expected value in the onBlockedStatusChanged callback.
    // @param tmt how long to wait for each of the callbacks.
    fun expectAvailableCallbacks(
        net: Network,
        suspended: Boolean = false,
        validated: Boolean = true,
        blocked: Boolean = false,
        tmt: Long = defaultTimeoutMs
    ) {
        expectCallback<Available>(net, tmt)
        if (suspended) {
            expectCallback<Suspended>(net, tmt)
        }
        expectCapabilitiesThat(net, tmt) { validated == it.hasCapability(NET_CAPABILITY_VALIDATED) }
        expectCallback<LinkPropertiesChanged>(net, tmt)
        // Honor the caller's timeout for the last callback too (it used to fall back to
        // defaultTimeoutMs regardless of tmt).
        expectBlockedStatusCallback(blocked, net, tmt)
    }

    // Backward compatibility for existing Java code. Use named arguments instead and remove all
    // these when there is no user left.
    fun expectAvailableAndSuspendedCallbacks(
        net: Network,
        validated: Boolean,
        tmt: Long = defaultTimeoutMs
    ) = expectAvailableCallbacks(net, suspended = true, validated = validated, tmt = tmt)

    /** Expects the next callback to be [BlockedStatus] on [net] with the given [blocked] value. */
    fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) {
        expectCallback<BlockedStatus>(net, tmt).also {
            // kotlin.test's assertEquals takes (expected, actual), in that order.
            assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
        }
    }

    // Expects the available callbacks (where the onCapabilitiesChanged must contain the
    // VALIDATED capability), plus another onCapabilitiesChanged which is identical to the
    // one we just sent.
    // TODO: this is likely a bug. Fix it and remove this method.
    fun expectAvailableDoubleValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
        // Named `start` so it does not shadow the `mark` property.
        val start = history.mark
        expectAvailableCallbacks(net, tmt = tmt)
        // Look up the first CapabilitiesChanged recorded since `start`, i.e. the one consumed by
        // expectAvailableCallbacks above, and require the next callback to equal it.
        val firstCaps = history.poll(tmt, start) { it is CapabilitiesChanged }
        assertEquals(firstCaps, expectCallback<CapabilitiesChanged>(net, tmt))
    }

    // Expects the available callbacks where the onCapabilitiesChanged must not have validated,
    // then expects another onCapabilitiesChanged that has the validated bit set. This is used
    // when a network connects and satisfies a callback, and then immediately validates.
    fun expectAvailableThenValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
        expectAvailableCallbacks(net, validated = false, tmt = tmt)
        expectCapabilitiesThat(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
    }

    // Temporary Java compat measure : have MockNetworkAgent implement this so that all existing
    // calls with networkAgent can be routed through here without moving MockNetworkAgent.
    // TODO: clean this up, remove this method.
    interface HasNetwork {
        val network: Network
    }

    /**
     * Java-friendly variant of the reified expectCallback: expects the next callback to be an
     * instance of [type] on network [n]. A null [n] means [NULL_NETWORK].
     */
    @JvmOverloads
    open fun <T : CallbackEntry> expectCallback(
        type: KClass<T>,
        n: Network?,
        timeoutMs: Long = defaultTimeoutMs
    ) = pollForNextCallback(timeoutMs).also {
        val network = n ?: NULL_NETWORK
        // TODO : remove this .java access if the tests ever use kotlin-reflect. At the time of
        // this writing this would be the only use of this library in the tests.
        assertTrue(type.java.isInstance(it) && it.network == network,
                "Unexpected callback : $it, expected ${type.java} with Network[$network]")
    } as T

    /** [HasNetwork] variant of the KClass-based expectCallback; a null [n] means no network. */
    @JvmOverloads
    open fun <T : CallbackEntry> expectCallback(
        type: KClass<T>,
        n: HasNetwork?,
        timeoutMs: Long = defaultTimeoutMs
    ) = expectCallback(type, n?.network, timeoutMs)

    /** [HasNetwork] variant of [expectAvailableCallbacks] with every argument explicit. */
    fun expectAvailableCallbacks(
        n: HasNetwork,
        suspended: Boolean,
        validated: Boolean,
        blocked: Boolean,
        timeoutMs: Long
    ) = expectAvailableCallbacks(n.network, suspended, validated, blocked, timeoutMs)

    fun expectAvailableAndSuspendedCallbacks(n: HasNetwork, expectValidated: Boolean) {
        expectAvailableAndSuspendedCallbacks(n.network, expectValidated)
    }

    /** Available callbacks with VALIDATED and not blocked. */
    fun expectAvailableCallbacksValidated(n: HasNetwork) {
        expectAvailableCallbacks(n.network)
    }

    /** Available callbacks with VALIDATED and blocked. */
    fun expectAvailableCallbacksValidatedAndBlocked(n: HasNetwork) {
        expectAvailableCallbacks(n.network, blocked = true)
    }

    /** Available callbacks without VALIDATED and not blocked. */
    fun expectAvailableCallbacksUnvalidated(n: HasNetwork) {
        expectAvailableCallbacks(n.network, validated = false)
    }

    /** Available callbacks without VALIDATED and blocked. */
    fun expectAvailableCallbacksUnvalidatedAndBlocked(n: HasNetwork) {
        expectAvailableCallbacks(n.network, validated = false, blocked = true)
    }

    fun expectAvailableDoubleValidatedCallbacks(n: HasNetwork) {
        expectAvailableDoubleValidatedCallbacks(n.network, defaultTimeoutMs)
    }

    fun expectAvailableThenValidatedCallbacks(n: HasNetwork) {
        expectAvailableThenValidatedCallbacks(n.network, defaultTimeoutMs)
    }

    @JvmOverloads
    fun expectLinkPropertiesThat(
        n: HasNetwork,
        tmt: Long = defaultTimeoutMs,
        valid: (LinkProperties) -> Boolean
    ) = expectLinkPropertiesThat(n.network, tmt, valid)

    @JvmOverloads
    fun expectCapabilitiesThat(
        n: HasNetwork,
        tmt: Long = defaultTimeoutMs,
        valid: (NetworkCapabilities) -> Boolean
    ) = expectCapabilitiesThat(n.network, tmt, valid)

    /**
     * Expects a [CapabilitiesChanged] on [n] whose capabilities include [capability]; returns
     * those capabilities.
     */
    @JvmOverloads
    fun expectCapabilitiesWith(
        capability: Int,
        n: HasNetwork,
        timeoutMs: Long = defaultTimeoutMs
    ): NetworkCapabilities {
        return expectCapabilitiesThat(n.network, timeoutMs) { it.hasCapability(capability) }.caps
    }

    /**
     * Expects a [CapabilitiesChanged] on [n] whose capabilities do not include [capability];
     * returns those capabilities.
     */
    @JvmOverloads
    fun expectCapabilitiesWithout(
        capability: Int,
        n: HasNetwork,
        timeoutMs: Long = defaultTimeoutMs
    ): NetworkCapabilities {
        return expectCapabilitiesThat(n.network, timeoutMs) { !it.hasCapability(capability) }.caps
    }

    fun expectBlockedStatusCallback(expectBlocked: Boolean, n: HasNetwork) {
        expectBlockedStatusCallback(expectBlocked, n.network, defaultTimeoutMs)
    }
}
394