1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.lock_checker;
18 
19 import android.app.ActivityThread;
20 import android.os.Handler;
21 import android.os.HandlerThread;
22 import android.os.Looper;
23 import android.os.Message;
24 import android.os.Process;
25 import android.util.Log;
26 import android.util.LogWriter;
27 
28 import com.android.internal.os.RuntimeInit;
29 import com.android.internal.os.SomeArgs;
30 import com.android.internal.util.StatLogger;
31 
32 import dalvik.system.AnnotatedStackTraceElement;
33 
34 import libcore.util.HexEncoding;
35 
36 import java.io.PrintWriter;
37 import java.nio.charset.Charset;
38 import java.security.MessageDigest;
39 import java.security.NoSuchAlgorithmException;
40 import java.util.Map;
41 import java.util.concurrent.ConcurrentLinkedQueue;
42 import java.util.concurrent.atomic.AtomicInteger;
43 
44 /**
45  * Entry class for lock inversion infrastructure. The agent will inject calls to preLock
46  * and postLock, and the hook will call the checker, and store violations.
47  */
48 public class LockHook {
    /** Tag used for the logcat output produced by this class. */
    private static final String TAG = "LockHook";

    /** Charset used to encode file, class and method names before they are hashed. */
    private static final Charset sFilenameCharset = Charset.forName("UTF-8");

    /** Thread whose looper backs {@link #sHandler}; the lock hooks do nothing on this thread. */
    private static final HandlerThread sHandlerThread;
    /** Handler that receives violations and processes them on {@link #sHandlerThread}. */
    private static final WtfHandler sHandler;

    /** Incremented once per checked {@link #preLock} call. */
    private static final AtomicInteger sTotalObtainCount = new AtomicInteger();
    /** Incremented once per checked {@link #postLock} call. */
    private static final AtomicInteger sTotalReleaseCount = new AtomicInteger();
    /** Largest value ever passed to {@link #updateDeepestNest}. */
    private static final AtomicInteger sDeepestNest = new AtomicInteger();

    /**
     * Whether to do the lock check on this thread.
     */
    private static final ThreadLocal<Boolean> sDoCheck = ThreadLocal.withInitial(() -> true);

    /** Constants for {@link #sStats}. */
    interface Stats {
        // Presumably the index of the "on-thread" entry in sStats -- TODO confirm against users.
        int ON_THREAD = 0;
    }

    /** Stat logger with a single "on-thread" entry; printed by {@link #dump(PrintWriter, String)}. */
    static final StatLogger sStats = new StatLogger(new String[] { "on-thread", });

    /** Most recently reported violations; trimmed to {@link #MAX_VIOLATIONS} on every add. */
    private static final ConcurrentLinkedQueue<Violation> sViolations =
            new ConcurrentLinkedQueue<>();
    /** Upper bound on the number of violations retained in {@link #sViolations}. */
    private static final int MAX_VIOLATIONS = 50;

    /** Checkers that are notified from {@link #preLock} and {@link #postLock}. */
    private static final LockChecker[] sCheckers;

    /** When true, violations are additionally forwarded to the native {@code nWtf} hook. */
    private static boolean sNativeHandling = false;
    /** When true, violations are additionally logged through {@code RuntimeInit.logUncaught}. */
    private static boolean sSimulateCrash = false;
79 
    // One-time class setup: start the reporting thread, create the checkers, and read the
    // two configuration flags from native code.
    static {
        sHandlerThread = new HandlerThread("LockHook:wtf", Process.THREAD_PRIORITY_BACKGROUND);
        // The thread is started before its looper is requested for the handler below.
        sHandlerThread.start();
        sHandler = new WtfHandler(sHandlerThread.getLooper());

        sCheckers = new LockChecker[] { new OnThreadLockChecker() };

        sNativeHandling = getNativeHandlingConfig();
        sSimulateCrash = getSimulateCrashConfig();
    }
90 
    /** Native hook whose result is stored in {@code sNativeHandling} during class init. */
    private static native boolean getNativeHandlingConfig();
    /** Native hook whose result is stored in {@code sSimulateCrash} during class init. */
    private static native boolean getSimulateCrashConfig();
93 
shouldDumpStacktrace(StacktraceHasher hasher, Map<String, T> dumpedSet, T val, AnnotatedStackTraceElement[] st, int from, int to)94     static <T> boolean shouldDumpStacktrace(StacktraceHasher hasher, Map<String, T> dumpedSet,
95             T val, AnnotatedStackTraceElement[] st, int from, int to) {
96         final String stacktraceHash = hasher.stacktraceHash(st, from, to);
97         if (dumpedSet.containsKey(stacktraceHash)) {
98             return false;
99         }
100         dumpedSet.put(stacktraceHash, val);
101         return true;
102     }
103 
updateDeepestNest(int nest)104     static void updateDeepestNest(int nest) {
105         for (;;) {
106             final int knownDeepest = sDeepestNest.get();
107             if (knownDeepest >= nest) {
108                 return;
109             }
110             if (sDeepestNest.compareAndSet(knownDeepest, nest)) {
111                 return;
112             }
113         }
114     }
115 
wtf(Violation v)116     static void wtf(Violation v) {
117         sHandler.wtf(v);
118     }
119 
doCheckOnThisThread(boolean check)120     static void doCheckOnThisThread(boolean check) {
121         sDoCheck.set(check);
122     }
123 
124     /**
125      * This method is called when a lock is about to be held. (Except if it's a
126      * synchronized, the lock is already held.)
127      */
preLock(Object lock)128     public static void preLock(Object lock) {
129         if (Thread.currentThread() != sHandlerThread && sDoCheck.get()) {
130             sDoCheck.set(false);
131             try {
132                 sTotalObtainCount.incrementAndGet();
133                 for (LockChecker checker : sCheckers) {
134                     checker.pre(lock);
135                 }
136             } finally {
137                 sDoCheck.set(true);
138             }
139         }
140     }
141 
142     /**
143      * This method is called when a lock is about to be released.
144      */
postLock(Object lock)145     public static void postLock(Object lock) {
146         if (Thread.currentThread() != sHandlerThread && sDoCheck.get()) {
147             sDoCheck.set(false);
148             try {
149                 sTotalReleaseCount.incrementAndGet();
150                 for (LockChecker checker : sCheckers) {
151                     checker.post(lock);
152                 }
153             } finally {
154                 sDoCheck.set(true);
155             }
156         }
157     }
158 
159     private static class WtfHandler extends Handler {
160         private static final int MSG_WTF = 1;
161 
WtfHandler(Looper looper)162         WtfHandler(Looper looper) {
163             super(looper);
164         }
165 
wtf(Violation v)166         public void wtf(Violation v) {
167             sDoCheck.set(false);
168             SomeArgs args = SomeArgs.obtain();
169             args.arg1 = v;
170             obtainMessage(MSG_WTF, args).sendToTarget();
171             sDoCheck.set(true);
172         }
173 
174         @Override
handleMessage(Message msg)175         public void handleMessage(Message msg) {
176             switch (msg.what) {
177                 case MSG_WTF:
178                     SomeArgs args = (SomeArgs) msg.obj;
179                     handleViolation((Violation) args.arg1);
180                     args.recycle();
181                     break;
182             }
183         }
184     }
185 
handleViolation(Violation v)186     private static void handleViolation(Violation v) {
187         String msg = v.toString();
188         Log.wtf(TAG, msg);
189         if (sNativeHandling) {
190             nWtf(msg);  // Also send to native.
191         }
192         if (sSimulateCrash) {
193             RuntimeInit.logUncaught("LockAgent",
194                     ActivityThread.isSystem() ? "system_server"
195                             : ActivityThread.currentProcessName(),
196                     Process.myPid(), v.getException());
197         }
198     }
199 
    /** Native hook that receives the violation message when {@code sNativeHandling} is set. */
    private static native void nWtf(String msg);
201 
202     /**
203      * Generates a hash for a given stacktrace of a {@link Throwable}.
204      */
205     static class StacktraceHasher {
206         private byte[] mLineNumberBuffer = new byte[4];
207         private final MessageDigest mHash;
208 
StacktraceHasher()209         StacktraceHasher() {
210             try {
211                 mHash = MessageDigest.getInstance("MD5");
212             } catch (NoSuchAlgorithmException e) {
213                 throw new RuntimeException(e);
214             }
215         }
216 
stacktraceHash(Throwable t)217         public String stacktraceHash(Throwable t) {
218             mHash.reset();
219             for (StackTraceElement e : t.getStackTrace()) {
220                 hashStackTraceElement(e);
221             }
222             return HexEncoding.encodeToString(mHash.digest());
223         }
224 
stacktraceHash(AnnotatedStackTraceElement[] annotatedStack, int from, int to)225         public String stacktraceHash(AnnotatedStackTraceElement[] annotatedStack, int from,
226                 int to) {
227             mHash.reset();
228             for (int i = from; i <= to; i++) {
229                 hashStackTraceElement(annotatedStack[i].getStackTraceElement());
230             }
231             return HexEncoding.encodeToString(mHash.digest());
232         }
233 
hashStackTraceElement(StackTraceElement e)234         private void hashStackTraceElement(StackTraceElement e) {
235             if (e.getFileName() != null) {
236                 mHash.update(sFilenameCharset.encode(e.getFileName()).array());
237             } else {
238                 if (e.getClassName() != null) {
239                     mHash.update(sFilenameCharset.encode(e.getClassName()).array());
240                 }
241                 if (e.getMethodName() != null) {
242                     mHash.update(sFilenameCharset.encode(e.getMethodName()).array());
243                 }
244             }
245 
246             final int line = e.getLineNumber();
247             mLineNumberBuffer[0] = (byte) ((line >> 24) & 0xff);
248             mLineNumberBuffer[1] = (byte) ((line >> 16) & 0xff);
249             mLineNumberBuffer[2] = (byte) ((line >> 8) & 0xff);
250             mLineNumberBuffer[3] = (byte) ((line >> 0) & 0xff);
251             mHash.update(mLineNumberBuffer);
252         }
253     }
254 
addViolation(Violation v)255     static void addViolation(Violation v) {
256         wtf(v);
257 
258         sViolations.offer(v);
259         while (sViolations.size() > MAX_VIOLATIONS) {
260             sViolations.poll();
261         }
262     }
263 
264     /**
265      * Dump stats to the given PrintWriter.
266      */
dump(PrintWriter pw, String indent)267     public static void dump(PrintWriter pw, String indent) {
268         final int oc = LockHook.sTotalObtainCount.get();
269         final int rc = LockHook.sTotalReleaseCount.get();
270         final int dn = LockHook.sDeepestNest.get();
271         pw.print("Lock stats: oc=");
272         pw.print(oc);
273         pw.print(" rc=");
274         pw.print(rc);
275         pw.print(" dn=");
276         pw.print(dn);
277         pw.println();
278 
279         for (LockChecker checker : sCheckers) {
280             pw.print(indent);
281             pw.print("  ");
282             checker.dump(pw);
283             pw.println();
284         }
285 
286         sStats.dump(pw, indent);
287 
288         pw.print(indent);
289         pw.println("Violations:");
290         for (Object v : sViolations) {
291             pw.print(indent); // This won't really indent a multiline string,
292                               // though.
293             pw.println(v);
294         }
295     }
296 
297     /**
298      * Dump stats to logcat.
299      */
dump()300     public static void dump() {
301         // Dump to logcat.
302         PrintWriter out = new PrintWriter(new LogWriter(Log.WARN, TAG), true);
303         dump(out, "");
304         out.close();
305     }
306 
    /**
     * A lock checker. {@link LockHook} forwards lock events to every checker it holds and
     * asks each one to contribute to the dump.
     */
    interface LockChecker {
        /** Called from {@link LockHook#preLock} with the lock passed to that hook. */
        void pre(Object lock);

        /** Called from {@link LockHook#postLock} with the lock passed to that hook. */
        void post(Object lock);

        // NOTE(review): presumably the total number of violations detected so far; not used
        // in this file -- confirm against implementations.
        int getNumDetected();

        // NOTE(review): presumably the number of distinct violations detected so far; not used
        // in this file -- confirm against implementations.
        int getNumDetectedUnique();

        // NOTE(review): presumably a human-readable name for this checker; not used in this
        // file -- confirm against implementations.
        String getCheckerName();

        /** Writes this checker's state to {@code pw}; called while LockHook dumps its stats. */
        void dump(PrintWriter pw);
    }
320 
    /**
     * A detected lock violation. When one is reported, its {@code toString()} supplies the
     * log message.
     */
    interface Violation {
        /** Returns the throwable that is logged as an uncaught exception when crash simulation is on. */
        Throwable getException();
    }
324 }
325