1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <https/ClientSocket.h>

#include <algorithm>
#include <cassert>
#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <iterator>

#include <android-base/logging.h>

#include <https/HTTPServer.h>
#include <https/RunLoop.h>
#include <https/SafeCallbackable.h>
#include <https/ServerSocket.h>
27
ClientSocket(std::shared_ptr<RunLoop> rl,HTTPServer * server,ServerSocket * parent,const sockaddr_in & addr,int sock)28 ClientSocket::ClientSocket(
29 std::shared_ptr<RunLoop> rl,
30 HTTPServer *server,
31 ServerSocket *parent,
32 const sockaddr_in &addr,
33 int sock)
34 : mRunLoop(rl),
35 mServer(server),
36 mParent(parent),
37 mRemoteAddr(addr),
38 mInBufferLen(0),
39 mSendPending(false),
40 mDisconnecting(false) {
41 if (parent->transportType() == ServerSocket::TransportType::TLS) {
42 mImplSSL = std::make_shared<SSLSocket>(
43 mRunLoop,
44 sock,
45 *server->certificate_pem_path(),
46 *server->private_key_pem_path());
47 } else {
48 mImplPlain = std::make_shared<PlainSocket>(mRunLoop, sock);
49 }
50 }
51
run()52 void ClientSocket::run() {
53 getImpl()->postRecv(makeSafeCallback(this, &ClientSocket::handleIncomingData));
54 }
55
fd() const56 int ClientSocket::fd() const {
57 return getImpl()->fd();
58 }
59
setWebSocketHandler(std::shared_ptr<WebSocketHandler> handler)60 void ClientSocket::setWebSocketHandler(
61 std::shared_ptr<WebSocketHandler> handler) {
62 mWebSocketHandler = handler;
63 mWebSocketHandler->setClientSocket(shared_from_this());
64 }
65
handleIncomingData()66 void ClientSocket::handleIncomingData() {
67 mInBuffer.resize(mInBufferLen + 1024);
68
69 ssize_t n;
70 do {
71 n = getImpl()->recv(mInBuffer.data() + mInBufferLen, 1024);
72 } while (n < 0 && errno == EINTR);
73
74 if (n == 0) {
75 if (errno == 0) {
76 // Don't process any data if there was an actual failure.
77 // This could be an authentication failure for example...
78 // We shouldn't trust anything the client says.
79 (void)handleRequest(true /* sawEOS */);
80 }
81
82 disconnect();
83 return;
84 } else if (n < 0) {
85 LOG(ERROR)
86 << "recv returned error "
87 << errno
88 << " ("
89 << strerror(errno)
90 << ")";
91
92 mParent->onClientSocketClosed(fd());
93 return;
94 }
95
96 mInBufferLen += static_cast<size_t>(n);
97 const bool closeConnection = handleRequest(false /* sawEOS */);
98
99 if (closeConnection) {
100 disconnect();
101 } else {
102 getImpl()->postRecv(
103 makeSafeCallback(this, &ClientSocket::handleIncomingData));
104 }
105 }
106
disconnect()107 void ClientSocket::disconnect() {
108 if (mDisconnecting) {
109 return;
110 }
111
112 mDisconnecting = true;
113
114 finishDisconnect();
115 }
116
finishDisconnect()117 void ClientSocket::finishDisconnect() {
118 if (!mSendPending) {
119 // Our output queue may now be empty, but the underlying socket
120 // implementation may still buffer something that we need to flush
121 // first.
122 getImpl()->postFlush(
123 makeSafeCallback<ClientSocket>(this, [](ClientSocket *me) {
124 me->mParent->onClientSocketClosed(me->fd());
125 }));
126 }
127 }
128
handleRequest(bool isEOS)129 bool ClientSocket::handleRequest(bool isEOS) {
130 if (mWebSocketHandler) {
131 ssize_t n = mWebSocketHandler->handleRequest(
132 mInBuffer.data(), mInBufferLen, isEOS);
133
134 LOG(VERBOSE)
135 << "handleRequest returned "
136 << n
137 << " when called with "
138 << mInBufferLen
139 << ", eos="
140 << isEOS;
141
142 if (n > 0) {
143 mInBuffer.erase(mInBuffer.begin(), mInBuffer.begin() + n);
144 mInBufferLen -= n;
145 }
146
147 // NOTE: Do not return true, i.e. disconnect, if the json handler
148 // returns 0 bytes read, it simply means we need more data to continue.
149 return n < 0;
150 }
151
152 size_t len = mInBufferLen;
153
154 if (!isEOS) {
155 static const char kPattern[] = "\r\n\r\n";
156
157 // Don't count the trailing NUL.
158 static constexpr size_t kPatternLength = sizeof(kPattern) - 1;
159
160 size_t i = 0;
161 while (i + kPatternLength <= mInBufferLen
162 && memcmp(mInBuffer.data() + i, kPattern, kPatternLength)) {
163 ++i;
164 }
165
166 if (i + kPatternLength > mInBufferLen) {
167 return false;
168 }
169
170 // Found a match.
171 len = i + kPatternLength;
172 }
173
174 const bool closeConnection =
175 mServer->handleSingleRequest(this, mInBuffer.data(), len, isEOS);
176
177 mInBuffer.clear();
178 mInBufferLen = 0;
179
180 return closeConnection;
181 }
182
queueOutputData(const uint8_t * data,size_t size)183 void ClientSocket::queueOutputData(const uint8_t *data, size_t size) {
184 std::copy(data, data + size, std::back_inserter(mOutBuffer));
185
186 if (!mSendPending) {
187 mSendPending = true;
188 getImpl()->postSend(makeSafeCallback(this, &ClientSocket::sendOutputData));
189 }
190 }
191
remoteAddr() const192 sockaddr_in ClientSocket::remoteAddr() const {
193 return mRemoteAddr;
194 }
195
queueResponse(const std::string & response,const std::string & body)196 void ClientSocket::queueResponse(
197 const std::string &response, const std::string &body) {
198 std::copy(response.begin(), response.end(), std::back_inserter(mOutBuffer));
199 std::copy(body.begin(), body.end(), std::back_inserter(mOutBuffer));
200
201 if (!mSendPending) {
202 mSendPending = true;
203 getImpl()->postSend(makeSafeCallback(this, &ClientSocket::sendOutputData));
204 }
205 }
206
sendOutputData()207 void ClientSocket::sendOutputData() {
208 mSendPending = false;
209
210 const size_t size = mOutBuffer.size();
211 size_t offset = 0;
212
213 while (offset < size) {
214 ssize_t n = getImpl()->send(mOutBuffer.data() + offset, size - offset);
215
216 if (n < 0) {
217 if (errno == EINTR) {
218 continue;
219 }
220
221 assert(!"Should not be here");
222 } else if (n == 0) {
223 // The remote seems gone, clear the output buffer and disconnect.
224 offset = size;
225 mDisconnecting = true;
226 break;
227 }
228
229 offset += static_cast<size_t>(n);
230 }
231
232 mOutBuffer.erase(mOutBuffer.begin(), mOutBuffer.begin() + offset);
233
234 if (!mOutBuffer.empty()) {
235 mSendPending = true;
236 getImpl()->postSend(makeSafeCallback(this, &ClientSocket::sendOutputData));
237 return;
238 }
239
240 if (mDisconnecting) {
241 finishDisconnect();
242 }
243 }
244
245