1 //
2 // Copyright (C) 2020 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
15 
16 #include "host/frontend/webrtc/lib/ws_connection.h"
17 
18 #include <android-base/logging.h>
19 #include <libwebsockets.h>
20 
21 class WsConnectionContextImpl;
22 
23 class WsConnectionImpl : public WsConnection,
24                          public std::enable_shared_from_this<WsConnectionImpl> {
25  public:
26   struct CreateConnectionSul {
27     lws_sorted_usec_list_t sul = {};
28     std::weak_ptr<WsConnectionImpl> weak_this;
29   };
30 
31   WsConnectionImpl(int port, const std::string& addr, const std::string& path,
32                    Security secure,
33                    std::weak_ptr<WsConnectionObserver> observer,
34                    std::shared_ptr<WsConnectionContextImpl> context);
35 
36   ~WsConnectionImpl() override;
37 
38   void Connect() override;
39   void ConnectInner();
40 
41   bool Send(const uint8_t* data, size_t len, bool binary = false) override;
42 
43   void OnError(const std::string& error);
44   void OnReceive(const uint8_t* data, size_t len, bool is_binary);
45   void OnOpen();
46   void OnClose();
47   void OnWriteable();
48 
49  private:
50   struct WsBuffer {
51     WsBuffer() = default;
WsBufferWsConnectionImpl::WsBuffer52     WsBuffer(const uint8_t* data, size_t len, bool binary)
53         : buffer_(LWS_PRE + len), is_binary_(binary) {
54       memcpy(&buffer_[LWS_PRE], data, len);
55     }
56 
dataWsConnectionImpl::WsBuffer57     uint8_t* data() { return &buffer_[LWS_PRE]; }
is_binaryWsConnectionImpl::WsBuffer58     bool is_binary() const { return is_binary_; }
sizeWsConnectionImpl::WsBuffer59     size_t size() const { return buffer_.size() - LWS_PRE; }
60 
61    private:
62     std::vector<uint8_t> buffer_;
63     bool is_binary_;
64   };
65 
66   CreateConnectionSul extended_sul_;
67   struct lws* wsi_;
68   const int port_;
69   const std::string addr_;
70   const std::string path_;
71   const Security security_;
72 
73   std::weak_ptr<WsConnectionObserver> observer_;
74 
75   // each element contains the data to be sent and whether it's binary or not
76   std::deque<WsBuffer> write_queue_;
77   std::mutex write_queue_mutex_;
78   // The connection object should not outlive the context object. This reference
79   // guarantees it.
80   std::shared_ptr<WsConnectionContextImpl> context_;
81 };
82 
83 class WsConnectionContextImpl
84     : public WsConnectionContext,
85       public std::enable_shared_from_this<WsConnectionContextImpl> {
86  public:
87   WsConnectionContextImpl(struct lws_context* lws_ctx);
88   ~WsConnectionContextImpl() override;
89 
90   std::shared_ptr<WsConnection> CreateConnection(
91       int port, const std::string& addr, const std::string& path,
92       WsConnection::Security secure,
93       std::weak_ptr<WsConnectionObserver> observer) override;
94 
95   void RememberConnection(void*, std::weak_ptr<WsConnectionImpl>);
96   void ForgetConnection(void*);
97   std::shared_ptr<WsConnectionImpl> GetConnection(void*);
98 
lws_context()99   struct lws_context* lws_context() {
100     return lws_context_;
101   }
102 
103  private:
104   void Start();
105 
106   std::map<void*, std::weak_ptr<WsConnectionImpl>> weak_by_ptr_;
107   std::mutex map_mutex_;
108   struct lws_context* lws_context_;
109   std::thread message_loop_;
110 };
111 
// Protocol callback registered with libwebsockets through kProtocols; it
// dispatches connection events to the matching WsConnectionImpl.
int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user,
                void* in, size_t len);
// Timer callback scheduled by WsConnectionImpl::Connect(); it performs the
// actual connection attempt.
void CreateConnectionCallback(lws_sorted_usec_list_t* sul);
115 
116 namespace {
117 
118 constexpr char kProtocolName[] = "lws-websocket-protocol";
119 constexpr int kBufferSize = 65536;
120 
121 const uint32_t backoff_ms[] = {1000, 2000, 3000, 4000, 5000};
122 
123 const lws_retry_bo_t kRetry = {
124     .retry_ms_table = backoff_ms,
125     .retry_ms_table_count = LWS_ARRAY_SIZE(backoff_ms),
126     .conceal_count = LWS_ARRAY_SIZE(backoff_ms),
127 
128     .secs_since_valid_ping = 3,    /* force PINGs after secs idle */
129     .secs_since_valid_hangup = 10, /* hangup after secs idle */
130 
131     .jitter_percent = 20,
132 };
133 
134 const struct lws_protocols kProtocols[2] = {
135     {kProtocolName, LwsCallback, 0, kBufferSize, 0, NULL, 0},
136     {NULL, NULL, 0, 0, 0, NULL, 0}};
137 
138 }  // namespace
139 
Create()140 std::shared_ptr<WsConnectionContext> WsConnectionContext::Create() {
141   struct lws_context_creation_info context_info = {};
142   context_info.port = CONTEXT_PORT_NO_LISTEN;
143   context_info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
144   context_info.protocols = kProtocols;
145   struct lws_context* lws_ctx = lws_create_context(&context_info);
146   if (!lws_ctx) {
147     return nullptr;
148   }
149   return std::shared_ptr<WsConnectionContext>(
150       new WsConnectionContextImpl(lws_ctx));
151 }
152 
WsConnectionContextImpl(struct lws_context * lws_ctx)153 WsConnectionContextImpl::WsConnectionContextImpl(struct lws_context* lws_ctx)
154     : lws_context_(lws_ctx) {
155   Start();
156 }
157 
// Destroys the libwebsockets context and waits for the service thread to end.
WsConnectionContextImpl::~WsConnectionContextImpl() {
  // The context is destroyed before joining: presumably this is what makes
  // lws_service() in the message loop return a negative value so the thread
  // can exit -- TODO confirm against the libwebsockets documentation.
  lws_context_destroy(lws_context_);
  if (message_loop_.joinable()) {
    message_loop_.join();
  }
}
164 
Start()165 void WsConnectionContextImpl::Start() {
166   message_loop_ = std::thread([this]() {
167     for (;;) {
168       if (lws_service(lws_context_, 0) < 0) {
169         break;
170       }
171     }
172   });
173 }
174 
CreateConnection(int port,const std::string & addr,const std::string & path,WsConnection::Security security,std::weak_ptr<WsConnectionObserver> observer)175 std::shared_ptr<WsConnection> WsConnectionContextImpl::CreateConnection(
176     int port, const std::string& addr, const std::string& path,
177     WsConnection::Security security,
178     std::weak_ptr<WsConnectionObserver> observer) {
179   return std::shared_ptr<WsConnection>(new WsConnectionImpl(
180       port, addr, path, security, observer, shared_from_this()));
181 }
182 
GetConnection(void * raw)183 std::shared_ptr<WsConnectionImpl> WsConnectionContextImpl::GetConnection(
184     void* raw) {
185   std::shared_ptr<WsConnectionImpl> connection;
186   {
187     std::lock_guard<std::mutex> lock(map_mutex_);
188     if (weak_by_ptr_.count(raw) == 0) {
189       return nullptr;
190     }
191     connection = weak_by_ptr_[raw].lock();
192     if (!connection) {
193       weak_by_ptr_.erase(raw);
194     }
195   }
196   return connection;
197 }
198 
RememberConnection(void * raw,std::weak_ptr<WsConnectionImpl> conn)199 void WsConnectionContextImpl::RememberConnection(
200     void* raw, std::weak_ptr<WsConnectionImpl> conn) {
201   std::lock_guard<std::mutex> lock(map_mutex_);
202   weak_by_ptr_.emplace(
203       std::pair<void*, std::weak_ptr<WsConnectionImpl>>(raw, conn));
204 }
205 
ForgetConnection(void * raw)206 void WsConnectionContextImpl::ForgetConnection(void* raw) {
207   std::lock_guard<std::mutex> lock(map_mutex_);
208   weak_by_ptr_.erase(raw);
209 }
210 
WsConnectionImpl(int port,const std::string & addr,const std::string & path,Security security,std::weak_ptr<WsConnectionObserver> observer,std::shared_ptr<WsConnectionContextImpl> context)211 WsConnectionImpl::WsConnectionImpl(
212     int port, const std::string& addr, const std::string& path,
213     Security security, std::weak_ptr<WsConnectionObserver> observer,
214     std::shared_ptr<WsConnectionContextImpl> context)
215     : port_(port),
216       addr_(addr),
217       path_(path),
218       security_(security),
219       observer_(observer),
220       context_(context) {}
221 
~WsConnectionImpl()222 WsConnectionImpl::~WsConnectionImpl() {
223   context_->ForgetConnection(this);
224   // This will cause the callback to be called which will drop the connection
225   // after seeing the context doesn't remember this object
226   lws_callback_on_writable(wsi_);
227 }
228 
Connect()229 void WsConnectionImpl::Connect() {
230   memset(&extended_sul_.sul, 0, sizeof(extended_sul_.sul));
231   extended_sul_.weak_this = weak_from_this();
232   lws_sul_schedule(context_->lws_context(), 0, &extended_sul_.sul,
233                    CreateConnectionCallback, 1);
234 }
235 
OnError(const std::string & error)236 void WsConnectionImpl::OnError(const std::string& error) {
237   auto observer = observer_.lock();
238   if (observer) {
239     observer->OnError(error);
240   }
241 }
OnReceive(const uint8_t * data,size_t len,bool is_binary)242 void WsConnectionImpl::OnReceive(const uint8_t* data, size_t len,
243                                  bool is_binary) {
244   auto observer = observer_.lock();
245   if (observer) {
246     observer->OnReceive(data, len, is_binary);
247   }
248 }
OnOpen()249 void WsConnectionImpl::OnOpen() {
250   auto observer = observer_.lock();
251   if (observer) {
252     observer->OnOpen();
253   }
254 }
OnClose()255 void WsConnectionImpl::OnClose() {
256   auto observer = observer_.lock();
257   if (observer) {
258     observer->OnClose();
259   }
260 }
261 
OnWriteable()262 void WsConnectionImpl::OnWriteable() {
263   WsBuffer buffer;
264   {
265     std::lock_guard<std::mutex> lock(write_queue_mutex_);
266     if (write_queue_.size() == 0) {
267       return;
268     }
269     buffer = std::move(write_queue_.front());
270     write_queue_.pop_front();
271   }
272   auto flags = lws_write_ws_flags(
273       buffer.is_binary() ? LWS_WRITE_BINARY : LWS_WRITE_TEXT, true, true);
274   auto res = lws_write(wsi_, buffer.data(), buffer.size(),
275                        (enum lws_write_protocol)flags);
276   if (res != buffer.size()) {
277     LOG(WARNING) << "Unable to send the entire message!";
278   }
279 }
280 
Send(const uint8_t * data,size_t len,bool binary)281 bool WsConnectionImpl::Send(const uint8_t* data, size_t len, bool binary) {
282   if (!wsi_) {
283     LOG(WARNING) << "Send called on an uninitialized connection!!";
284     return false;
285   }
286   WsBuffer buffer(data, len, binary);
287   {
288     std::lock_guard<std::mutex> lock(write_queue_mutex_);
289     write_queue_.emplace_back(std::move(buffer));
290   }
291 
292   lws_callback_on_writable(wsi_);
293   return true;
294 }
295 
LwsCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)296 int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user,
297                 void* in, size_t len) {
298   constexpr int DROP = -1;
299   constexpr int OK = 0;
300 
301   // For some values of `reason`, `user` doesn't point to the value provided
302   // when the connection was created. This function object should be used with
303   // care.
304   auto with_connection =
305       [wsi, user](std::function<void(std::shared_ptr<WsConnectionImpl>)> cb) {
306         auto context = reinterpret_cast<WsConnectionContextImpl*>(user);
307         auto connection = context->GetConnection(wsi);
308         if (!connection) {
309           return DROP;
310         }
311         cb(connection);
312         return OK;
313       };
314 
315   switch (reason) {
316     case LWS_CALLBACK_CLIENT_CONNECTION_ERROR:
317       return with_connection(
318           [in](std::shared_ptr<WsConnectionImpl> connection) {
319             connection->OnError(in ? (char*)in : "(null)");
320           });
321 
322     case LWS_CALLBACK_CLIENT_RECEIVE:
323       return with_connection(
324           [in, len, wsi](std::shared_ptr<WsConnectionImpl> connection) {
325             connection->OnReceive((const uint8_t*)in, len,
326                                   lws_frame_is_binary(wsi));
327           });
328 
329     case LWS_CALLBACK_CLIENT_ESTABLISHED:
330       return with_connection([](std::shared_ptr<WsConnectionImpl> connection) {
331         connection->OnOpen();
332       });
333 
334     case LWS_CALLBACK_CLIENT_CLOSED:
335       return with_connection([](std::shared_ptr<WsConnectionImpl> connection) {
336         connection->OnClose();
337       });
338 
339     case LWS_CALLBACK_CLIENT_WRITEABLE:
340       return with_connection([](std::shared_ptr<WsConnectionImpl> connection) {
341         connection->OnWriteable();
342       });
343 
344     default:
345       LOG(VERBOSE) << "Unhandled value: " << reason;
346       return lws_callback_http_dummy(wsi, reason, user, in, len);
347   }
348 }
349 
CreateConnectionCallback(lws_sorted_usec_list_t * sul)350 void CreateConnectionCallback(lws_sorted_usec_list_t* sul) {
351   std::shared_ptr<WsConnectionImpl> connection =
352       reinterpret_cast<WsConnectionImpl::CreateConnectionSul*>(sul)
353           ->weak_this.lock();
354   if (!connection) {
355     LOG(WARNING) << "The object was already destroyed by the time of the first "
356                  << "connection attempt. That's unusual.";
357     return;
358   }
359   connection->ConnectInner();
360 }
361 
ConnectInner()362 void WsConnectionImpl::ConnectInner() {
363   struct lws_client_connect_info connect_info;
364 
365   memset(&connect_info, 0, sizeof(connect_info));
366 
367   connect_info.context = context_->lws_context();
368   connect_info.port = port_;
369   connect_info.address = addr_.c_str();
370   connect_info.path = path_.c_str();
371   connect_info.host = connect_info.address;
372   connect_info.origin = connect_info.address;
373   switch (security_) {
374     case Security::kAllowSelfSigned:
375       connect_info.ssl_connection = LCCSCF_ALLOW_SELFSIGNED |
376                                     LCCSCF_SKIP_SERVER_CERT_HOSTNAME_CHECK |
377                                     LCCSCF_USE_SSL;
378       break;
379     case Security::kStrict:
380       connect_info.ssl_connection = LCCSCF_USE_SSL;
381       break;
382     case Security::kInsecure:
383       connect_info.ssl_connection = 0;
384       break;
385   }
386   connect_info.protocol = "UNNUSED";
387   connect_info.local_protocol_name = kProtocolName;
388   connect_info.pwsi = &wsi_;
389   connect_info.retry_and_idle_policy = &kRetry;
390   // There is no guarantee the connection object still exists when the callback
391   // is called. Put the context instead as the user data which is guaranteed to
392   // still exist and holds a weak ptr to the connection.
393   connect_info.userdata = context_.get();
394 
395   if (lws_client_connect_via_info(&connect_info)) {
396     // wsi_ is not initialized until after the call to
397     // lws_client_connect_via_info(). Luckily, this is guaranteed to run before
398     // the protocol callback is called because it runs in the same loop.
399     context_->RememberConnection(wsi_, weak_from_this());
400   } else {
401     LOG(ERROR) << "Connection failed!";
402   }
403 }
404