1 /*
2  * Copyright (C) 2007 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define TRACE_TAG SOCKETS
18 
19 #include "sysdeps.h"
20 
21 #include <ctype.h>
22 #include <errno.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <unistd.h>
27 
28 #include <algorithm>
29 #include <chrono>
30 #include <mutex>
31 #include <string>
32 #include <vector>
33 
34 #include <android-base/strings.h>
35 
36 #if !ADB_HOST
37 #include <android-base/properties.h>
38 #include <log/log_properties.h>
39 #endif
40 
41 #include "adb.h"
42 #include "adb_io.h"
43 #include "adb_utils.h"
44 #include "transport.h"
45 #include "types.h"
46 
47 using namespace std::chrono_literals;
48 
// Lock taken by the functions below before they touch the socket lists or the
// id counter. It is recursive: close_all_sockets() holds it while invoking
// close(), and local_socket_close() acquires it again.
// Allocated with new and bound to a reference, so it is never destroyed.
static std::recursive_mutex& local_socket_list_lock = *new std::recursive_mutex();
// Id assigned to the next socket passed to install_local_socket(). Starts at 1.
static unsigned local_socket_next_id = 1;

// Local sockets registered through install_local_socket().
static auto& local_socket_list = *new std::vector<asocket*>();

/* The list of currently closing local sockets.
** These no longer have a peer, but still have packets
** to write to their fd.
*/
static auto& local_socket_closing_list = *new std::vector<asocket*>();
59 
60 // Parse the global list of sockets to find one with id |local_id|.
61 // If |peer_id| is not 0, also check that it is connected to a peer
62 // with id |peer_id|. Returns an asocket handle on success, NULL on failure.
find_local_socket(unsigned local_id,unsigned peer_id)63 asocket* find_local_socket(unsigned local_id, unsigned peer_id) {
64     asocket* result = nullptr;
65 
66     std::lock_guard<std::recursive_mutex> lock(local_socket_list_lock);
67     for (asocket* s : local_socket_list) {
68         if (s->id != local_id) {
69             continue;
70         }
71         if (peer_id == 0 || (s->peer && s->peer->id == peer_id)) {
72             result = s;
73         }
74         break;
75     }
76 
77     return result;
78 }
79 
install_local_socket(asocket * s)80 void install_local_socket(asocket* s) {
81     std::lock_guard<std::recursive_mutex> lock(local_socket_list_lock);
82 
83     s->id = local_socket_next_id++;
84 
85     // Socket ids should never be 0.
86     if (local_socket_next_id == 0) {
87         LOG(FATAL) << "local socket id overflow";
88     }
89 
90     local_socket_list.push_back(s);
91 }
92 
remove_socket(asocket * s)93 void remove_socket(asocket* s) {
94     std::lock_guard<std::recursive_mutex> lock(local_socket_list_lock);
95     for (auto list : { &local_socket_list, &local_socket_closing_list }) {
96         list->erase(std::remove_if(list->begin(), list->end(), [s](asocket* x) { return x == s; }),
97                     list->end());
98     }
99 }
100 
close_all_sockets(atransport * t)101 void close_all_sockets(atransport* t) {
102     /* this is a little gross, but since s->close() *will* modify
103     ** the list out from under you, your options are limited.
104     */
105     std::lock_guard<std::recursive_mutex> lock(local_socket_list_lock);
106 restart:
107     for (asocket* s : local_socket_list) {
108         if (s->transport == t || (s->peer && s->peer->transport == t)) {
109             s->close(s);
110             goto restart;
111         }
112     }
113 }
114 
// Outcome reported by local_socket_flush_incoming().
enum class SocketFlushResult {
    // The socket was marked as closing and close() has been called on it.
    Destroyed,
    // Data is still queued; FDE_WRITE was requested so the write is retried later.
    TryAgain,
    // No retry is pending; interest in FDE_WRITE has been removed.
    Completed,
};
120 
local_socket_flush_incoming(asocket * s)121 static SocketFlushResult local_socket_flush_incoming(asocket* s) {
122     if (!s->packet_queue.empty()) {
123         std::vector<adb_iovec> iov = s->packet_queue.iovecs();
124         ssize_t rc = adb_writev(s->fd, iov.data(), iov.size());
125         if (rc > 0 && static_cast<size_t>(rc) == s->packet_queue.size()) {
126             s->packet_queue.clear();
127         } else if (rc > 0) {
128             s->packet_queue.drop_front(rc);
129             fdevent_add(s->fde, FDE_WRITE);
130             return SocketFlushResult::TryAgain;
131         } else if (rc == -1 && errno == EAGAIN) {
132             fdevent_add(s->fde, FDE_WRITE);
133             return SocketFlushResult::TryAgain;
134         } else {
135             // We failed to write, but it's possible that we can still read from the socket.
136             // Give that a try before giving up.
137             s->has_write_error = true;
138         }
139     }
140 
141     // If we sent the last packet of a closing socket, we can now destroy it.
142     if (s->closing) {
143         s->close(s);
144         return SocketFlushResult::Destroyed;
145     }
146 
147     fdevent_del(s->fde, FDE_WRITE);
148     return SocketFlushResult::Completed;
149 }
150 
// Reads up to one max-payload worth of data from the socket's fd and hands it
// to the peer's enqueue().
// Returns false if the socket has been closed and destroyed as a side-effect of this function.
static bool local_socket_flush_outgoing(asocket* s) {
    const size_t max_payload = s->get_max_payload();
    apacket::payload_type data;
    data.resize(max_payload);
    // |x| is the write cursor into |data|; |avail| is the space left behind it.
    char* x = &data[0];
    size_t avail = max_payload;
    int r = 0;
    int is_eof = 0;

    // Read until the buffer is full, the fd would block (EAGAIN), or we hit
    // EOF / an error other than EAGAIN.
    while (avail > 0) {
        r = adb_read(s->fd, x, avail);
        D("LS(%d): post adb_read(fd=%d,...) r=%d (errno=%d) avail=%zu", s->id, s->fd, r,
          r < 0 ? errno : 0, avail);
        if (r == -1) {
            if (errno == EAGAIN) {
                break;
            }
        } else if (r > 0) {
            avail -= r;
            x += r;
            continue;
        }

        /* r = 0 or unhandled error */
        is_eof = 1;
        break;
    }
    D("LS(%d): fd=%d post avail loop. r=%d is_eof=%d forced_eof=%d", s->id, s->fd, r, is_eof,
      s->fde->force_eof);

    // Forward whatever was read, provided there is still a peer to take it.
    if (avail != max_payload && s->peer) {
        // Shrink the buffer to the number of bytes actually read.
        data.resize(max_payload - avail);

        // s->peer->enqueue() may call s->close() and free s,
        // so save variables for debug printing below.
        unsigned saved_id = s->id;
        int saved_fd = s->fd;
        // From here on |r| holds the enqueue() result, not the last read result.
        r = s->peer->enqueue(s->peer, std::move(data));
        D("LS(%u): fd=%d post peer->enqueue(). r=%d", saved_id, saved_fd, r);

        if (r < 0) {
            // Error return means they closed us as a side-effect and we must
            // return immediately.
            //
            // Note that if we still have buffered packets, the socket will be
            // placed on the closing socket list. This handler function will be
            // called again to process FDE_WRITE events.
            return false;
        }

        if (r > 0) {
            /* if the remote cannot accept further events,
            ** we disable notification of READs.  They'll
            ** be enabled again when we get a call to ready()
            */
            fdevent_del(s->fde, FDE_READ);
        }
    }

    // Don't allow a forced eof if data is still there.
    if ((s->fde->force_eof && !r) || is_eof) {
        D(" closing because is_eof=%d r=%d s->fde.force_eof=%d", is_eof, r, s->fde->force_eof);
        s->close(s);
        return false;
    }

    return true;
}
220 
local_socket_enqueue(asocket * s,apacket::payload_type data)221 static int local_socket_enqueue(asocket* s, apacket::payload_type data) {
222     D("LS(%d): enqueue %zu", s->id, data.size());
223 
224     s->packet_queue.append(std::move(data));
225     switch (local_socket_flush_incoming(s)) {
226         case SocketFlushResult::Destroyed:
227             return -1;
228 
229         case SocketFlushResult::TryAgain:
230             return 1;
231 
232         case SocketFlushResult::Completed:
233             return 0;
234     }
235 
236     return !s->packet_queue.empty();
237 }
238 
local_socket_ready(asocket * s)239 static void local_socket_ready(asocket* s) {
240     /* far side is ready for data, pay attention to
241        readable events */
242     fdevent_add(s->fde, FDE_READ);
243 }
244 
// Bookkeeping passed to the fdevent callback that deferred_close() installs.
struct ClosingSocket {
    // Time at which deferred_close() started draining the fd; the callback
    // compares against it to enforce its one second limit.
    std::chrono::steady_clock::time_point begin;
};
248 
// The standard (RFC 1122 - 4.2.2.13) says that if we call close on a
// socket while we have pending data, a TCP RST should be sent to the
// other end to notify it that we didn't read all of its data. However,
// this can result in data that we've successfully written out to be dropped
// on the other end. To avoid this, instead of immediately closing a
// socket, call shutdown on it instead, and then read from the file
// descriptor until we hit EOF or an error before closing.
//
// Takes ownership of |fd|: it is released into a new fdevent, which the
// callback below destroys once draining has finished or timed out.
static void deferred_close(unique_fd fd) {
    // Shutdown the socket in the outgoing direction only, so that
    // we don't have the same problem on the opposite end.
    adb_shutdown(fd.get(), SHUT_WR);
    auto callback = [](fdevent* fde, unsigned event, void* arg) {
        auto socket_info = static_cast<ClosingSocket*>(arg);
        if (event & FDE_READ) {
            ssize_t rc;
            char buf[BUFSIZ];
            // Discard everything that is currently readable.
            while ((rc = adb_read(fde->fd.get(), buf, sizeof(buf))) > 0) {
                continue;
            }

            if (rc == -1 && errno == EAGAIN) {
                // There's potentially more data to read.
                auto duration = std::chrono::steady_clock::now() - socket_info->begin;
                if (duration > 1s) {
                    LOG(WARNING) << "timeout expired while flushing socket, closing";
                } else {
                    // Still within the time budget: keep the fdevent and wait
                    // for the next event.
                    return;
                }
            }
        } else if (event & FDE_TIMEOUT) {
            LOG(WARNING) << "timeout expired while flushing socket, closing";
        }

        // Either there was an error, we hit the end of the socket, or our timeout expired.
        fdevent_destroy(fde);
        delete socket_info;
    };

    // Freed by the callback above when it tears the fdevent down.
    ClosingSocket* socket_info = new ClosingSocket{
            .begin = std::chrono::steady_clock::now(),
    };

    fdevent* fde = fdevent_create(fd.release(), callback, socket_info);
    fdevent_add(fde, FDE_READ);
    fdevent_set_timeout(fde, 1s);
}
295 
296 // be sure to hold the socket list lock when calling this
local_socket_destroy(asocket * s)297 static void local_socket_destroy(asocket* s) {
298     int exit_on_close = s->exit_on_close;
299 
300     D("LS(%d): destroying fde.fd=%d", s->id, s->fd);
301 
302     deferred_close(fdevent_release(s->fde));
303 
304     remove_socket(s);
305     delete s;
306 
307     if (exit_on_close) {
308         D("local_socket_destroy: exiting");
309         exit(1);
310     }
311 }
312 
// close() handler for local sockets. Detaches and closes the peer (if any),
// then either destroys |s| right away or parks it on the closing list until
// its queued packets have been written.
static void local_socket_close(asocket* s) {
    D("entered local_socket_close. LS(%d) fd=%d", s->id, s->fd);
    std::lock_guard<std::recursive_mutex> lock(local_socket_list_lock);
    if (s->peer) {
        D("LS(%d): closing peer. peer->id=%d peer->fd=%d", s->id, s->peer->id, s->peer->fd);
        /* Note: it's important to call shutdown before disconnecting from
         * the peer, this ensures that remote sockets can still get the id
         * of the local socket they're connected to, to send a CLOSE()
         * protocol event. */
        if (s->peer->shutdown) {
            s->peer->shutdown(s->peer);
        }
        // Clear the peer's back-pointer before closing it, so that its close()
        // does not find |s| as a peer and try to close it again.
        s->peer->peer = nullptr;
        s->peer->close(s->peer);
        s->peer = nullptr;
    }

    /* If we are already closing, or if there are no
    ** pending packets, destroy immediately
    */
    if (s->closing || s->has_write_error || s->packet_queue.empty()) {
        // |s| is freed by local_socket_destroy(), so keep the id for the log line.
        int id = s->id;
        local_socket_destroy(s);
        D("LS(%d): closed", id);
        return;
    }

    /* otherwise, put on the closing list
    */
    D("LS(%d): closing", s->id);
    s->closing = 1;
    // No peer is left to receive data, so stop watching for readable events.
    fdevent_del(s->fde, FDE_READ);
    remove_socket(s);
    D("LS(%d): put on socket_closing_list fd=%d", s->id, s->fd);
    local_socket_closing_list.push_back(s);
    // Packets are still queued, so FDE_WRITE is expected to be set already.
    CHECK_EQ(FDE_WRITE, s->fde->state & FDE_WRITE);
}
350 
local_socket_event_func(int fd,unsigned ev,void * _s)351 static void local_socket_event_func(int fd, unsigned ev, void* _s) {
352     asocket* s = reinterpret_cast<asocket*>(_s);
353     D("LS(%d): event_func(fd=%d(==%d), ev=%04x)", s->id, s->fd, fd, ev);
354 
355     /* put the FDE_WRITE processing before the FDE_READ
356     ** in order to simplify the code.
357     */
358     if (ev & FDE_WRITE) {
359         switch (local_socket_flush_incoming(s)) {
360             case SocketFlushResult::Destroyed:
361                 return;
362 
363             case SocketFlushResult::TryAgain:
364                 break;
365 
366             case SocketFlushResult::Completed:
367                 s->peer->ready(s->peer);
368                 break;
369         }
370     }
371 
372     if (ev & FDE_READ) {
373         if (!local_socket_flush_outgoing(s)) {
374             return;
375         }
376     }
377 
378     if (ev & FDE_ERROR) {
379         /* this should be caught be the next read or write
380         ** catching it here means we may skip the last few
381         ** bytes of readable data.
382         */
383         D("LS(%d): FDE_ERROR (fd=%d)", s->id, s->fd);
384         return;
385     }
386 }
387 
create_local_socket(unique_fd ufd)388 asocket* create_local_socket(unique_fd ufd) {
389     int fd = ufd.release();
390     asocket* s = new asocket();
391     s->fd = fd;
392     s->enqueue = local_socket_enqueue;
393     s->ready = local_socket_ready;
394     s->shutdown = nullptr;
395     s->close = local_socket_close;
396     install_local_socket(s);
397 
398     s->fde = fdevent_create(fd, local_socket_event_func, s);
399     D("LS(%d): created (fd=%d)", s->id, s->fd);
400     return s;
401 }
402 
create_local_service_socket(std::string_view name,atransport * transport)403 asocket* create_local_service_socket(std::string_view name, atransport* transport) {
404 #if !ADB_HOST
405     if (asocket* s = daemon_service_to_socket(name); s) {
406         return s;
407     }
408 #endif
409     unique_fd fd = service_to_fd(name, transport);
410     if (fd < 0) {
411         return nullptr;
412     }
413 
414     int fd_value = fd.get();
415     asocket* s = create_local_socket(std::move(fd));
416     LOG(VERBOSE) << "LS(" << s->id << "): bound to '" << name << "' via " << fd_value;
417 
418 #if !ADB_HOST
419     if ((name.starts_with("root:") && getuid() != 0 && __android_log_is_debuggable()) ||
420         (name.starts_with("unroot:") && getuid() == 0) || name.starts_with("usb:") ||
421         name.starts_with("tcpip:")) {
422         D("LS(%d): enabling exit_on_close", s->id);
423         s->exit_on_close = 1;
424     }
425 #endif
426 
427     return s;
428 }
429 
remote_socket_enqueue(asocket * s,apacket::payload_type data)430 static int remote_socket_enqueue(asocket* s, apacket::payload_type data) {
431     D("entered remote_socket_enqueue RS(%d) WRITE fd=%d peer.fd=%d", s->id, s->fd, s->peer->fd);
432     apacket* p = get_apacket();
433 
434     p->msg.command = A_WRTE;
435     p->msg.arg0 = s->peer->id;
436     p->msg.arg1 = s->id;
437 
438     if (data.size() > MAX_PAYLOAD) {
439         put_apacket(p);
440         return -1;
441     }
442 
443     p->payload = std::move(data);
444     p->msg.data_length = p->payload.size();
445 
446     send_packet(p, s->transport);
447     return 1;
448 }
449 
remote_socket_ready(asocket * s)450 static void remote_socket_ready(asocket* s) {
451     D("entered remote_socket_ready RS(%d) OKAY fd=%d peer.fd=%d", s->id, s->fd, s->peer->fd);
452     apacket* p = get_apacket();
453     p->msg.command = A_OKAY;
454     p->msg.arg0 = s->peer->id;
455     p->msg.arg1 = s->id;
456     send_packet(p, s->transport);
457 }
458 
remote_socket_shutdown(asocket * s)459 static void remote_socket_shutdown(asocket* s) {
460     D("entered remote_socket_shutdown RS(%d) CLOSE fd=%d peer->fd=%d", s->id, s->fd,
461       s->peer ? s->peer->fd : -1);
462     apacket* p = get_apacket();
463     p->msg.command = A_CLSE;
464     if (s->peer) {
465         p->msg.arg0 = s->peer->id;
466     }
467     p->msg.arg1 = s->id;
468     send_packet(p, s->transport);
469 }
470 
remote_socket_close(asocket * s)471 static void remote_socket_close(asocket* s) {
472     if (s->peer) {
473         s->peer->peer = nullptr;
474         D("RS(%d) peer->close()ing peer->id=%d peer->fd=%d", s->id, s->peer->id, s->peer->fd);
475         s->peer->close(s->peer);
476     }
477     D("entered remote_socket_close RS(%d) CLOSE fd=%d peer->fd=%d", s->id, s->fd,
478       s->peer ? s->peer->fd : -1);
479     D("RS(%d): closed", s->id);
480     delete s;
481 }
482 
483 // Create a remote socket to exchange packets with a remote service through transport
484 // |t|. Where |id| is the socket id of the corresponding service on the other
485 //  side of the transport (it is allocated by the remote side and _cannot_ be 0).
486 // Returns a new non-NULL asocket handle.
create_remote_socket(unsigned id,atransport * t)487 asocket* create_remote_socket(unsigned id, atransport* t) {
488     if (id == 0) {
489         LOG(FATAL) << "invalid remote socket id (0)";
490     }
491     asocket* s = new asocket();
492     s->id = id;
493     s->enqueue = remote_socket_enqueue;
494     s->ready = remote_socket_ready;
495     s->shutdown = remote_socket_shutdown;
496     s->close = remote_socket_close;
497     s->transport = t;
498 
499     D("RS(%d): created", s->id);
500     return s;
501 }
502 
connect_to_remote(asocket * s,std::string_view destination)503 void connect_to_remote(asocket* s, std::string_view destination) {
504     D("Connect_to_remote call RS(%d) fd=%d", s->id, s->fd);
505     apacket* p = get_apacket();
506 
507     LOG(VERBOSE) << "LS(" << s->id << ": connect(" << destination << ")";
508     p->msg.command = A_OPEN;
509     p->msg.arg0 = s->id;
510 
511     // adbd used to expect a null-terminated string.
512     // Keep doing so to maintain backward compatibility.
513     p->payload.resize(destination.size() + 1);
514     memcpy(p->payload.data(), destination.data(), destination.size());
515     p->payload[destination.size()] = '\0';
516     p->msg.data_length = p->payload.size();
517 
518     CHECK_LE(p->msg.data_length, s->get_max_payload());
519 
520     send_packet(p, s->transport);
521 }
522 
523 #if ADB_HOST
524 /* this is used by magic sockets to rig local sockets to
525    send the go-ahead message when they connect */
local_socket_ready_notify(asocket * s)526 static void local_socket_ready_notify(asocket* s) {
527     s->ready = local_socket_ready;
528     s->shutdown = nullptr;
529     s->close = local_socket_close;
530     SendOkay(s->fd);
531     s->ready(s);
532 }
533 
534 /* this is used by magic sockets to rig local sockets to
535    send the failure message if they are closed before
536    connected (to avoid closing them without a status message) */
local_socket_close_notify(asocket * s)537 static void local_socket_close_notify(asocket* s) {
538     s->ready = local_socket_ready;
539     s->shutdown = nullptr;
540     s->close = local_socket_close;
541     SendFail(s->fd, "closed");
542     s->close(s);
543 }
544 
// Parses the first |len| characters of |s| as a hexadecimal number (upper- or
// lowercase digits). Returns the parsed value, or 0xffffffff as soon as a
// character that is not a hex digit is encountered.
static unsigned unhex(const char* s, int len) {
    unsigned value = 0;

    for (int i = 0; i < len; ++i) {
        const char ch = s[i];
        unsigned nibble;

        if (ch >= '0' && ch <= '9') {
            nibble = static_cast<unsigned>(ch - '0');
        } else if (ch >= 'a' && ch <= 'f') {
            nibble = static_cast<unsigned>(ch - 'a') + 10;
        } else if (ch >= 'A' && ch <= 'F') {
            nibble = static_cast<unsigned>(ch - 'A') + 10;
        } else {
            return 0xffffffff;
        }

        value = (value << 4) | nibble;
    }

    return value;
}
587 
588 namespace internal {
589 
590 // Parses a host service string of the following format:
591 //   * [tcp:|udp:]<serial>[:<port>]:<command>
592 //   * <prefix>:<serial>:<command>
593 // Where <port> must be a base-10 number and <prefix> may be any of {usb,product,model,device}.
parse_host_service(std::string_view * out_serial,std::string_view * out_command,std::string_view full_service)594 bool parse_host_service(std::string_view* out_serial, std::string_view* out_command,
595                         std::string_view full_service) {
596     if (full_service.empty()) {
597         return false;
598     }
599 
600     std::string_view serial;
601     std::string_view command = full_service;
602     // Remove |count| bytes from the beginning of command and add them to |serial|.
603     auto consume = [&full_service, &serial, &command](size_t count) {
604         CHECK_LE(count, command.size());
605         if (!serial.empty()) {
606             CHECK_EQ(serial.data() + serial.size(), command.data());
607         }
608 
609         serial = full_service.substr(0, serial.size() + count);
610         command.remove_prefix(count);
611     };
612 
613     // Remove the trailing : from serial, and assign the values to the output parameters.
614     auto finish = [out_serial, out_command, &serial, &command] {
615         if (serial.empty() || command.empty()) {
616             return false;
617         }
618 
619         CHECK_EQ(':', serial.back());
620         serial.remove_suffix(1);
621 
622         *out_serial = serial;
623         *out_command = command;
624         return true;
625     };
626 
627     static constexpr std::string_view prefixes[] = {
628             "usb:", "product:", "model:", "device:", "localfilesystem:"};
629     for (std::string_view prefix : prefixes) {
630         if (command.starts_with(prefix)) {
631             consume(prefix.size());
632 
633             size_t offset = command.find_first_of(':');
634             if (offset == std::string::npos) {
635                 return false;
636             }
637             consume(offset + 1);
638             return finish();
639         }
640     }
641 
642     // For fastboot compatibility, ignore protocol prefixes.
643     if (command.starts_with("tcp:") || command.starts_with("udp:")) {
644         consume(4);
645         if (command.empty()) {
646             return false;
647         }
648     }
649     if (command.starts_with("vsock:")) {
650         // vsock serials are vsock:cid:port, which have an extra colon compared to tcp.
651         size_t next_colon = command.find(':');
652         if (next_colon == std::string::npos) {
653             return false;
654         }
655         consume(next_colon + 1);
656     }
657 
658     bool found_address = false;
659     if (command[0] == '[') {
660         // Read an IPv6 address. `adb connect` creates the serial number from the canonical
661         // network address so it will always have the [] delimiters.
662         size_t ipv6_end = command.find_first_of(']');
663         if (ipv6_end != std::string::npos) {
664             consume(ipv6_end + 1);
665             if (command.empty()) {
666                 // Nothing after the IPv6 address.
667                 return false;
668             } else if (command[0] != ':') {
669                 // Garbage after the IPv6 address.
670                 return false;
671             }
672             consume(1);
673             found_address = true;
674         }
675     }
676 
677     if (!found_address) {
678         // Scan ahead to the next colon.
679         size_t offset = command.find_first_of(':');
680         if (offset == std::string::npos) {
681             return false;
682         }
683         consume(offset + 1);
684     }
685 
686     // We're either at the beginning of a port, or the command itself.
687     // Look for a port in between colons.
688     size_t next_colon = command.find_first_of(':');
689     if (next_colon == std::string::npos) {
690         // No colon, we must be at the command.
691         return finish();
692     }
693 
694     bool port_valid = true;
695     if (command.size() <= next_colon) {
696         return false;
697     }
698 
699     std::string_view port = command.substr(0, next_colon);
700     for (auto digit : port) {
701         if (!isdigit(digit)) {
702             // Port isn't a number.
703             port_valid = false;
704             break;
705         }
706     }
707 
708     if (port_valid) {
709         consume(next_colon + 1);
710     }
711     return finish();
712 }
713 
714 }  // namespace internal
715 
smart_socket_enqueue(asocket * s,apacket::payload_type data)716 static int smart_socket_enqueue(asocket* s, apacket::payload_type data) {
717     std::string_view service;
718     std::string_view serial;
719     TransportId transport_id = 0;
720     TransportType type = kTransportAny;
721 
722     D("SS(%d): enqueue %zu", s->id, data.size());
723 
724     if (s->smart_socket_data.empty()) {
725         // TODO: Make this an IOVector?
726         s->smart_socket_data.assign(data.begin(), data.end());
727     } else {
728         std::copy(data.begin(), data.end(), std::back_inserter(s->smart_socket_data));
729     }
730 
731     /* don't bother if we can't decode the length */
732     if (s->smart_socket_data.size() < 4) {
733         return 0;
734     }
735 
736     uint32_t len = unhex(s->smart_socket_data.data(), 4);
737     if (len == 0 || len > MAX_PAYLOAD) {
738         D("SS(%d): bad size (%u)", s->id, len);
739         goto fail;
740     }
741 
742     D("SS(%d): len is %u", s->id, len);
743     /* can't do anything until we have the full header */
744     if ((len + 4) > s->smart_socket_data.size()) {
745         D("SS(%d): waiting for %zu more bytes", s->id, len + 4 - s->smart_socket_data.size());
746         return 0;
747     }
748 
749     s->smart_socket_data[len + 4] = 0;
750 
751     D("SS(%d): '%s'", s->id, (char*)(s->smart_socket_data.data() + 4));
752 
753     service = std::string_view(s->smart_socket_data).substr(4);
754 
755     // TODO: These should be handled in handle_host_request.
756     if (android::base::ConsumePrefix(&service, "host-serial:")) {
757         // serial number should follow "host:" and could be a host:port string.
758         if (!internal::parse_host_service(&serial, &service, service)) {
759             LOG(ERROR) << "SS(" << s->id << "): failed to parse host service: " << service;
760             goto fail;
761         }
762     } else if (android::base::ConsumePrefix(&service, "host-transport-id:")) {
763         if (!ParseUint(&transport_id, service, &service)) {
764             LOG(ERROR) << "SS(" << s->id << "): failed to parse host transport id: " << service;
765             return -1;
766         }
767         if (!android::base::ConsumePrefix(&service, ":")) {
768             LOG(ERROR) << "SS(" << s->id << "): host-transport-id without command";
769             return -1;
770         }
771     } else if (android::base::ConsumePrefix(&service, "host-usb:")) {
772         type = kTransportUsb;
773     } else if (android::base::ConsumePrefix(&service, "host-local:")) {
774         type = kTransportLocal;
775     } else if (android::base::ConsumePrefix(&service, "host:")) {
776         type = kTransportAny;
777     } else {
778         service = std::string_view{};
779     }
780 
781     if (!service.empty()) {
782         asocket* s2;
783 
784         // Some requests are handled immediately -- in that case the handle_host_request() routine
785         // has sent the OKAY or FAIL message and all we have to do is clean up.
786         auto host_request_result = handle_host_request(
787                 service, type, serial.empty() ? nullptr : std::string(serial).c_str(), transport_id,
788                 s->peer->fd, s);
789 
790         switch (host_request_result) {
791             case HostRequestResult::Handled:
792                 LOG(VERBOSE) << "SS(" << s->id << "): handled host service '" << service << "'";
793                 goto fail;
794 
795             case HostRequestResult::SwitchedTransport:
796                 D("SS(%d): okay transport", s->id);
797                 s->smart_socket_data.clear();
798                 return 0;
799 
800             case HostRequestResult::Unhandled:
801                 break;
802         }
803 
804         /* try to find a local service with this name.
805         ** if no such service exists, we'll fail out
806         ** and tear down here.
807         */
808         // TODO: Convert to string_view.
809         s2 = host_service_to_socket(service, serial, transport_id);
810         if (s2 == nullptr) {
811             LOG(VERBOSE) << "SS(" << s->id << "): couldn't create host service '" << service << "'";
812             SendFail(s->peer->fd, "unknown host service");
813             goto fail;
814         }
815 
816         /* we've connected to a local host service,
817         ** so we make our peer back into a regular
818         ** local socket and bind it to the new local
819         ** service socket, acknowledge the successful
820         ** connection, and close this smart socket now
821         ** that its work is done.
822         */
823         SendOkay(s->peer->fd);
824 
825         s->peer->ready = local_socket_ready;
826         s->peer->shutdown = nullptr;
827         s->peer->close = local_socket_close;
828         s->peer->peer = s2;
829         s2->peer = s->peer;
830         s->peer = nullptr;
831         D("SS(%d): okay", s->id);
832         s->close(s);
833 
834         /* initial state is "ready" */
835         s2->ready(s2);
836         return 0;
837     }
838 
839     if (!s->transport) {
840         SendFail(s->peer->fd, "device offline (no transport)");
841         goto fail;
842     } else if (!ConnectionStateIsOnline(s->transport->GetConnectionState())) {
843         /* if there's no remote we fail the connection
844          ** right here and terminate it
845          */
846         SendFail(s->peer->fd, "device offline (transport offline)");
847         goto fail;
848     }
849 
850     /* instrument our peer to pass the success or fail
851     ** message back once it connects or closes, then
852     ** detach from it, request the connection, and
853     ** tear down
854     */
855     s->peer->ready = local_socket_ready_notify;
856     s->peer->shutdown = nullptr;
857     s->peer->close = local_socket_close_notify;
858     s->peer->peer = nullptr;
859     /* give him our transport and upref it */
860     s->peer->transport = s->transport;
861 
862     connect_to_remote(s->peer, std::string_view(s->smart_socket_data).substr(4));
863     s->peer = nullptr;
864     s->close(s);
865     return 1;
866 
867 fail:
868     /* we're going to close our peer as a side-effect, so
869     ** return -1 to signal that state to the local socket
870     ** who is enqueueing against us
871     */
872     s->close(s);
873     return -1;
874 }
875 
smart_socket_ready(asocket * s)876 static void smart_socket_ready(asocket* s) {
877     D("SS(%d): ready", s->id);
878 }
879 
smart_socket_close(asocket * s)880 static void smart_socket_close(asocket* s) {
881     D("SS(%d): closed", s->id);
882     if (s->peer) {
883         s->peer->peer = nullptr;
884         s->peer->close(s->peer);
885         s->peer = nullptr;
886     }
887     delete s;
888 }
889 
create_smart_socket(void)890 static asocket* create_smart_socket(void) {
891     D("Creating smart socket");
892     asocket* s = new asocket();
893     s->enqueue = smart_socket_enqueue;
894     s->ready = smart_socket_ready;
895     s->shutdown = nullptr;
896     s->close = smart_socket_close;
897 
898     D("SS(%d)", s->id);
899     return s;
900 }
901 
connect_to_smartsocket(asocket * s)902 void connect_to_smartsocket(asocket* s) {
903     D("Connecting to smart socket");
904     asocket* ss = create_smart_socket();
905     s->peer = ss;
906     ss->peer = s;
907     s->ready(s);
908 }
909 #endif
910 
get_max_payload() const911 size_t asocket::get_max_payload() const {
912     size_t max_payload = MAX_PAYLOAD;
913     if (transport) {
914         max_payload = std::min(max_payload, transport->get_max_payload());
915     }
916     if (peer && peer->transport) {
917         max_payload = std::min(max_payload, peer->transport->get_max_payload());
918     }
919     return max_payload;
920 }
921