1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "resolv"
18 
19 #include "DnsTlsTransport.h"
20 
#include <memory>
#include <thread>

#include <android-base/logging.h>
#include <android-base/stringprintf.h>
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <netdutils/ThreadUtil.h>
26 
27 #include "DnsTlsSocketFactory.h"
28 #include "IDnsTlsSocketFactory.h"
29 
30 using android::base::StringPrintf;
31 using android::netdutils::setThreadName;
32 
33 namespace android {
34 namespace net {
35 
query(const netdutils::Slice query)36 std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) {
37     std::lock_guard guard(mLock);
38 
39     auto record = mQueries.recordQuery(query);
40     if (!record) {
41         return std::async(std::launch::deferred, []{
42             return (Result) { .code = Response::internal_error };
43         });
44     }
45 
46     if (!mSocket) {
47         LOG(DEBUG) << "No socket for query.  Opening socket and sending.";
48         doConnect();
49     } else {
50         sendQuery(record->query);
51     }
52 
53     return std::move(record->result);
54 }
55 
getConnectCounter() const56 int DnsTlsTransport::getConnectCounter() const {
57     std::lock_guard guard(mLock);
58     return mConnectCounter;
59 }
60 
sendQuery(const DnsTlsQueryMap::Query & q)61 bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query& q) {
62     // Strip off the ID number and send the new ID instead.
63     const bool sent = mSocket->query(q.newId, netdutils::drop(netdutils::makeSlice(q.query), 2));
64     if (sent) {
65         mQueries.markTried(q.newId);
66     }
67     return sent;
68 }
69 
doConnect()70 void DnsTlsTransport::doConnect() {
71     LOG(DEBUG) << "Constructing new socket";
72     mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache);
73     mConnectCounter++;
74 
75     if (mSocket) {
76         auto queries = mQueries.getAll();
77         LOG(DEBUG) << "Initialization succeeded.  Reissuing " << queries.size() << " queries.";
78         for(auto& q : queries) {
79             if (!sendQuery(q)) {
80                 break;
81             }
82         }
83     } else {
84         LOG(DEBUG) << "Initialization failed.";
85         mSocket.reset();
86         LOG(DEBUG) << "Failing all pending queries.";
87         mQueries.clear();
88     }
89 }
90 
onResponse(std::vector<uint8_t> response)91 void DnsTlsTransport::onResponse(std::vector<uint8_t> response) {
92     mQueries.onResponse(std::move(response));
93 }
94 
onClosed()95 void DnsTlsTransport::onClosed() {
96     std::lock_guard guard(mLock);
97     if (mClosing) {
98         return;
99     }
100     // Move remaining operations to a new thread.
101     // This is necessary because
102     // 1. onClosed is currently running on a thread that blocks mSocket's destructor
103     // 2. doReconnect will call that destructor
104     if (mReconnectThread) {
105         // Complete cleanup of a previous reconnect thread, if present.
106         mReconnectThread->join();
107         // Joining a thread that is trying to acquire mLock, while holding mLock,
108         // looks like it risks a deadlock.  However, a deadlock will not occur because
109         // once onClosed is called, it cannot be called again until after doReconnect
110         // acquires mLock.
111     }
112     mReconnectThread.reset(new std::thread(&DnsTlsTransport::doReconnect, this));
113 }
114 
doReconnect()115 void DnsTlsTransport::doReconnect() {
116     std::lock_guard guard(mLock);
117     setThreadName(StringPrintf("TlsReconn_%u", mMark & 0xffff).c_str());
118     if (mClosing) {
119         return;
120     }
121     mQueries.cleanup();
122     if (!mQueries.empty()) {
123         LOG(DEBUG) << "Fast reconnect to retry remaining queries";
124         doConnect();
125     } else {
126         LOG(DEBUG) << "No pending queries.  Going idle.";
127         mSocket.reset();
128     }
129 }
130 
// Tears down the transport. The order of the steps below matters:
//  1. Under mLock, drop all pending queries and set mClosing, so that any later
//     onClosed() or doReconnect() returns immediately without touching the socket.
//  2. Outside mLock, join the reconnect thread (if any). It may be blocked waiting
//     for mLock, so it must not be joined while the lock is held.
//  3. Destroy the socket last, so its callback threads are cleaned up while this
//     object's fields are still valid.
DnsTlsTransport::~DnsTlsTransport() {
    LOG(DEBUG) << "Destructor";
    {
        std::lock_guard guard(mLock);
        LOG(DEBUG) << "Locked destruction procedure";
        mQueries.clear();
        mClosing = true;
    }
    // It's possible that a reconnect thread was spawned and waiting for mLock.
    // It's safe for that thread to run now because mClosing is true (and mQueries is empty),
    // but we need to wait for it to finish before allowing destruction to proceed.
    // Reading mReconnectThread without the lock is OK here: onClosed() only assigns it
    // under mLock and returns early once mClosing is set.
    if (mReconnectThread) {
        LOG(DEBUG) << "Waiting for reconnect thread to terminate";
        mReconnectThread->join();
        mReconnectThread.reset();
    }
    // Ensure that the socket is destroyed, and can clean up its callback threads,
    // before any of this object's fields become invalid.
    mSocket.reset();
    LOG(DEBUG) << "Destructor completed";
}
152 
153 // static
154 // TODO: Use this function to preheat the session cache.
155 // That may require moving it to DnsTlsDispatcher.
validate(const DnsTlsServer & server,unsigned netid,uint32_t mark)156 bool DnsTlsTransport::validate(const DnsTlsServer& server, unsigned netid, uint32_t mark) {
157     LOG(DEBUG) << "Beginning validation on " << netid;
158     // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
159     // order to prove that it is actually a working DNS over TLS server.
160     static const char kDnsSafeChars[] =
161             "abcdefhijklmnopqrstuvwxyz"
162             "ABCDEFHIJKLMNOPQRSTUVWXYZ"
163             "0123456789";
164     const auto c = [](uint8_t rnd) -> uint8_t {
165         return kDnsSafeChars[(rnd % std::size(kDnsSafeChars))];
166     };
167     uint8_t rnd[8];
168     arc4random_buf(rnd, std::size(rnd));
169     // We could try to use res_mkquery() here, but it's basically the same.
170     uint8_t query[] = {
171         rnd[6], rnd[7],  // [0-1]   query ID
172         1, 0,  // [2-3]   flags; query[2] = 1 for recursion desired (RD).
173         0, 1,  // [4-5]   QDCOUNT (number of queries)
174         0, 0,  // [6-7]   ANCOUNT (number of answers)
175         0, 0,  // [8-9]   NSCOUNT (number of name server records)
176         0, 0,  // [10-11] ARCOUNT (number of additional records)
177         17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
178             '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
179         6, 'm', 'e', 't', 'r', 'i', 'c',
180         7, 'g', 's', 't', 'a', 't', 'i', 'c',
181         3, 'c', 'o', 'm',
182         0,  // null terminator of FQDN (root TLD)
183         0, ns_t_aaaa,  // QTYPE
184         0, ns_c_in     // QCLASS
185     };
186     const int qlen = std::size(query);
187 
188     int replylen = 0;
189     DnsTlsSocketFactory factory;
190     DnsTlsTransport transport(server, mark, &factory);
191     auto r = transport.query(netdutils::Slice(query, qlen)).get();
192     if (r.code != Response::success) {
193         LOG(DEBUG) << "query failed";
194         return false;
195     }
196 
197     const std::vector<uint8_t>& recvbuf = r.response;
198     if (recvbuf.size() < NS_HFIXEDSZ) {
199         LOG(WARNING) << "short response: " << replylen;
200         return false;
201     }
202 
203     const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
204     if (qdcount != 1) {
205         LOG(WARNING) << "reply query count != 1: " << qdcount;
206         return false;
207     }
208 
209     const int ancount = (recvbuf[6] << 8) | recvbuf[7];
210     LOG(DEBUG) << netid << " answer count: " << ancount;
211 
212     // TODO: Further validate the response contents (check for valid AAAA record, ...).
213     // Note that currently, integration tests rely on this function accepting a
214     // response with zero records.
215 
216     return true;
217 }
218 
219 }  // end of namespace net
220 }  // end of namespace android
221