1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "host/frontend/webrtc/lib/streamer.h"

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <android-base/logging.h>
#include <json/json.h>

#include <api/audio_codecs/audio_decoder_factory.h>
#include <api/audio_codecs/audio_encoder_factory.h>
#include <api/audio_codecs/builtin_audio_decoder_factory.h>
#include <api/audio_codecs/builtin_audio_encoder_factory.h>
#include <api/create_peerconnection_factory.h>
#include <api/peer_connection_interface.h>
#include <api/video_codecs/builtin_video_decoder_factory.h>
#include <api/video_codecs/builtin_video_encoder_factory.h>
#include <api/video_codecs/video_decoder_factory.h>
#include <api/video_codecs/video_encoder_factory.h>
#include <media/base/video_broadcaster.h>
#include <pc/video_track_source.h>

#include "host/frontend/gcastv2/signaling_server/constants/signaling_constants.h"
#include "host/frontend/webrtc/lib/client_handler.h"
#include "host/frontend/webrtc/lib/port_range_socket_factory.h"
#include "host/frontend/webrtc/lib/video_track_source_impl.h"
#include "host/frontend/webrtc/lib/vp8only_encoder_factory.h"
40 
41 namespace cuttlefish {
42 namespace webrtc_streaming {
43 namespace {
44 
// JSON keys used to describe each display in the registration message sent to
// the operator (built in StreamerImpl::OnOpen).
constexpr char kStreamIdField[] = "stream_id";
constexpr char kXResField[] = "x_res";
constexpr char kYResField[] = "y_res";
constexpr char kDpiField[] = "dpi";
constexpr char kIsTouchField[] = "is_touch";
constexpr char kDisplaysField[] = "displays";
51 
SendJson(WsConnection * ws_conn,const Json::Value & data)52 void SendJson(WsConnection* ws_conn, const Json::Value& data) {
53   Json::FastWriter json_writer;
54   auto data_str = json_writer.write(data);
55   ws_conn->Send(reinterpret_cast<const uint8_t*>(data_str.c_str()),
56                 data_str.size());
57 }
58 
ParseMessage(const uint8_t * data,size_t length,Json::Value * msg_out)59 bool ParseMessage(const uint8_t* data, size_t length, Json::Value* msg_out) {
60   Json::Reader json_reader;
61   auto str = reinterpret_cast<const char*>(data);
62   return json_reader.parse(str, str + length, *msg_out) >= 0;
63 }
64 
CreateAndStartThread(const std::string & name)65 std::unique_ptr<rtc::Thread> CreateAndStartThread(const std::string& name) {
66   auto thread = rtc::Thread::CreateWithSocketServer();
67   if (!thread) {
68     LOG(ERROR) << "Failed to create " << name << " thread";
69     return nullptr;
70   }
71   thread->SetName(name, nullptr);
72   if (!thread->Start()) {
73     LOG(ERROR) << "Failed to start " << name << " thread";
74     return nullptr;
75   }
76   return thread;
77 }
78 
79 class StreamerImpl : public Streamer {
80  public:
81   StreamerImpl(const StreamerConfig& cfg,
82                rtc::scoped_refptr<webrtc::PeerConnectionFactoryInterface>
83                    peer_connection_factory,
84                std::unique_ptr<rtc::Thread> network_thread,
85                std::unique_ptr<rtc::Thread> worker_thread,
86                std::unique_ptr<rtc::Thread> signal_thread,
87                std::shared_ptr<ConnectionObserverFactory> factory);
88   ~StreamerImpl() override = default;
89 
90   std::shared_ptr<VideoSink> AddDisplay(const std::string& label, int width,
91                                         int height, int dpi,
92                                         bool touch_enabled) override;
93   void AddAudio(const std::string& label) override;
94   void Register(std::weak_ptr<OperatorObserver> operator_observer) override;
95   void Unregister() override;
96 
97  private:
98   // This allows the websocket observer methods to be private in Streamer.
99   class WsObserver : public WsConnectionObserver {
100    public:
WsObserver(StreamerImpl * streamer)101     WsObserver(StreamerImpl* streamer) : streamer_(streamer) {}
102     ~WsObserver() override = default;
103 
OnOpen()104     void OnOpen() override { streamer_->OnOpen(); }
OnClose()105     void OnClose() override { streamer_->OnClose(); }
OnError(const std::string & error)106     void OnError(const std::string& error) override {
107       streamer_->OnError(error);
108     }
OnReceive(const uint8_t * msg,size_t length,bool is_binary)109     void OnReceive(const uint8_t* msg, size_t length, bool is_binary) override {
110       streamer_->OnReceive(msg, length, is_binary);
111     }
112 
113    private:
114     StreamerImpl* streamer_;
115   };
116   struct DisplayDescriptor {
117     int width;
118     int height;
119     int dpi;
120     bool touch_enabled;
121     rtc::scoped_refptr<webrtc::VideoTrackSourceInterface> source;
122   };
123   // TODO (jemoreira): move to a place in common with the signaling server
124   struct OperatorServerConfig {
125     std::vector<webrtc::PeerConnectionInterface::IceServer> servers;
126   };
127 
128   std::shared_ptr<ClientHandler> CreateClientHandler(int client_id);
129 
130   void SendMessageToClient(int client_id, const Json::Value& msg);
131   void DestroyClientHandler(int client_id);
132 
133   // For use by WsObserver
134   void OnOpen();
135   void OnClose();
136   void OnError(const std::string& error);
137   void OnReceive(const uint8_t* msg, size_t length, bool is_binary);
138 
139   void HandleConfigMessage(const Json::Value& msg);
140   void HandleClientMessage(const Json::Value& server_message);
141 
142   // All accesses to these variables happen from the signal_thread_, so there is
143   // no need for extra synchronization mechanisms (mutex)
144   StreamerConfig config_;
145   OperatorServerConfig operator_config_;
146   std::shared_ptr<WsConnection> server_connection_;
147   std::shared_ptr<ConnectionObserverFactory> connection_observer_factory_;
148   rtc::scoped_refptr<webrtc::PeerConnectionFactoryInterface>
149       peer_connection_factory_;
150   std::unique_ptr<rtc::Thread> network_thread_;
151   std::unique_ptr<rtc::Thread> worker_thread_;
152   std::unique_ptr<rtc::Thread> signal_thread_;
153   std::map<std::string, DisplayDescriptor> displays_;
154   std::map<int, std::shared_ptr<ClientHandler>> clients_;
155   std::shared_ptr<WsObserver> ws_observer_;
156   std::weak_ptr<OperatorObserver> operator_observer_;
157 };
158 
StreamerImpl(const StreamerConfig & cfg,rtc::scoped_refptr<webrtc::PeerConnectionFactoryInterface> peer_connection_factory,std::unique_ptr<rtc::Thread> network_thread,std::unique_ptr<rtc::Thread> worker_thread,std::unique_ptr<rtc::Thread> signal_thread,std::shared_ptr<ConnectionObserverFactory> connection_observer_factory)159 StreamerImpl::StreamerImpl(
160     const StreamerConfig& cfg,
161     rtc::scoped_refptr<webrtc::PeerConnectionFactoryInterface>
162         peer_connection_factory,
163     std::unique_ptr<rtc::Thread> network_thread,
164     std::unique_ptr<rtc::Thread> worker_thread,
165     std::unique_ptr<rtc::Thread> signal_thread,
166     std::shared_ptr<ConnectionObserverFactory> connection_observer_factory)
167     : config_(cfg),
168       connection_observer_factory_(connection_observer_factory),
169       peer_connection_factory_(peer_connection_factory),
170       network_thread_(std::move(network_thread)),
171       worker_thread_(std::move(worker_thread)),
172       signal_thread_(std::move(signal_thread)),
173       ws_observer_(new WsObserver(this)) {}
174 
AddDisplay(const std::string & label,int width,int height,int dpi,bool touch_enabled)175 std::shared_ptr<VideoSink> StreamerImpl::AddDisplay(const std::string& label,
176                                                 int width, int height, int dpi,
177                                                 bool touch_enabled) {
178   // Usually called from an application thread
179   return signal_thread_->Invoke<std::shared_ptr<VideoSink>>(
180       RTC_FROM_HERE,
181       [this, &label, width, height, dpi,
182        touch_enabled]() -> std::shared_ptr<VideoSink> {
183         if (displays_.count(label)) {
184           LOG(ERROR) << "Display with same label already exists: " << label;
185           return nullptr;
186         }
187         rtc::scoped_refptr<VideoTrackSourceImpl> source(
188             new rtc::RefCountedObject<VideoTrackSourceImpl>(width, height));
189         displays_[label] = {width, height, dpi, touch_enabled, source};
190         return std::shared_ptr<VideoSink>(
191             new VideoTrackSourceImplSinkWrapper(source));
192       });
193 }
194 
AddAudio(const std::string & label)195 void StreamerImpl::AddAudio(const std::string& label) {
196   // Usually called from an application thread
197   // TODO (b/128328845): audio support. Use signal_thread_->Invoke<>();
198 }
199 
Register(std::weak_ptr<OperatorObserver> observer)200 void StreamerImpl::Register(std::weak_ptr<OperatorObserver> observer) {
201   // Usually called from an application thread
202   // No need to block the calling thread on this, the observer will be notified
203   // when the connection is established.
204   signal_thread_->PostTask(RTC_FROM_HERE, [this, observer]() {
205     operator_observer_ = observer;
206     // This can be a local variable since the connection object will keep a
207     // reference to it.
208     auto ws_context = WsConnectionContext::Create();
209     CHECK(ws_context) << "Failed to create websocket context";
210     server_connection_ = ws_context->CreateConnection(
211         config_.operator_server.port, config_.operator_server.addr,
212         config_.operator_server.path, config_.operator_server.security,
213         ws_observer_);
214 
215     CHECK(server_connection_) << "Unable to create websocket connection object";
216 
217     server_connection_->Connect();
218   });
219 }
220 
Unregister()221 void StreamerImpl::Unregister() {
222   // Usually called from an application thread.
223   signal_thread_->PostTask(RTC_FROM_HERE,
224                            [this]() { server_connection_.reset(); });
225 }
226 
OnOpen()227 void StreamerImpl::OnOpen() {
228   // Called from the websocket thread.
229   // Connected to operator.
230   signal_thread_->PostTask(RTC_FROM_HERE, [this]() {
231     Json::Value register_obj;
232     register_obj[cuttlefish::webrtc_signaling::kTypeField] =
233         cuttlefish::webrtc_signaling::kRegisterType;
234     register_obj[cuttlefish::webrtc_signaling::kDeviceIdField] =
235         config_.device_id;
236 
237     Json::Value device_info;
238     Json::Value displays(Json::ValueType::arrayValue);
239     // No need to synchronize with other accesses to display_ because all
240     // happens on signal_thread.
241     for (auto& entry : displays_) {
242       Json::Value display;
243       display[kStreamIdField] = entry.first;
244       display[kXResField] = entry.second.width;
245       display[kYResField] = entry.second.height;
246       display[kDpiField] = entry.second.dpi;
247       display[kIsTouchField] = true;
248       displays.append(display);
249     }
250     device_info[kDisplaysField] = displays;
251     register_obj[cuttlefish::webrtc_signaling::kDeviceInfoField] = device_info;
252     SendJson(server_connection_.get(), register_obj);
253     // Do this last as OnRegistered() is user code and may take some time to
254     // complete (although it shouldn't...)
255     auto observer = operator_observer_.lock();
256     if (observer) {
257       observer->OnRegistered();
258     }
259   });
260 }
261 
OnClose()262 void StreamerImpl::OnClose() {
263   // Called from websocket thread
264   // The operator shouldn't close the connection with the client, it's up to the
265   // device to decide when to disconnect.
266   LOG(WARNING) << "Websocket closed unexpectedly";
267   signal_thread_->PostTask(RTC_FROM_HERE, [this]() {
268     auto observer = operator_observer_.lock();
269     if (observer) {
270       observer->OnClose();
271     }
272   });
273 }
274 
OnError(const std::string & error)275 void StreamerImpl::OnError(const std::string& error) {
276   // Called from websocket thread.
277   LOG(ERROR) << "Error on connection with the operator: " << error;
278   signal_thread_->PostTask(RTC_FROM_HERE, [this]() {
279     auto observer = operator_observer_.lock();
280     if (observer) {
281       observer->OnError();
282     }
283   });
284 }
285 
HandleConfigMessage(const Json::Value & server_message)286 void StreamerImpl::HandleConfigMessage(const Json::Value& server_message) {
287   CHECK(signal_thread_->IsCurrent())
288       << __FUNCTION__ << " called from the wrong thread";
289   if (server_message.isMember("ice_servers") &&
290       server_message["ice_servers"].isArray()) {
291     auto servers = server_message["ice_servers"];
292     operator_config_.servers.clear();
293     for (int server_idx = 0; server_idx < servers.size(); server_idx++) {
294       auto server = servers[server_idx];
295       webrtc::PeerConnectionInterface::IceServer ice_server;
296       if (!server.isMember("urls") || !server["urls"].isArray()) {
297         // The urls field is required
298         LOG(WARNING)
299             << "Invalid ICE server specification obtained from server: "
300             << server.toStyledString();
301         continue;
302       }
303       auto urls = server["urls"];
304       for (int url_idx = 0; url_idx < urls.size(); url_idx++) {
305         auto url = urls[url_idx];
306         if (!url.isString()) {
307           LOG(WARNING) << "Non string 'urls' field in ice server: "
308                        << url.toStyledString();
309           continue;
310         }
311         ice_server.urls.push_back(url.asString());
312         if (server.isMember("credential") && server["credential"].isString()) {
313           ice_server.password = server["credential"].asString();
314         }
315         if (server.isMember("username") && server["username"].isString()) {
316           ice_server.username = server["username"].asString();
317         }
318         operator_config_.servers.push_back(ice_server);
319       }
320     }
321   }
322 }
323 
HandleClientMessage(const Json::Value & server_message)324 void StreamerImpl::HandleClientMessage(const Json::Value& server_message) {
325   CHECK(signal_thread_->IsCurrent())
326       << __FUNCTION__ << " called from the wrong thread";
327   if (!server_message.isMember(cuttlefish::webrtc_signaling::kClientIdField) ||
328       !server_message[cuttlefish::webrtc_signaling::kClientIdField].isInt()) {
329     LOG(ERROR) << "Client message received without valid client id";
330     return;
331   }
332   auto client_id =
333       server_message[cuttlefish::webrtc_signaling::kClientIdField].asInt();
334   if (!server_message.isMember(cuttlefish::webrtc_signaling::kPayloadField)) {
335     LOG(WARNING) << "Received empty client message";
336     return;
337   }
338   auto client_message =
339       server_message[cuttlefish::webrtc_signaling::kPayloadField];
340   if (clients_.count(client_id) == 0) {
341     auto client_handler = CreateClientHandler(client_id);
342     if (!client_handler) {
343       LOG(ERROR) << "Failed to create a new client handler";
344       return;
345     }
346     clients_.emplace(client_id, client_handler);
347   }
348 
349   auto client_handler = clients_[client_id];
350 
351   client_handler->HandleMessage(client_message);
352 }
353 
OnReceive(const uint8_t * msg,size_t length,bool is_binary)354 void StreamerImpl::OnReceive(const uint8_t* msg, size_t length, bool is_binary) {
355   // Usually called from websocket thread.
356   Json::Value server_message;
357   // Once OnReceive returns the buffer can be destroyed/recycled at any time, so
358   // parse the data into a JSON object while still on the websocket thread.
359   if (is_binary || !ParseMessage(msg, length, &server_message)) {
360     LOG(ERROR) << "Received invalid JSON from server: '"
361                << (is_binary ? std::string("(binary_data)")
362                              : std::string(msg, msg + length))
363                << "'";
364     return;
365   }
366   // Transition to the signal thread before member variables are accessed.
367   signal_thread_->PostTask(RTC_FROM_HERE, [this, server_message]() {
368     if (!server_message.isMember(cuttlefish::webrtc_signaling::kTypeField) ||
369         !server_message[cuttlefish::webrtc_signaling::kTypeField].isString()) {
370       LOG(ERROR) << "No message_type field from server";
371       // Notify the caller
372       OnError(
373           "Invalid message received from operator: no message type field "
374           "present");
375       return;
376     }
377     auto type =
378         server_message[cuttlefish::webrtc_signaling::kTypeField].asString();
379     if (type == cuttlefish::webrtc_signaling::kConfigType) {
380       HandleConfigMessage(server_message);
381     } else if (type == cuttlefish::webrtc_signaling::kClientMessageType) {
382       HandleClientMessage(server_message);
383     } else {
384       LOG(ERROR) << "Unknown message type: " << type;
385       // Notify the caller
386       OnError("Invalid message received from operator: unknown message type");
387       return;
388     }
389   });
390 }
391 
CreateClientHandler(int client_id)392 std::shared_ptr<ClientHandler> StreamerImpl::CreateClientHandler(int client_id) {
393   CHECK(signal_thread_->IsCurrent())
394       << __FUNCTION__ << " called from the wrong thread";
395   auto observer = connection_observer_factory_->CreateObserver();
396 
397   auto client_handler = ClientHandler::Create(
398       client_id, observer,
399       [this, client_id](const Json::Value& msg) {
400         SendMessageToClient(client_id, msg);
401       },
402       [this, client_id] { DestroyClientHandler(client_id); });
403 
404   webrtc::PeerConnectionInterface::RTCConfiguration config;
405   config.sdp_semantics = webrtc::SdpSemantics::kUnifiedPlan;
406   config.enable_dtls_srtp = true;
407   config.servers.insert(config.servers.end(), operator_config_.servers.begin(),
408                         operator_config_.servers.end());
409   webrtc::PeerConnectionDependencies dependencies(client_handler.get());
410   // PortRangeSocketFactory's super class' constructor needs to be called on the
411   // network thread or have it as a parameter
412   dependencies.packet_socket_factory.reset(new PortRangeSocketFactory(
413       network_thread_.get(), config_.udp_port_range, config_.tcp_port_range));
414   auto peer_connection = peer_connection_factory_->CreatePeerConnection(
415       config, std::move(dependencies));
416 
417   if (!peer_connection) {
418     LOG(ERROR) << "Failed to create peer connection";
419     return nullptr;
420   }
421 
422   if (!client_handler->SetPeerConnection(std::move(peer_connection))) {
423     return nullptr;
424   }
425 
426   for (auto& entry : displays_) {
427     auto& label = entry.first;
428     auto& video_source = entry.second.source;
429 
430     auto video_track =
431         peer_connection_factory_->CreateVideoTrack(label, video_source.get());
432     client_handler->AddDisplay(video_track, label);
433   }
434 
435   return client_handler;
436 }
437 
SendMessageToClient(int client_id,const Json::Value & msg)438 void StreamerImpl::SendMessageToClient(int client_id, const Json::Value& msg) {
439   LOG(VERBOSE) << "Sending to client: " << msg.toStyledString();
440   CHECK(signal_thread_->IsCurrent())
441       << __FUNCTION__ << " called from the wrong thread";
442   Json::Value wrapper;
443   wrapper[cuttlefish::webrtc_signaling::kPayloadField] = msg;
444   wrapper[cuttlefish::webrtc_signaling::kTypeField] =
445       cuttlefish::webrtc_signaling::kForwardType;
446   wrapper[cuttlefish::webrtc_signaling::kClientIdField] = client_id;
447   // This is safe to call from the webrtc threads because
448   // WsConnection is thread safe
449   SendJson(server_connection_.get(), wrapper);
450 }
451 
DestroyClientHandler(int client_id)452 void StreamerImpl::DestroyClientHandler(int client_id) {
453   // Usually called from signal thread, could be called from websocket thread or
454   // an application thread.
455   signal_thread_->PostTask(RTC_FROM_HERE, [this, client_id]() {
456     // This needs to be 'posted' to the thread instead of 'invoked'
457     // immediately for two reasons:
458     // * The client handler is destroyed by this code, it's generally a
459     // bad idea (though not necessarily wrong) to return to a member
460     // function of a destroyed object.
461     // * The client handler may call this from within a peer connection
462     // observer callback, destroying the client handler there leads to a
463     // deadlock.
464     clients_.erase(client_id);
465   });
466 }
467 
468 }  // namespace
469 
470 /* static */
Create(const StreamerConfig & cfg,std::shared_ptr<ConnectionObserverFactory> connection_observer_factory)471 std::shared_ptr<Streamer> Streamer::Create(
472     const StreamerConfig& cfg,
473     std::shared_ptr<ConnectionObserverFactory> connection_observer_factory) {
474   auto network_thread = CreateAndStartThread("network-thread");
475   auto worker_thread = CreateAndStartThread("work-thread");
476   auto signal_thread = CreateAndStartThread("signal-thread");
477   if (!network_thread || !worker_thread || !signal_thread) {
478     return nullptr;
479   }
480 
481   auto pc_factory = webrtc::CreatePeerConnectionFactory(
482       network_thread.get(), worker_thread.get(), signal_thread.get(),
483       nullptr /* default_adm */, webrtc::CreateBuiltinAudioEncoderFactory(),
484       webrtc::CreateBuiltinAudioDecoderFactory(),
485       std::make_unique<VP8OnlyEncoderFactory>(
486           webrtc::CreateBuiltinVideoEncoderFactory()),
487       webrtc::CreateBuiltinVideoDecoderFactory(), nullptr /* audio_mixer */,
488       nullptr /* audio_processing */);
489 
490   if (!pc_factory) {
491     LOG(ERROR) << "Failed to create peer connection factory";
492     return nullptr;
493   }
494 
495   webrtc::PeerConnectionFactoryInterface::Options options;
496   // By default the loopback network is ignored, but generating candidates for
497   // it is useful when using TCP port forwarding.
498   options.network_ignore_mask = 0;
499   pc_factory->SetOptions(options);
500 
501   return std::shared_ptr<Streamer>(new StreamerImpl(
502       cfg, pc_factory, std::move(network_thread), std::move(worker_thread),
503       std::move(signal_thread), connection_observer_factory));
504 }
505 
506 }  // namespace webrtc_streaming
507 }  // namespace cuttlefish
508