1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "resolv"
18
19 #include "DnsTlsTransport.h"
20
21 #include <android-base/logging.h>
22 #include <android-base/stringprintf.h>
23 #include <arpa/inet.h>
24 #include <arpa/nameser.h>
25 #include <netdutils/ThreadUtil.h>
26
27 #include "DnsTlsSocketFactory.h"
28 #include "IDnsTlsSocketFactory.h"
29
30 using android::base::StringPrintf;
31 using android::netdutils::setThreadName;
32
33 namespace android {
34 namespace net {
35
query(const netdutils::Slice query)36 std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) {
37 std::lock_guard guard(mLock);
38
39 auto record = mQueries.recordQuery(query);
40 if (!record) {
41 return std::async(std::launch::deferred, []{
42 return (Result) { .code = Response::internal_error };
43 });
44 }
45
46 if (!mSocket) {
47 LOG(DEBUG) << "No socket for query. Opening socket and sending.";
48 doConnect();
49 } else {
50 sendQuery(record->query);
51 }
52
53 return std::move(record->result);
54 }
55
getConnectCounter() const56 int DnsTlsTransport::getConnectCounter() const {
57 std::lock_guard guard(mLock);
58 return mConnectCounter;
59 }
60
sendQuery(const DnsTlsQueryMap::Query & q)61 bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query& q) {
62 // Strip off the ID number and send the new ID instead.
63 const bool sent = mSocket->query(q.newId, netdutils::drop(netdutils::makeSlice(q.query), 2));
64 if (sent) {
65 mQueries.markTried(q.newId);
66 }
67 return sent;
68 }
69
doConnect()70 void DnsTlsTransport::doConnect() {
71 LOG(DEBUG) << "Constructing new socket";
72 mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache);
73 mConnectCounter++;
74
75 if (mSocket) {
76 auto queries = mQueries.getAll();
77 LOG(DEBUG) << "Initialization succeeded. Reissuing " << queries.size() << " queries.";
78 for(auto& q : queries) {
79 if (!sendQuery(q)) {
80 break;
81 }
82 }
83 } else {
84 LOG(DEBUG) << "Initialization failed.";
85 mSocket.reset();
86 LOG(DEBUG) << "Failing all pending queries.";
87 mQueries.clear();
88 }
89 }
90
onResponse(std::vector<uint8_t> response)91 void DnsTlsTransport::onResponse(std::vector<uint8_t> response) {
92 mQueries.onResponse(std::move(response));
93 }
94
onClosed()95 void DnsTlsTransport::onClosed() {
96 std::lock_guard guard(mLock);
97 if (mClosing) {
98 return;
99 }
100 // Move remaining operations to a new thread.
101 // This is necessary because
102 // 1. onClosed is currently running on a thread that blocks mSocket's destructor
103 // 2. doReconnect will call that destructor
104 if (mReconnectThread) {
105 // Complete cleanup of a previous reconnect thread, if present.
106 mReconnectThread->join();
107 // Joining a thread that is trying to acquire mLock, while holding mLock,
108 // looks like it risks a deadlock. However, a deadlock will not occur because
109 // once onClosed is called, it cannot be called again until after doReconnect
110 // acquires mLock.
111 }
112 mReconnectThread.reset(new std::thread(&DnsTlsTransport::doReconnect, this));
113 }
114
doReconnect()115 void DnsTlsTransport::doReconnect() {
116 std::lock_guard guard(mLock);
117 setThreadName(StringPrintf("TlsReconn_%u", mMark & 0xffff).c_str());
118 if (mClosing) {
119 return;
120 }
121 mQueries.cleanup();
122 if (!mQueries.empty()) {
123 LOG(DEBUG) << "Fast reconnect to retry remaining queries";
124 doConnect();
125 } else {
126 LOG(DEBUG) << "No pending queries. Going idle.";
127 mSocket.reset();
128 }
129 }
130
~DnsTlsTransport()131 DnsTlsTransport::~DnsTlsTransport() {
132 LOG(DEBUG) << "Destructor";
133 {
134 std::lock_guard guard(mLock);
135 LOG(DEBUG) << "Locked destruction procedure";
136 mQueries.clear();
137 mClosing = true;
138 }
139 // It's possible that a reconnect thread was spawned and waiting for mLock.
140 // It's safe for that thread to run now because mClosing is true (and mQueries is empty),
141 // but we need to wait for it to finish before allowing destruction to proceed.
142 if (mReconnectThread) {
143 LOG(DEBUG) << "Waiting for reconnect thread to terminate";
144 mReconnectThread->join();
145 mReconnectThread.reset();
146 }
147 // Ensure that the socket is destroyed, and can clean up its callback threads,
148 // before any of this object's fields become invalid.
149 mSocket.reset();
150 LOG(DEBUG) << "Destructor completed";
151 }
152
153 // static
154 // TODO: Use this function to preheat the session cache.
155 // That may require moving it to DnsTlsDispatcher.
validate(const DnsTlsServer & server,unsigned netid,uint32_t mark)156 bool DnsTlsTransport::validate(const DnsTlsServer& server, unsigned netid, uint32_t mark) {
157 LOG(DEBUG) << "Beginning validation on " << netid;
158 // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
159 // order to prove that it is actually a working DNS over TLS server.
160 static const char kDnsSafeChars[] =
161 "abcdefhijklmnopqrstuvwxyz"
162 "ABCDEFHIJKLMNOPQRSTUVWXYZ"
163 "0123456789";
164 const auto c = [](uint8_t rnd) -> uint8_t {
165 return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))];
166 };
167 uint8_t rnd[8];
168 arc4random_buf(rnd, std::size(rnd));
169 // We could try to use res_mkquery() here, but it's basically the same.
170 uint8_t query[] = {
171 rnd[6], rnd[7], // [0-1] query ID
172 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD).
173 0, 1, // [4-5] QDCOUNT (number of queries)
174 0, 0, // [6-7] ANCOUNT (number of answers)
175 0, 0, // [8-9] NSCOUNT (number of name server records)
176 0, 0, // [10-11] ARCOUNT (number of additional records)
177 17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
178 '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
179 6, 'm', 'e', 't', 'r', 'i', 'c',
180 7, 'g', 's', 't', 'a', 't', 'i', 'c',
181 3, 'c', 'o', 'm',
182 0, // null terminator of FQDN (root TLD)
183 0, ns_t_aaaa, // QTYPE
184 0, ns_c_in // QCLASS
185 };
186 const int qlen = std::size(query);
187
188 int replylen = 0;
189 DnsTlsSocketFactory factory;
190 DnsTlsTransport transport(server, mark, &factory);
191 auto r = transport.query(netdutils::Slice(query, qlen)).get();
192 if (r.code != Response::success) {
193 LOG(DEBUG) << "query failed";
194 return false;
195 }
196
197 const std::vector<uint8_t>& recvbuf = r.response;
198 if (recvbuf.size() < NS_HFIXEDSZ) {
199 LOG(WARNING) << "short response: " << replylen;
200 return false;
201 }
202
203 const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
204 if (qdcount != 1) {
205 LOG(WARNING) << "reply query count != 1: " << qdcount;
206 return false;
207 }
208
209 const int ancount = (recvbuf[6] << 8) | recvbuf[7];
210 LOG(DEBUG) << netid << " answer count: " << ancount;
211
212 // TODO: Further validate the response contents (check for valid AAAA record, ...).
213 // Note that currently, integration tests rely on this function accepting a
214 // response with zero records.
215
216 return true;
217 }
218
219 } // end of namespace net
220 } // end of namespace android
221