1 /*
2  * Copyright (C) 2015 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.connectivity;
18 
19 import static android.system.OsConstants.*;
20 
21 import android.annotation.NonNull;
22 import android.annotation.Nullable;
23 import android.net.LinkAddress;
24 import android.net.LinkProperties;
25 import android.net.Network;
26 import android.net.NetworkUtils;
27 import android.net.RouteInfo;
28 import android.net.TrafficStats;
29 import android.net.shared.PrivateDnsConfig;
30 import android.net.util.NetworkConstants;
31 import android.os.SystemClock;
32 import android.system.ErrnoException;
33 import android.system.Os;
34 import android.system.StructTimeval;
35 import android.text.TextUtils;
36 import android.util.Pair;
37 
38 import com.android.internal.util.IndentingPrintWriter;
39 import com.android.internal.util.TrafficStatsConstants;
40 
41 import libcore.io.IoUtils;
42 
43 import java.io.Closeable;
44 import java.io.DataInputStream;
45 import java.io.DataOutputStream;
46 import java.io.FileDescriptor;
47 import java.io.IOException;
48 import java.io.InterruptedIOException;
49 import java.net.Inet4Address;
50 import java.net.Inet6Address;
51 import java.net.InetAddress;
52 import java.net.InetSocketAddress;
53 import java.net.NetworkInterface;
54 import java.net.SocketAddress;
55 import java.net.SocketException;
56 import java.net.UnknownHostException;
57 import java.nio.ByteBuffer;
58 import java.nio.charset.StandardCharsets;
59 import java.util.ArrayList;
60 import java.util.Collections;
61 import java.util.HashMap;
62 import java.util.List;
63 import java.util.Map;
64 import java.util.Random;
65 import java.util.concurrent.CountDownLatch;
66 import java.util.concurrent.TimeUnit;
67 
68 import javax.net.ssl.SNIHostName;
69 import javax.net.ssl.SNIServerName;
70 import javax.net.ssl.SSLParameters;
71 import javax.net.ssl.SSLSocket;
72 import javax.net.ssl.SSLSocketFactory;
73 
74 /**
75  * NetworkDiagnostics
76  *
77  * A simple class to diagnose network connectivity fundamentals.  Current
78  * checks performed are:
79  *     - ICMPv4/v6 echo requests for all routers
80  *     - ICMPv4/v6 echo requests for all DNS servers
81  *     - DNS UDP queries to all DNS servers
82  *
83  * Currently unimplemented checks include:
84  *     - report ARP/ND data about on-link neighbors
85  *     - DNS TCP queries to all DNS servers
86  *     - HTTP DIRECT and PROXY checks
87  *     - port 443 blocking/TLS intercept checks
88  *     - QUIC reachability checks
89  *     - MTU checks
90  *
91  * The supplied timeout bounds the entire diagnostic process.  Each specific
92  * check class must implement this upper bound on measurements in whichever
93  * manner is most appropriate and effective.
94  *
95  * @hide
96  */
97 public class NetworkDiagnostics {
    // Label printed at the start of dump() output.
    private static final String TAG = "NetworkDiagnostics";

    // Fixed public DNS server addresses. The constructor adds them to the DNS server list of
    // the LinkProperties under test (subject to reachability-style checks) so that off-link
    // connectivity is probed as well.
    private static final InetAddress TEST_DNS4 = NetworkUtils.numericToInetAddress("8.8.8.8");
    private static final InetAddress TEST_DNS6 = NetworkUtils.numericToInetAddress(
            "2001:4860:4860::8888");
103 
104     // For brevity elsewhere.
now()105     private static final long now() {
106         return SystemClock.elapsedRealtime();
107     }
108 
109     // Values from RFC 1035 section 4.1.1, names from <arpa/nameser.h>.
110     // Should be a member of DnsUdpCheck, but "compiler says no".
111     public static enum DnsResponseCode { NOERROR, FORMERR, SERVFAIL, NXDOMAIN, NOTIMP, REFUSED };
112 
    // Network under test; measurement sockets are bound to it before use.
    private final Network mNetwork;
    // Link configuration being diagnosed. The constructor may add the hardcoded test DNS
    // servers to this object.
    private final LinkProperties mLinkProperties;
    // Private DNS configuration; its hostname and IPs are used for the DNS-over-TLS probes.
    private final PrivateDnsConfig mPrivateDnsCfg;
    // Index of the interface named by mLinkProperties, or null if it could not be looked up.
    private final Integer mInterfaceIndex;

    // Time budget in milliseconds, plus the now()-based instants at which the diagnostics
    // started and at which they are due (start + budget).
    private final long mTimeoutMs;
    private final long mStartTime;
    private final long mDeadlineTime;

    // A counter, initialized to the total number of measurements,
    // so callers can wait for completion.
    private final CountDownLatch mCountDownLatch;
125 
126     public class Measurement {
127         private static final String SUCCEEDED = "SUCCEEDED";
128         private static final String FAILED = "FAILED";
129 
130         private boolean succeeded;
131 
132         // Package private.  TODO: investigate better encapsulation.
133         String description = "";
134         long startTime;
135         long finishTime;
136         String result = "";
137         Thread thread;
138 
checkSucceeded()139         public boolean checkSucceeded() { return succeeded; }
140 
recordSuccess(String msg)141         void recordSuccess(String msg) {
142             maybeFixupTimes();
143             succeeded = true;
144             result = SUCCEEDED + ": " + msg;
145             if (mCountDownLatch != null) {
146                 mCountDownLatch.countDown();
147             }
148         }
149 
recordFailure(String msg)150         void recordFailure(String msg) {
151             maybeFixupTimes();
152             succeeded = false;
153             result = FAILED + ": " + msg;
154             if (mCountDownLatch != null) {
155                 mCountDownLatch.countDown();
156             }
157         }
158 
maybeFixupTimes()159         private void maybeFixupTimes() {
160             // Allows the caller to just set success/failure and not worry
161             // about also setting the correct finishing time.
162             if (finishTime == 0) { finishTime = now(); }
163 
164             // In cases where, for example, a failure has occurred before the
165             // measurement even began, fixup the start time to reflect as much.
166             if (startTime == 0) { startTime = finishTime; }
167         }
168 
169         @Override
toString()170         public String toString() {
171             return description + ": " + result + " (" + (finishTime - startTime) + "ms)";
172         }
173     }
174 
    // Measurements grouped by check type. Each map is keyed by the probed target address, or
    // by a (source, target) pair for the explicit-source ICMP checks.
    private final Map<InetAddress, Measurement> mIcmpChecks = new HashMap<>();
    private final Map<Pair<InetAddress, InetAddress>, Measurement> mExplicitSourceIcmpChecks =
            new HashMap<>();
    private final Map<InetAddress, Measurement> mDnsUdpChecks = new HashMap<>();
    private final Map<InetAddress, Measurement> mDnsTlsChecks = new HashMap<>();
    // Summary of the interfaces and network under test; printed as the first line of dump().
    private final String mDescription;
181 
182 
    /**
     * Prepares one measurement per check type and target, then starts all of them right away.
     *
     * @param network the network that measurement sockets are bound to
     * @param lp the link properties to diagnose; DNS servers may be added to this object
     * @param privateDnsCfg private DNS configuration whose IPs (and hostname) are probed over
     *         TLS
     * @param timeoutMs time budget in milliseconds, counted from construction
     */
    public NetworkDiagnostics(Network network, LinkProperties lp,
            @NonNull PrivateDnsConfig privateDnsCfg, long timeoutMs) {
        mNetwork = network;
        mLinkProperties = lp;
        mPrivateDnsCfg = privateDnsCfg;
        mInterfaceIndex = getInterfaceIndex(mLinkProperties.getInterfaceName());
        mTimeoutMs = timeoutMs;
        mStartTime = now();
        mDeadlineTime = mStartTime + mTimeoutMs;

        // Hardcode measurements to TEST_DNS4 and TEST_DNS6 in order to test off-link connectivity.
        // We are free to modify mLinkProperties with impunity because ConnectivityService passes us
        // a copy and not the original object. It's easier to do it this way because we don't need
        // to check whether the LinkProperties already contains these DNS servers because
        // LinkProperties#addDnsServer checks for duplicates.
        if (mLinkProperties.isReachable(TEST_DNS4)) {
            mLinkProperties.addDnsServer(TEST_DNS4);
        }
        // TODO: we could use mLinkProperties.isReachable(TEST_DNS6) here, because we won't set any
        // DNS servers for which isReachable() is false, but since this is diagnostic code, be extra
        // careful.
        if (mLinkProperties.hasGlobalIpv6Address() || mLinkProperties.hasIpv6DefaultRoute()) {
            mLinkProperties.addDnsServer(TEST_DNS6);
        }

        // Every gateway gets an ICMP check; gateways of IPv6 default routes are additionally
        // pinged from each eligible local source address.
        for (RouteInfo route : mLinkProperties.getRoutes()) {
            if (route.hasGateway()) {
                InetAddress gateway = route.getGateway();
                prepareIcmpMeasurement(gateway);
                if (route.isIPv6Default()) {
                    prepareExplicitSourceIcmpMeasurements(gateway);
                }
            }
        }
        for (InetAddress nameserver : mLinkProperties.getDnsServers()) {
            prepareIcmpMeasurement(nameserver);
            prepareDnsMeasurement(nameserver);

            // Unlike the DnsResolver which doesn't do certificate validation in opportunistic mode,
            // DoT probes to the DNS servers will fail if certificate validation fails.
            prepareDnsTlsMeasurement(null /* hostname */, nameserver);
        }

        for (InetAddress tlsNameserver : mPrivateDnsCfg.ips) {
            // Reachability check is necessary since when resolving the strict mode hostname,
            // NetworkMonitor always queries for both A and AAAA records, even if the network
            // is IPv4-only or IPv6-only.
            if (mLinkProperties.isReachable(tlsNameserver)) {
                // If there are IPs, there must have been a name that resolved to them.
                prepareDnsTlsMeasurement(mPrivateDnsCfg.hostname, tlsNameserver);
            }
        }

        // The latch is created only now, when the number of measurements is final. While the
        // measurements were being prepared it was still null, which is why the Measurement
        // record methods check for null before counting down.
        mCountDownLatch = new CountDownLatch(totalMeasurementCount());

        startMeasurements();

        mDescription = "ifaces{" + TextUtils.join(",", mLinkProperties.getAllInterfaceNames()) + "}"
                + " index{" + mInterfaceIndex + "}"
                + " network{" + mNetwork + "}"
                + " nethandle{" + mNetwork.getNetworkHandle() + "}";
    }
245 
getInterfaceIndex(String ifname)246     private static Integer getInterfaceIndex(String ifname) {
247         try {
248             NetworkInterface ni = NetworkInterface.getByName(ifname);
249             return ni.getIndex();
250         } catch (NullPointerException | SocketException e) {
251             return null;
252         }
253     }
254 
socketAddressToString(@onNull SocketAddress sockAddr)255     private static String socketAddressToString(@NonNull SocketAddress sockAddr) {
256         // The default toString() implementation is not the prettiest.
257         InetSocketAddress inetSockAddr = (InetSocketAddress) sockAddr;
258         InetAddress localAddr = inetSockAddr.getAddress();
259         return String.format(
260                 (localAddr instanceof Inet6Address ? "[%s]:%d" : "%s:%d"),
261                 localAddr.getHostAddress(), inetSockAddr.getPort());
262     }
263 
prepareIcmpMeasurement(InetAddress target)264     private void prepareIcmpMeasurement(InetAddress target) {
265         if (!mIcmpChecks.containsKey(target)) {
266             Measurement measurement = new Measurement();
267             measurement.thread = new Thread(new IcmpCheck(target, measurement));
268             mIcmpChecks.put(target, measurement);
269         }
270     }
271 
prepareExplicitSourceIcmpMeasurements(InetAddress target)272     private void prepareExplicitSourceIcmpMeasurements(InetAddress target) {
273         for (LinkAddress l : mLinkProperties.getLinkAddresses()) {
274             InetAddress source = l.getAddress();
275             if (source instanceof Inet6Address && l.isGlobalPreferred()) {
276                 Pair<InetAddress, InetAddress> srcTarget = new Pair<>(source, target);
277                 if (!mExplicitSourceIcmpChecks.containsKey(srcTarget)) {
278                     Measurement measurement = new Measurement();
279                     measurement.thread = new Thread(new IcmpCheck(source, target, measurement));
280                     mExplicitSourceIcmpChecks.put(srcTarget, measurement);
281                 }
282             }
283         }
284     }
285 
prepareDnsMeasurement(InetAddress target)286     private void prepareDnsMeasurement(InetAddress target) {
287         if (!mDnsUdpChecks.containsKey(target)) {
288             Measurement measurement = new Measurement();
289             measurement.thread = new Thread(new DnsUdpCheck(target, measurement));
290             mDnsUdpChecks.put(target, measurement);
291         }
292     }
293 
prepareDnsTlsMeasurement(@ullable String hostname, @NonNull InetAddress target)294     private void prepareDnsTlsMeasurement(@Nullable String hostname, @NonNull InetAddress target) {
295         // This might overwrite an existing entry in mDnsTlsChecks, because |target| can be an IP
296         // address configured by the network as well as an IP address learned by resolving the
297         // strict mode DNS hostname. If the entry is overwritten, the overwritten measurement
298         // thread will not execute.
299         Measurement measurement = new Measurement();
300         measurement.thread = new Thread(new DnsTlsCheck(hostname, target, measurement));
301         mDnsTlsChecks.put(target, measurement);
302     }
303 
totalMeasurementCount()304     private int totalMeasurementCount() {
305         return mIcmpChecks.size() + mExplicitSourceIcmpChecks.size() + mDnsUdpChecks.size()
306                 + mDnsTlsChecks.size();
307     }
308 
startMeasurements()309     private void startMeasurements() {
310         for (Measurement measurement : mIcmpChecks.values()) {
311             measurement.thread.start();
312         }
313         for (Measurement measurement : mExplicitSourceIcmpChecks.values()) {
314             measurement.thread.start();
315         }
316         for (Measurement measurement : mDnsUdpChecks.values()) {
317             measurement.thread.start();
318         }
319         for (Measurement measurement : mDnsTlsChecks.values()) {
320             measurement.thread.start();
321         }
322     }
323 
waitForMeasurements()324     public void waitForMeasurements() {
325         try {
326             mCountDownLatch.await(mDeadlineTime - now(), TimeUnit.MILLISECONDS);
327         } catch (InterruptedException ignored) {}
328     }
329 
getMeasurements()330     public List<Measurement> getMeasurements() {
331         // TODO: Consider moving waitForMeasurements() in here to minimize the
332         // chance of caller errors.
333 
334         ArrayList<Measurement> measurements = new ArrayList(totalMeasurementCount());
335 
336         // Sort measurements IPv4 first.
337         for (Map.Entry<InetAddress, Measurement> entry : mIcmpChecks.entrySet()) {
338             if (entry.getKey() instanceof Inet4Address) {
339                 measurements.add(entry.getValue());
340             }
341         }
342         for (Map.Entry<Pair<InetAddress, InetAddress>, Measurement> entry :
343                 mExplicitSourceIcmpChecks.entrySet()) {
344             if (entry.getKey().first instanceof Inet4Address) {
345                 measurements.add(entry.getValue());
346             }
347         }
348         for (Map.Entry<InetAddress, Measurement> entry : mDnsUdpChecks.entrySet()) {
349             if (entry.getKey() instanceof Inet4Address) {
350                 measurements.add(entry.getValue());
351             }
352         }
353         for (Map.Entry<InetAddress, Measurement> entry : mDnsTlsChecks.entrySet()) {
354             if (entry.getKey() instanceof Inet4Address) {
355                 measurements.add(entry.getValue());
356             }
357         }
358 
359         // IPv6 measurements second.
360         for (Map.Entry<InetAddress, Measurement> entry : mIcmpChecks.entrySet()) {
361             if (entry.getKey() instanceof Inet6Address) {
362                 measurements.add(entry.getValue());
363             }
364         }
365         for (Map.Entry<Pair<InetAddress, InetAddress>, Measurement> entry :
366                 mExplicitSourceIcmpChecks.entrySet()) {
367             if (entry.getKey().first instanceof Inet6Address) {
368                 measurements.add(entry.getValue());
369             }
370         }
371         for (Map.Entry<InetAddress, Measurement> entry : mDnsUdpChecks.entrySet()) {
372             if (entry.getKey() instanceof Inet6Address) {
373                 measurements.add(entry.getValue());
374             }
375         }
376         for (Map.Entry<InetAddress, Measurement> entry : mDnsTlsChecks.entrySet()) {
377             if (entry.getKey() instanceof Inet6Address) {
378                 measurements.add(entry.getValue());
379             }
380         }
381 
382         return measurements;
383     }
384 
dump(IndentingPrintWriter pw)385     public void dump(IndentingPrintWriter pw) {
386         pw.println(TAG + ":" + mDescription);
387         final long unfinished = mCountDownLatch.getCount();
388         if (unfinished > 0) {
389             // This can't happen unless a caller forgets to call waitForMeasurements()
390             // or a measurement isn't implemented to correctly honor the timeout.
391             pw.println("WARNING: countdown wait incomplete: "
392                     + unfinished + " unfinished measurements");
393         }
394 
395         pw.increaseIndent();
396 
397         String prefix;
398         for (Measurement m : getMeasurements()) {
399             prefix = m.checkSucceeded() ? "." : "F";
400             pw.println(prefix + "  " + m.toString());
401         }
402 
403         pw.decreaseIndent();
404     }
405 
406 
407     private class SimpleSocketCheck implements Closeable {
408         protected final InetAddress mSource;  // Usually null.
409         protected final InetAddress mTarget;
410         protected final int mAddressFamily;
411         protected final Measurement mMeasurement;
412         protected FileDescriptor mFileDescriptor;
413         protected SocketAddress mSocketAddress;
414 
SimpleSocketCheck( InetAddress source, InetAddress target, Measurement measurement)415         protected SimpleSocketCheck(
416                 InetAddress source, InetAddress target, Measurement measurement) {
417             mMeasurement = measurement;
418 
419             if (target instanceof Inet6Address) {
420                 Inet6Address targetWithScopeId = null;
421                 if (target.isLinkLocalAddress() && mInterfaceIndex != null) {
422                     try {
423                         targetWithScopeId = Inet6Address.getByAddress(
424                                 null, target.getAddress(), mInterfaceIndex);
425                     } catch (UnknownHostException e) {
426                         mMeasurement.recordFailure(e.toString());
427                     }
428                 }
429                 mTarget = (targetWithScopeId != null) ? targetWithScopeId : target;
430                 mAddressFamily = AF_INET6;
431             } else {
432                 mTarget = target;
433                 mAddressFamily = AF_INET;
434             }
435 
436             // We don't need to check the scope ID here because we currently only do explicit-source
437             // measurements from global IPv6 addresses.
438             mSource = source;
439         }
440 
SimpleSocketCheck(InetAddress target, Measurement measurement)441         protected SimpleSocketCheck(InetAddress target, Measurement measurement) {
442             this(null, target, measurement);
443         }
444 
setupSocket( int sockType, int protocol, long writeTimeout, long readTimeout, int dstPort)445         protected void setupSocket(
446                 int sockType, int protocol, long writeTimeout, long readTimeout, int dstPort)
447                 throws ErrnoException, IOException {
448             final int oldTag = TrafficStats.getAndSetThreadStatsTag(
449                     TrafficStatsConstants.TAG_SYSTEM_PROBE);
450             try {
451                 mFileDescriptor = Os.socket(mAddressFamily, sockType, protocol);
452             } finally {
453                 // TODO: The tag should remain set until all traffic is sent and received.
454                 // Consider tagging the socket after the measurement thread is started.
455                 TrafficStats.setThreadStatsTag(oldTag);
456             }
457             // Setting SNDTIMEO is purely for defensive purposes.
458             Os.setsockoptTimeval(mFileDescriptor,
459                     SOL_SOCKET, SO_SNDTIMEO, StructTimeval.fromMillis(writeTimeout));
460             Os.setsockoptTimeval(mFileDescriptor,
461                     SOL_SOCKET, SO_RCVTIMEO, StructTimeval.fromMillis(readTimeout));
462             // TODO: Use IP_RECVERR/IPV6_RECVERR, pending OsContants availability.
463             mNetwork.bindSocket(mFileDescriptor);
464             if (mSource != null) {
465                 Os.bind(mFileDescriptor, mSource, 0);
466             }
467             Os.connect(mFileDescriptor, mTarget, dstPort);
468             mSocketAddress = Os.getsockname(mFileDescriptor);
469         }
470 
ensureMeasurementNecessary()471         protected boolean ensureMeasurementNecessary() {
472             if (mMeasurement.finishTime == 0) return false;
473 
474             // Countdown latch was not decremented when the measurement failed during setup.
475             mCountDownLatch.countDown();
476             return true;
477         }
478 
479         @Override
close()480         public void close() {
481             IoUtils.closeQuietly(mFileDescriptor);
482         }
483     }
484 
485 
486     private class IcmpCheck extends SimpleSocketCheck implements Runnable {
487         private static final int TIMEOUT_SEND = 100;
488         private static final int TIMEOUT_RECV = 300;
489         private static final int PACKET_BUFSIZE = 512;
490         private final int mProtocol;
491         private final int mIcmpType;
492 
IcmpCheck(InetAddress source, InetAddress target, Measurement measurement)493         public IcmpCheck(InetAddress source, InetAddress target, Measurement measurement) {
494             super(source, target, measurement);
495 
496             if (mAddressFamily == AF_INET6) {
497                 mProtocol = IPPROTO_ICMPV6;
498                 mIcmpType = NetworkConstants.ICMPV6_ECHO_REQUEST_TYPE;
499                 mMeasurement.description = "ICMPv6";
500             } else {
501                 mProtocol = IPPROTO_ICMP;
502                 mIcmpType = NetworkConstants.ICMPV4_ECHO_REQUEST_TYPE;
503                 mMeasurement.description = "ICMPv4";
504             }
505 
506             mMeasurement.description += " dst{" + mTarget.getHostAddress() + "}";
507         }
508 
IcmpCheck(InetAddress target, Measurement measurement)509         public IcmpCheck(InetAddress target, Measurement measurement) {
510             this(null, target, measurement);
511         }
512 
513         @Override
run()514         public void run() {
515             if (ensureMeasurementNecessary()) return;
516 
517             try {
518                 setupSocket(SOCK_DGRAM, mProtocol, TIMEOUT_SEND, TIMEOUT_RECV, 0);
519             } catch (ErrnoException | IOException e) {
520                 mMeasurement.recordFailure(e.toString());
521                 return;
522             }
523             mMeasurement.description += " src{" + socketAddressToString(mSocketAddress) + "}";
524 
525             // Build a trivial ICMP packet.
526             final byte[] icmpPacket = {
527                     (byte) mIcmpType, 0, 0, 0, 0, 0, 0, 0  // ICMP header
528             };
529 
530             int count = 0;
531             mMeasurement.startTime = now();
532             while (now() < mDeadlineTime - (TIMEOUT_SEND + TIMEOUT_RECV)) {
533                 count++;
534                 icmpPacket[icmpPacket.length - 1] = (byte) count;
535                 try {
536                     Os.write(mFileDescriptor, icmpPacket, 0, icmpPacket.length);
537                 } catch (ErrnoException | InterruptedIOException e) {
538                     mMeasurement.recordFailure(e.toString());
539                     break;
540                 }
541 
542                 try {
543                     ByteBuffer reply = ByteBuffer.allocate(PACKET_BUFSIZE);
544                     Os.read(mFileDescriptor, reply);
545                     // TODO: send a few pings back to back to guesstimate packet loss.
546                     mMeasurement.recordSuccess("1/" + count);
547                     break;
548                 } catch (ErrnoException | InterruptedIOException e) {
549                     continue;
550                 }
551             }
552             if (mMeasurement.finishTime == 0) {
553                 mMeasurement.recordFailure("0/" + count);
554             }
555 
556             close();
557         }
558     }
559 
560 
561     private class DnsUdpCheck extends SimpleSocketCheck implements Runnable {
562         private static final int TIMEOUT_SEND = 100;
563         private static final int TIMEOUT_RECV = 500;
564         private static final int RR_TYPE_A = 1;
565         private static final int RR_TYPE_AAAA = 28;
566         private static final int PACKET_BUFSIZE = 512;
567 
568         protected final Random mRandom = new Random();
569 
570         // Should be static, but the compiler mocks our puny, human attempts at reason.
responseCodeStr(int rcode)571         protected String responseCodeStr(int rcode) {
572             try {
573                 return DnsResponseCode.values()[rcode].toString();
574             } catch (IndexOutOfBoundsException e) {
575                 return String.valueOf(rcode);
576             }
577         }
578 
579         protected final int mQueryType;
580 
DnsUdpCheck(InetAddress target, Measurement measurement)581         public DnsUdpCheck(InetAddress target, Measurement measurement) {
582             super(target, measurement);
583 
584             // TODO: Ideally, query the target for both types regardless of address family.
585             if (mAddressFamily == AF_INET6) {
586                 mQueryType = RR_TYPE_AAAA;
587             } else {
588                 mQueryType = RR_TYPE_A;
589             }
590 
591             mMeasurement.description = "DNS UDP dst{" + mTarget.getHostAddress() + "}";
592         }
593 
594         @Override
run()595         public void run() {
596             if (ensureMeasurementNecessary()) return;
597 
598             try {
599                 setupSocket(SOCK_DGRAM, IPPROTO_UDP, TIMEOUT_SEND, TIMEOUT_RECV,
600                         NetworkConstants.DNS_SERVER_PORT);
601             } catch (ErrnoException | IOException e) {
602                 mMeasurement.recordFailure(e.toString());
603                 return;
604             }
605 
606             // This needs to be fixed length so it can be dropped into the pre-canned packet.
607             final String sixRandomDigits = String.valueOf(mRandom.nextInt(900000) + 100000);
608             appendDnsToMeasurementDescription(sixRandomDigits, mSocketAddress);
609 
610             // Build a trivial DNS packet.
611             final byte[] dnsPacket = getDnsQueryPacket(sixRandomDigits);
612 
613             int count = 0;
614             mMeasurement.startTime = now();
615             while (now() < mDeadlineTime - (TIMEOUT_RECV + TIMEOUT_RECV)) {
616                 count++;
617                 try {
618                     Os.write(mFileDescriptor, dnsPacket, 0, dnsPacket.length);
619                 } catch (ErrnoException | InterruptedIOException e) {
620                     mMeasurement.recordFailure(e.toString());
621                     break;
622                 }
623 
624                 try {
625                     ByteBuffer reply = ByteBuffer.allocate(PACKET_BUFSIZE);
626                     Os.read(mFileDescriptor, reply);
627                     // TODO: more correct and detailed evaluation of the response,
628                     // possibly adding the returned IP address(es) to the output.
629                     final String rcodeStr = (reply.limit() > 3)
630                             ? " " + responseCodeStr((int) (reply.get(3)) & 0x0f)
631                             : "";
632                     mMeasurement.recordSuccess("1/" + count + rcodeStr);
633                     break;
634                 } catch (ErrnoException | InterruptedIOException e) {
635                     continue;
636                 }
637             }
638             if (mMeasurement.finishTime == 0) {
639                 mMeasurement.recordFailure("0/" + count);
640             }
641 
642             close();
643         }
644 
getDnsQueryPacket(String sixRandomDigits)645         protected byte[] getDnsQueryPacket(String sixRandomDigits) {
646             byte[] rnd = sixRandomDigits.getBytes(StandardCharsets.US_ASCII);
647             return new byte[] {
648                 (byte) mRandom.nextInt(), (byte) mRandom.nextInt(),  // [0-1]   query ID
649                 1, 0,  // [2-3]   flags; byte[2] = 1 for recursion desired (RD).
650                 0, 1,  // [4-5]   QDCOUNT (number of queries)
651                 0, 0,  // [6-7]   ANCOUNT (number of answers)
652                 0, 0,  // [8-9]   NSCOUNT (number of name server records)
653                 0, 0,  // [10-11] ARCOUNT (number of additional records)
654                 17, rnd[0], rnd[1], rnd[2], rnd[3], rnd[4], rnd[5],
655                         '-', 'a', 'n', 'd', 'r', 'o', 'i', 'd', '-', 'd', 's',
656                 6, 'm', 'e', 't', 'r', 'i', 'c',
657                 7, 'g', 's', 't', 'a', 't', 'i', 'c',
658                 3, 'c', 'o', 'm',
659                 0,  // null terminator of FQDN (root TLD)
660                 0, (byte) mQueryType,  // QTYPE
661                 0, 1  // QCLASS, set to 1 = IN (Internet)
662             };
663         }
664 
appendDnsToMeasurementDescription( String sixRandomDigits, SocketAddress sockAddr)665         protected void appendDnsToMeasurementDescription(
666                 String sixRandomDigits, SocketAddress sockAddr) {
667             mMeasurement.description += " src{" + socketAddressToString(sockAddr) + "}"
668                     + " qtype{" + mQueryType + "}"
669                     + " qname{" + sixRandomDigits + "-android-ds.metric.gstatic.com}";
670         }
671     }
672 
673     // TODO: Have it inherited from SimpleSocketCheck, and separate common DNS helpers out of
674     // DnsUdpCheck.
675     private class DnsTlsCheck extends DnsUdpCheck {
676         private static final int TCP_CONNECT_TIMEOUT_MS = 2500;
677         private static final int TCP_TIMEOUT_MS = 2000;
678         private static final int DNS_TLS_PORT = 853;
679         private static final int DNS_HEADER_SIZE = 12;
680 
681         private final String mHostname;
682 
DnsTlsCheck(@ullable String hostname, @NonNull InetAddress target, @NonNull Measurement measurement)683         public DnsTlsCheck(@Nullable String hostname, @NonNull InetAddress target,
684                 @NonNull Measurement measurement) {
685             super(target, measurement);
686 
687             mHostname = hostname;
688             mMeasurement.description = "DNS TLS dst{" + mTarget.getHostAddress() + "} hostname{"
689                     + TextUtils.emptyIfNull(mHostname) + "}";
690         }
691 
setupSSLSocket()692         private SSLSocket setupSSLSocket() throws IOException {
693             // A TrustManager will be created and initialized with a KeyStore containing system
694             // CaCerts. During SSL handshake, it will be used to validate the certificates from
695             // the server.
696             SSLSocket sslSocket = (SSLSocket) SSLSocketFactory.getDefault().createSocket();
697             sslSocket.setSoTimeout(TCP_TIMEOUT_MS);
698 
699             if (!TextUtils.isEmpty(mHostname)) {
700                 // Set SNI.
701                 final List<SNIServerName> names =
702                         Collections.singletonList(new SNIHostName(mHostname));
703                 SSLParameters params = sslSocket.getSSLParameters();
704                 params.setServerNames(names);
705                 sslSocket.setSSLParameters(params);
706             }
707 
708             mNetwork.bindSocket(sslSocket);
709             return sslSocket;
710         }
711 
sendDoTProbe(@ullable SSLSocket sslSocket)712         private void sendDoTProbe(@Nullable SSLSocket sslSocket) throws IOException {
713             final String sixRandomDigits = String.valueOf(mRandom.nextInt(900000) + 100000);
714             final byte[] dnsPacket = getDnsQueryPacket(sixRandomDigits);
715 
716             mMeasurement.startTime = now();
717             sslSocket.connect(new InetSocketAddress(mTarget, DNS_TLS_PORT), TCP_CONNECT_TIMEOUT_MS);
718 
719             // Synchronous call waiting for the TLS handshake complete.
720             sslSocket.startHandshake();
721             appendDnsToMeasurementDescription(sixRandomDigits, sslSocket.getLocalSocketAddress());
722 
723             final DataOutputStream output = new DataOutputStream(sslSocket.getOutputStream());
724             output.writeShort(dnsPacket.length);
725             output.write(dnsPacket, 0, dnsPacket.length);
726 
727             final DataInputStream input = new DataInputStream(sslSocket.getInputStream());
728             final int replyLength = Short.toUnsignedInt(input.readShort());
729             final byte[] reply = new byte[replyLength];
730             int bytesRead = 0;
731             while (bytesRead < replyLength) {
732                 bytesRead += input.read(reply, bytesRead, replyLength - bytesRead);
733             }
734 
735             if (bytesRead > DNS_HEADER_SIZE && bytesRead == replyLength) {
736                 mMeasurement.recordSuccess("1/1 " + responseCodeStr((int) (reply[3]) & 0x0f));
737             } else {
738                 mMeasurement.recordFailure("1/1 Read " + bytesRead + " bytes while expected to be "
739                         + replyLength + " bytes");
740             }
741         }
742 
743         @Override
run()744         public void run() {
745             if (ensureMeasurementNecessary()) return;
746 
747             // No need to restore the tag, since this thread is only used for this measurement.
748             TrafficStats.getAndSetThreadStatsTag(TrafficStatsConstants.TAG_SYSTEM_PROBE);
749 
750             try (SSLSocket sslSocket = setupSSLSocket()) {
751                 sendDoTProbe(sslSocket);
752             } catch (IOException e) {
753                 mMeasurement.recordFailure(e.toString());
754             }
755         }
756     }
757 }
758