1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "tcp_client.h"
#include "constants.h"

#include <chrono>
#include <string_view>
#include <thread>

#include <android-base/errors.h>
#include <android-base/logging.h>
#include <android-base/parseint.h>
#include <android-base/properties.h>
#include <android-base/stringprintf.h>
#include <android-base/strings.h>
26 
namespace {

// TCP port the fastboot listening socket is opened on.
constexpr int kDefaultPort = 5554;
// Version sent back in the "FBxx" handshake; hosts announcing a lower version are rejected.
constexpr int kProtocolVersion = 1;
// Timeout, in milliseconds, for receiving the host's handshake after a connection is accepted.
constexpr int kHandshakeTimeoutMs = 2000;
// Handshake size in bytes: "FB" followed by a two-digit version.
constexpr size_t kHandshakeLength = 4;

}  // namespace
31 
// Decodes the 8-byte big-endian length header that precedes each message on the wire
// into a host-order 64-bit value. |buffer| must point to at least 8 readable bytes.
static uint64_t ExtractMessageLength(const void* buffer) {
    const auto* bytes = static_cast<const uint8_t*>(buffer);
    uint64_t length = 0;
    for (size_t idx = 0; idx < 8; ++idx) {
        // Most significant byte first: shift what we have so far and append the next byte.
        length = (length << 8) | bytes[idx];
    }
    return length;
}
40 
// Writes |length| into |buffer| as the 8-byte big-endian header used on the wire.
// |buffer| must point to at least 8 writable bytes.
static void EncodeMessageLength(uint64_t length, void* buffer) {
    auto* bytes = static_cast<uint8_t*>(buffer);
    // Fill from the last (least significant) byte toward the first.
    for (int idx = 7; idx >= 0; --idx) {
        bytes[idx] = static_cast<uint8_t>(length & 0xff);
        length >>= 8;
    }
}
47 
ClientTcpTransport()48 ClientTcpTransport::ClientTcpTransport() {
49     service_ = Socket::NewServer(Socket::Protocol::kTcp, kDefaultPort);
50 
51     // A workaround to notify recovery to continue its work.
52     android::base::SetProperty("sys.usb.ffs.ready", "1");
53 }
54 
Read(void * data,size_t len)55 ssize_t ClientTcpTransport::Read(void* data, size_t len) {
56     if (len > SSIZE_MAX) {
57         return -1;
58     }
59 
60     size_t total_read = 0;
61     do {
62         // Read a new message
63         while (message_bytes_left_ == 0) {
64             if (socket_ == nullptr) {
65                 ListenFastbootSocket();
66             }
67 
68             char buffer[8];
69             if (socket_->ReceiveAll(buffer, 8, 0) == 8) {
70                 message_bytes_left_ = ExtractMessageLength(buffer);
71             } else {
72                 // If connection is closed by host, Receive will return 0 immediately.
73                 socket_.reset(nullptr);
74                 // In DATA phase, return error.
75                 if (downloading_) {
76                     return -1;
77                 }
78             }
79         }
80 
81         size_t read_length = len - total_read;
82         if (read_length > message_bytes_left_) {
83             read_length = message_bytes_left_;
84         }
85         ssize_t bytes_read =
86                 socket_->ReceiveAll(reinterpret_cast<char*>(data) + total_read, read_length, 0);
87         if (bytes_read == -1) {
88             socket_.reset(nullptr);
89             return -1;
90         } else {
91             message_bytes_left_ -= bytes_read;
92             total_read += bytes_read;
93         }
94     // There are more than one DATA phases if the downloading buffer is too
95     // large, like a very big system image. All of data phases should be
96     // received until the whole buffer is filled in that case.
97     } while (downloading_ && total_read < len);
98 
99     return total_read;
100 }
101 
Write(const void * data,size_t len)102 ssize_t ClientTcpTransport::Write(const void* data, size_t len) {
103     if (socket_ == nullptr || len > SSIZE_MAX) {
104         return -1;
105     }
106 
107     // Use multi-buffer writes for better performance.
108     char header[8];
109     EncodeMessageLength(len, header);
110 
111     if (!socket_->Send(std::vector<cutils_socket_buffer_t>{{header, 8}, {data, len}})) {
112         socket_.reset(nullptr);
113         return -1;
114     }
115 
116     // In DATA phase
117     if (android::base::StartsWith(reinterpret_cast<const char*>(data), RESPONSE_DATA)) {
118         downloading_ = true;
119     } else {
120         downloading_ = false;
121     }
122 
123     return len;
124 }
125 
Close()126 int ClientTcpTransport::Close() {
127     if (socket_ == nullptr) {
128         return -1;
129     }
130     socket_.reset(nullptr);
131 
132     return 0;
133 }
134 
Reset()135 int ClientTcpTransport::Reset() {
136     return Close();
137 }
138 
ListenFastbootSocket()139 void ClientTcpTransport::ListenFastbootSocket() {
140     while (true) {
141         socket_ = service_->Accept();
142 
143         // Handshake
144         char buffer[kHandshakeLength + 1];
145         buffer[kHandshakeLength] = '\0';
146         if (socket_->ReceiveAll(buffer, kHandshakeLength, kHandshakeTimeoutMs) !=
147             kHandshakeLength) {
148             PLOG(ERROR) << "No Handshake message received";
149             socket_.reset(nullptr);
150             continue;
151         }
152 
153         if (memcmp(buffer, "FB", 2) != 0) {
154             PLOG(ERROR) << "Unrecognized initialization message";
155             socket_.reset(nullptr);
156             continue;
157         }
158 
159         int version = 0;
160         if (!android::base::ParseInt(buffer + 2, &version) || version < kProtocolVersion) {
161             LOG(ERROR) << "Unknown TCP protocol version " << buffer + 2
162                        << ", our version: " << kProtocolVersion;
163             socket_.reset(nullptr);
164             continue;
165         }
166 
167         std::string handshake_message(android::base::StringPrintf("FB%02d", kProtocolVersion));
168         if (!socket_->Send(handshake_message.c_str(), kHandshakeLength)) {
169             PLOG(ERROR) << "Failed to send initialization message";
170             socket_.reset(nullptr);
171             continue;
172         }
173 
174         break;
175     }
176 }
177