1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "adb/pairing/pairing_server.h"
18
#include <fcntl.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <unistd.h>

#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <deque>
#include <iomanip>
#include <memory>
#include <mutex>
#include <optional>
#include <sstream>
#include <string>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <variant>
#include <vector>
32
33 #include <adb/crypto/rsa_2048_key.h>
34 #include <adb/crypto/x509_generator.h>
35 #include <adb/pairing/pairing_connection.h>
36 #include <android-base/logging.h>
37 #include <android-base/parsenetaddress.h>
38 #include <android-base/thread_annotations.h>
39 #include <android-base/unique_fd.h>
40 #include <cutils/sockets.h>
41
42 #include "internal/constants.h"
43
44 using android::base::ScopedLockAssertion;
45 using android::base::unique_fd;
46 using namespace adb::crypto;
47 using namespace adb::pairing;
48
49 // The implementation has two background threads running: one to handle and
50 // accept any new pairing connection requests (socket accept), and the other to
51 // handle connection events (connection started, connection finished).
52 struct PairingServerCtx {
53 public:
54 using Data = std::vector<uint8_t>;
55
56 virtual ~PairingServerCtx();
57
58 // All parameters must be non-empty.
59 explicit PairingServerCtx(const Data& pswd, const PeerInfo& peer_info, const Data& cert,
60 const Data& priv_key, uint16_t port);
61
62 // Starts the pairing server. This call is non-blocking. Upon completion,
63 // if the pairing was successful, then |cb| will be called with the PublicKeyHeader
64 // containing the info of the trusted peer. Otherwise, |cb| will be
65 // called with an empty value. Start can only be called once in the lifetime
66 // of this object.
67 //
68 // Returns the port number if PairingServerCtx was successfully started. Otherwise,
69 // returns 0.
70 uint16_t Start(pairing_server_result_cb cb, void* opaque);
71
72 private:
73 // Setup the server socket to accept incoming connections. Returns the
74 // server port number (> 0 on success).
75 uint16_t SetupServer();
76 // Force stop the server thread.
77 void StopServer();
78
79 // handles a new pairing client connection
80 bool HandleNewClientConnection(int fd) EXCLUDES(conn_mutex_);
81
82 // ======== connection events thread =============
83 std::mutex conn_mutex_;
84 std::condition_variable conn_cv_;
85
86 using FdVal = int;
87 struct ConnectionDeleter {
operator ()PairingServerCtx::ConnectionDeleter88 void operator()(PairingConnectionCtx* p) { pairing_connection_destroy(p); }
89 };
90 using ConnectionPtr = std::unique_ptr<PairingConnectionCtx, ConnectionDeleter>;
91 static ConnectionPtr CreatePairingConnection(const Data& pswd, const PeerInfo& peer_info,
92 const Data& cert, const Data& priv_key);
93 using NewConnectionEvent = std::tuple<unique_fd, ConnectionPtr>;
94 // <fd, PeerInfo.type, PeerInfo.data>
95 using ConnectionFinishedEvent = std::tuple<FdVal, uint8_t, std::optional<std::string>>;
96 using ConnectionEvent = std::variant<NewConnectionEvent, ConnectionFinishedEvent>;
97 // Queue for connections to write into. We have a separate queue to read
98 // from, in order to minimize the time the server thread is blocked.
99 std::deque<ConnectionEvent> conn_write_queue_ GUARDED_BY(conn_mutex_);
100 std::deque<ConnectionEvent> conn_read_queue_;
101 // Map of fds to their PairingConnections currently running.
102 std::unordered_map<FdVal, ConnectionPtr> connections_;
103
104 // Two threads launched when starting the pairing server:
105 // 1) A server thread that waits for incoming client connections, and
106 // 2) A connection events thread that synchonizes events from all of the
107 // clients, since each PairingConnection is running in it's own thread.
108 void StartConnectionEventsThread();
109 void StartServerThread();
110
111 static void PairingConnectionCallback(const PeerInfo* peer_info, int fd, void* opaque);
112
113 std::thread conn_events_thread_;
114 void ConnectionEventsWorker();
115 std::thread server_thread_;
116 void ServerWorker();
117 bool is_terminate_ GUARDED_BY(conn_mutex_) = false;
118
119 enum class State {
120 Ready,
121 Running,
122 Stopped,
123 };
124 State state_ = State::Ready;
125 Data pswd_;
126 PeerInfo peer_info_;
127 Data cert_;
128 Data priv_key_;
129 uint16_t port_;
130
131 pairing_server_result_cb cb_;
132 void* opaque_ = nullptr;
133 bool got_valid_pairing_ = false;
134
135 static const int kEpollConstSocket = 0;
136 // Used to break the server thread from epoll_wait
137 static const int kEpollConstEventFd = 1;
138 unique_fd epoll_fd_;
139 unique_fd server_fd_;
140 unique_fd event_fd_;
141 }; // PairingServerCtx
142
143 // static
CreatePairingConnection(const Data & pswd,const PeerInfo & peer_info,const Data & cert,const Data & priv_key)144 PairingServerCtx::ConnectionPtr PairingServerCtx::CreatePairingConnection(const Data& pswd,
145 const PeerInfo& peer_info,
146 const Data& cert,
147 const Data& priv_key) {
148 return ConnectionPtr(pairing_connection_server_new(pswd.data(), pswd.size(), &peer_info,
149 cert.data(), cert.size(), priv_key.data(),
150 priv_key.size()));
151 }
152
PairingServerCtx(const Data & pswd,const PeerInfo & peer_info,const Data & cert,const Data & priv_key,uint16_t port)153 PairingServerCtx::PairingServerCtx(const Data& pswd, const PeerInfo& peer_info, const Data& cert,
154 const Data& priv_key, uint16_t port)
155 : pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key), port_(port) {
156 CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty());
157 }
158
~PairingServerCtx()159 PairingServerCtx::~PairingServerCtx() {
160 // Since these connections have references to us, let's make sure they
161 // destruct before us.
162 if (server_thread_.joinable()) {
163 StopServer();
164 server_thread_.join();
165 }
166
167 {
168 std::lock_guard<std::mutex> lock(conn_mutex_);
169 is_terminate_ = true;
170 }
171 conn_cv_.notify_one();
172 if (conn_events_thread_.joinable()) {
173 conn_events_thread_.join();
174 }
175
176 // Notify the cb_ if it hasn't already.
177 if (!got_valid_pairing_ && cb_ != nullptr) {
178 cb_(nullptr, opaque_);
179 }
180 }
181
Start(pairing_server_result_cb cb,void * opaque)182 uint16_t PairingServerCtx::Start(pairing_server_result_cb cb, void* opaque) {
183 cb_ = cb;
184 opaque_ = opaque;
185
186 if (state_ != State::Ready) {
187 LOG(ERROR) << "PairingServerCtx already running or stopped";
188 return 0;
189 }
190
191 port_ = SetupServer();
192 if (port_ == 0) {
193 LOG(ERROR) << "Unable to start PairingServer";
194 state_ = State::Stopped;
195 return 0;
196 }
197 LOG(INFO) << "Pairing server started on port " << port_;
198
199 state_ = State::Running;
200 return port_;
201 }
202
StopServer()203 void PairingServerCtx::StopServer() {
204 if (event_fd_.get() == -1) {
205 return;
206 }
207 uint64_t value = 1;
208 ssize_t rc = write(event_fd_.get(), &value, sizeof(value));
209 if (rc == -1) {
210 // This can happen if the server didn't start.
211 PLOG(ERROR) << "write to eventfd failed";
212 } else if (rc != sizeof(value)) {
213 LOG(FATAL) << "write to event returned short (" << rc << ")";
214 }
215 }
216
SetupServer()217 uint16_t PairingServerCtx::SetupServer() {
218 epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC));
219 if (epoll_fd_ == -1) {
220 PLOG(ERROR) << "failed to create epoll fd";
221 return 0;
222 }
223
224 event_fd_.reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
225 if (event_fd_ == -1) {
226 PLOG(ERROR) << "failed to create eventfd";
227 return 0;
228 }
229
230 server_fd_.reset(socket_inaddr_any_server(port_, SOCK_STREAM));
231 if (server_fd_.get() == -1) {
232 PLOG(ERROR) << "Failed to start pairing connection server";
233 return 0;
234 } else if (fcntl(server_fd_.get(), F_SETFD, FD_CLOEXEC) != 0) {
235 PLOG(ERROR) << "Failed to make server socket cloexec";
236 return 0;
237 } else if (fcntl(server_fd_.get(), F_SETFD, O_NONBLOCK) != 0) {
238 PLOG(ERROR) << "Failed to make server socket nonblocking";
239 return 0;
240 }
241
242 StartConnectionEventsThread();
243 StartServerThread();
244 int port = socket_get_local_port(server_fd_.get());
245 return (port <= 0 ? 0 : port);
246 }
247
StartServerThread()248 void PairingServerCtx::StartServerThread() {
249 server_thread_ = std::thread([this]() { ServerWorker(); });
250 }
251
StartConnectionEventsThread()252 void PairingServerCtx::StartConnectionEventsThread() {
253 conn_events_thread_ = std::thread([this]() { ConnectionEventsWorker(); });
254 }
255
ServerWorker()256 void PairingServerCtx::ServerWorker() {
257 {
258 struct epoll_event event;
259 event.events = EPOLLIN;
260 event.data.u64 = kEpollConstSocket;
261 CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, server_fd_.get(), &event));
262 }
263
264 {
265 struct epoll_event event;
266 event.events = EPOLLIN;
267 event.data.u64 = kEpollConstEventFd;
268 CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, event_fd_.get(), &event));
269 }
270
271 while (true) {
272 struct epoll_event events[2];
273 int rc = TEMP_FAILURE_RETRY(epoll_wait(epoll_fd_.get(), events, 2, -1));
274 if (rc == -1) {
275 PLOG(ERROR) << "epoll_wait failed";
276 return;
277 } else if (rc == 0) {
278 LOG(ERROR) << "epoll_wait returned 0";
279 return;
280 }
281
282 for (int i = 0; i < rc; ++i) {
283 struct epoll_event& event = events[i];
284 switch (event.data.u64) {
285 case kEpollConstSocket:
286 HandleNewClientConnection(server_fd_.get());
287 break;
288 case kEpollConstEventFd:
289 uint64_t dummy;
290 int rc = TEMP_FAILURE_RETRY(read(event_fd_.get(), &dummy, sizeof(dummy)));
291 if (rc != sizeof(dummy)) {
292 PLOG(FATAL) << "failed to read from eventfd (rc=" << rc << ")";
293 }
294 return;
295 }
296 }
297 }
298 }
299
300 // static
PairingConnectionCallback(const PeerInfo * peer_info,int fd,void * opaque)301 void PairingServerCtx::PairingConnectionCallback(const PeerInfo* peer_info, int fd, void* opaque) {
302 auto* p = reinterpret_cast<PairingServerCtx*>(opaque);
303
304 ConnectionFinishedEvent event;
305 if (peer_info != nullptr) {
306 if (peer_info->type == ADB_RSA_PUB_KEY) {
307 event = std::make_tuple(fd, peer_info->type,
308 std::string(reinterpret_cast<const char*>(peer_info->data)));
309 } else {
310 LOG(WARNING) << "Ignoring successful pairing because of unknown "
311 << "PeerInfo type=" << peer_info->type;
312 }
313 } else {
314 event = std::make_tuple(fd, 0, std::nullopt);
315 }
316 {
317 std::lock_guard<std::mutex> lock(p->conn_mutex_);
318 p->conn_write_queue_.push_back(std::move(event));
319 }
320 p->conn_cv_.notify_one();
321 }
322
ConnectionEventsWorker()323 void PairingServerCtx::ConnectionEventsWorker() {
324 uint8_t num_tries = 0;
325 for (;;) {
326 // Transfer the write queue to the read queue.
327 {
328 std::unique_lock<std::mutex> lock(conn_mutex_);
329 ScopedLockAssertion assume_locked(conn_mutex_);
330
331 if (is_terminate_) {
332 // We check |is_terminate_| twice because condition_variable's
333 // notify() only wakes up a thread if it is in the wait state
334 // prior to notify(). Furthermore, we aren't holding the mutex
335 // when processing the events in |conn_read_queue_|.
336 return;
337 }
338 if (conn_write_queue_.empty()) {
339 // We need to wait for new events, or the termination signal.
340 conn_cv_.wait(lock, [this]() REQUIRES(conn_mutex_) {
341 return (is_terminate_ || !conn_write_queue_.empty());
342 });
343 }
344 if (is_terminate_) {
345 // We're done.
346 return;
347 }
348 // Move all events into the read queue.
349 conn_read_queue_ = std::move(conn_write_queue_);
350 conn_write_queue_.clear();
351 }
352
353 // Process all events in the read queue.
354 while (conn_read_queue_.size() > 0) {
355 auto& event = conn_read_queue_.front();
356 if (auto* p = std::get_if<NewConnectionEvent>(&event)) {
357 // Ignore if we are already at the max number of connections
358 if (connections_.size() >= internal::kMaxConnections) {
359 conn_read_queue_.pop_front();
360 continue;
361 }
362 auto [ufd, connection] = std::move(*p);
363 int fd = ufd.release();
364 bool started = pairing_connection_start(connection.get(), fd,
365 PairingConnectionCallback, this);
366 if (!started) {
367 LOG(ERROR) << "PairingServer unable to start a PairingConnection fd=" << fd;
368 ufd.reset(fd);
369 } else {
370 connections_[fd] = std::move(connection);
371 }
372 } else if (auto* p = std::get_if<ConnectionFinishedEvent>(&event)) {
373 auto [fd, info_type, public_key] = std::move(*p);
374 if (public_key.has_value() && !public_key->empty()) {
375 // Valid pairing. Let's shutdown the server and close any
376 // pairing connections in progress.
377 StopServer();
378 connections_.clear();
379
380 PeerInfo info = {};
381 info.type = info_type;
382 strncpy(reinterpret_cast<char*>(info.data), public_key->data(),
383 public_key->size());
384
385 cb_(&info, opaque_);
386
387 got_valid_pairing_ = true;
388 return;
389 }
390 // Invalid pairing. Close the invalid connection.
391 if (connections_.find(fd) != connections_.end()) {
392 connections_.erase(fd);
393 }
394
395 if (++num_tries >= internal::kMaxPairingAttempts) {
396 cb_(nullptr, opaque_);
397 // To prevent the destructor from calling it again.
398 cb_ = nullptr;
399 return;
400 }
401 }
402 conn_read_queue_.pop_front();
403 }
404 }
405 }
406
HandleNewClientConnection(int fd)407 bool PairingServerCtx::HandleNewClientConnection(int fd) {
408 unique_fd ufd(TEMP_FAILURE_RETRY(accept4(fd, nullptr, nullptr, SOCK_CLOEXEC)));
409 if (ufd == -1) {
410 PLOG(WARNING) << "adb_socket_accept failed fd=" << fd;
411 return false;
412 }
413 auto connection = CreatePairingConnection(pswd_, peer_info_, cert_, priv_key_);
414 if (connection == nullptr) {
415 LOG(ERROR) << "PairingServer unable to create a PairingConnection fd=" << fd;
416 return false;
417 }
418 // send the new connection to the connection thread for further processing
419 NewConnectionEvent event = std::make_tuple(std::move(ufd), std::move(connection));
420 {
421 std::lock_guard<std::mutex> lock(conn_mutex_);
422 conn_write_queue_.push_back(std::move(event));
423 }
424 conn_cv_.notify_one();
425
426 return true;
427 }
428
pairing_server_start(PairingServerCtx * ctx,pairing_server_result_cb cb,void * opaque)429 uint16_t pairing_server_start(PairingServerCtx* ctx, pairing_server_result_cb cb, void* opaque) {
430 return ctx->Start(cb, opaque);
431 }
432
pairing_server_new(const uint8_t * pswd,size_t pswd_len,const PeerInfo * peer_info,const uint8_t * x509_cert_pem,size_t x509_size,const uint8_t * priv_key_pem,size_t priv_size,uint16_t port)433 PairingServerCtx* pairing_server_new(const uint8_t* pswd, size_t pswd_len,
434 const PeerInfo* peer_info, const uint8_t* x509_cert_pem,
435 size_t x509_size, const uint8_t* priv_key_pem,
436 size_t priv_size, uint16_t port) {
437 CHECK(pswd);
438 CHECK_GT(pswd_len, 0U);
439 CHECK(x509_cert_pem);
440 CHECK_GT(x509_size, 0U);
441 CHECK(priv_key_pem);
442 CHECK_GT(priv_size, 0U);
443 CHECK(peer_info);
444 std::vector<uint8_t> vec_pswd(pswd, pswd + pswd_len);
445 std::vector<uint8_t> vec_x509_cert(x509_cert_pem, x509_cert_pem + x509_size);
446 std::vector<uint8_t> vec_priv_key(priv_key_pem, priv_key_pem + priv_size);
447 return new PairingServerCtx(vec_pswd, *peer_info, vec_x509_cert, vec_priv_key, port);
448 }
449
pairing_server_new_no_cert(const uint8_t * pswd,size_t pswd_len,const PeerInfo * peer_info,uint16_t port)450 PairingServerCtx* pairing_server_new_no_cert(const uint8_t* pswd, size_t pswd_len,
451 const PeerInfo* peer_info, uint16_t port) {
452 auto rsa_2048 = CreateRSA2048Key();
453 auto x509_cert = GenerateX509Certificate(rsa_2048->GetEvpPkey());
454 std::string pkey_pem = Key::ToPEMString(rsa_2048->GetEvpPkey());
455 std::string cert_pem = X509ToPEMString(x509_cert.get());
456
457 return pairing_server_new(pswd, pswd_len, peer_info,
458 reinterpret_cast<const uint8_t*>(cert_pem.data()), cert_pem.size(),
459 reinterpret_cast<const uint8_t*>(pkey_pem.data()), pkey_pem.size(),
460 port);
461 }
462
pairing_server_destroy(PairingServerCtx * ctx)463 void pairing_server_destroy(PairingServerCtx* ctx) {
464 CHECK(ctx);
465 delete ctx;
466 }
467