1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except
6  * in compliance with the License. You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "adb/tls/tls_connection.h"
18 
#include <algorithm>
#include <utility>
#include <vector>
21 
22 #include <android-base/logging.h>
23 #include <android-base/strings.h>
24 #include <openssl/err.h>
25 #include <openssl/ssl.h>
26 
27 using android::base::borrowed_fd;
28 
29 namespace adb {
30 namespace tls {
31 
32 namespace {
33 
// Label handed to SSL_export_keying_material() by ExportKeyingMaterial(). That call
// passes sizeof(kExportedKeyLabel) as the label length, so the trailing NUL byte is
// part of the label; altering this value or its length changes the exported bytes.
constexpr char kExportedKeyLabel[] = "adb-label";
35 
36 class TlsConnectionImpl : public TlsConnection {
37   public:
38     explicit TlsConnectionImpl(Role role, std::string_view cert, std::string_view priv_key,
39                                borrowed_fd fd);
40     ~TlsConnectionImpl() override;
41 
42     bool AddTrustedCertificate(std::string_view cert) override;
43     void SetCertVerifyCallback(CertVerifyCb cb) override;
44     void SetCertificateCallback(SetCertCb cb) override;
45     void SetClientCAList(STACK_OF(X509_NAME) * ca_list) override;
46     std::vector<uint8_t> ExportKeyingMaterial(size_t length) override;
47     void EnableClientPostHandshakeCheck(bool enable) override;
48     TlsError DoHandshake() override;
49     std::vector<uint8_t> ReadFully(size_t size) override;
50     bool ReadFully(void* buf, size_t size) override;
51     bool WriteFully(std::string_view data) override;
52 
53     static bssl::UniquePtr<EVP_PKEY> EvpPkeyFromPEM(std::string_view pem);
54     static bssl::UniquePtr<CRYPTO_BUFFER> BufferFromPEM(std::string_view pem);
55 
56   private:
57     static int SSLSetCertVerifyCb(X509_STORE_CTX* ctx, void* opaque);
58     static int SSLSetCertCb(SSL* ssl, void* opaque);
59 
60     static bssl::UniquePtr<X509> X509FromBuffer(bssl::UniquePtr<CRYPTO_BUFFER> buffer);
61     static const char* SSLErrorString();
62     void Invalidate();
63     TlsError GetFailureReason(int err);
RoleToString()64     const char* RoleToString() { return role_ == Role::Server ? kServerRoleStr : kClientRoleStr; }
65 
66     Role role_;
67     bssl::UniquePtr<EVP_PKEY> priv_key_;
68     bssl::UniquePtr<CRYPTO_BUFFER> cert_;
69 
70     bssl::UniquePtr<STACK_OF(X509_NAME)> ca_list_;
71     bssl::UniquePtr<SSL_CTX> ssl_ctx_;
72     bssl::UniquePtr<SSL> ssl_;
73     std::vector<bssl::UniquePtr<X509>> known_certificates_;
74     bool client_verify_post_handshake_ = false;
75 
76     CertVerifyCb cert_verify_cb_;
77     SetCertCb set_cert_cb_;
78     borrowed_fd fd_;
79     static constexpr char kClientRoleStr[] = "[client]: ";
80     static constexpr char kServerRoleStr[] = "[server]: ";
81 };  // TlsConnectionImpl
82 
TlsConnectionImpl(Role role,std::string_view cert,std::string_view priv_key,borrowed_fd fd)83 TlsConnectionImpl::TlsConnectionImpl(Role role, std::string_view cert, std::string_view priv_key,
84                                      borrowed_fd fd)
85     : role_(role), fd_(fd) {
86     CHECK(!cert.empty() && !priv_key.empty());
87     LOG(INFO) << RoleToString() << "Initializing adbwifi TlsConnection";
88     cert_ = BufferFromPEM(cert);
89     CHECK(cert_);
90     priv_key_ = EvpPkeyFromPEM(priv_key);
91     CHECK(priv_key_);
92 }
93 
~TlsConnectionImpl()94 TlsConnectionImpl::~TlsConnectionImpl() {
95     // shutdown the SSL connection
96     if (ssl_ != nullptr) {
97         SSL_shutdown(ssl_.get());
98     }
99 }
100 
101 // static
SSLErrorString()102 const char* TlsConnectionImpl::SSLErrorString() {
103     auto sslerr = ERR_peek_last_error();
104     return ERR_reason_error_string(sslerr);
105 }
106 
107 // static
EvpPkeyFromPEM(std::string_view pem)108 bssl::UniquePtr<EVP_PKEY> TlsConnectionImpl::EvpPkeyFromPEM(std::string_view pem) {
109     bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(pem.data(), pem.size()));
110     return bssl::UniquePtr<EVP_PKEY>(PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
111 }
112 
113 // static
BufferFromPEM(std::string_view pem)114 bssl::UniquePtr<CRYPTO_BUFFER> TlsConnectionImpl::BufferFromPEM(std::string_view pem) {
115     bssl::UniquePtr<BIO> bio(BIO_new_mem_buf(pem.data(), pem.size()));
116     char* name = nullptr;
117     char* header = nullptr;
118     uint8_t* data = nullptr;
119     long data_len = 0;
120 
121     if (!PEM_read_bio(bio.get(), &name, &header, &data, &data_len)) {
122         LOG(ERROR) << "Failed to read certificate";
123         return nullptr;
124     }
125     OPENSSL_free(name);
126     OPENSSL_free(header);
127 
128     auto ret = bssl::UniquePtr<CRYPTO_BUFFER>(CRYPTO_BUFFER_new(data, data_len, nullptr));
129     OPENSSL_free(data);
130     return ret;
131 }
132 
133 // static
X509FromBuffer(bssl::UniquePtr<CRYPTO_BUFFER> buffer)134 bssl::UniquePtr<X509> TlsConnectionImpl::X509FromBuffer(bssl::UniquePtr<CRYPTO_BUFFER> buffer) {
135     if (!buffer) {
136         return nullptr;
137     }
138     return bssl::UniquePtr<X509>(X509_parse_from_buffer(buffer.get()));
139 }
140 
141 // static
SSLSetCertVerifyCb(X509_STORE_CTX * ctx,void * opaque)142 int TlsConnectionImpl::SSLSetCertVerifyCb(X509_STORE_CTX* ctx, void* opaque) {
143     auto* p = reinterpret_cast<TlsConnectionImpl*>(opaque);
144     return p->cert_verify_cb_(ctx);
145 }
146 
147 // static
SSLSetCertCb(SSL * ssl,void * opaque)148 int TlsConnectionImpl::SSLSetCertCb(SSL* ssl, void* opaque) {
149     auto* p = reinterpret_cast<TlsConnectionImpl*>(opaque);
150     return p->set_cert_cb_(ssl);
151 }
152 
AddTrustedCertificate(std::string_view cert)153 bool TlsConnectionImpl::AddTrustedCertificate(std::string_view cert) {
154     // Create X509 buffer from the certificate string
155     auto buf = X509FromBuffer(BufferFromPEM(cert));
156     if (buf == nullptr) {
157         LOG(ERROR) << RoleToString() << "Failed to create a X509 buffer for the certificate.";
158         return false;
159     }
160     known_certificates_.push_back(std::move(buf));
161     return true;
162 }
163 
SetCertVerifyCallback(CertVerifyCb cb)164 void TlsConnectionImpl::SetCertVerifyCallback(CertVerifyCb cb) {
165     cert_verify_cb_ = cb;
166 }
167 
SetCertificateCallback(SetCertCb cb)168 void TlsConnectionImpl::SetCertificateCallback(SetCertCb cb) {
169     set_cert_cb_ = cb;
170 }
171 
SetClientCAList(STACK_OF (X509_NAME)* ca_list)172 void TlsConnectionImpl::SetClientCAList(STACK_OF(X509_NAME) * ca_list) {
173     CHECK(role_ == Role::Server);
174     ca_list_.reset(ca_list != nullptr ? SSL_dup_CA_list(ca_list) : nullptr);
175 }
176 
ExportKeyingMaterial(size_t length)177 std::vector<uint8_t> TlsConnectionImpl::ExportKeyingMaterial(size_t length) {
178     if (ssl_.get() == nullptr) {
179         return {};
180     }
181 
182     std::vector<uint8_t> out(length);
183     if (SSL_export_keying_material(ssl_.get(), out.data(), out.size(), kExportedKeyLabel,
184                                    sizeof(kExportedKeyLabel), nullptr, 0, false) == 0) {
185         return {};
186     }
187     return out;
188 }
189 
EnableClientPostHandshakeCheck(bool enable)190 void TlsConnectionImpl::EnableClientPostHandshakeCheck(bool enable) {
191     client_verify_post_handshake_ = enable;
192 }
193 
GetFailureReason(int err)194 TlsConnection::TlsError TlsConnectionImpl::GetFailureReason(int err) {
195     switch (ERR_GET_REASON(err)) {
196         case SSL_R_SSLV3_ALERT_BAD_CERTIFICATE:
197         case SSL_R_SSLV3_ALERT_UNSUPPORTED_CERTIFICATE:
198         case SSL_R_SSLV3_ALERT_CERTIFICATE_REVOKED:
199         case SSL_R_SSLV3_ALERT_CERTIFICATE_EXPIRED:
200         case SSL_R_SSLV3_ALERT_CERTIFICATE_UNKNOWN:
201         case SSL_R_TLSV1_ALERT_ACCESS_DENIED:
202         case SSL_R_TLSV1_ALERT_UNKNOWN_CA:
203         case SSL_R_TLSV1_CERTIFICATE_REQUIRED:
204             return TlsError::PeerRejectedCertificate;
205         case SSL_R_CERTIFICATE_VERIFY_FAILED:
206             return TlsError::CertificateRejected;
207         default:
208             return TlsError::UnknownFailure;
209     }
210 }
211 
DoHandshake()212 TlsConnection::TlsError TlsConnectionImpl::DoHandshake() {
213     LOG(INFO) << RoleToString() << "Starting adbwifi tls handshake";
214     ssl_ctx_.reset(SSL_CTX_new(TLS_method()));
215     // TODO: Remove set_max_proto_version() once external/boringssl is updated
216     // past
217     // https://boringssl.googlesource.com/boringssl/+/58d56f4c59969a23e5f52014e2651c76fea2f877
218     if (ssl_ctx_.get() == nullptr ||
219         !SSL_CTX_set_min_proto_version(ssl_ctx_.get(), TLS1_3_VERSION) ||
220         !SSL_CTX_set_max_proto_version(ssl_ctx_.get(), TLS1_3_VERSION)) {
221         LOG(ERROR) << RoleToString() << "Failed to create SSL context";
222         return TlsError::UnknownFailure;
223     }
224 
225     // Register user-supplied known certificates
226     for (auto const& cert : known_certificates_) {
227         if (X509_STORE_add_cert(SSL_CTX_get_cert_store(ssl_ctx_.get()), cert.get()) == 0) {
228             LOG(ERROR) << RoleToString() << "Unable to add certificates into the X509_STORE";
229             return TlsError::UnknownFailure;
230         }
231     }
232 
233     // Custom certificate verification
234     if (cert_verify_cb_) {
235         SSL_CTX_set_cert_verify_callback(ssl_ctx_.get(), SSLSetCertVerifyCb, this);
236     }
237 
238     // set select certificate callback, if any.
239     if (set_cert_cb_) {
240         SSL_CTX_set_cert_cb(ssl_ctx_.get(), SSLSetCertCb, this);
241     }
242 
243     // Server-allowed client CA list
244     if (ca_list_ != nullptr) {
245         bssl::UniquePtr<STACK_OF(X509_NAME)> names(SSL_dup_CA_list(ca_list_.get()));
246         SSL_CTX_set_client_CA_list(ssl_ctx_.get(), names.release());
247     }
248 
249     // Register our certificate and private key.
250     std::vector<CRYPTO_BUFFER*> cert_chain = {
251             cert_.get(),
252     };
253     if (!SSL_CTX_set_chain_and_key(ssl_ctx_.get(), cert_chain.data(), cert_chain.size(),
254                                    priv_key_.get(), nullptr)) {
255         LOG(ERROR) << RoleToString()
256                    << "Unable to register the certificate chain file and private key ["
257                    << SSLErrorString() << "]";
258         Invalidate();
259         return TlsError::UnknownFailure;
260     }
261 
262     SSL_CTX_set_verify(ssl_ctx_.get(), SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
263 
264     // Okay! Let's try to do the handshake!
265     ssl_.reset(SSL_new(ssl_ctx_.get()));
266     if (!SSL_set_fd(ssl_.get(), fd_.get())) {
267         LOG(ERROR) << RoleToString() << "SSL_set_fd failed. [" << SSLErrorString() << "]";
268         return TlsError::UnknownFailure;
269     }
270 
271     switch (role_) {
272         case Role::Server:
273             SSL_set_accept_state(ssl_.get());
274             break;
275         case Role::Client:
276             SSL_set_connect_state(ssl_.get());
277             break;
278     }
279     if (SSL_do_handshake(ssl_.get()) != 1) {
280         LOG(ERROR) << RoleToString() << "Handshake failed in SSL_accept/SSL_connect ["
281                    << SSLErrorString() << "]";
282         auto sslerr = ERR_get_error();
283         Invalidate();
284         return GetFailureReason(sslerr);
285     }
286 
287     if (client_verify_post_handshake_ && role_ == Role::Client) {
288         uint8_t check;
289         // Try to peek one byte for any failures. This assumes on success that
290         // the server actually sends something.
291         if (SSL_peek(ssl_.get(), &check, 1) <= 0) {
292             LOG(ERROR) << RoleToString() << "Post-handshake SSL_peek failed [" << SSLErrorString()
293                        << "]";
294             auto sslerr = ERR_get_error();
295             Invalidate();
296             return GetFailureReason(sslerr);
297         }
298     }
299 
300     LOG(INFO) << RoleToString() << "Handshake succeeded.";
301     return TlsError::Success;
302 }
303 
Invalidate()304 void TlsConnectionImpl::Invalidate() {
305     ssl_.reset();
306     ssl_ctx_.reset();
307 }
308 
ReadFully(size_t size)309 std::vector<uint8_t> TlsConnectionImpl::ReadFully(size_t size) {
310     std::vector<uint8_t> buf(size);
311     if (!ReadFully(buf.data(), buf.size())) {
312         return {};
313     }
314 
315     return buf;
316 }
317 
ReadFully(void * buf,size_t size)318 bool TlsConnectionImpl::ReadFully(void* buf, size_t size) {
319     CHECK_GT(size, 0U);
320     if (!ssl_) {
321         LOG(ERROR) << RoleToString() << "Tried to read on a null SSL connection";
322         return false;
323     }
324 
325     size_t offset = 0;
326     uint8_t* p8 = reinterpret_cast<uint8_t*>(buf);
327     while (size > 0) {
328         int bytes_read =
329                 SSL_read(ssl_.get(), p8 + offset, std::min(static_cast<size_t>(INT_MAX), size));
330         if (bytes_read <= 0) {
331             LOG(ERROR) << RoleToString() << "SSL_read failed [" << SSLErrorString() << "]";
332             return false;
333         }
334         size -= bytes_read;
335         offset += bytes_read;
336     }
337     return true;
338 }
339 
WriteFully(std::string_view data)340 bool TlsConnectionImpl::WriteFully(std::string_view data) {
341     CHECK(!data.empty());
342     if (!ssl_) {
343         LOG(ERROR) << RoleToString() << "Tried to read on a null SSL connection";
344         return false;
345     }
346 
347     while (!data.empty()) {
348         int bytes_out = SSL_write(ssl_.get(), data.data(),
349                                   std::min(static_cast<size_t>(INT_MAX), data.size()));
350         if (bytes_out <= 0) {
351             LOG(ERROR) << RoleToString() << "SSL_write failed [" << SSLErrorString() << "]";
352             return false;
353         }
354         data = data.substr(bytes_out);
355     }
356     return true;
357 }
358 }  // namespace
359 
360 // static
Create(TlsConnection::Role role,std::string_view cert,std::string_view priv_key,borrowed_fd fd)361 std::unique_ptr<TlsConnection> TlsConnection::Create(TlsConnection::Role role,
362                                                      std::string_view cert,
363                                                      std::string_view priv_key, borrowed_fd fd) {
364     CHECK(!cert.empty());
365     CHECK(!priv_key.empty());
366 
367     return std::make_unique<TlsConnectionImpl>(role, cert, priv_key, fd);
368 }
369 
370 // static
SetCertAndKey(SSL * ssl,std::string_view cert,std::string_view priv_key)371 bool TlsConnection::SetCertAndKey(SSL* ssl, std::string_view cert, std::string_view priv_key) {
372     CHECK(ssl);
373     // Note: declaring these in local scope is okay because
374     // SSL_set_chain_and_key will increase the refcount (bssl::UpRef).
375     auto x509_cert = TlsConnectionImpl::BufferFromPEM(cert);
376     auto evp_pkey = TlsConnectionImpl::EvpPkeyFromPEM(priv_key);
377     if (x509_cert == nullptr || evp_pkey == nullptr) {
378         return false;
379     }
380 
381     std::vector<CRYPTO_BUFFER*> cert_chain = {
382             x509_cert.get(),
383     };
384     if (!SSL_set_chain_and_key(ssl, cert_chain.data(), cert_chain.size(), evp_pkey.get(),
385                                nullptr)) {
386         LOG(ERROR) << "SSL_set_chain_and_key failed";
387         return false;
388     }
389 
390     return true;
391 }
392 
393 }  // namespace tls
394 }  // namespace adb
395