1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <https/ClientSocket.h>

#include <https/HTTPServer.h>
#include <https/RunLoop.h>
#include <https/SafeCallbackable.h>
#include <https/ServerSocket.h>

#include <android-base/logging.h>

#include <algorithm>
#include <cassert>
#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <iterator>
27 
ClientSocket(std::shared_ptr<RunLoop> rl,HTTPServer * server,ServerSocket * parent,const sockaddr_in & addr,int sock)28 ClientSocket::ClientSocket(
29         std::shared_ptr<RunLoop> rl,
30         HTTPServer *server,
31         ServerSocket *parent,
32         const sockaddr_in &addr,
33         int sock)
34     : mRunLoop(rl),
35       mServer(server),
36       mParent(parent),
37       mRemoteAddr(addr),
38       mInBufferLen(0),
39       mSendPending(false),
40       mDisconnecting(false) {
41     if (parent->transportType() == ServerSocket::TransportType::TLS) {
42         mImplSSL = std::make_shared<SSLSocket>(
43                 mRunLoop,
44                 sock,
45                 *server->certificate_pem_path(),
46                 *server->private_key_pem_path());
47     } else {
48         mImplPlain = std::make_shared<PlainSocket>(mRunLoop, sock);
49     }
50 }
51 
run()52 void ClientSocket::run() {
53     getImpl()->postRecv(makeSafeCallback(this, &ClientSocket::handleIncomingData));
54 }
55 
fd() const56 int ClientSocket::fd() const {
57     return getImpl()->fd();
58 }
59 
setWebSocketHandler(std::shared_ptr<WebSocketHandler> handler)60 void ClientSocket::setWebSocketHandler(
61         std::shared_ptr<WebSocketHandler> handler) {
62     mWebSocketHandler = handler;
63     mWebSocketHandler->setClientSocket(shared_from_this());
64 }
65 
handleIncomingData()66 void ClientSocket::handleIncomingData() {
67     mInBuffer.resize(mInBufferLen + 1024);
68 
69     ssize_t n;
70     do {
71         n = getImpl()->recv(mInBuffer.data() + mInBufferLen, 1024);
72     } while (n < 0 && errno == EINTR);
73 
74     if (n == 0) {
75         if (errno == 0) {
76             // Don't process any data if there was an actual failure.
77             // This could be an authentication failure for example...
78             // We shouldn't trust anything the client says.
79             (void)handleRequest(true /* sawEOS */);
80         }
81 
82         disconnect();
83         return;
84     } else if (n < 0) {
85         LOG(ERROR)
86             << "recv returned error "
87             << errno
88             << " ("
89             << strerror(errno)
90             << ")";
91 
92         mParent->onClientSocketClosed(fd());
93         return;
94     }
95 
96     mInBufferLen += static_cast<size_t>(n);
97     const bool closeConnection = handleRequest(false /* sawEOS */);
98 
99     if (closeConnection) {
100         disconnect();
101     } else {
102         getImpl()->postRecv(
103                 makeSafeCallback(this, &ClientSocket::handleIncomingData));
104     }
105 }
106 
disconnect()107 void ClientSocket::disconnect() {
108     if (mDisconnecting) {
109         return;
110     }
111 
112     mDisconnecting = true;
113 
114     finishDisconnect();
115 }
116 
finishDisconnect()117 void ClientSocket::finishDisconnect() {
118     if (!mSendPending) {
119         // Our output queue may now be empty, but the underlying socket
120         // implementation may still buffer something that we need to flush
121         // first.
122         getImpl()->postFlush(
123                 makeSafeCallback<ClientSocket>(this, [](ClientSocket *me) {
124                     me->mParent->onClientSocketClosed(me->fd());
125                 }));
126     }
127 }
128 
handleRequest(bool isEOS)129 bool ClientSocket::handleRequest(bool isEOS) {
130     if (mWebSocketHandler) {
131         ssize_t n = mWebSocketHandler->handleRequest(
132                 mInBuffer.data(), mInBufferLen, isEOS);
133 
134         LOG(VERBOSE)
135             << "handleRequest returned "
136             << n
137             << " when called with "
138             << mInBufferLen
139             << ", eos="
140             << isEOS;
141 
142         if (n > 0) {
143             mInBuffer.erase(mInBuffer.begin(), mInBuffer.begin() + n);
144             mInBufferLen -= n;
145         }
146 
147         // NOTE: Do not return true, i.e. disconnect, if the json handler
148         // returns 0 bytes read, it simply means we need more data to continue.
149         return n < 0;
150     }
151 
152     size_t len = mInBufferLen;
153 
154     if (!isEOS) {
155         static const char kPattern[] = "\r\n\r\n";
156 
157         // Don't count the trailing NUL.
158         static constexpr size_t kPatternLength = sizeof(kPattern) - 1;
159 
160         size_t i = 0;
161         while (i + kPatternLength <= mInBufferLen
162                 && memcmp(mInBuffer.data() + i, kPattern, kPatternLength)) {
163             ++i;
164         }
165 
166         if (i + kPatternLength > mInBufferLen) {
167             return false;
168         }
169 
170         // Found a match.
171         len = i + kPatternLength;
172     }
173 
174     const bool closeConnection =
175         mServer->handleSingleRequest(this, mInBuffer.data(), len, isEOS);
176 
177     mInBuffer.clear();
178     mInBufferLen = 0;
179 
180     return closeConnection;
181 }
182 
queueOutputData(const uint8_t * data,size_t size)183 void ClientSocket::queueOutputData(const uint8_t *data, size_t size) {
184     std::copy(data, data + size, std::back_inserter(mOutBuffer));
185 
186     if (!mSendPending) {
187         mSendPending = true;
188         getImpl()->postSend(makeSafeCallback(this, &ClientSocket::sendOutputData));
189     }
190 }
191 
remoteAddr() const192 sockaddr_in ClientSocket::remoteAddr() const {
193     return mRemoteAddr;
194 }
195 
queueResponse(const std::string & response,const std::string & body)196 void ClientSocket::queueResponse(
197         const std::string &response, const std::string &body) {
198     std::copy(response.begin(), response.end(), std::back_inserter(mOutBuffer));
199     std::copy(body.begin(), body.end(), std::back_inserter(mOutBuffer));
200 
201     if (!mSendPending) {
202         mSendPending = true;
203         getImpl()->postSend(makeSafeCallback(this, &ClientSocket::sendOutputData));
204     }
205 }
206 
sendOutputData()207 void ClientSocket::sendOutputData() {
208     mSendPending = false;
209 
210     const size_t size = mOutBuffer.size();
211     size_t offset = 0;
212 
213     while (offset < size) {
214         ssize_t n = getImpl()->send(mOutBuffer.data() + offset, size - offset);
215 
216         if (n < 0) {
217             if (errno == EINTR) {
218                 continue;
219             }
220 
221             assert(!"Should not be here");
222         } else if (n == 0) {
223             // The remote seems gone, clear the output buffer and disconnect.
224             offset = size;
225             mDisconnecting = true;
226             break;
227         }
228 
229         offset += static_cast<size_t>(n);
230     }
231 
232     mOutBuffer.erase(mOutBuffer.begin(), mOutBuffer.begin() + offset);
233 
234     if (!mOutBuffer.empty()) {
235         mSendPending = true;
236         getImpl()->postSend(makeSafeCallback(this, &ClientSocket::sendOutputData));
237         return;
238     }
239 
240     if (mDisconnecting) {
241         finishDisconnect();
242     }
243 }
244 
245