1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <https/HTTPClientConnection.h>

#include <https/HTTPRequestResponse.h>
#include <https/PlainSocket.h>
#include <https/RunLoop.h>
#include <https/SafeCallbackable.h>
#include <https/SSLSocket.h>

#include <https/Support.h>

#include <android-base/logging.h>

#include <arpa/inet.h>
#include <netinet/in.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <cstring>
#include <iterator>
33 
34 using namespace android;
35 
HTTPClientConnection(std::shared_ptr<RunLoop> rl,std::shared_ptr<WebSocketHandler> webSocketHandler,std::string_view path,ServerSocket::TransportType transportType,const std::optional<std::string> & trusted_pem_path)36 HTTPClientConnection::HTTPClientConnection(
37         std::shared_ptr<RunLoop> rl,
38         std::shared_ptr<WebSocketHandler> webSocketHandler,
39         std::string_view path,
40         ServerSocket::TransportType transportType,
41         const std::optional<std::string> &trusted_pem_path)
42     : mInitCheck(-ENODEV),
43       mRunLoop(rl),
44       mWebSocketHandler(webSocketHandler),
45       mPath(path),
46       mTransportType(transportType),
47       mSendPending(false),
48       mInBufferLen(0),
49       mWebSocketMode(false) {
50     int sock;
51 
52     sock = socket(PF_INET, SOCK_STREAM, 0);
53     if (sock < 0) {
54         mInitCheck = -errno;
55         goto bail;
56     }
57 
58     makeFdNonblocking(sock);
59 
60     if (mTransportType == ServerSocket::TransportType::TLS) {
61         CHECK(trusted_pem_path.has_value());
62 
63         mImpl = std::make_shared<SSLSocket>(
64                 mRunLoop, sock, 0 /* flags */, *trusted_pem_path);
65     } else {
66         mImpl = std::make_shared<PlainSocket>(mRunLoop, sock);
67     }
68 
69     mInitCheck = 0;
70     return;
71 
72 bail:
73     ;
74 }
75 
initCheck() const76 int HTTPClientConnection::initCheck() const {
77     return mInitCheck;
78 }
79 
connect(const char * host,uint16_t port)80 int HTTPClientConnection::connect(const char *host, uint16_t port) {
81     if (mInitCheck < 0) {
82         return mInitCheck;
83     }
84 
85     sockaddr_in addr;
86     memset(addr.sin_zero, 0, sizeof(addr.sin_zero));
87     addr.sin_family = AF_INET;
88     addr.sin_port = htons(port);
89     addr.sin_addr.s_addr = inet_addr(host);
90 
91     mRemoteAddr = addr;
92 
93     int res = ::connect(
94             mImpl->fd(), reinterpret_cast<sockaddr *>(&addr), sizeof(addr));
95 
96     if (res < 0 && errno != EINPROGRESS) {
97         return -errno;
98     }
99 
100     mImpl->postSend(makeSafeCallback(this, &HTTPClientConnection::sendRequest));
101 
102     return 0;
103 }
104 
sendRequest()105 void HTTPClientConnection::sendRequest() {
106     std::string request;
107     request =
108         "GET " + mPath + " HTTP/1.1\r\n"
109         "Connection: Upgrade\r\n"
110         "Upgrade: websocket\r\n"
111         "Sec-WebSocket-Version: 13\r\n"
112         "Sec-WebSocket-Key: foobar\r\n"
113         "\r\n";
114 
115     CHECK(mRunLoop->isCurrentThread());
116     std::copy(request.begin(), request.end(), std::back_inserter(mOutBuffer));
117 
118     if (!mSendPending) {
119         mSendPending = true;
120         mImpl->postSend(
121                 makeSafeCallback(this, &HTTPClientConnection::sendOutputData));
122     }
123 
124     mImpl->postRecv(
125             makeSafeCallback(this, &HTTPClientConnection::receiveResponse));
126 }
127 
receiveResponse()128 void HTTPClientConnection::receiveResponse() {
129     mInBuffer.resize(mInBufferLen + 1024);
130 
131     ssize_t n;
132     do {
133         n = mImpl->recv(mInBuffer.data() + mInBufferLen, 1024);
134     } while (n < 0 && errno == EINTR);
135 
136     if (n == 0) {
137         (void)handleResponse(true /* isEOS */);
138         return;
139     } else if (n < 0) {
140         LOG(ERROR) << "recv returned error '" << strerror(errno) << "'.";
141         return;
142     }
143 
144     mInBufferLen += static_cast<size_t>(n);
145 
146     if (!handleResponse(false /* isEOS */)) {
147         mImpl->postRecv(
148                 makeSafeCallback(this, &HTTPClientConnection::receiveResponse));
149     }
150 }
151 
handleResponse(bool isEOS)152 bool HTTPClientConnection::handleResponse(bool isEOS) {
153     if (mWebSocketMode) {
154         ssize_t n = mWebSocketHandler->handleRequest(
155                 mInBuffer.data(), mInBufferLen, isEOS);
156 
157         if (n > 0) {
158             mInBuffer.erase(mInBuffer.begin(), mInBuffer.begin() + n);
159             mInBufferLen -= n;
160         }
161 
162         return n <= 0;
163     }
164 
165     size_t len = mInBufferLen;
166 
167     if (!isEOS) {
168         static const char kPattern[] = "\r\n\r\n";
169 
170         // Don't count the trailing NUL.
171         static constexpr size_t kPatternLength = sizeof(kPattern) - 1;
172 
173         size_t i = 0;
174         while (i + kPatternLength <= mInBufferLen
175                 && memcmp(mInBuffer.data() + i, kPattern, kPatternLength)) {
176             ++i;
177         }
178 
179         if (i + kPatternLength > mInBufferLen) {
180             return false;
181         }
182 
183         // Found a match.
184         len = i + kPatternLength;
185     }
186 
187     HTTPResponse response;
188     if (response.setTo(mInBuffer.data(), len) < 0) {
189         LOG(ERROR) << "failed to get valid server response.";
190 
191         mInBuffer.clear();
192         mInBufferLen = 0;
193 
194         return true;
195     } else {
196         LOG(INFO)
197             << "got response: "
198             << response.getVersion()
199             << ", "
200             << response.getStatusCode()
201             << ", "
202             << response.getStatusMessage();
203 
204         LOG(INFO) << hexdump(mInBuffer.data(), len);
205 
206         mInBuffer.erase(mInBuffer.begin(), mInBuffer.begin() + len);
207         mInBufferLen -= len;
208 
209         size_t contentLength = response.getContentLength();
210         LOG(VERBOSE) << "contentLength = " << contentLength;
211         assert(mInBufferLen >= contentLength);
212 
213         LOG(INFO) << hexdump(mInBuffer.data(), contentLength);
214         mInBuffer.clear();
215 
216         if (response.getStatusCode() == 101) {
217             mWebSocketMode = true;
218 
219             mWebSocketHandler->setOutputCallback(
220                     mRemoteAddr,
221                     [this](const uint8_t *data, size_t size) {
222                         queueOutputData(data, size);
223                     });
224 
225             const std::string msg = "\"message\":\"Hellow, world!\"";
226             mWebSocketHandler->sendMessage(msg.c_str(), msg.size());
227 
228             return false;
229         }
230     }
231 
232     return true;
233 }
234 
queueOutputData(const uint8_t * data,size_t size)235 void HTTPClientConnection::queueOutputData(const uint8_t *data, size_t size) {
236     CHECK(mRunLoop->isCurrentThread());
237     std::copy(data, &data[size], std::back_inserter(mOutBuffer));
238 
239     if (!mSendPending) {
240         mSendPending = true;
241         mImpl->postSend(
242                 makeSafeCallback(this, &HTTPClientConnection::sendOutputData));
243     }
244 }
245 
sendOutputData()246 void HTTPClientConnection::sendOutputData() {
247     mSendPending = false;
248 
249     const size_t size = mOutBuffer.size();
250     size_t offset = 0;
251 
252     while (offset < size) {
253         ssize_t n = mImpl->send(mOutBuffer.data() + offset, size - offset);
254 
255         if (n < 0) {
256             if (errno == EINTR) {
257                 continue;
258             }
259 
260             if (errno == EAGAIN) {
261                 break;
262             }
263 
264             // The remote is gone (due to error), clear the output buffer and disconnect.
265             offset = size;
266             break;
267         } else if (n == 0) {
268             // The remote seems gone, clear the output buffer and disconnect.
269             offset = size;
270             break;
271         }
272 
273         offset += static_cast<size_t>(n);
274     }
275 
276     mOutBuffer.erase(mOutBuffer.begin(), mOutBuffer.begin() + offset);
277 
278     if (!mOutBuffer.empty()) {
279         mSendPending = true;
280         mImpl->postSend(
281                 makeSafeCallback(this, &HTTPClientConnection::sendOutputData));
282 
283         return;
284     }
285 }
286 
287