1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 import annotations.BootstrapMethod;
18 import annotations.CalledByIndy;
19 import java.lang.invoke.CallSite;
20 import java.lang.invoke.ConstantCallSite;
21 import java.lang.invoke.MethodHandle;
22 import java.lang.invoke.MethodHandles;
23 import java.lang.invoke.MethodType;
24 import java.util.concurrent.CyclicBarrier;
25 import java.util.concurrent.atomic.AtomicInteger;
26 
27 public class TestInvokeCustomWithConcurrentThreads extends TestBase implements Runnable {
28     private static final int NUMBER_OF_THREADS = 16;
29 
30     private static final AtomicInteger nextIndex = new AtomicInteger(0);
31 
32     private static final ThreadLocal<Integer> threadIndex =
33             new ThreadLocal<Integer>() {
34                 @Override
35                 protected Integer initialValue() {
36                     return nextIndex.getAndIncrement();
37                 }
38             };
39 
40     // Array of call sites instantiated, one per thread
41     private static final CallSite[] instantiated = new CallSite[NUMBER_OF_THREADS];
42 
43     // Array of counters for how many times each instantiated call site is called
44     private static final AtomicInteger[] called = new AtomicInteger[NUMBER_OF_THREADS];
45 
46     // Array of call site indices of which call site a thread invoked
47     private static final AtomicInteger[] targetted = new AtomicInteger[NUMBER_OF_THREADS];
48 
49     // Synchronization barrier all threads will wait on in the bootstrap method.
50     private static final CyclicBarrier barrier = new CyclicBarrier(NUMBER_OF_THREADS);
51 
TestInvokeCustomWithConcurrentThreads()52     private TestInvokeCustomWithConcurrentThreads() {}
53 
getThreadIndex()54     private static int getThreadIndex() {
55         return threadIndex.get().intValue();
56     }
57 
notUsed(int x)58     public static int notUsed(int x) {
59         return x;
60     }
61 
run()62     public void run() {
63         int x = setCalled(-1 /* argument dropped */);
64         notUsed(x);
65     }
66 
67     @CalledByIndy(
68         bootstrapMethod =
69                 @BootstrapMethod(
70                     enclosingType = TestInvokeCustomWithConcurrentThreads.class,
71                     name = "linkerMethod",
72                     parameterTypes = {MethodHandles.Lookup.class, String.class, MethodType.class}
73                 ),
74         fieldOrMethodName = "setCalled",
75         returnType = int.class,
76         parameterTypes = {int.class}
77     )
setCalled(int index)78     private static int setCalled(int index) {
79         called[index].getAndIncrement();
80         targetted[getThreadIndex()].set(index);
81         return 0;
82     }
83 
84     @SuppressWarnings("unused")
linkerMethod( MethodHandles.Lookup caller, String name, MethodType methodType)85     private static CallSite linkerMethod(
86             MethodHandles.Lookup caller, String name, MethodType methodType) throws Throwable {
87         MethodHandle mh =
88                 caller.findStatic(TestInvokeCustomWithConcurrentThreads.class, name, methodType);
89         assertEquals(methodType, mh.type());
90         assertEquals(mh.type().parameterCount(), 1);
91         mh = MethodHandles.insertArguments(mh, 0, getThreadIndex());
92         mh = MethodHandles.dropArguments(mh, 0, int.class);
93         assertEquals(mh.type().parameterCount(), 1);
94         assertEquals(methodType, mh.type());
95 
96         // Wait for all threads to be in this method.
97         // Multiple call sites should be created, but only one
98         // invoked.
99         barrier.await();
100 
101         instantiated[getThreadIndex()] = new ConstantCallSite(mh);
102         return instantiated[getThreadIndex()];
103     }
104 
test()105     public static void test() throws Throwable {
106         // Initialize counters for which call site gets invoked
107         for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
108             called[i] = new AtomicInteger(0);
109             targetted[i] = new AtomicInteger(0);
110         }
111 
112         // Run threads that each invoke-custom the call site
113         Thread[] threads = new Thread[NUMBER_OF_THREADS];
114         for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
115             threads[i] = new Thread(new TestInvokeCustomWithConcurrentThreads());
116             threads[i].start();
117         }
118 
119         // Wait for all threads to complete
120         for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
121             threads[i].join();
122         }
123 
124         // Check one call site instance won
125         int winners = 0;
126         int votes = 0;
127         for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
128             assertNotEquals(instantiated[i], null);
129             if (called[i].get() != 0) {
130                 winners++;
131                 votes += called[i].get();
132             }
133         }
134 
135         System.out.println("Winners " + winners + " Votes " + votes);
136 
137         // We assert this below but output details when there's an error as
138         // it's non-deterministic.
139         if (winners != 1) {
140             System.out.println("Threads did not the same call-sites:");
141             for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
142                 System.out.format(
143                         " Thread % 2d invoked call site instance #%02d\n", i, targetted[i].get());
144             }
145         }
146 
147         // We assert this below but output details when there's an error as
148         // it's non-deterministic.
149         if (votes != NUMBER_OF_THREADS) {
150             System.out.println("Call-sites invocations :");
151             for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
152                 System.out.format(
153                         " Call site instance #%02d was invoked % 2d times\n", i, called[i].get());
154             }
155         }
156 
157         assertEquals(winners, 1);
158         assertEquals(votes, NUMBER_OF_THREADS);
159     }
160 }
161