1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <https/WebSocketHandler.h>

#include <https/ClientSocket.h>
#include <https/Support.h>

#include <cstdint>
#include <iostream>
#include <sstream>
#include <utility>

#include <string.h>
26
handleRequest(uint8_t * data,size_t size,bool isEOS)27 ssize_t WebSocketHandler::handleRequest(
28 uint8_t *data, size_t size, bool isEOS) {
29 (void)isEOS;
30
31 size_t offset = 0;
32 while (offset + 1 < size) {
33 uint8_t *packet = &data[offset];
34 const size_t avail = size - offset;
35
36 size_t packetOffset = 0;
37 const uint8_t headerByte = packet[packetOffset];
38
39 const bool hasMask = (packet[packetOffset + 1] & 0x80) != 0;
40 size_t payloadLen = packet[packetOffset + 1] & 0x7f;
41 packetOffset += 2;
42
43 if (payloadLen == 126) {
44 if (packetOffset + 1 >= avail) {
45 break;
46 }
47
48 payloadLen = U16_AT(&packet[packetOffset]);
49 packetOffset += 2;
50 } else if (payloadLen == 127) {
51 if (packetOffset + 7 >= avail) {
52 break;
53 }
54
55 payloadLen = U64_AT(&packet[packetOffset]);
56 packetOffset += 8;
57 }
58
59 uint32_t mask = 0;
60 if (hasMask) {
61 if (packetOffset + 3 >= avail) {
62 break;
63 }
64
65 mask = U32_AT(&packet[packetOffset]);
66 packetOffset += 4;
67 }
68
69 if (packetOffset + payloadLen > avail) {
70 break;
71 }
72
73 if (mask) {
74 for (size_t i = 0; i < payloadLen; ++i) {
75 packet[packetOffset + i] ^= ((mask >> (8 * (3 - (i % 4)))) & 0xff);
76 }
77 }
78
79 int err = 0;
80 bool is_control_frame = (headerByte & 0x08) != 0;
81 if (is_control_frame) {
82 uint8_t opcode = headerByte & 0x0f;
83 if (opcode == 0x9 /*ping*/) {
84 sendMessage(&packet[packetOffset], payloadLen, SendMode::pong);
85 } else if (opcode == 0x8 /*close*/) {
86 return -1;
87 }
88 } else {
89 err = handleMessage(headerByte, &packet[packetOffset], payloadLen);
90 }
91
92 offset += packetOffset + payloadLen;
93
94 if (err < 0) {
95 return err;
96 }
97 }
98
99 return offset;
100 }
101
isConnected()102 bool WebSocketHandler::isConnected() {
103 return mOutputCallback != nullptr || mClientSocket.lock() != nullptr;
104 }
105
setClientSocket(std::weak_ptr<ClientSocket> clientSocket)106 void WebSocketHandler::setClientSocket(std::weak_ptr<ClientSocket> clientSocket) {
107 mClientSocket = clientSocket;
108 }
109
setOutputCallback(const sockaddr_in & remoteAddr,OutputCallback fn)110 void WebSocketHandler::setOutputCallback(
111 const sockaddr_in &remoteAddr, OutputCallback fn) {
112 mOutputCallback = fn;
113 mRemoteAddr = remoteAddr;
114 }
115
handleMessage(uint8_t headerByte,const uint8_t * msg,size_t len)116 int WebSocketHandler::handleMessage(
117 uint8_t headerByte, const uint8_t *msg, size_t len) {
118 std::cerr
119 << "WebSocketHandler::handleMessage(0x"
120 << std::hex
121 << (unsigned)headerByte
122 << std::dec
123 << ")"
124 << std::endl;
125
126 std::cerr << hexdump(msg, len);
127
128 const uint8_t opcode = headerByte & 0x0f;
129 if (opcode == 8) {
130 // Connection close.
131 return -1;
132 }
133
134 return 0;
135 }
136
sendMessage(const void * data,size_t size,SendMode mode)137 int WebSocketHandler::sendMessage(
138 const void *data, size_t size, SendMode mode) {
139 static constexpr bool kUseMask = false;
140
141 size_t numHeaderBytes = 2 + (kUseMask ? 4 : 0);
142 if (size > 65535) {
143 numHeaderBytes += 8;
144 } else if (size > 125) {
145 numHeaderBytes += 2;
146 }
147
148 static constexpr uint8_t kOpCodeBySendMode[] = {
149 0x1, // text
150 0x2, // binary
151 0x8, // closeConnection
152 0xa, // pong
153 };
154
155 auto opcode = kOpCodeBySendMode[static_cast<uint8_t>(mode)];
156
157 std::unique_ptr<uint8_t[]> buffer(new uint8_t[numHeaderBytes + size]);
158 uint8_t *msg = buffer.get();
159 msg[0] = 0x80 | opcode; // FIN==1
160 msg[1] = kUseMask ? 0x80 : 0x00;
161
162 if (size > 65535) {
163 msg[1] |= 127;
164 msg[2] = 0x00;
165 msg[3] = 0x00;
166 msg[4] = 0x00;
167 msg[5] = 0x00;
168 msg[6] = (size >> 24) & 0xff;
169 msg[7] = (size >> 16) & 0xff;
170 msg[8] = (size >> 8) & 0xff;
171 msg[9] = size & 0xff;
172 } else if (size > 125) {
173 msg[1] |= 126;
174 msg[2] = (size >> 8) & 0xff;
175 msg[3] = size & 0xff;
176 } else {
177 msg[1] |= size;
178 }
179
180 if (kUseMask) {
181 uint32_t mask = rand();
182 msg[numHeaderBytes - 4] = (mask >> 24) & 0xff;
183 msg[numHeaderBytes - 3] = (mask >> 16) & 0xff;
184 msg[numHeaderBytes - 2] = (mask >> 8) & 0xff;
185 msg[numHeaderBytes - 1] = mask & 0xff;
186
187 for (size_t i = 0; i < size; ++i) {
188 msg[numHeaderBytes + i] =
189 ((const uint8_t *)data)[i]
190 ^ ((mask >> (8 * (3 - (i % 4)))) & 0xff);
191 }
192 } else {
193 memcpy(&msg[numHeaderBytes], data, size);
194 }
195
196 if (mOutputCallback) {
197 mOutputCallback(msg, numHeaderBytes + size);
198 } else {
199 auto clientSocket = mClientSocket.lock();
200 if (clientSocket) {
201 clientSocket->queueOutputData(msg, numHeaderBytes + size);
202 }
203 }
204
205 return 0;
206 }
207
remoteHost() const208 std::string WebSocketHandler::remoteHost() const {
209 sockaddr_in remoteAddr;
210
211 if (mOutputCallback) {
212 remoteAddr = mRemoteAddr;
213 } else {
214 auto clientSocket = mClientSocket.lock();
215 if (clientSocket) {
216 remoteAddr = clientSocket->remoteAddr();
217 } else {
218 return "0.0.0.0";
219 }
220 }
221
222 const uint32_t ipAddress = ntohl(remoteAddr.sin_addr.s_addr);
223
224 std::stringstream ss;
225 ss << (ipAddress >> 24)
226 << "."
227 << ((ipAddress >> 16) & 0xff)
228 << "."
229 << ((ipAddress >> 8) & 0xff)
230 << "."
231 << (ipAddress & 0xff);
232
233 return ss.str();
234 }
235
236