1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "dns_tls_frontend.h"
18 
#include <arpa/inet.h>
#include <errno.h>
#include <netdb.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
#include <openssl/x509.h>
#include <sys/eventfd.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <vector>

#define LOG_TAG "DnsTlsFrontend"
#include <android-base/logging.h>
#include <netdutils/InternetAddresses.h>
#include <netdutils/SocketOption.h>
#include "dns_tls_certificate.h"
36 
37 using android::netdutils::enableSockopt;
38 using android::netdutils::ScopedAddrinfo;
39 
40 namespace {
stringToX509Certs(const char * certs)41 static bssl::UniquePtr<X509> stringToX509Certs(const char* certs) {
42     bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(certs, strlen(certs)));
43     return bssl::UniquePtr<X509>(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
44 }
45 
46 // Convert a string buffer containing an RSA Private Key into an OpenSSL RSA struct.
stringToRSAPrivateKey(const char * key)47 static bssl::UniquePtr<RSA> stringToRSAPrivateKey(const char* key) {
48     bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(key, strlen(key)));
49     return bssl::UniquePtr<RSA>(PEM_read_bio_RSAPrivateKey(bio.get(), nullptr, nullptr, nullptr));
50 }
51 
// Formats the socket address |sa| (|sa_len| bytes long) as a numeric host
// string. NI_NUMERICHOST avoids any reverse-DNS lookup. Returns an empty
// string if getnameinfo() fails.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char numeric_host[NI_MAXHOST] = {};
    const int status = getnameinfo(sa, sa_len, numeric_host, sizeof(numeric_host), nullptr, 0,
                                   NI_NUMERICHOST);
    return status == 0 ? std::string(numeric_host) : std::string();
}
58 
59 }  // namespace
60 
61 namespace test {
62 
// Starts the DNS-over-TLS frontend. In order, it:
//   1. builds ctx_ from the built-in test certificate and RSA key;
//   2. binds a TCP listening socket (socket_) on listen_address_:listen_service_;
//   3. opens a UDP socket (backend_socket_) towards backend_address_:backend_service_;
//   4. creates event_fd_, which is used to stop the handler;
//   5. spawns requestHandler() on handler_thread_.
// Returns true on success. On failure it logs the reason and returns false;
// members already set up by earlier steps are left as they are.
bool DnsTlsFrontend::startServer() {
    OpenSSL_add_ssl_algorithms();

    // Reset queries_ every time startServer() is called so that callers can
    // check the per-run count via waitForQueries().
    queries_ = 0;

    ctx_.reset(SSL_CTX_new(TLS_server_method()));
    if (!ctx_) {
        LOG(ERROR) << "SSL context creation failed";
        return false;
    }

    SSL_CTX_set_ecdh_auto(ctx_.get(), 1);

    bssl::UniquePtr<X509> ca_certs(stringToX509Certs(kCertificate));
    if (!ca_certs) {
        LOG(ERROR) << "StringToX509Certs failed";
        return false;
    }

    if (SSL_CTX_use_certificate(ctx_.get(), ca_certs.get()) <= 0) {
        LOG(ERROR) << "SSL_CTX_use_certificate failed";
        return false;
    }

    bssl::UniquePtr<RSA> private_key(stringToRSAPrivateKey(kPrivatekey));
    if (!private_key) {
        // Report a parse failure explicitly rather than as an opaque
        // SSL_CTX_use_RSAPrivateKey(nullptr) error.
        LOG(ERROR) << "stringToRSAPrivateKey failed";
        return false;
    }
    if (SSL_CTX_use_RSAPrivateKey(ctx_.get(), private_key.get()) <= 0) {
        LOG(ERROR) << "Error loading client RSA Private Key data.";
        return false;
    }

    // Set up TCP server socket for clients.
    addrinfo frontend_ai_hints{
            .ai_flags = AI_PASSIVE,
            .ai_family = AF_UNSPEC,
            .ai_socktype = SOCK_STREAM,
    };
    addrinfo* frontend_ai_res = nullptr;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(), &frontend_ai_hints,
                         &frontend_ai_res);
    ScopedAddrinfo frontend_ai_res_cleanup(frontend_ai_res);
    if (rv) {
        LOG(ERROR) << "frontend getaddrinfo(" << listen_address_.c_str() << ", "
                   << listen_service_.c_str() << ") failed: " << gai_strerror(rv);
        return false;
    }

    // Bind to the first resolved address that accepts a socket + bind.
    for (const addrinfo* ai = frontend_ai_res; ai; ai = ai->ai_next) {
        android::base::unique_fd s(socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol));
        if (s.get() < 0) {
            PLOG(INFO) << "ignore creating socket failed " << s.get();
            continue;
        }
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEPORT).ignoreError();
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        if (bind(s.get(), ai->ai_addr, ai->ai_addrlen)) {
            PLOG(INFO) << "failed to bind TCP " << host_str.c_str() << ":"
                       << listen_service_.c_str();
            continue;
        }
        LOG(INFO) << "bound to TCP " << host_str.c_str() << ":" << listen_service_.c_str();
        socket_ = std::move(s);
        break;
    }

    // Without this check, failing to bind every address only surfaces as an
    // EBADF from listen(-1), which hides the real cause.
    if (socket_.get() < 0) {
        LOG(ERROR) << "failed to bind any TCP address for " << listen_address_.c_str() << ":"
                   << listen_service_.c_str();
        return false;
    }

    if (listen(socket_.get(), 1) < 0) {
        PLOG(INFO) << "failed to listen socket " << socket_.get();
        return false;
    }

    // Set up UDP client socket to backend.
    addrinfo backend_ai_hints{.ai_family = AF_UNSPEC, .ai_socktype = SOCK_DGRAM};
    addrinfo* backend_ai_res = nullptr;
    rv = getaddrinfo(backend_address_.c_str(), backend_service_.c_str(), &backend_ai_hints,
                     &backend_ai_res);
    ScopedAddrinfo backend_ai_res_cleanup(backend_ai_res);
    if (rv) {
        // Report the backend address that actually failed to resolve.
        LOG(ERROR) << "backend getaddrinfo(" << backend_address_.c_str() << ", "
                   << backend_service_.c_str() << ") failed: " << gai_strerror(rv);
        return false;
    }
    backend_socket_.reset(socket(backend_ai_res->ai_family, backend_ai_res->ai_socktype,
                                 backend_ai_res->ai_protocol));
    if (backend_socket_.get() < 0) {
        PLOG(INFO) << "backend socket " << backend_socket_.get() << " creation failed";
        return false;
    }

    // connect() always fails in the test DnsTlsSocketTest.SlowDestructor because of
    // no backend server. Don't check it.
    connect(backend_socket_.get(), backend_ai_res->ai_addr, backend_ai_res->ai_addrlen);

    // Set up eventfd socket.
    event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
    if (event_fd_.get() == -1) {
        PLOG(INFO) << "failed to create eventfd " << event_fd_.get();
        return false;
    }

    {
        std::lock_guard lock(update_mutex_);
        handler_thread_ = std::thread(&DnsTlsFrontend::requestHandler, this);
    }
    LOG(INFO) << "server started successfully";
    return true;
}
171 
requestHandler()172 void DnsTlsFrontend::requestHandler() {
173     LOG(DEBUG) << "Request handler started";
174     enum { EVENT_FD = 0, LISTEN_FD = 1 };
175     pollfd fds[2] = {{.fd = event_fd_.get(), .events = POLLIN},
176                      {.fd = socket_.get(), .events = POLLIN}};
177     android::base::unique_fd clientFd;
178 
179     while (true) {
180         int poll_code = poll(fds, std::size(fds), -1);
181         if (poll_code <= 0) {
182             PLOG(WARNING) << "Poll failed with error " << poll_code;
183             break;
184         }
185 
186         if (fds[EVENT_FD].revents & (POLLIN | POLLERR)) {
187             handleEventFd();
188             break;
189         }
190         if (fds[LISTEN_FD].revents & (POLLIN | POLLERR)) {
191             sockaddr_storage addr;
192             socklen_t len = sizeof(addr);
193 
194             LOG(DEBUG) << "Trying to accept a client";
195             android::base::unique_fd client(
196                     accept4(socket_.get(), reinterpret_cast<sockaddr*>(&addr), &len, SOCK_CLOEXEC));
197             if (client.get() < 0) {
198                 // Stop
199                 PLOG(INFO) << "failed to accept client socket " << client.get();
200                 break;
201             }
202 
203             accept_connection_count_++;
204             if (hangOnHandshake_) {
205                 LOG(DEBUG) << "TEST ONLY: unresponsive to SSL handshake";
206 
207                 // The previous fd already stored in clientFd will be closed automatically.
208                 clientFd = std::move(client);
209                 continue;
210             }
211 
212             bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get()));
213             SSL_set_fd(ssl.get(), client.get());
214 
215             LOG(DEBUG) << "Doing SSL handshake";
216             if (SSL_accept(ssl.get()) <= 0) {
217                 LOG(INFO) << "SSL negotiation failure";
218             } else {
219                 LOG(DEBUG) << "SSL handshake complete";
220                 // Increment queries_ as late as possible, because it represents
221                 // a query that is fully processed, and the response returned to the
222                 // client, including cleanup actions.
223                 queries_ += handleRequests(ssl.get(), client.get());
224             }
225         }
226     }
227     LOG(DEBUG) << "Ending loop";
228 }
229 
handleRequests(SSL * ssl,int clientFd)230 int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
231     int queryCounts = 0;
232     pollfd fds = {.fd = clientFd, .events = POLLIN};
233     do {
234         uint8_t queryHeader[2];
235         if (SSL_read(ssl, &queryHeader, 2) != 2) {
236             LOG(INFO) << "Not enough header bytes";
237             return queryCounts;
238         }
239         const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
240         uint8_t query[qlen];
241         size_t qbytes = 0;
242         while (qbytes < qlen) {
243             int ret = SSL_read(ssl, query + qbytes, qlen - qbytes);
244             if (ret <= 0) {
245                 LOG(INFO) << "Error while reading query";
246                 return queryCounts;
247             }
248             qbytes += ret;
249         }
250         int sent = send(backend_socket_.get(), query, qlen, 0);
251         if (sent != qlen) {
252             LOG(INFO) << "Failed to send query";
253             return queryCounts;
254         }
255         const int max_size = 4096;
256         uint8_t recv_buffer[max_size];
257         int rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0);
258         if (rlen <= 0) {
259             LOG(INFO) << "Failed to receive response";
260             return queryCounts;
261         }
262         uint8_t responseHeader[2];
263         responseHeader[0] = rlen >> 8;
264         responseHeader[1] = rlen;
265         if (SSL_write(ssl, responseHeader, 2) != 2) {
266             LOG(INFO) << "Failed to write response header";
267             return queryCounts;
268         }
269         if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
270             LOG(INFO) << "Failed to write response body";
271             return queryCounts;
272         }
273         ++queryCounts;
274     } while (poll(&fds, 1, 1) > 0);
275 
276     LOG(DEBUG) << __func__ << " return: " << queryCounts;
277     return queryCounts;
278 }
279 
stopServer()280 bool DnsTlsFrontend::stopServer() {
281     std::lock_guard lock(update_mutex_);
282     if (!running()) {
283         LOG(INFO) << "server not running";
284         return false;
285     }
286 
287     LOG(INFO) << "stopping frontend";
288     if (!sendToEventFd()) {
289         return false;
290     }
291     handler_thread_.join();
292     socket_.reset();
293     backend_socket_.reset();
294     event_fd_.reset();
295     ctx_.reset();
296     LOG(INFO) << "frontend stopped successfully";
297     return true;
298 }
299 
300 // TODO: use a condition variable instead of polling
301 // TODO: also clear queries_ to eliminate potential race conditions
waitForQueries(int expected_count) const302 bool DnsTlsFrontend::waitForQueries(int expected_count) const {
303     constexpr int intervalMs = 20;
304     constexpr int timeoutMs = 5000;
305     int limit = timeoutMs / intervalMs;
306     for (int count = 0; count <= limit; ++count) {
307         bool done = queries_ >= expected_count;
308         // Always sleep at least one more interval after we are done, to wait for
309         // any immediate post-query actions that the client may take (such as
310         // marking this server as reachable during validation).
311         usleep(intervalMs * 1000);
312         if (done) {
313             // For ensuring that calls have sufficient headroom for slow machines
314             LOG(DEBUG) << "Query arrived in " << count << "/" << limit << " of allotted time";
315             return true;
316         }
317     }
318     return false;
319 }
320 
sendToEventFd()321 bool DnsTlsFrontend::sendToEventFd() {
322     const uint64_t data = 1;
323     if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
324         PLOG(INFO) << "failed to write eventfd, rt=" << rt;
325         return false;
326     }
327     return true;
328 }
329 
handleEventFd()330 void DnsTlsFrontend::handleEventFd() {
331     int64_t data;
332     if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
333         PLOG(INFO) << "ignore reading eventfd failed, rt=" << rt;
334     }
335 }
336 
337 }  // namespace test
338