/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 import art.Redefinition;
18 import java.lang.reflect.Method;
19 import java.util.Arrays;
20 import java.util.Base64;
21 import java.util.concurrent.CountDownLatch;
22 import java.util.concurrent.Phaser;
23 import java.util.function.Consumer;
24 
25 public class Main {
26   public static final int NUM_THREADS = 10;
27   public static final boolean PRINT = false;
28 
29   // import java.util.function.Consumer;
30   //
31   // class Transform {
32   //   public native void nativeSayHi(Consumer<Consumer<String>> r, Consumer<String> rep);
33   //   public void sayHi(Consumer<Consumer<String>> r, Consumer<String> reporter) {
34   //     reporter.accept("goodbye - Start method sayHi");
35   //     r.accept(reporter);
36   //     reporter.accept("goodbye - End method sayHi");
37   //   }
38   // }
39   private static final byte[] CLASS_BYTES = Base64.getDecoder().decode(
40       "yv66vgAAADUAHAoABgASCAATCwAUABUIABYHABcHABgBAAY8aW5pdD4BAAMoKVYBAARDb2RlAQAP"
41       + "TGluZU51bWJlclRhYmxlAQALbmF0aXZlU2F5SGkBAD0oTGphdmEvdXRpbC9mdW5jdGlvbi9Db25z"
42       + "dW1lcjtMamF2YS91dGlsL2Z1bmN0aW9uL0NvbnN1bWVyOylWAQAJU2lnbmF0dXJlAQCEKExqYXZh"
43       + "L3V0aWwvZnVuY3Rpb24vQ29uc3VtZXI8TGphdmEvdXRpbC9mdW5jdGlvbi9Db25zdW1lcjxMamF2"
44       + "YS9sYW5nL1N0cmluZzs+Oz47TGphdmEvdXRpbC9mdW5jdGlvbi9Db25zdW1lcjxMamF2YS9sYW5n"
45       + "L1N0cmluZzs+OylWAQAFc2F5SGkBAApTb3VyY2VGaWxlAQAOVHJhbnNmb3JtLmphdmEMAAcACAEA"
46       + "HGdvb2RieWUgLSBTdGFydCBtZXRob2Qgc2F5SGkHABkMABoAGwEAGmdvb2RieWUgLSBFbmQgbWV0"
47       + "aG9kIHNheUhpAQAJVHJhbnNmb3JtAQAQamF2YS9sYW5nL09iamVjdAEAG2phdmEvdXRpbC9mdW5j"
48       + "dGlvbi9Db25zdW1lcgEABmFjY2VwdAEAFShMamF2YS9sYW5nL09iamVjdDspVgAgAAUABgAAAAAA"
49       + "AwAAAAcACAABAAkAAAAdAAEAAQAAAAUqtwABsQAAAAEACgAAAAYAAQAAAAcBAQALAAwAAQANAAAA"
50       + "AgAOAAEADwAMAAIACQAAADwAAgADAAAAGCwSArkAAwIAKyy5AAMCACwSBLkAAwIAsQAAAAEACgAA"
51       + "ABIABAAAABAACAARAA8AEgAXABMADQAAAAIADgABABAAAAACABE=");
52 
53   private static final byte[] DEX_BYTES = Base64.getDecoder().decode(
54       "ZGV4CjAzNQAztWgsKV3wmz41jXurCJpvXfxhxtK7W8NQBAAAcAAAAHhWNBIAAAAAAAAAAJgDAAAV"
55       + "AAAAcAAAAAUAAADEAAAAAwAAANgAAAAAAAAAAAAAAAUAAAD8AAAAAQAAACQBAAAMAwAARAEAAKgB"
56       + "AACrAQAAswEAALkBAAC/AQAAzAEAAOsBAAD/AQAAEwIAADICAABRAgAAYQIAAGQCAABoAgAAbQIA"
57       + "AHUCAACRAgAArwIAALwCAADDAgAAygIAAAQAAAAFAAAABgAAAAgAAAALAAAACwAAAAQAAAAAAAAA"
58       + "DAAAAAQAAACYAQAADQAAAAQAAACgAQAAAAAAAAEAAAAAAAIAEQAAAAAAAgASAAAAAgAAAAEAAAAD"
59       + "AAEADgAAAAAAAAAAAAAAAgAAAAAAAAAKAAAAeAMAAFgDAAAAAAAAAQABAAEAAACIAQAABAAAAHAQ"
60       + "AwAAAA4ABAADAAIAAACMAQAADgAAABoAEAByIAQAAwByIAQAMgAaAg8AciAEACMADgAHAA4AEAIA"
61       + "AA5aPFoAAAAAAQAAAAIAAAACAAAAAwADAAEoAAY8aW5pdD4ABD47KVYABD47PjsAC0xUcmFuc2Zv"
62       + "cm07AB1MZGFsdmlrL2Fubm90YXRpb24vU2lnbmF0dXJlOwASTGphdmEvbGFuZy9PYmplY3Q7ABJM"
63       + "amF2YS9sYW5nL1N0cmluZzsAHUxqYXZhL3V0aWwvZnVuY3Rpb24vQ29uc3VtZXI7AB1MamF2YS91"
64       + "dGlsL2Z1bmN0aW9uL0NvbnN1bWVyPAAOVHJhbnNmb3JtLmphdmEAAVYAAlZMAANWTEwABmFjY2Vw"
65       + "dAAaZ29vZGJ5ZSAtIEVuZCBtZXRob2Qgc2F5SGkAHGdvb2RieWUgLSBTdGFydCBtZXRob2Qgc2F5"
66       + "SGkAC25hdGl2ZVNheUhpAAVzYXlIaQAFdmFsdWUAdn5+RDh7ImNvbXBpbGF0aW9uLW1vZGUiOiJk"
67       + "ZWJ1ZyIsIm1pbi1hcGkiOjEsInNoYS0xIjoiNzExMWEzNWJhZTZkNTE4NWRjZmIzMzhkNjEwNzRh"
68       + "Y2E4NDI2YzAwNiIsInZlcnNpb24iOiIxLjUuMTQtZGV2In0AAgEBExwIFwAXCRcJFwcXAxcJFwcX"
69       + "AgAAAQIAgIAExAIBgQIAAQHcAgAAAAAAAAEAAABCAwAAbAMAAAAAAAACAAAAAAAAAAEAAABwAwAA"
70       + "AgAAAHADAAAPAAAAAAAAAAEAAAAAAAAAAQAAABUAAABwAAAAAgAAAAUAAADEAAAAAwAAAAMAAADY"
71       + "AAAABQAAAAUAAAD8AAAABgAAAAEAAAAkAQAAASAAAAIAAABEAQAAAyAAAAIAAACIAQAAARAAAAIA"
72       + "AACYAQAAAiAAABUAAACoAQAABCAAAAEAAABCAwAAACAAAAEAAABYAwAAAxAAAAIAAABsAwAABiAA"
73       + "AAEAAAB4AwAAABAAAAEAAACYAwAA");
74 
75   // A class that we can use to keep track of the output of this test.
76   private static class TestWatcher implements Consumer<String> {
77     private StringBuilder sb;
78     private String thread;
TestWatcher(String thread)79     public TestWatcher(String thread) {
80       sb = new StringBuilder();
81       this.thread = thread;
82     }
83 
84     @Override
accept(String s)85     public void accept(String s) {
86       String msg = thread + ": \t" + s;
87       maybePrint(msg);
88       sb.append(msg);
89       sb.append('\n');
90     }
91 
getOutput()92     public String getOutput() {
93       return sb.toString();
94     }
95 
clear()96     public void clear() {
97       sb = new StringBuilder();
98     }
99   }
100 
main(String[] args)101   public static void main(String[] args) throws Exception {
102     doTest(new Transform());
103   }
104 
105   private static boolean interpreting = true;
106 
doTest(Transform t)107   public static void doTest(Transform t) throws Exception {
108     TestWatcher[] watchers = new TestWatcher[NUM_THREADS];
109     for (int i = 0; i < NUM_THREADS; i++) {
110       watchers[i] = new TestWatcher("Thread " + i);
111     }
112 
113     // This just prints something out to show we are running the Runnable.
114     Consumer<Consumer<String>> say_nothing = (Consumer<String> w) -> {
115       w.accept("Not doing anything here");
116     };
117 
118     // Run ensureJitCompiled here since it might get GCd
119     ensureJitCompiled(Transform.class, "nativeSayHi");
120     final CountDownLatch arrive = new CountDownLatch(NUM_THREADS);
121     final CountDownLatch depart = new CountDownLatch(1);
122     Consumer<Consumer<String>> request_redefine = (Consumer<String> w) -> {
123       try {
124         arrive.countDown();
125         w.accept("Requesting redefinition");
126         depart.await();
127       } catch (Exception e) {
128         throw new RuntimeException("Failed to do something", e);
129       }
130     };
131     Thread redefinition_thread = new RedefinitionThread(arrive, depart);
132     redefinition_thread.start();
133     Thread[] threads = new Thread[NUM_THREADS];
134     for (int i = 0; i < NUM_THREADS; i++) {
135       threads[i] = new TestThread(t, watchers[i], say_nothing, request_redefine);
136       threads[i].start();
137     }
138     redefinition_thread.join();
139     Arrays.stream(threads).forEach((thr) -> {
140       try {
141         thr.join();
142       } catch (Exception e) {
143         throw new RuntimeException("Failed to join: ", e);
144       }
145     });
146     Arrays.stream(watchers).forEach((w) -> { System.out.println(w.getOutput()); });
147   }
148 
149   private static class RedefinitionThread extends Thread {
150     private CountDownLatch arrivalLatch;
151     private CountDownLatch departureLatch;
RedefinitionThread(CountDownLatch arrival, CountDownLatch departure)152     public RedefinitionThread(CountDownLatch arrival, CountDownLatch departure) {
153       super("Redefine thread!");
154       this.arrivalLatch = arrival;
155       this.departureLatch = departure;
156     }
157 
run()158     public void run() {
159       try {
160         this.arrivalLatch.await();
161         maybePrint("REDEFINITION THREAD: redefining something!");
162         Redefinition.doCommonClassRedefinition(Transform.class, CLASS_BYTES, DEX_BYTES);
163         maybePrint("REDEFINITION THREAD: redefined something!");
164         this.departureLatch.countDown();
165       } catch (Exception e) {
166         e.printStackTrace(System.out);
167         throw new RuntimeException("Failed to redefine", e);
168       }
169     }
170   }
171 
maybePrint(String s)172   private static synchronized void maybePrint(String s) {
173     if (PRINT) {
174       System.out.println(s);
175     }
176   }
177 
178   private static class TestThread extends Thread {
179     private Transform t;
180     private TestWatcher w;
181     private Consumer<Consumer<String>> do_nothing;
182     private Consumer<Consumer<String>> request_redefinition;
TestThread(Transform t, TestWatcher w, Consumer<Consumer<String>> do_nothing, Consumer<Consumer<String>> request_redefinition)183     public TestThread(Transform t,
184                       TestWatcher w,
185                       Consumer<Consumer<String>> do_nothing,
186                       Consumer<Consumer<String>> request_redefinition) {
187       super();
188       this.t = t;
189       this.w = w;
190       this.do_nothing = do_nothing;
191       this.request_redefinition = request_redefinition;
192     }
193 
run()194     public void run() {
195       w.clear();
196       t.nativeSayHi(do_nothing, w);
197       t.nativeSayHi(request_redefinition, w);
198       t.nativeSayHi(do_nothing, w);
199     }
200   }
201 
ensureJitCompiled(Class c, String name)202   private static native void ensureJitCompiled(Class c, String name);
203 }
204