1 /**
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 
17 #define LOG_TAG "InputChannelTest"
18 
19 #include "../includes/common.h"
20 
21 #include <android-base/stringprintf.h>
22 #include <input/InputTransport.h>
23 
24 using namespace android;
25 using android::base::StringPrintf;
26 
/**
 * Renders numBytes bytes starting at address as uppercase hexadecimal, each byte
 * written as exactly two digits followed by one space (e.g. {0x0A, 0xFF} -> "0A FF ").
 * An empty range yields an empty string.
 */
static std::string memoryAsHexString(const void* const address, size_t numBytes) {
    static const char kHexDigits[] = "0123456789ABCDEF";
    const uint8_t* const first = static_cast<const uint8_t*>(address);
    const uint8_t* const last = first + numBytes;

    std::string hex;
    for (const uint8_t* cursor = first; cursor != last; ++cursor) {
        const uint8_t byte = *cursor;
        hex.push_back(kHexDigits[byte >> 4]);    // high nibble
        hex.push_back(kHexDigits[byte & 0x0F]);  // low nibble
        hex.push_back(' ');
    }
    return hex;
}
34 
35 /**
36  * There could be non-zero bytes in-between InputMessage fields. Force-initialize the entire
37  * memory to zero, then only copy the valid bytes on a per-field basis.
38  * Input: message msg
39  * Output: cleaned message outMsg
40  */
sanitizeMessage(const InputMessage & msg,InputMessage * outMsg)41 static void sanitizeMessage(const InputMessage& msg, InputMessage* outMsg) {
42     memset(outMsg, 0, sizeof(*outMsg));
43 
44     // Write the header
45     outMsg->header.type = msg.header.type;
46 
47     // Write the body
48     switch(msg.header.type) {
49         case InputMessage::TYPE_KEY: {
50             // uint32_t seq
51             outMsg->body.key.seq = msg.body.key.seq;
52             // nsecs_t eventTime
53             outMsg->body.key.eventTime = msg.body.key.eventTime;
54             // int32_t deviceId
55             outMsg->body.key.deviceId = msg.body.key.deviceId;
56             // int32_t source
57             outMsg->body.key.source = msg.body.key.source;
58             // int32_t displayId
59             outMsg->body.key.displayId = msg.body.key.displayId;
60             // int32_t action
61             outMsg->body.key.action = msg.body.key.action;
62             // int32_t flags
63             outMsg->body.key.flags = msg.body.key.flags;
64             // int32_t keyCode
65             outMsg->body.key.keyCode = msg.body.key.keyCode;
66             // int32_t scanCode
67             outMsg->body.key.scanCode = msg.body.key.scanCode;
68             // int32_t metaState
69             outMsg->body.key.metaState = msg.body.key.metaState;
70             // int32_t repeatCount
71             outMsg->body.key.repeatCount = msg.body.key.repeatCount;
72             // nsecs_t downTime
73             outMsg->body.key.downTime = msg.body.key.downTime;
74             break;
75         }
76         case InputMessage::TYPE_MOTION: {
77             // uint32_t seq
78             outMsg->body.motion.seq = msg.body.motion.seq;
79             // nsecs_t eventTime
80             outMsg->body.motion.eventTime = msg.body.motion.eventTime;
81             // int32_t deviceId
82             outMsg->body.motion.deviceId = msg.body.motion.deviceId;
83             // int32_t source
84             outMsg->body.motion.source = msg.body.motion.source;
85             // int32_t displayId
86             outMsg->body.motion.displayId = msg.body.motion.displayId;
87             // int32_t action
88             outMsg->body.motion.action = msg.body.motion.action;
89             // int32_t actionButton
90             outMsg->body.motion.actionButton = msg.body.motion.actionButton;
91             // int32_t flags
92             outMsg->body.motion.flags = msg.body.motion.flags;
93             // int32_t metaState
94             outMsg->body.motion.metaState = msg.body.motion.metaState;
95             // int32_t buttonState
96             outMsg->body.motion.buttonState = msg.body.motion.buttonState;
97             // MotionClassification classification
98             outMsg->body.motion.classification = msg.body.motion.classification;
99             // int32_t edgeFlags
100             outMsg->body.motion.edgeFlags = msg.body.motion.edgeFlags;
101             // nsecs_t downTime
102             outMsg->body.motion.downTime = msg.body.motion.downTime;
103             // float xOffset
104             outMsg->body.motion.xOffset = msg.body.motion.xOffset;
105             // float yOffset
106             outMsg->body.motion.yOffset = msg.body.motion.yOffset;
107             // float xPrecision
108             outMsg->body.motion.xPrecision = msg.body.motion.xPrecision;
109             // float yPrecision
110             outMsg->body.motion.yPrecision = msg.body.motion.yPrecision;
111             // uint32_t pointerCount
112             outMsg->body.motion.pointerCount = msg.body.motion.pointerCount;
113             //struct Pointer pointers[MAX_POINTERS]
114             for (size_t i = 0; i < msg.body.motion.pointerCount; i++) {
115                 // PointerProperties properties
116                 outMsg->body.motion.pointers[i].properties.id =
117                         msg.body.motion.pointers[i].properties.id;
118                 outMsg->body.motion.pointers[i].properties.toolType =
119                         msg.body.motion.pointers[i].properties.toolType;
120                 // PointerCoords coords
121                 outMsg->body.motion.pointers[i].coords.bits =
122                         msg.body.motion.pointers[i].coords.bits;
123                 const uint32_t count = BitSet64::count(msg.body.motion.pointers[i].coords.bits);
124                 memcpy(&outMsg->body.motion.pointers[i].coords.values[0],
125                         &msg.body.motion.pointers[i].coords.values[0],
126                         count * sizeof(msg.body.motion.pointers[i].coords.values[0]));
127             }
128             break;
129         }
130         case InputMessage::TYPE_FINISHED: {
131             outMsg->body.finished.seq = msg.body.finished.seq;
132             outMsg->body.finished.handled = msg.body.finished.handled;
133             break;
134         }
135     }
136 }
137 
138 /**
139  * Return false if vulnerability is found for a given message type
140  */
checkMessage(sp<InputChannel> server,sp<InputChannel> client,int type)141 static bool checkMessage(sp<InputChannel> server, sp<InputChannel> client, int type) {
142     InputMessage serverMsg;
143     // Set all potentially uninitialized bytes to 1, for easier comparison
144 
145     memset(&serverMsg, 1, sizeof(serverMsg));
146     serverMsg.header.type = type;
147     if (type == InputMessage::TYPE_MOTION) {
148         serverMsg.body.motion.pointerCount = MAX_POINTERS;
149     }
150     status_t result = server->sendMessage(&serverMsg);
151     if (result != OK) {
152         ALOGE("Could not send message to the input channel");
153         return false;
154     }
155 
156     InputMessage clientMsg;
157     result = client->receiveMessage(&clientMsg);
158     if (result != OK) {
159         ALOGE("Could not receive message from the input channel");
160         return false;
161     }
162     if (serverMsg.header.type != clientMsg.header.type) {
163         ALOGE("Types do not match");
164         return false;
165     }
166 
167     if (clientMsg.header.padding != 0) {
168         ALOGE("Found padding to be uninitialized");
169         return false;
170     }
171 
172     InputMessage sanitizedClientMsg;
173     sanitizeMessage(clientMsg, &sanitizedClientMsg);
174     if (memcmp(&clientMsg, &sanitizedClientMsg, clientMsg.size()) != 0) {
175         ALOGE("Client received un-sanitized message");
176         ALOGE("Received message: %s", memoryAsHexString(&clientMsg, clientMsg.size()).c_str());
177         ALOGE("Expected message: %s",
178                 memoryAsHexString(&sanitizedClientMsg, clientMsg.size()).c_str());
179         return false;
180     }
181 
182     return true;
183 }
184 
185 /**
186  * Create an unsanitized message
187  * Send
188  * Receive
189  * Compare the received message to a sanitized expected message
190  * Do this for all message types
191  */
main()192 int main() {
193     sp<InputChannel> server, client;
194 
195     status_t result = InputChannel::openInputChannelPair("channel name", server, client);
196     if (result != OK) {
197         ALOGE("Could not open input channel pair");
198         return 0;
199     }
200 
201     int types[] = {InputMessage::TYPE_KEY, InputMessage::TYPE_MOTION, InputMessage::TYPE_FINISHED};
202     for (int type : types) {
203         bool success = checkMessage(server, client, type);
204         if (!success) {
205             ALOGE("Check message failed for type %i", type);
206             return EXIT_VULNERABLE;
207         }
208     }
209 
210     return 0;
211 }
212 
213