1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <https/WebSocketHandler.h>
18 
19 #include <https/ClientSocket.h>
20 #include <https/Support.h>
21 
22 #include <iostream>
23 #include <sstream>
24 
25 #include <string.h>
26 
handleRequest(uint8_t * data,size_t size,bool isEOS)27 ssize_t WebSocketHandler::handleRequest(
28         uint8_t *data, size_t size, bool isEOS) {
29     (void)isEOS;
30 
31     size_t offset = 0;
32     while (offset + 1 < size) {
33         uint8_t *packet = &data[offset];
34         const size_t avail = size - offset;
35 
36         size_t packetOffset = 0;
37         const uint8_t headerByte = packet[packetOffset];
38 
39         const bool hasMask = (packet[packetOffset + 1] & 0x80) != 0;
40         size_t payloadLen = packet[packetOffset + 1] & 0x7f;
41         packetOffset += 2;
42 
43         if (payloadLen == 126) {
44             if (packetOffset + 1 >= avail) {
45                 break;
46             }
47 
48             payloadLen = U16_AT(&packet[packetOffset]);
49             packetOffset += 2;
50         } else if (payloadLen == 127) {
51             if (packetOffset + 7 >= avail) {
52                 break;
53             }
54 
55             payloadLen = U64_AT(&packet[packetOffset]);
56             packetOffset += 8;
57         }
58 
59         uint32_t mask = 0;
60         if (hasMask) {
61             if (packetOffset + 3 >= avail) {
62                 break;
63             }
64 
65             mask = U32_AT(&packet[packetOffset]);
66             packetOffset += 4;
67         }
68 
69         if (packetOffset + payloadLen > avail) {
70             break;
71         }
72 
73         if (mask) {
74             for (size_t i = 0; i < payloadLen; ++i) {
75                 packet[packetOffset + i] ^= ((mask >> (8 * (3 - (i % 4)))) & 0xff);
76             }
77         }
78 
79         int err = 0;
80         bool is_control_frame = (headerByte & 0x08) != 0;
81         if (is_control_frame) {
82           uint8_t opcode = headerByte & 0x0f;
83           if (opcode == 0x9 /*ping*/) {
84             sendMessage(&packet[packetOffset], payloadLen, SendMode::pong);
85           } else if (opcode == 0x8 /*close*/) {
86             return -1;
87           }
88         } else {
89           err = handleMessage(headerByte, &packet[packetOffset], payloadLen);
90         }
91 
92         offset += packetOffset + payloadLen;
93 
94         if (err < 0) {
95             return err;
96         }
97     }
98 
99     return offset;
100 }
101 
isConnected()102 bool WebSocketHandler::isConnected() {
103     return mOutputCallback != nullptr || mClientSocket.lock() != nullptr;
104 }
105 
setClientSocket(std::weak_ptr<ClientSocket> clientSocket)106 void WebSocketHandler::setClientSocket(std::weak_ptr<ClientSocket> clientSocket) {
107     mClientSocket = clientSocket;
108 }
109 
setOutputCallback(const sockaddr_in & remoteAddr,OutputCallback fn)110 void WebSocketHandler::setOutputCallback(
111         const sockaddr_in &remoteAddr, OutputCallback fn) {
112     mOutputCallback = fn;
113     mRemoteAddr = remoteAddr;
114 }
115 
handleMessage(uint8_t headerByte,const uint8_t * msg,size_t len)116 int WebSocketHandler::handleMessage(
117         uint8_t headerByte, const uint8_t *msg, size_t len) {
118     std::cerr
119         << "WebSocketHandler::handleMessage(0x"
120         << std::hex
121         << (unsigned)headerByte
122         << std::dec
123         << ")"
124         << std::endl;
125 
126     std::cerr << hexdump(msg, len);
127 
128     const uint8_t opcode = headerByte & 0x0f;
129     if (opcode == 8) {
130         // Connection close.
131         return -1;
132     }
133 
134     return 0;
135 }
136 
sendMessage(const void * data,size_t size,SendMode mode)137 int WebSocketHandler::sendMessage(
138         const void *data, size_t size, SendMode mode) {
139     static constexpr bool kUseMask = false;
140 
141     size_t numHeaderBytes = 2 + (kUseMask ? 4 : 0);
142     if (size > 65535) {
143         numHeaderBytes += 8;
144     } else if (size > 125) {
145         numHeaderBytes += 2;
146     }
147 
148     static constexpr uint8_t kOpCodeBySendMode[] = {
149         0x1,  // text
150         0x2,  // binary
151         0x8,  // closeConnection
152         0xa,  // pong
153     };
154 
155     auto opcode = kOpCodeBySendMode[static_cast<uint8_t>(mode)];
156 
157     std::unique_ptr<uint8_t[]> buffer(new uint8_t[numHeaderBytes + size]);
158     uint8_t *msg = buffer.get();
159     msg[0] = 0x80 | opcode;  // FIN==1
160     msg[1] = kUseMask ? 0x80 : 0x00;
161 
162     if (size > 65535) {
163         msg[1] |= 127;
164         msg[2] = 0x00;
165         msg[3] = 0x00;
166         msg[4] = 0x00;
167         msg[5] = 0x00;
168         msg[6] = (size >> 24) & 0xff;
169         msg[7] = (size >> 16) & 0xff;
170         msg[8] = (size >> 8) & 0xff;
171         msg[9] = size & 0xff;
172     } else if (size > 125) {
173         msg[1] |= 126;
174         msg[2] = (size >> 8) & 0xff;
175         msg[3] = size & 0xff;
176     } else {
177         msg[1] |= size;
178     }
179 
180     if (kUseMask) {
181         uint32_t mask = rand();
182         msg[numHeaderBytes - 4] = (mask >> 24) & 0xff;
183         msg[numHeaderBytes - 3] = (mask >> 16) & 0xff;
184         msg[numHeaderBytes - 2] = (mask >> 8) & 0xff;
185         msg[numHeaderBytes - 1] = mask & 0xff;
186 
187         for (size_t i = 0; i < size; ++i) {
188             msg[numHeaderBytes + i] =
189                 ((const uint8_t *)data)[i]
190                     ^ ((mask >> (8 * (3 - (i % 4)))) & 0xff);
191         }
192     } else {
193         memcpy(&msg[numHeaderBytes], data, size);
194     }
195 
196     if (mOutputCallback) {
197         mOutputCallback(msg, numHeaderBytes + size);
198     } else {
199         auto clientSocket = mClientSocket.lock();
200         if (clientSocket) {
201             clientSocket->queueOutputData(msg, numHeaderBytes + size);
202         }
203     }
204 
205     return 0;
206 }
207 
remoteHost() const208 std::string WebSocketHandler::remoteHost() const {
209     sockaddr_in remoteAddr;
210 
211     if (mOutputCallback) {
212         remoteAddr = mRemoteAddr;
213     } else {
214         auto clientSocket = mClientSocket.lock();
215         if (clientSocket) {
216             remoteAddr = clientSocket->remoteAddr();
217         } else {
218             return "0.0.0.0";
219         }
220     }
221 
222     const uint32_t ipAddress = ntohl(remoteAddr.sin_addr.s_addr);
223 
224     std::stringstream ss;
225     ss << (ipAddress >> 24)
226        << "."
227        << ((ipAddress >> 16) & 0xff)
228        << "."
229        << ((ipAddress >> 8) & 0xff)
230        << "."
231        << (ipAddress & 0xff);
232 
233     return ss.str();
234 }
235 
236