1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "adb/pairing/pairing_connection.h"

#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <atomic>
#include <functional>
#include <memory>
#include <string_view>
#include <thread>
#include <vector>

#include <adb/pairing/pairing_auth.h>
#include <adb/tls/tls_connection.h>
#include <android-base/endian.h>
#include <android-base/logging.h>
#include <android-base/macros.h>
#include <android-base/unique_fd.h>

#include "pairing.pb.h"
36
37 using namespace adb;
38 using android::base::unique_fd;
39 using TlsError = tls::TlsConnection::TlsError;
40
41 const uint8_t kCurrentKeyHeaderVersion = 1;
42 const uint8_t kMinSupportedKeyHeaderVersion = 1;
43 const uint8_t kMaxSupportedKeyHeaderVersion = 1;
44 const uint32_t kMaxPayloadSize = kMaxPeerInfoSize * 2;
45
// Header that precedes every pairing packet on the wire. The struct is packed
// so its in-memory layout is exactly the 6-byte wire format: version, type,
// then the payload size. |payload| is sent in network byte order; callers
// convert it with htonl()/ntohl() when writing/reading.
struct PairingPacketHeader {
    uint8_t version;   // PairingPacket version
    uint8_t type;      // the type of packet (PairingPacket.Type)
    uint32_t payload;  // Size of the payload in bytes
} __attribute__((packed));
// Guard the wire layout: any padding here would break interop with the peer.
static_assert(sizeof(PairingPacketHeader) == 6,
              "PairingPacketHeader must match the 6-byte wire format");
51
52 struct PairingAuthDeleter {
operator ()PairingAuthDeleter53 void operator()(PairingAuthCtx* p) { pairing_auth_destroy(p); }
54 }; // PairingAuthDeleter
55 using PairingAuthPtr = std::unique_ptr<PairingAuthCtx, PairingAuthDeleter>;
56
57 // PairingConnectionCtx encapsulates the protocol to authenticate two peers with
58 // each other. This class will open the tcp sockets and handle the pairing
59 // process. On completion, both sides will have each other's public key
60 // (certificate) if successful, otherwise, the pairing failed. The tcp port
61 // number is hardcoded (see pairing_connection.cpp).
62 //
63 // Each PairingConnectionCtx instance represents a different device trying to
64 // pair. So for the device, we can have multiple PairingConnectionCtxs while the
65 // host may have only one (unless host has a PairingServer).
66 //
67 // See pairing_connection_test.cpp for example usage.
68 //
69 struct PairingConnectionCtx {
70 public:
71 using Data = std::vector<uint8_t>;
72 using ResultCallback = pairing_result_cb;
73 enum class Role {
74 Client,
75 Server,
76 };
77
78 explicit PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
79 const Data& certificate, const Data& priv_key);
80 virtual ~PairingConnectionCtx();
81
82 // Starts the pairing connection on a separate thread.
83 // Upon completion, if the pairing was successful,
84 // |cb| will be called with the peer information and certificate.
85 // Otherwise, |cb| will be called with empty data. |fd| should already
86 // be opened. PairingConnectionCtx will take ownership of the |fd|.
87 //
88 // Pairing is successful if both server/client uses the same non-empty
89 // |pswd|, and they are able to exchange the information. |pswd| and
90 // |certificate| must be non-empty. Start() can only be called once in the
91 // lifetime of this object.
92 //
93 // Returns true if the thread was successfully started, false otherwise.
94 bool Start(int fd, ResultCallback cb, void* opaque);
95
96 private:
97 // Setup the tls connection.
98 bool SetupTlsConnection();
99
100 /************ PairingPacketHeader methods ****************/
101 // Tries to write out the header and payload.
102 bool WriteHeader(const PairingPacketHeader* header, std::string_view payload);
103 // Tries to parse incoming data into the |header|. Returns true if header
104 // is valid and header version is supported. |header| is filled on success.
105 // |header| may contain garbage if unsuccessful.
106 bool ReadHeader(PairingPacketHeader* header);
107 // Creates a PairingPacketHeader.
108 void CreateHeader(PairingPacketHeader* header, adb::proto::PairingPacket::Type type,
109 uint32_t payload_size);
110 // Checks if actual matches expected.
111 bool CheckHeaderType(adb::proto::PairingPacket::Type expected, uint8_t actual);
112
113 /*********** State related methods **************/
114 // Handles the State::ExchangingMsgs state.
115 bool DoExchangeMsgs();
116 // Handles the State::ExchangingPeerInfo state.
117 bool DoExchangePeerInfo();
118
119 // The background task to do the pairing.
120 void StartWorker();
121
122 // Calls |cb_| and sets the state to Stopped.
123 void NotifyResult(const PeerInfo* p);
124
125 static PairingAuthPtr CreatePairingAuthPtr(Role role, const Data& pswd);
126
127 enum class State {
128 Ready,
129 ExchangingMsgs,
130 ExchangingPeerInfo,
131 Stopped,
132 };
133
134 std::atomic<State> state_{State::Ready};
135 Role role_;
136 Data pswd_;
137 PeerInfo peer_info_;
138 Data cert_;
139 Data priv_key_;
140
141 // Peer's info
142 PeerInfo their_info_;
143
144 ResultCallback cb_;
145 void* opaque_ = nullptr;
146 std::unique_ptr<tls::TlsConnection> tls_;
147 PairingAuthPtr auth_;
148 unique_fd fd_;
149 std::thread thread_;
150 static constexpr size_t kExportedKeySize = 64;
151 }; // PairingConnectionCtx
152
PairingConnectionCtx(Role role,const Data & pswd,const PeerInfo & peer_info,const Data & cert,const Data & priv_key)153 PairingConnectionCtx::PairingConnectionCtx(Role role, const Data& pswd, const PeerInfo& peer_info,
154 const Data& cert, const Data& priv_key)
155 : role_(role), pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key) {
156 CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty());
157 }
158
~PairingConnectionCtx()159 PairingConnectionCtx::~PairingConnectionCtx() {
160 // Force close the fd and wait for the worker thread to finish.
161 fd_.reset();
162 if (thread_.joinable()) {
163 thread_.join();
164 }
165 }
166
SetupTlsConnection()167 bool PairingConnectionCtx::SetupTlsConnection() {
168 tls_ = tls::TlsConnection::Create(
169 role_ == Role::Server ? tls::TlsConnection::Role::Server
170 : tls::TlsConnection::Role::Client,
171 std::string_view(reinterpret_cast<const char*>(cert_.data()), cert_.size()),
172 std::string_view(reinterpret_cast<const char*>(priv_key_.data()), priv_key_.size()),
173 fd_);
174
175 if (tls_ == nullptr) {
176 LOG(ERROR) << "Unable to start TlsConnection. Unable to pair fd=" << fd_.get();
177 return false;
178 }
179
180 // Allow any peer certificate
181 tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
182
183 // SSL doesn't seem to behave correctly with fdevents so just do a blocking
184 // read for the pairing data.
185 if (tls_->DoHandshake() != TlsError::Success) {
186 LOG(ERROR) << "Failed to handshake with the peer fd=" << fd_.get();
187 return false;
188 }
189
190 // To ensure the connection is not stolen while we do the PAKE, append the
191 // exported key material from the tls connection to the password.
192 std::vector<uint8_t> exportedKeyMaterial = tls_->ExportKeyingMaterial(kExportedKeySize);
193 if (exportedKeyMaterial.empty()) {
194 LOG(ERROR) << "Failed to export key material";
195 return false;
196 }
197 pswd_.insert(pswd_.end(), std::make_move_iterator(exportedKeyMaterial.begin()),
198 std::make_move_iterator(exportedKeyMaterial.end()));
199 auth_ = CreatePairingAuthPtr(role_, pswd_);
200
201 return true;
202 }
203
WriteHeader(const PairingPacketHeader * header,std::string_view payload)204 bool PairingConnectionCtx::WriteHeader(const PairingPacketHeader* header,
205 std::string_view payload) {
206 PairingPacketHeader network_header = *header;
207 network_header.payload = htonl(network_header.payload);
208 if (!tls_->WriteFully(std::string_view(reinterpret_cast<const char*>(&network_header),
209 sizeof(PairingPacketHeader))) ||
210 !tls_->WriteFully(payload)) {
211 LOG(ERROR) << "Failed to write out PairingPacketHeader";
212 state_ = State::Stopped;
213 return false;
214 }
215 return true;
216 }
217
ReadHeader(PairingPacketHeader * header)218 bool PairingConnectionCtx::ReadHeader(PairingPacketHeader* header) {
219 auto data = tls_->ReadFully(sizeof(PairingPacketHeader));
220 if (data.empty()) {
221 return false;
222 }
223
224 uint8_t* p = data.data();
225 // First byte is always PairingPacketHeader version
226 header->version = *p;
227 ++p;
228 if (header->version < kMinSupportedKeyHeaderVersion ||
229 header->version > kMaxSupportedKeyHeaderVersion) {
230 LOG(ERROR) << "PairingPacketHeader version mismatch (us=" << kCurrentKeyHeaderVersion
231 << " them=" << header->version << ")";
232 return false;
233 }
234 // Next byte is the PairingPacket::Type
235 if (!adb::proto::PairingPacket::Type_IsValid(*p)) {
236 LOG(ERROR) << "Unknown PairingPacket type=" << static_cast<uint32_t>(*p);
237 return false;
238 }
239 header->type = *p;
240 ++p;
241 // Last, the payload size
242 header->payload = ntohl(*(reinterpret_cast<uint32_t*>(p)));
243 if (header->payload == 0 || header->payload > kMaxPayloadSize) {
244 LOG(ERROR) << "header payload not within a safe payload size (size=" << header->payload
245 << ")";
246 return false;
247 }
248
249 return true;
250 }
251
CreateHeader(PairingPacketHeader * header,adb::proto::PairingPacket::Type type,uint32_t payload_size)252 void PairingConnectionCtx::CreateHeader(PairingPacketHeader* header,
253 adb::proto::PairingPacket::Type type,
254 uint32_t payload_size) {
255 header->version = kCurrentKeyHeaderVersion;
256 uint8_t type8 = static_cast<uint8_t>(static_cast<int>(type));
257 header->type = type8;
258 header->payload = payload_size;
259 }
260
CheckHeaderType(adb::proto::PairingPacket::Type expected_type,uint8_t actual)261 bool PairingConnectionCtx::CheckHeaderType(adb::proto::PairingPacket::Type expected_type,
262 uint8_t actual) {
263 uint8_t expected = *reinterpret_cast<uint8_t*>(&expected_type);
264 if (actual != expected) {
265 LOG(ERROR) << "Unexpected header type (expected=" << static_cast<uint32_t>(expected)
266 << " actual=" << static_cast<uint32_t>(actual) << ")";
267 return false;
268 }
269 return true;
270 }
271
NotifyResult(const PeerInfo * p)272 void PairingConnectionCtx::NotifyResult(const PeerInfo* p) {
273 cb_(p, fd_.get(), opaque_);
274 state_ = State::Stopped;
275 }
276
Start(int fd,ResultCallback cb,void * opaque)277 bool PairingConnectionCtx::Start(int fd, ResultCallback cb, void* opaque) {
278 if (fd < 0) {
279 return false;
280 }
281 fd_.reset(fd);
282
283 State expected = State::Ready;
284 if (!state_.compare_exchange_strong(expected, State::ExchangingMsgs)) {
285 return false;
286 }
287
288 cb_ = cb;
289 opaque_ = opaque;
290
291 thread_ = std::thread([this] { StartWorker(); });
292 return true;
293 }
294
DoExchangeMsgs()295 bool PairingConnectionCtx::DoExchangeMsgs() {
296 uint32_t payload = pairing_auth_msg_size(auth_.get());
297 std::vector<uint8_t> msg(payload);
298 pairing_auth_get_spake2_msg(auth_.get(), msg.data());
299
300 PairingPacketHeader header;
301 CreateHeader(&header, adb::proto::PairingPacket::SPAKE2_MSG, payload);
302
303 // Write our SPAKE2 msg
304 if (!WriteHeader(&header,
305 std::string_view(reinterpret_cast<const char*>(msg.data()), msg.size()))) {
306 LOG(ERROR) << "Failed to write SPAKE2 msg.";
307 return false;
308 }
309
310 // Read the peer's SPAKE2 msg header
311 if (!ReadHeader(&header)) {
312 LOG(ERROR) << "Invalid PairingPacketHeader.";
313 return false;
314 }
315 if (!CheckHeaderType(adb::proto::PairingPacket::SPAKE2_MSG, header.type)) {
316 return false;
317 }
318
319 // Read the SPAKE2 msg payload and initialize the cipher for
320 // encrypting the PeerInfo and certificate.
321 auto their_msg = tls_->ReadFully(header.payload);
322 if (their_msg.empty() ||
323 !pairing_auth_init_cipher(auth_.get(), their_msg.data(), their_msg.size())) {
324 LOG(ERROR) << "Unable to initialize pairing cipher [their_msg.size=" << their_msg.size()
325 << "]";
326 return false;
327 }
328
329 return true;
330 }
331
DoExchangePeerInfo()332 bool PairingConnectionCtx::DoExchangePeerInfo() {
333 // Encrypt PeerInfo
334 std::vector<uint8_t> buf;
335 uint8_t* p = reinterpret_cast<uint8_t*>(&peer_info_);
336 buf.assign(p, p + sizeof(peer_info_));
337 std::vector<uint8_t> outbuf(pairing_auth_safe_encrypted_size(auth_.get(), buf.size()));
338 CHECK(!outbuf.empty());
339 size_t outsize;
340 if (!pairing_auth_encrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
341 LOG(ERROR) << "Failed to encrypt peer info";
342 return false;
343 }
344 outbuf.resize(outsize);
345
346 // Write out the packet header
347 PairingPacketHeader out_header;
348 out_header.version = kCurrentKeyHeaderVersion;
349 out_header.type = static_cast<uint8_t>(static_cast<int>(adb::proto::PairingPacket::PEER_INFO));
350 out_header.payload = htonl(outbuf.size());
351 if (!tls_->WriteFully(
352 std::string_view(reinterpret_cast<const char*>(&out_header), sizeof(out_header)))) {
353 LOG(ERROR) << "Unable to write PairingPacketHeader";
354 return false;
355 }
356
357 // Write out the encrypted payload
358 if (!tls_->WriteFully(
359 std::string_view(reinterpret_cast<const char*>(outbuf.data()), outbuf.size()))) {
360 LOG(ERROR) << "Unable to write encrypted peer info";
361 return false;
362 }
363
364 // Read in the peer's packet header
365 PairingPacketHeader header;
366 if (!ReadHeader(&header)) {
367 LOG(ERROR) << "Invalid PairingPacketHeader.";
368 return false;
369 }
370
371 if (!CheckHeaderType(adb::proto::PairingPacket::PEER_INFO, header.type)) {
372 return false;
373 }
374
375 // Read in the encrypted peer certificate
376 buf = tls_->ReadFully(header.payload);
377 if (buf.empty()) {
378 return false;
379 }
380
381 // Try to decrypt the certificate
382 outbuf.resize(pairing_auth_safe_decrypted_size(auth_.get(), buf.data(), buf.size()));
383 if (outbuf.empty()) {
384 LOG(ERROR) << "Unsupported payload while decrypting peer info.";
385 return false;
386 }
387
388 if (!pairing_auth_decrypt(auth_.get(), buf.data(), buf.size(), outbuf.data(), &outsize)) {
389 LOG(ERROR) << "Failed to decrypt";
390 return false;
391 }
392 outbuf.resize(outsize);
393
394 // The decrypted message should contain the PeerInfo.
395 if (outbuf.size() != sizeof(PeerInfo)) {
396 LOG(ERROR) << "Got size=" << outbuf.size() << "PeerInfo.size=" << sizeof(PeerInfo);
397 return false;
398 }
399
400 p = outbuf.data();
401 ::memcpy(&their_info_, p, sizeof(PeerInfo));
402 p += sizeof(PeerInfo);
403
404 return true;
405 }
406
StartWorker()407 void PairingConnectionCtx::StartWorker() {
408 // Setup the secure transport
409 if (!SetupTlsConnection()) {
410 NotifyResult(nullptr);
411 return;
412 }
413
414 for (;;) {
415 switch (state_) {
416 case State::ExchangingMsgs:
417 if (!DoExchangeMsgs()) {
418 NotifyResult(nullptr);
419 return;
420 }
421 state_ = State::ExchangingPeerInfo;
422 break;
423 case State::ExchangingPeerInfo:
424 if (!DoExchangePeerInfo()) {
425 NotifyResult(nullptr);
426 return;
427 }
428 NotifyResult(&their_info_);
429 return;
430 case State::Ready:
431 case State::Stopped:
432 LOG(FATAL) << __func__ << ": Got invalid state";
433 return;
434 }
435 }
436 }
437
438 // static
CreatePairingAuthPtr(Role role,const Data & pswd)439 PairingAuthPtr PairingConnectionCtx::CreatePairingAuthPtr(Role role, const Data& pswd) {
440 switch (role) {
441 case Role::Client:
442 return PairingAuthPtr(pairing_auth_client_new(pswd.data(), pswd.size()));
443 break;
444 case Role::Server:
445 return PairingAuthPtr(pairing_auth_server_new(pswd.data(), pswd.size()));
446 break;
447 }
448 }
449
CreateConnection(PairingConnectionCtx::Role role,const uint8_t * pswd,size_t pswd_len,const PeerInfo * peer_info,const uint8_t * x509_cert_pem,size_t x509_size,const uint8_t * priv_key_pem,size_t priv_size)450 static PairingConnectionCtx* CreateConnection(PairingConnectionCtx::Role role, const uint8_t* pswd,
451 size_t pswd_len, const PeerInfo* peer_info,
452 const uint8_t* x509_cert_pem, size_t x509_size,
453 const uint8_t* priv_key_pem, size_t priv_size) {
454 CHECK(pswd);
455 CHECK_GT(pswd_len, 0U);
456 CHECK(x509_cert_pem);
457 CHECK_GT(x509_size, 0U);
458 CHECK(priv_key_pem);
459 CHECK_GT(priv_size, 0U);
460 CHECK(peer_info);
461 std::vector<uint8_t> vec_pswd(pswd, pswd + pswd_len);
462 std::vector<uint8_t> vec_x509_cert(x509_cert_pem, x509_cert_pem + x509_size);
463 std::vector<uint8_t> vec_priv_key(priv_key_pem, priv_key_pem + priv_size);
464 return new PairingConnectionCtx(role, vec_pswd, *peer_info, vec_x509_cert, vec_priv_key);
465 }
466
pairing_connection_client_new(const uint8_t * pswd,size_t pswd_len,const PeerInfo * peer_info,const uint8_t * x509_cert_pem,size_t x509_size,const uint8_t * priv_key_pem,size_t priv_size)467 PairingConnectionCtx* pairing_connection_client_new(const uint8_t* pswd, size_t pswd_len,
468 const PeerInfo* peer_info,
469 const uint8_t* x509_cert_pem, size_t x509_size,
470 const uint8_t* priv_key_pem, size_t priv_size) {
471 return CreateConnection(PairingConnectionCtx::Role::Client, pswd, pswd_len, peer_info,
472 x509_cert_pem, x509_size, priv_key_pem, priv_size);
473 }
474
pairing_connection_server_new(const uint8_t * pswd,size_t pswd_len,const PeerInfo * peer_info,const uint8_t * x509_cert_pem,size_t x509_size,const uint8_t * priv_key_pem,size_t priv_size)475 PairingConnectionCtx* pairing_connection_server_new(const uint8_t* pswd, size_t pswd_len,
476 const PeerInfo* peer_info,
477 const uint8_t* x509_cert_pem, size_t x509_size,
478 const uint8_t* priv_key_pem, size_t priv_size) {
479 return CreateConnection(PairingConnectionCtx::Role::Server, pswd, pswd_len, peer_info,
480 x509_cert_pem, x509_size, priv_key_pem, priv_size);
481 }
482
pairing_connection_destroy(PairingConnectionCtx * ctx)483 void pairing_connection_destroy(PairingConnectionCtx* ctx) {
484 CHECK(ctx);
485 delete ctx;
486 }
487
pairing_connection_start(PairingConnectionCtx * ctx,int fd,pairing_result_cb cb,void * opaque)488 bool pairing_connection_start(PairingConnectionCtx* ctx, int fd, pairing_result_cb cb,
489 void* opaque) {
490 return ctx->Start(fd, cb, opaque);
491 }
492