1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.captiveportallogin;
18 
19 import static android.Manifest.permission.MANAGE_TEST_NETWORKS;
20 import static android.app.Activity.RESULT_OK;
21 import static android.content.Intent.ACTION_CREATE_DOCUMENT;
22 import static android.net.ConnectivityManager.ACTION_CAPTIVE_PORTAL_SIGN_IN;
23 import static android.net.ConnectivityManager.EXTRA_CAPTIVE_PORTAL;
24 import static android.net.ConnectivityManager.EXTRA_CAPTIVE_PORTAL_URL;
25 import static android.net.ConnectivityManager.EXTRA_CAPTIVE_PORTAL_USER_AGENT;
26 import static android.net.ConnectivityManager.EXTRA_NETWORK;
27 import static android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED;
28 import static android.provider.DeviceConfig.NAMESPACE_CONNECTIVITY;
29 
30 import static androidx.test.espresso.intent.Intents.intending;
31 import static androidx.test.espresso.intent.matcher.IntentMatchers.hasAction;
32 import static androidx.test.espresso.intent.matcher.IntentMatchers.isInternal;
33 import static androidx.test.espresso.web.sugar.Web.onWebView;
34 import static androidx.test.espresso.web.webdriver.DriverAtoms.findElement;
35 import static androidx.test.espresso.web.webdriver.DriverAtoms.webClick;
36 import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation;
37 
38 import static com.android.dx.mockito.inline.extended.ExtendedMockito.doReturn;
39 import static com.android.dx.mockito.inline.extended.ExtendedMockito.mockitoSession;
40 import static com.android.dx.mockito.inline.extended.ExtendedMockito.spyOn;
41 import static com.android.testutils.TestNetworkTrackerKt.initTestNetwork;
42 
43 import static junit.framework.Assert.assertEquals;
44 
45 import static org.hamcrest.CoreMatchers.not;
46 import static org.junit.Assert.assertFalse;
47 import static org.junit.Assert.assertNotNull;
48 import static org.junit.Assert.assertTrue;
49 import static org.mockito.ArgumentMatchers.argThat;
50 import static org.mockito.Mockito.any;
51 import static org.mockito.Mockito.mock;
52 import static org.mockito.Mockito.spy;
53 import static org.mockito.Mockito.verify;
54 
55 import android.app.Instrumentation.ActivityResult;
56 import android.app.KeyguardManager;
57 import android.app.UiAutomation;
58 import android.app.admin.DevicePolicyManager;
59 import android.content.ComponentName;
60 import android.content.Context;
61 import android.content.Intent;
62 import android.net.CaptivePortal;
63 import android.net.ConnectivityManager;
64 import android.net.InetAddresses;
65 import android.net.LinkAddress;
66 import android.net.Network;
67 import android.net.NetworkCapabilities;
68 import android.net.Uri;
69 import android.os.Parcel;
70 import android.os.Parcelable;
71 import android.provider.DeviceConfig;
72 
73 import androidx.test.InstrumentationRegistry;
74 import androidx.test.core.app.ActivityScenario;
75 import androidx.test.espresso.intent.Intents;
76 import androidx.test.espresso.intent.rule.IntentsTestRule;
77 import androidx.test.espresso.web.webdriver.Locator;
78 import androidx.test.ext.junit.runners.AndroidJUnit4;
79 import androidx.test.filters.SmallTest;
80 
81 import com.android.testutils.TestNetworkTracker;
82 
83 import org.junit.After;
84 import org.junit.Before;
85 import org.junit.Rule;
86 import org.junit.Test;
87 import org.junit.runner.RunWith;
88 import org.mockito.ArgumentCaptor;
89 import org.mockito.MockitoAnnotations;
90 import org.mockito.MockitoSession;
91 import org.mockito.quality.Strictness;
92 
93 import java.io.IOException;
94 import java.net.ServerSocket;
95 import java.util.Collections;
96 import java.util.HashMap;
97 import java.util.Map;
98 import java.util.function.BooleanSupplier;
99 
100 import fi.iki.elonen.NanoHTTPD;
101 
102 @RunWith(AndroidJUnit4.class)
103 @SmallTest
104 public class CaptivePortalLoginActivityTest {
105     private static final String TEST_URL = "http://android.test.com";
106     private static final int TEST_NETID = 1234;
107     private static final String TEST_URL_QUERY = "testquery";
108     private static final long TEST_TIMEOUT_MS = 10_000L;
109     private static final LinkAddress TEST_LINKADDR = new LinkAddress(
110             InetAddresses.parseNumericAddress("2001:db8::8"), 64);
111     private static final String TEST_USERAGENT = "Test/42.0 Unit-test";
112     private CaptivePortalLoginActivity mActivity;
113     private MockitoSession mSession;
114     private Network mNetwork = new Network(TEST_NETID);
115     private TestNetworkTracker mTestNetworkTracker;
116 
117     private static ConnectivityManager sConnectivityManager;
118     private static DevicePolicyManager sMockDevicePolicyManager;
119 
120     public static class InstrumentedCaptivePortalLoginActivity extends CaptivePortalLoginActivity {
121         @Override
getSystemService(String name)122         public Object getSystemService(String name) {
123             if (Context.CONNECTIVITY_SERVICE.equals(name)) return sConnectivityManager;
124             if (Context.DEVICE_POLICY_SERVICE.equals(name)) return sMockDevicePolicyManager;
125             return super.getSystemService(name);
126         }
127     }
128 
129     /** Class to replace CaptivePortal to prevent mock object is updated and replaced by parcel. */
130     public static class MockCaptivePortal extends CaptivePortal {
131         int mDismissTimes;
132         int mIgnoreTimes;
133         int mUseTimes;
134 
MockCaptivePortal()135         private MockCaptivePortal() {
136             this(0, 0, 0);
137         }
MockCaptivePortal(int dismissTimes, int ignoreTimes, int useTimes)138         private MockCaptivePortal(int dismissTimes, int ignoreTimes, int useTimes) {
139             super(null);
140             mDismissTimes = dismissTimes;
141             mIgnoreTimes = ignoreTimes;
142             mUseTimes = useTimes;
143         }
144         @Override
reportCaptivePortalDismissed()145         public void reportCaptivePortalDismissed() {
146             mDismissTimes++;
147         }
148 
149         @Override
ignoreNetwork()150         public void ignoreNetwork() {
151             mIgnoreTimes++;
152         }
153 
154         @Override
useNetwork()155         public void useNetwork() {
156             mUseTimes++;
157         }
158 
159         @Override
logEvent(int eventId, String packageName)160         public void logEvent(int eventId, String packageName) {
161             // Do nothing
162         }
163 
164         @Override
writeToParcel(Parcel out, int flags)165         public void writeToParcel(Parcel out, int flags) {
166             out.writeInt(mDismissTimes);
167             out.writeInt(mIgnoreTimes);
168             out.writeInt(mUseTimes);
169         }
170 
171         public static final Parcelable.Creator<MockCaptivePortal> CREATOR =
172                 new Parcelable.Creator<MockCaptivePortal>() {
173                 @Override
174                 public MockCaptivePortal createFromParcel(Parcel in) {
175                     return new MockCaptivePortal(in.readInt(), in.readInt(), in.readInt());
176                 }
177 
178                 @Override
179                 public MockCaptivePortal[] newArray(int size) {
180                     return new MockCaptivePortal[size];
181                 }
182         };
183     }
184 
185     @Rule
186     public final IntentsTestRule mActivityRule =
187             new IntentsTestRule<>(InstrumentedCaptivePortalLoginActivity.class,
188                     false /* initialTouchMode */, false  /* launchActivity */);
189 
190     @Before
setUp()191     public void setUp() throws Exception {
192         final Context context = getInstrumentation().getContext();
193         sConnectivityManager = spy(context.getSystemService(ConnectivityManager.class));
194         sMockDevicePolicyManager = mock(DevicePolicyManager.class);
195         MockitoAnnotations.initMocks(this);
196         mSession = mockitoSession()
197                 .spyStatic(DeviceConfig.class)
198                 .strictness(Strictness.WARN)
199                 .startMocking();
200         setDismissPortalInValidatedNetwork(true);
201         // Use a real (but test) network for the application. The application will pass this
202         // network to ConnectivityManager#bindProcessToNetwork, so it needs to be a real, existing
203         // network on the device but otherwise has no functional use at all. The http server set up
204         // by this test will run on the loopback interface and will not use this test network.
205         final UiAutomation automation = InstrumentationRegistry.getInstrumentation()
206                 .getUiAutomation();
207         automation.adoptShellPermissionIdentity(MANAGE_TEST_NETWORKS);
208         try {
209             mTestNetworkTracker = initTestNetwork(
210                     getInstrumentation().getContext(), TEST_LINKADDR, TEST_TIMEOUT_MS);
211         } finally {
212             automation.dropShellPermissionIdentity();
213         }
214         mNetwork = mTestNetworkTracker.getNetwork();
215     }
216 
217     @After
tearDown()218     public void tearDown() throws Exception {
219         mSession.finishMocking();
220         mActivityRule.finishActivity();
221         getInstrumentation().getContext().getSystemService(ConnectivityManager.class)
222                 .bindProcessToNetwork(null);
223         mTestNetworkTracker.teardown();
224     }
225 
initActivity(String url)226     private void initActivity(String url) {
227         // onCreate will be triggered in launchActivity(). Handle mock objects after
228         // launchActivity() if any new mock objects. Activity launching flow will be
229         //  1. launchActivity()
230         //  2. onCreate()
231         //  3. end of launchActivity()
232         mActivity = (InstrumentedCaptivePortalLoginActivity) mActivityRule.launchActivity(
233             new Intent(ACTION_CAPTIVE_PORTAL_SIGN_IN)
234                 .putExtra(EXTRA_CAPTIVE_PORTAL_URL, url)
235                 .putExtra(EXTRA_NETWORK, mNetwork)
236                 .putExtra(EXTRA_CAPTIVE_PORTAL_USER_AGENT, TEST_USERAGENT)
237                 .putExtra(EXTRA_CAPTIVE_PORTAL, new MockCaptivePortal())
238         );
239         // Verify activity created successfully.
240         assertNotNull(mActivity);
241         getInstrumentation().getContext().getSystemService(KeyguardManager.class)
242                 .requestDismissKeyguard(mActivity, null);
243         // Dismiss dialogs or notification shade, so that the test can interact with the activity.
244         mActivity.sendBroadcast(new Intent(Intent.ACTION_CLOSE_SYSTEM_DIALOGS));
245     }
246 
getCaptivePortal()247     private MockCaptivePortal getCaptivePortal() {
248         return (MockCaptivePortal) mActivity.mCaptivePortal;
249     }
250 
configNonVpnNetwork()251     private void configNonVpnNetwork() {
252         final Network[] networks = new Network[] {new Network(mNetwork)};
253         doReturn(networks).when(sConnectivityManager).getAllNetworks();
254         final NetworkCapabilities nonVpnCapabilities = new NetworkCapabilities()
255                 .addTransportType(NetworkCapabilities.TRANSPORT_WIFI);
256         doReturn(nonVpnCapabilities).when(sConnectivityManager).getNetworkCapabilities(
257                 mNetwork);
258     }
259 
configVpnNetwork()260     private void configVpnNetwork() {
261         final Network network1 = new Network(TEST_NETID + 1);
262         final Network network2 = new Network(TEST_NETID + 2);
263         final Network[] networks = new Network[] {network1, network2};
264         doReturn(networks).when(sConnectivityManager).getAllNetworks();
265         final NetworkCapabilities underlyingCapabilities = new NetworkCapabilities()
266                 .addTransportType(NetworkCapabilities.TRANSPORT_WIFI);
267         final NetworkCapabilities vpnCapabilities = new NetworkCapabilities(underlyingCapabilities)
268                 .addTransportType(NetworkCapabilities.TRANSPORT_VPN);
269         doReturn(underlyingCapabilities).when(sConnectivityManager).getNetworkCapabilities(
270                 network1);
271         doReturn(vpnCapabilities).when(sConnectivityManager).getNetworkCapabilities(network2);
272     }
273 
274     @Test
testHasVpnNetwork()275     public void testHasVpnNetwork() throws Exception {
276         initActivity(TEST_URL);
277         // Test non-vpn case.
278         configNonVpnNetwork();
279         assertFalse(mActivity.hasVpnNetwork());
280         // Test vpn case.
281         configVpnNetwork();
282         assertTrue(mActivity.hasVpnNetwork());
283     }
284 
285     @Test
testIsAlwaysOnVpnEnabled()286     public void testIsAlwaysOnVpnEnabled() throws Exception {
287         initActivity(TEST_URL);
288         doReturn(false).when(sMockDevicePolicyManager).isAlwaysOnVpnLockdownEnabled(any());
289         assertFalse(mActivity.isAlwaysOnVpnEnabled());
290         doReturn(true).when(sMockDevicePolicyManager).isAlwaysOnVpnLockdownEnabled(any());
291         assertTrue(mActivity.isAlwaysOnVpnEnabled());
292     }
293 
294     @Test
testVpnMsgOrLinkToBrowser()295     public void testVpnMsgOrLinkToBrowser() throws Exception {
296         initActivity(TEST_URL);
297         // Test non-vpn case.
298         configNonVpnNetwork();
299         doReturn(false).when(sMockDevicePolicyManager).isAlwaysOnVpnLockdownEnabled(any());
300         final String linkMatcher = ".*<a[^>]+href.*";
301         assertTrue(mActivity.getWebViewClient().getVpnMsgOrLinkToBrowser().matches(linkMatcher));
302 
303         // Test has vpn case.
304         configVpnNetwork();
305         final String vpnMatcher = ".*<div.*vpnwarning.*";
306         assertTrue(mActivity.getWebViewClient().getVpnMsgOrLinkToBrowser().matches(vpnMatcher));
307 
308         // Test always-on vpn case.
309         configNonVpnNetwork();
310         doReturn(true).when(sMockDevicePolicyManager).isAlwaysOnVpnLockdownEnabled(any());
311         assertTrue(mActivity.getWebViewClient().getVpnMsgOrLinkToBrowser().matches(vpnMatcher));
312     }
313 
notifyCapabilitiesChanged(final NetworkCapabilities nc)314     private void notifyCapabilitiesChanged(final NetworkCapabilities nc) {
315         mActivity.handleCapabilitiesChanged(mNetwork, nc);
316         InstrumentationRegistry.getInstrumentation().waitForIdleSync();
317     }
318 
verifyDismissed()319     private void verifyDismissed() {
320         final MockCaptivePortal cp = getCaptivePortal();
321         assertEquals(cp.mDismissTimes, 1);
322         assertEquals(cp.mIgnoreTimes, 0);
323         assertEquals(cp.mUseTimes, 0);
324     }
325 
notifyValidatedChangedAndDismissed(final NetworkCapabilities nc)326     private void notifyValidatedChangedAndDismissed(final NetworkCapabilities nc) {
327         notifyCapabilitiesChanged(nc);
328         verifyDismissed();
329     }
330 
verifyNotDone()331     private void verifyNotDone() {
332         final MockCaptivePortal cp = getCaptivePortal();
333         assertEquals(cp.mDismissTimes, 0);
334         assertEquals(cp.mIgnoreTimes, 0);
335         assertEquals(cp.mUseTimes, 0);
336     }
337 
notifyValidatedChangedNotDone(final NetworkCapabilities nc)338     private void notifyValidatedChangedNotDone(final NetworkCapabilities nc) {
339         notifyCapabilitiesChanged(nc);
340         verifyNotDone();
341     }
342 
verifyUseAsIs()343     private void verifyUseAsIs() {
344         final MockCaptivePortal cp = getCaptivePortal();
345         assertEquals(cp.mDismissTimes, 0);
346         assertEquals(cp.mIgnoreTimes, 0);
347         assertEquals(cp.mUseTimes, 1);
348     }
349 
setDismissPortalInValidatedNetwork(final boolean enable)350     private void setDismissPortalInValidatedNetwork(final boolean enable) {
351         // Feature is enabled if the package version greater than configuration. Instead of reading
352         // the package version, use Long.MAX_VALUE to replace disable configuration and 1 for
353         // enabling.
354         doReturn(enable ? 1 : Long.MAX_VALUE).when(() -> DeviceConfig.getLong(
355                 NAMESPACE_CONNECTIVITY,
356                 CaptivePortalLoginActivity.DISMISS_PORTAL_IN_VALIDATED_NETWORK, 0 /* default */));
357     }
358 
359     @Test
testNetworkCapabilitiesUpdate()360     public void testNetworkCapabilitiesUpdate() throws Exception {
361         initActivity(TEST_URL);
362         // NetworkCapabilities updates w/o NET_CAPABILITY_VALIDATED.
363         final NetworkCapabilities nc = new NetworkCapabilities();
364         notifyValidatedChangedNotDone(nc);
365 
366         // NetworkCapabilities updates w/ NET_CAPABILITY_VALIDATED.
367         nc.setCapability(NET_CAPABILITY_VALIDATED, true);
368         notifyValidatedChangedAndDismissed(nc);
369     }
370 
371     @Test
testNetworkCapabilitiesUpdateWithFlag()372     public void testNetworkCapabilitiesUpdateWithFlag() throws Exception {
373         initActivity(TEST_URL);
374         final NetworkCapabilities nc = new NetworkCapabilities();
375         nc.setCapability(NET_CAPABILITY_VALIDATED, true);
376         // Disable flag. Auto-dismiss should not happen.
377         setDismissPortalInValidatedNetwork(false);
378         notifyValidatedChangedNotDone(nc);
379 
380         // Enable flag. Auto-dismissed.
381         setDismissPortalInValidatedNetwork(true);
382         notifyValidatedChangedAndDismissed(nc);
383     }
384 
runCustomSchemeTest(String linkUri)385     private HttpServer runCustomSchemeTest(String linkUri) throws Exception {
386         final HttpServer server = new HttpServer();
387         server.setResponseBody(TEST_URL_QUERY,
388                 "<a id='tst_link' href='" + linkUri + "'>Test link</a>");
389 
390         server.start();
391         ActivityScenario.launch(RequestDismissKeyguardActivity.class);
392         initActivity(server.makeUrl(TEST_URL_QUERY));
393         // Mock all external intents
394         intending(not(isInternal())).respondWith(new ActivityResult(RESULT_OK, null));
395 
396         onWebView().withElement(findElement(Locator.ID, "tst_link")).perform(webClick());
397         getInstrumentation().waitForIdleSync();
398         return server;
399     }
400 
401     @Test
testTelScheme()402     public void testTelScheme() throws Exception {
403         final String telUri = "tel:0123456789";
404         final HttpServer server = runCustomSchemeTest(telUri);
405 
406         final Intent sentIntent = Intents.getIntents().get(0);
407         assertEquals(Intent.ACTION_DIAL, sentIntent.getAction());
408         assertEquals(Uri.parse(telUri), sentIntent.getData());
409 
410         server.stop();
411     }
412 
413     @Test
testSmsScheme()414     public void testSmsScheme() throws Exception {
415         final String telUri = "sms:0123456789";
416         final HttpServer server = runCustomSchemeTest(telUri);
417 
418         final Intent sentIntent = Intents.getIntents().get(0);
419         assertEquals(Intent.ACTION_SENDTO, sentIntent.getAction());
420         assertEquals(Uri.parse(telUri), sentIntent.getData());
421 
422         server.stop();
423     }
424 
425     @Test
testUnsupportedScheme()426     public void testUnsupportedScheme() throws Exception {
427         final HttpServer server = runCustomSchemeTest("mailto:test@example.com");
428         assertEquals(0, Intents.getIntents().size());
429 
430         onWebView().withElement(findElement(Locator.ID, "continue_link"))
431                 .perform(webClick());
432 
433         // The intent is sent in onDestroy(); there is no way to wait for that event, so poll
434         // until the intent is found.
435         assertTrue(isEventually(() -> Intents.getIntents().size() == 1, TEST_TIMEOUT_MS));
436         verifyUseAsIs();
437         final Intent sentIntent = Intents.getIntents().get(0);
438         assertEquals(Intent.ACTION_VIEW, sentIntent.getAction());
439         assertEquals(Uri.parse(server.makeUrl(TEST_URL_QUERY)), sentIntent.getData());
440 
441         server.stop();
442     }
443 
444     @Test
testDownload()445     public void testDownload() throws Exception {
446         // Setup the server with a single link on the portal page, leading to a download
447         final HttpServer server = new HttpServer();
448         final String linkIdDownload = "download";
449         final String downloadQuery = "dl";
450         final String filename = "testfile.png";
451         final String mimetype = "image/png";
452         server.setResponseBody(TEST_URL_QUERY,
453                 "<a id='" + linkIdDownload + "' href='?" + downloadQuery + "'>Download</a>");
454         server.setResponse(downloadQuery, "This is a test file", mimetype, Collections.singletonMap(
455                 "Content-Disposition", "attachment; filename=\"" + filename + "\""));
456         server.start();
457 
458         ActivityScenario.launch(RequestDismissKeyguardActivity.class);
459         initActivity(server.makeUrl(TEST_URL_QUERY));
460 
461         // Create a mock file to be returned when mocking the file chooser
462         final Context ctx = mActivity.getApplicationContext();
463         final Intent mockFileResponse = new Intent();
464         final Uri mockFile = Uri.parse("content://mockdata");
465         mockFileResponse.setData(mockFile);
466 
467         // Mock file chooser and DownloadService intents
468         intending(hasAction(ACTION_CREATE_DOCUMENT)).respondWith(
469                 new ActivityResult(RESULT_OK, mockFileResponse));
470         // mockito-intents does not support mocking service starts (only startActivity), and the
471         // activity is created by the framework from the activity start intent. Use extended mockito
472         // to inject a mock on startForegroundService.
473         spyOn(mActivity);
474         final ComponentName downloadComponent = new ComponentName(ctx, DownloadService.class);
475         doReturn(downloadComponent).when(mActivity).startForegroundService(argThat(intent ->
476                 downloadComponent.equals(intent.getComponent())));
477         // No intent fired yet
478         assertEquals(0, Intents.getIntents().size());
479 
480         onWebView().withElement(findElement(Locator.ID, linkIdDownload))
481                 .perform(webClick());
482 
483         // The create file intent should be fired when the download starts
484         assertTrue("Create file intent not received within timeout",
485                 isEventually(() -> Intents.getIntents().size() == 1, TEST_TIMEOUT_MS));
486 
487         final Intent fileIntent = Intents.getIntents().get(0);
488         assertEquals(ACTION_CREATE_DOCUMENT, fileIntent.getAction());
489         assertEquals(mimetype, fileIntent.getType());
490         assertEquals(filename, fileIntent.getStringExtra(Intent.EXTRA_TITLE));
491 
492         // The download intent should be fired after the create file result is received
493         final ArgumentCaptor<Intent> intentCaptor = ArgumentCaptor.forClass(Intent.class);
494         verify(mActivity).startForegroundService(intentCaptor.capture());
495         final Intent dlIntent = intentCaptor.getValue();
496 
497         assertEquals(downloadComponent, dlIntent.getComponent());
498         assertEquals(mNetwork, dlIntent.getParcelableExtra(DownloadService.ARG_NETWORK));
499         assertEquals(TEST_USERAGENT, dlIntent.getStringExtra(DownloadService.ARG_USERAGENT));
500         final String expectedUrl = server.makeUrl(downloadQuery);
501         assertEquals(expectedUrl, dlIntent.getStringExtra(DownloadService.ARG_URL));
502         assertEquals(filename, dlIntent.getStringExtra(DownloadService.ARG_DISPLAY_NAME));
503         assertEquals(mockFile, dlIntent.getParcelableExtra(DownloadService.ARG_OUTFILE));
504 
505         server.stop();
506     }
507 
isEventually(BooleanSupplier condition, long timeout)508     private static boolean isEventually(BooleanSupplier condition, long timeout)
509             throws InterruptedException {
510         final long start = System.currentTimeMillis();
511         do {
512             if (condition.getAsBoolean()) return true;
513             Thread.sleep(10);
514         } while ((System.currentTimeMillis() - start) < timeout);
515 
516         return false;
517     }
518 
519     private static class HttpServer extends NanoHTTPD {
520         private final ServerSocket mSocket;
521         // Responses per URL query
522         private final HashMap<String, MockResponse> mResponses = new HashMap<>();
523 
524         private static final class MockResponse {
525             private final String mBody;
526             private final String mMimetype;
527             private final Map<String, String> mHeaders;
528 
MockResponse(String body, String mimetype, Map<String, String> headers)529             MockResponse(String body, String mimetype, Map<String, String> headers) {
530                 this.mBody = body;
531                 this.mMimetype = mimetype;
532                 this.mHeaders = Collections.unmodifiableMap(new HashMap<>(headers));
533             }
534         }
535 
HttpServer()536         HttpServer() throws IOException {
537             this(new ServerSocket());
538         }
539 
HttpServer(ServerSocket socket)540         private HttpServer(ServerSocket socket) {
541             // 0 as port for picking a port automatically
542             super("localhost", 0);
543             mSocket = socket;
544         }
545 
546         @Override
getServerSocketFactory()547         public ServerSocketFactory getServerSocketFactory() {
548             return () -> mSocket;
549         }
550 
makeUrl(String query)551         private String makeUrl(String query) {
552             return new Uri.Builder()
553                     .scheme("http")
554                     .encodedAuthority("localhost:" + mSocket.getLocalPort())
555                     // Explicitly specify an empty path to match the format of URLs returned by
556                     // WebView (for example in onDownloadStart)
557                     .path("/")
558                     .query(query)
559                     .build()
560                     .toString();
561         }
562 
setResponseBody(String query, String body)563         private void setResponseBody(String query, String body) {
564             setResponse(query, body, NanoHTTPD.MIME_HTML, Collections.emptyMap());
565         }
566 
setResponse(String query, String body, String mimetype, Map<String, String> headers)567         private void setResponse(String query, String body, String mimetype,
568                 Map<String, String> headers) {
569             mResponses.put(query, new MockResponse(body, mimetype, headers));
570         }
571 
572         @Override
serve(IHTTPSession session)573         public Response serve(IHTTPSession session) {
574             final MockResponse mockResponse = mResponses.get(session.getQueryParameterString());
575             if (mockResponse == null) {
576                 // Default response is a 404
577                 return super.serve(session);
578             }
579 
580             final Response response = newFixedLengthResponse(Response.Status.OK,
581                     mockResponse.mMimetype,
582                     "<!doctype html>"
583                     + "<html>"
584                     + "<head><title>Test portal</title></head>"
585                     + "<body>" + mockResponse.mBody + "</body>"
586                     + "</html>");
587             mockResponse.mHeaders.forEach(response::addHeader);
588             return response;
589         }
590     }
591 }
592