1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "dns_tls_frontend.h"
18
#include <arpa/inet.h>
#include <netdb.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
#include <openssl/x509.h>
#include <sys/eventfd.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <vector>
30
31 #define LOG_TAG "DnsTlsFrontend"
32 #include <android-base/logging.h>
33 #include <netdutils/InternetAddresses.h>
34 #include <netdutils/SocketOption.h>
35 #include "dns_tls_certificate.h"
36
37 using android::netdutils::enableSockopt;
38 using android::netdutils::ScopedAddrinfo;
39
40 namespace {
stringToX509Certs(const char * certs)41 static bssl::UniquePtr<X509> stringToX509Certs(const char* certs) {
42 bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(certs, strlen(certs)));
43 return bssl::UniquePtr<X509>(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
44 }
45
46 // Convert a string buffer containing an RSA Private Key into an OpenSSL RSA struct.
stringToRSAPrivateKey(const char * key)47 static bssl::UniquePtr<RSA> stringToRSAPrivateKey(const char* key) {
48 bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(key, strlen(key)));
49 return bssl::UniquePtr<RSA>(PEM_read_bio_RSAPrivateKey(bio.get(), nullptr, nullptr, nullptr));
50 }
51
// Format the socket address |sa| (of length |sa_len|) as a numeric host string.
// NI_NUMERICHOST is passed so no name resolution is attempted.
// Returns an empty string when |sa| is null or getnameinfo() reports an error.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    // Not every libc validates the address pointer before reading sa_family,
    // so reject a null address here instead of relying on getnameinfo().
    if (sa == nullptr) return std::string();

    char host_str[NI_MAXHOST] = {0};
    const int rv = getnameinfo(sa, sa_len, host_str, sizeof(host_str), nullptr, 0, NI_NUMERICHOST);
    return rv == 0 ? std::string(host_str) : std::string();
}
58
59 } // namespace
60
61 namespace test {
62
// Bring the TLS frontend up:
//   1. create the SSL context and load the test certificate and RSA private key,
//   2. bind and listen on a TCP socket for listen_address_:listen_service_,
//   3. create a UDP socket connected to backend_address_:backend_service_,
//   4. create the eventfd that requestHandler() polls to learn it should exit,
//   5. launch the request handler thread.
// Returns true once the handler thread has been launched, false as soon as any
// step fails. Members already set by earlier steps are not rolled back on failure.
bool DnsTlsFrontend::startServer() {
    OpenSSL_add_ssl_algorithms();

    // Reset queries_ to 0 on every start so that callers can check the count
    // for this run via waitForQueries().
    queries_ = 0;

    ctx_.reset(SSL_CTX_new(TLS_server_method()));
    if (!ctx_) {
        LOG(ERROR) << "SSL context creation failed";
        return false;
    }

    SSL_CTX_set_ecdh_auto(ctx_.get(), 1);

    bssl::UniquePtr<X509> ca_certs(stringToX509Certs(kCertificate));
    if (!ca_certs) {
        LOG(ERROR) << "StringToX509Certs failed";
        return false;
    }

    if (SSL_CTX_use_certificate(ctx_.get(), ca_certs.get()) <= 0) {
        LOG(ERROR) << "SSL_CTX_use_certificate failed";
        return false;
    }

    bssl::UniquePtr<RSA> private_key(stringToRSAPrivateKey(kPrivatekey));
    if (SSL_CTX_use_RSAPrivateKey(ctx_.get(), private_key.get()) <= 0) {
        LOG(ERROR) << "Error loading client RSA Private Key data.";
        return false;
    }

    // Set up TCP server socket for clients.
    addrinfo frontend_ai_hints{
            .ai_flags = AI_PASSIVE,
            .ai_family = AF_UNSPEC,
            .ai_socktype = SOCK_STREAM,
    };
    addrinfo* frontend_ai_res = nullptr;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(), &frontend_ai_hints,
                         &frontend_ai_res);
    ScopedAddrinfo frontend_ai_res_cleanup(frontend_ai_res);
    if (rv) {
        LOG(ERROR) << "frontend getaddrinfo(" << listen_address_.c_str() << ", "
                   << listen_service_.c_str() << ") failed: " << gai_strerror(rv);
        return false;
    }

    // Walk the candidate addresses and keep the first one that can be bound.
    for (const addrinfo* ai = frontend_ai_res; ai; ai = ai->ai_next) {
        android::base::unique_fd s(socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol));
        if (s.get() < 0) {
            PLOG(INFO) << "ignore creating socket failed " << s.get();
            continue;
        }
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEPORT).ignoreError();
        enableSockopt(s.get(), SOL_SOCKET, SO_REUSEADDR).ignoreError();
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        if (bind(s.get(), ai->ai_addr, ai->ai_addrlen)) {
            PLOG(INFO) << "failed to bind TCP " << host_str.c_str() << ":"
                       << listen_service_.c_str();
            continue;
        }
        LOG(INFO) << "bound to TCP " << host_str.c_str() << ":" << listen_service_.c_str();
        socket_ = std::move(s);
        break;
    }

    // If no candidate could be bound, socket_ was not assigned above and
    // listen() is what reports the failure.
    if (listen(socket_.get(), 1) < 0) {
        PLOG(INFO) << "failed to listen socket " << socket_.get();
        return false;
    }

    // Set up UDP client socket to backend.
    addrinfo backend_ai_hints{.ai_family = AF_UNSPEC, .ai_socktype = SOCK_DGRAM};
    addrinfo* backend_ai_res = nullptr;
    rv = getaddrinfo(backend_address_.c_str(), backend_service_.c_str(), &backend_ai_hints,
                     &backend_ai_res);
    ScopedAddrinfo backend_ai_res_cleanup(backend_ai_res);
    if (rv) {
        // Report the backend endpoint that was actually looked up.
        LOG(ERROR) << "backend getaddrinfo(" << backend_address_.c_str() << ", "
                   << backend_service_.c_str() << ") failed: " << gai_strerror(rv);
        return false;
    }
    backend_socket_.reset(socket(backend_ai_res->ai_family, backend_ai_res->ai_socktype,
                                 backend_ai_res->ai_protocol));
    if (backend_socket_.get() < 0) {
        PLOG(INFO) << "backend socket " << backend_socket_.get() << " creation failed";
        return false;
    }

    // connect() always fails in the test DnsTlsSocketTest.SlowDestructor because of
    // no backend server. Don't check it.
    connect(backend_socket_.get(), backend_ai_res->ai_addr, backend_ai_res->ai_addrlen);

    // Set up eventfd socket.
    event_fd_.reset(eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC));
    if (event_fd_.get() == -1) {
        PLOG(INFO) << "failed to create eventfd " << event_fd_.get();
        return false;
    }

    {
        std::lock_guard lock(update_mutex_);
        handler_thread_ = std::thread(&DnsTlsFrontend::requestHandler, this);
    }
    LOG(INFO) << "server started successfully";
    return true;
}
171
requestHandler()172 void DnsTlsFrontend::requestHandler() {
173 LOG(DEBUG) << "Request handler started";
174 enum { EVENT_FD = 0, LISTEN_FD = 1 };
175 pollfd fds[2] = {{.fd = event_fd_.get(), .events = POLLIN},
176 {.fd = socket_.get(), .events = POLLIN}};
177 android::base::unique_fd clientFd;
178
179 while (true) {
180 int poll_code = poll(fds, std::size(fds), -1);
181 if (poll_code <= 0) {
182 PLOG(WARNING) << "Poll failed with error " << poll_code;
183 break;
184 }
185
186 if (fds[EVENT_FD].revents & (POLLIN | POLLERR)) {
187 handleEventFd();
188 break;
189 }
190 if (fds[LISTEN_FD].revents & (POLLIN | POLLERR)) {
191 sockaddr_storage addr;
192 socklen_t len = sizeof(addr);
193
194 LOG(DEBUG) << "Trying to accept a client";
195 android::base::unique_fd client(
196 accept4(socket_.get(), reinterpret_cast<sockaddr*>(&addr), &len, SOCK_CLOEXEC));
197 if (client.get() < 0) {
198 // Stop
199 PLOG(INFO) << "failed to accept client socket " << client.get();
200 break;
201 }
202
203 accept_connection_count_++;
204 if (hangOnHandshake_) {
205 LOG(DEBUG) << "TEST ONLY: unresponsive to SSL handshake";
206
207 // The previous fd already stored in clientFd will be closed automatically.
208 clientFd = std::move(client);
209 continue;
210 }
211
212 bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get()));
213 SSL_set_fd(ssl.get(), client.get());
214
215 LOG(DEBUG) << "Doing SSL handshake";
216 if (SSL_accept(ssl.get()) <= 0) {
217 LOG(INFO) << "SSL negotiation failure";
218 } else {
219 LOG(DEBUG) << "SSL handshake complete";
220 // Increment queries_ as late as possible, because it represents
221 // a query that is fully processed, and the response returned to the
222 // client, including cleanup actions.
223 queries_ += handleRequests(ssl.get(), client.get());
224 }
225 }
226 }
227 LOG(DEBUG) << "Ending loop";
228 }
229
handleRequests(SSL * ssl,int clientFd)230 int DnsTlsFrontend::handleRequests(SSL* ssl, int clientFd) {
231 int queryCounts = 0;
232 pollfd fds = {.fd = clientFd, .events = POLLIN};
233 do {
234 uint8_t queryHeader[2];
235 if (SSL_read(ssl, &queryHeader, 2) != 2) {
236 LOG(INFO) << "Not enough header bytes";
237 return queryCounts;
238 }
239 const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
240 uint8_t query[qlen];
241 size_t qbytes = 0;
242 while (qbytes < qlen) {
243 int ret = SSL_read(ssl, query + qbytes, qlen - qbytes);
244 if (ret <= 0) {
245 LOG(INFO) << "Error while reading query";
246 return queryCounts;
247 }
248 qbytes += ret;
249 }
250 int sent = send(backend_socket_.get(), query, qlen, 0);
251 if (sent != qlen) {
252 LOG(INFO) << "Failed to send query";
253 return queryCounts;
254 }
255 const int max_size = 4096;
256 uint8_t recv_buffer[max_size];
257 int rlen = recv(backend_socket_.get(), recv_buffer, max_size, 0);
258 if (rlen <= 0) {
259 LOG(INFO) << "Failed to receive response";
260 return queryCounts;
261 }
262 uint8_t responseHeader[2];
263 responseHeader[0] = rlen >> 8;
264 responseHeader[1] = rlen;
265 if (SSL_write(ssl, responseHeader, 2) != 2) {
266 LOG(INFO) << "Failed to write response header";
267 return queryCounts;
268 }
269 if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
270 LOG(INFO) << "Failed to write response body";
271 return queryCounts;
272 }
273 ++queryCounts;
274 } while (poll(&fds, 1, 1) > 0);
275
276 LOG(DEBUG) << __func__ << " return: " << queryCounts;
277 return queryCounts;
278 }
279
stopServer()280 bool DnsTlsFrontend::stopServer() {
281 std::lock_guard lock(update_mutex_);
282 if (!running()) {
283 LOG(INFO) << "server not running";
284 return false;
285 }
286
287 LOG(INFO) << "stopping frontend";
288 if (!sendToEventFd()) {
289 return false;
290 }
291 handler_thread_.join();
292 socket_.reset();
293 backend_socket_.reset();
294 event_fd_.reset();
295 ctx_.reset();
296 LOG(INFO) << "frontend stopped successfully";
297 return true;
298 }
299
300 // TODO: use a condition variable instead of polling
301 // TODO: also clear queries_ to eliminate potential race conditions
waitForQueries(int expected_count) const302 bool DnsTlsFrontend::waitForQueries(int expected_count) const {
303 constexpr int intervalMs = 20;
304 constexpr int timeoutMs = 5000;
305 int limit = timeoutMs / intervalMs;
306 for (int count = 0; count <= limit; ++count) {
307 bool done = queries_ >= expected_count;
308 // Always sleep at least one more interval after we are done, to wait for
309 // any immediate post-query actions that the client may take (such as
310 // marking this server as reachable during validation).
311 usleep(intervalMs * 1000);
312 if (done) {
313 // For ensuring that calls have sufficient headroom for slow machines
314 LOG(DEBUG) << "Query arrived in " << count << "/" << limit << " of allotted time";
315 return true;
316 }
317 }
318 return false;
319 }
320
sendToEventFd()321 bool DnsTlsFrontend::sendToEventFd() {
322 const uint64_t data = 1;
323 if (const ssize_t rt = write(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
324 PLOG(INFO) << "failed to write eventfd, rt=" << rt;
325 return false;
326 }
327 return true;
328 }
329
handleEventFd()330 void DnsTlsFrontend::handleEventFd() {
331 int64_t data;
332 if (const ssize_t rt = read(event_fd_.get(), &data, sizeof(data)); rt != sizeof(data)) {
333 PLOG(INFO) << "ignore reading eventfd failed, rt=" << rt;
334 }
335 }
336
337 } // namespace test
338