1 //
2 // Copyright (C) 2020 The Android Open Source Project
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 //
15
#include "host/frontend/webrtc/lib/ws_connection.h"

#include <atomic>

#include <android-base/logging.h>
#include <libwebsockets.h>
20
21 class WsConnectionContextImpl;
22
23 class WsConnectionImpl : public WsConnection,
24 public std::enable_shared_from_this<WsConnectionImpl> {
25 public:
26 struct CreateConnectionSul {
27 lws_sorted_usec_list_t sul = {};
28 std::weak_ptr<WsConnectionImpl> weak_this;
29 };
30
31 WsConnectionImpl(int port, const std::string& addr, const std::string& path,
32 Security secure,
33 std::weak_ptr<WsConnectionObserver> observer,
34 std::shared_ptr<WsConnectionContextImpl> context);
35
36 ~WsConnectionImpl() override;
37
38 void Connect() override;
39 void ConnectInner();
40
41 bool Send(const uint8_t* data, size_t len, bool binary = false) override;
42
43 void OnError(const std::string& error);
44 void OnReceive(const uint8_t* data, size_t len, bool is_binary);
45 void OnOpen();
46 void OnClose();
47 void OnWriteable();
48
49 private:
50 struct WsBuffer {
51 WsBuffer() = default;
WsBufferWsConnectionImpl::WsBuffer52 WsBuffer(const uint8_t* data, size_t len, bool binary)
53 : buffer_(LWS_PRE + len), is_binary_(binary) {
54 memcpy(&buffer_[LWS_PRE], data, len);
55 }
56
dataWsConnectionImpl::WsBuffer57 uint8_t* data() { return &buffer_[LWS_PRE]; }
is_binaryWsConnectionImpl::WsBuffer58 bool is_binary() const { return is_binary_; }
sizeWsConnectionImpl::WsBuffer59 size_t size() const { return buffer_.size() - LWS_PRE; }
60
61 private:
62 std::vector<uint8_t> buffer_;
63 bool is_binary_;
64 };
65
66 CreateConnectionSul extended_sul_;
67 struct lws* wsi_;
68 const int port_;
69 const std::string addr_;
70 const std::string path_;
71 const Security security_;
72
73 std::weak_ptr<WsConnectionObserver> observer_;
74
75 // each element contains the data to be sent and whether it's binary or not
76 std::deque<WsBuffer> write_queue_;
77 std::mutex write_queue_mutex_;
78 // The connection object should not outlive the context object. This reference
79 // guarantees it.
80 std::shared_ptr<WsConnectionContextImpl> context_;
81 };
82
83 class WsConnectionContextImpl
84 : public WsConnectionContext,
85 public std::enable_shared_from_this<WsConnectionContextImpl> {
86 public:
87 WsConnectionContextImpl(struct lws_context* lws_ctx);
88 ~WsConnectionContextImpl() override;
89
90 std::shared_ptr<WsConnection> CreateConnection(
91 int port, const std::string& addr, const std::string& path,
92 WsConnection::Security secure,
93 std::weak_ptr<WsConnectionObserver> observer) override;
94
95 void RememberConnection(void*, std::weak_ptr<WsConnectionImpl>);
96 void ForgetConnection(void*);
97 std::shared_ptr<WsConnectionImpl> GetConnection(void*);
98
lws_context()99 struct lws_context* lws_context() {
100 return lws_context_;
101 }
102
103 private:
104 void Start();
105
106 std::map<void*, std::weak_ptr<WsConnectionImpl>> weak_by_ptr_;
107 std::mutex map_mutex_;
108 struct lws_context* lws_context_;
109 std::thread message_loop_;
110 };
111
112 int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user,
113 void* in, size_t len);
114 void CreateConnectionCallback(lws_sorted_usec_list_t* sul);
115
116 namespace {
117
118 constexpr char kProtocolName[] = "lws-websocket-protocol";
119 constexpr int kBufferSize = 65536;
120
121 const uint32_t backoff_ms[] = {1000, 2000, 3000, 4000, 5000};
122
123 const lws_retry_bo_t kRetry = {
124 .retry_ms_table = backoff_ms,
125 .retry_ms_table_count = LWS_ARRAY_SIZE(backoff_ms),
126 .conceal_count = LWS_ARRAY_SIZE(backoff_ms),
127
128 .secs_since_valid_ping = 3, /* force PINGs after secs idle */
129 .secs_since_valid_hangup = 10, /* hangup after secs idle */
130
131 .jitter_percent = 20,
132 };
133
134 const struct lws_protocols kProtocols[2] = {
135 {kProtocolName, LwsCallback, 0, kBufferSize, 0, NULL, 0},
136 {NULL, NULL, 0, 0, 0, NULL, 0}};
137
138 } // namespace
139
Create()140 std::shared_ptr<WsConnectionContext> WsConnectionContext::Create() {
141 struct lws_context_creation_info context_info = {};
142 context_info.port = CONTEXT_PORT_NO_LISTEN;
143 context_info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
144 context_info.protocols = kProtocols;
145 struct lws_context* lws_ctx = lws_create_context(&context_info);
146 if (!lws_ctx) {
147 return nullptr;
148 }
149 return std::shared_ptr<WsConnectionContext>(
150 new WsConnectionContextImpl(lws_ctx));
151 }
152
WsConnectionContextImpl(struct lws_context * lws_ctx)153 WsConnectionContextImpl::WsConnectionContextImpl(struct lws_context* lws_ctx)
154 : lws_context_(lws_ctx) {
155 Start();
156 }
157
~WsConnectionContextImpl()158 WsConnectionContextImpl::~WsConnectionContextImpl() {
159 lws_context_destroy(lws_context_);
160 if (message_loop_.joinable()) {
161 message_loop_.join();
162 }
163 }
164
Start()165 void WsConnectionContextImpl::Start() {
166 message_loop_ = std::thread([this]() {
167 for (;;) {
168 if (lws_service(lws_context_, 0) < 0) {
169 break;
170 }
171 }
172 });
173 }
174
CreateConnection(int port,const std::string & addr,const std::string & path,WsConnection::Security security,std::weak_ptr<WsConnectionObserver> observer)175 std::shared_ptr<WsConnection> WsConnectionContextImpl::CreateConnection(
176 int port, const std::string& addr, const std::string& path,
177 WsConnection::Security security,
178 std::weak_ptr<WsConnectionObserver> observer) {
179 return std::shared_ptr<WsConnection>(new WsConnectionImpl(
180 port, addr, path, security, observer, shared_from_this()));
181 }
182
GetConnection(void * raw)183 std::shared_ptr<WsConnectionImpl> WsConnectionContextImpl::GetConnection(
184 void* raw) {
185 std::shared_ptr<WsConnectionImpl> connection;
186 {
187 std::lock_guard<std::mutex> lock(map_mutex_);
188 if (weak_by_ptr_.count(raw) == 0) {
189 return nullptr;
190 }
191 connection = weak_by_ptr_[raw].lock();
192 if (!connection) {
193 weak_by_ptr_.erase(raw);
194 }
195 }
196 return connection;
197 }
198
RememberConnection(void * raw,std::weak_ptr<WsConnectionImpl> conn)199 void WsConnectionContextImpl::RememberConnection(
200 void* raw, std::weak_ptr<WsConnectionImpl> conn) {
201 std::lock_guard<std::mutex> lock(map_mutex_);
202 weak_by_ptr_.emplace(
203 std::pair<void*, std::weak_ptr<WsConnectionImpl>>(raw, conn));
204 }
205
ForgetConnection(void * raw)206 void WsConnectionContextImpl::ForgetConnection(void* raw) {
207 std::lock_guard<std::mutex> lock(map_mutex_);
208 weak_by_ptr_.erase(raw);
209 }
210
WsConnectionImpl(int port,const std::string & addr,const std::string & path,Security security,std::weak_ptr<WsConnectionObserver> observer,std::shared_ptr<WsConnectionContextImpl> context)211 WsConnectionImpl::WsConnectionImpl(
212 int port, const std::string& addr, const std::string& path,
213 Security security, std::weak_ptr<WsConnectionObserver> observer,
214 std::shared_ptr<WsConnectionContextImpl> context)
215 : port_(port),
216 addr_(addr),
217 path_(path),
218 security_(security),
219 observer_(observer),
220 context_(context) {}
221
~WsConnectionImpl()222 WsConnectionImpl::~WsConnectionImpl() {
223 context_->ForgetConnection(this);
224 // This will cause the callback to be called which will drop the connection
225 // after seeing the context doesn't remember this object
226 lws_callback_on_writable(wsi_);
227 }
228
Connect()229 void WsConnectionImpl::Connect() {
230 memset(&extended_sul_.sul, 0, sizeof(extended_sul_.sul));
231 extended_sul_.weak_this = weak_from_this();
232 lws_sul_schedule(context_->lws_context(), 0, &extended_sul_.sul,
233 CreateConnectionCallback, 1);
234 }
235
OnError(const std::string & error)236 void WsConnectionImpl::OnError(const std::string& error) {
237 auto observer = observer_.lock();
238 if (observer) {
239 observer->OnError(error);
240 }
241 }
OnReceive(const uint8_t * data,size_t len,bool is_binary)242 void WsConnectionImpl::OnReceive(const uint8_t* data, size_t len,
243 bool is_binary) {
244 auto observer = observer_.lock();
245 if (observer) {
246 observer->OnReceive(data, len, is_binary);
247 }
248 }
OnOpen()249 void WsConnectionImpl::OnOpen() {
250 auto observer = observer_.lock();
251 if (observer) {
252 observer->OnOpen();
253 }
254 }
OnClose()255 void WsConnectionImpl::OnClose() {
256 auto observer = observer_.lock();
257 if (observer) {
258 observer->OnClose();
259 }
260 }
261
OnWriteable()262 void WsConnectionImpl::OnWriteable() {
263 WsBuffer buffer;
264 {
265 std::lock_guard<std::mutex> lock(write_queue_mutex_);
266 if (write_queue_.size() == 0) {
267 return;
268 }
269 buffer = std::move(write_queue_.front());
270 write_queue_.pop_front();
271 }
272 auto flags = lws_write_ws_flags(
273 buffer.is_binary() ? LWS_WRITE_BINARY : LWS_WRITE_TEXT, true, true);
274 auto res = lws_write(wsi_, buffer.data(), buffer.size(),
275 (enum lws_write_protocol)flags);
276 if (res != buffer.size()) {
277 LOG(WARNING) << "Unable to send the entire message!";
278 }
279 }
280
Send(const uint8_t * data,size_t len,bool binary)281 bool WsConnectionImpl::Send(const uint8_t* data, size_t len, bool binary) {
282 if (!wsi_) {
283 LOG(WARNING) << "Send called on an uninitialized connection!!";
284 return false;
285 }
286 WsBuffer buffer(data, len, binary);
287 {
288 std::lock_guard<std::mutex> lock(write_queue_mutex_);
289 write_queue_.emplace_back(std::move(buffer));
290 }
291
292 lws_callback_on_writable(wsi_);
293 return true;
294 }
295
LwsCallback(struct lws * wsi,enum lws_callback_reasons reason,void * user,void * in,size_t len)296 int LwsCallback(struct lws* wsi, enum lws_callback_reasons reason, void* user,
297 void* in, size_t len) {
298 constexpr int DROP = -1;
299 constexpr int OK = 0;
300
301 // For some values of `reason`, `user` doesn't point to the value provided
302 // when the connection was created. This function object should be used with
303 // care.
304 auto with_connection =
305 [wsi, user](std::function<void(std::shared_ptr<WsConnectionImpl>)> cb) {
306 auto context = reinterpret_cast<WsConnectionContextImpl*>(user);
307 auto connection = context->GetConnection(wsi);
308 if (!connection) {
309 return DROP;
310 }
311 cb(connection);
312 return OK;
313 };
314
315 switch (reason) {
316 case LWS_CALLBACK_CLIENT_CONNECTION_ERROR:
317 return with_connection(
318 [in](std::shared_ptr<WsConnectionImpl> connection) {
319 connection->OnError(in ? (char*)in : "(null)");
320 });
321
322 case LWS_CALLBACK_CLIENT_RECEIVE:
323 return with_connection(
324 [in, len, wsi](std::shared_ptr<WsConnectionImpl> connection) {
325 connection->OnReceive((const uint8_t*)in, len,
326 lws_frame_is_binary(wsi));
327 });
328
329 case LWS_CALLBACK_CLIENT_ESTABLISHED:
330 return with_connection([](std::shared_ptr<WsConnectionImpl> connection) {
331 connection->OnOpen();
332 });
333
334 case LWS_CALLBACK_CLIENT_CLOSED:
335 return with_connection([](std::shared_ptr<WsConnectionImpl> connection) {
336 connection->OnClose();
337 });
338
339 case LWS_CALLBACK_CLIENT_WRITEABLE:
340 return with_connection([](std::shared_ptr<WsConnectionImpl> connection) {
341 connection->OnWriteable();
342 });
343
344 default:
345 LOG(VERBOSE) << "Unhandled value: " << reason;
346 return lws_callback_http_dummy(wsi, reason, user, in, len);
347 }
348 }
349
CreateConnectionCallback(lws_sorted_usec_list_t * sul)350 void CreateConnectionCallback(lws_sorted_usec_list_t* sul) {
351 std::shared_ptr<WsConnectionImpl> connection =
352 reinterpret_cast<WsConnectionImpl::CreateConnectionSul*>(sul)
353 ->weak_this.lock();
354 if (!connection) {
355 LOG(WARNING) << "The object was already destroyed by the time of the first "
356 << "connection attempt. That's unusual.";
357 return;
358 }
359 connection->ConnectInner();
360 }
361
ConnectInner()362 void WsConnectionImpl::ConnectInner() {
363 struct lws_client_connect_info connect_info;
364
365 memset(&connect_info, 0, sizeof(connect_info));
366
367 connect_info.context = context_->lws_context();
368 connect_info.port = port_;
369 connect_info.address = addr_.c_str();
370 connect_info.path = path_.c_str();
371 connect_info.host = connect_info.address;
372 connect_info.origin = connect_info.address;
373 switch (security_) {
374 case Security::kAllowSelfSigned:
375 connect_info.ssl_connection = LCCSCF_ALLOW_SELFSIGNED |
376 LCCSCF_SKIP_SERVER_CERT_HOSTNAME_CHECK |
377 LCCSCF_USE_SSL;
378 break;
379 case Security::kStrict:
380 connect_info.ssl_connection = LCCSCF_USE_SSL;
381 break;
382 case Security::kInsecure:
383 connect_info.ssl_connection = 0;
384 break;
385 }
386 connect_info.protocol = "UNNUSED";
387 connect_info.local_protocol_name = kProtocolName;
388 connect_info.pwsi = &wsi_;
389 connect_info.retry_and_idle_policy = &kRetry;
390 // There is no guarantee the connection object still exists when the callback
391 // is called. Put the context instead as the user data which is guaranteed to
392 // still exist and holds a weak ptr to the connection.
393 connect_info.userdata = context_.get();
394
395 if (lws_client_connect_via_info(&connect_info)) {
396 // wsi_ is not initialized until after the call to
397 // lws_client_connect_via_info(). Luckily, this is guaranteed to run before
398 // the protocol callback is called because it runs in the same loop.
399 context_->RememberConnection(wsi_, weak_from_this());
400 } else {
401 LOG(ERROR) << "Connection failed!";
402 }
403 }
404