1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.android.tradefed.invoker.shard;
17 
18 import com.android.annotations.VisibleForTesting;
19 import com.android.tradefed.config.IConfiguration;
20 import com.android.tradefed.invoker.IRescheduler;
21 import com.android.tradefed.invoker.TestInformation;
22 import com.android.tradefed.log.ITestLogger;
23 import com.android.tradefed.log.LogUtil.CLog;
24 import com.android.tradefed.result.ITestLoggerReceiver;
25 import com.android.tradefed.testtype.IBuildReceiver;
26 import com.android.tradefed.testtype.IDeviceTest;
27 import com.android.tradefed.testtype.IInvocationContextReceiver;
28 import com.android.tradefed.testtype.IRemoteTest;
29 import com.android.tradefed.testtype.IRuntimeHintProvider;
30 import com.android.tradefed.testtype.IShardableTest;
31 import com.android.tradefed.testtype.suite.ITestSuite;
32 import com.android.tradefed.testtype.suite.ModuleMerger;
33 import com.android.tradefed.util.TimeUtil;
34 
35 import java.util.ArrayList;
36 import java.util.Collection;
37 import java.util.Collections;
38 import java.util.List;
39 
40 /** Sharding strategy to create strict shards that do not report together, */
41 public class StrictShardHelper extends ShardHelper {
42 
43     /** {@inheritDoc} */
44     @Override
shardConfig( IConfiguration config, TestInformation testInfo, IRescheduler rescheduler, ITestLogger logger)45     public boolean shardConfig(
46             IConfiguration config,
47             TestInformation testInfo,
48             IRescheduler rescheduler,
49             ITestLogger logger) {
50         Integer shardCount = config.getCommandOptions().getShardCount();
51         Integer shardIndex = config.getCommandOptions().getShardIndex();
52 
53         if (shardIndex == null) {
54             return super.shardConfig(config, testInfo, rescheduler, logger);
55         }
56         if (shardCount == null) {
57             throw new RuntimeException("shard-count is null while shard-index is " + shardIndex);
58         }
59 
60         // Split tests in place, without actually sharding.
61         List<IRemoteTest> listAllTests = getAllTests(config, shardCount, testInfo, logger);
62         // We cannot shuffle to get better average results
63         normalizeDistribution(listAllTests, shardCount);
64         List<IRemoteTest> splitList;
65         if (shardCount == 1) {
66             // not sharded
67             splitList = listAllTests;
68         } else {
69             splitList = splitTests(listAllTests, shardCount).get(shardIndex);
70         }
71         aggregateSuiteModules(splitList);
72         config.setTests(splitList);
73         return false;
74     }
75 
76     /**
77      * Helper to return the full list of {@link IRemoteTest} based on {@link IShardableTest} split.
78      *
79      * @param config the {@link IConfiguration} describing the invocation.
80      * @param shardCount the shard count hint to be provided to some tests.
81      * @param testInfo the {@link TestInformation} of the parent invocation.
82      * @return the list of all {@link IRemoteTest}.
83      */
getAllTests( IConfiguration config, Integer shardCount, TestInformation testInfo, ITestLogger logger)84     private List<IRemoteTest> getAllTests(
85             IConfiguration config,
86             Integer shardCount,
87             TestInformation testInfo,
88             ITestLogger logger) {
89         List<IRemoteTest> allTests = new ArrayList<>();
90         for (IRemoteTest test : config.getTests()) {
91             if (test instanceof IShardableTest) {
92                 // Inject current information to help with sharding
93                 if (test instanceof IBuildReceiver) {
94                     ((IBuildReceiver) test).setBuild(testInfo.getBuildInfo());
95                 }
96                 if (test instanceof IDeviceTest) {
97                     ((IDeviceTest) test).setDevice(testInfo.getDevice());
98                 }
99                 if (test instanceof IInvocationContextReceiver) {
100                     ((IInvocationContextReceiver) test).setInvocationContext(testInfo.getContext());
101                 }
102                 if (test instanceof ITestLoggerReceiver) {
103                     ((ITestLoggerReceiver) test).setTestLogger(logger);
104                 }
105 
106                 // Handling of the ITestSuite is a special case, we do not allow pool of tests
107                 // since each shard needs to be independent.
108                 if (test instanceof ITestSuite) {
109                     ((ITestSuite) test).setShouldMakeDynamicModule(false);
110                 }
111 
112                 Collection<IRemoteTest> subTests =
113                         ((IShardableTest) test).split(shardCount, testInfo);
114                 if (subTests == null) {
115                     // test did not shard so we add it as is.
116                     allTests.add(test);
117                 } else {
118                     allTests.addAll(subTests);
119                 }
120             } else {
121                 // if test is not shardable we add it as is.
122                 allTests.add(test);
123             }
124         }
125         return allTests;
126     }
127 
128     /**
129      * Split the list of tests to run however the implementation see fit. Sharding needs to be
130      * consistent. It is acceptable to return an empty list if no tests can be run in the shard.
131      *
132      * <p>Implement this in order to provide a test suite specific sharding. The default
133      * implementation attempts to balance the number of IRemoteTest per shards as much as possible
134      * as a first step, then use a minor criteria or run-hint to adjust the lists a bit more.
135      *
136      * @param fullList the initial full list of {@link IRemoteTest} containing all the tests that
137      *     need to run.
138      * @param shardCount the total number of shard that need to run.
139      * @return a list of list {@link IRemoteTest}s that have been assigned to each shard. The list
140      *     size will be the shardCount.
141      */
142     @VisibleForTesting
splitTests(List<IRemoteTest> fullList, int shardCount)143     protected List<List<IRemoteTest>> splitTests(List<IRemoteTest> fullList, int shardCount) {
144         List<List<IRemoteTest>> shards = new ArrayList<>();
145         // We are using Match.ceil to avoid the last shard having too much extra.
146         int numPerShard = (int) Math.ceil(fullList.size() / (float) shardCount);
147 
148         boolean needsCorrection = false;
149         float correctionRatio = 0f;
150         if (fullList.size() > shardCount) {
151             // In some cases because of the Math.ceil, some combination might run out of tests
152             // before the last shard, in that case we populate a correction to rebalance the tests.
153             needsCorrection = (numPerShard * (shardCount - 1)) > fullList.size();
154             correctionRatio = numPerShard - (fullList.size() / (float) shardCount);
155         }
156         // Recalculate the number of tests per shard with the correction taken into account.
157         numPerShard = (int) Math.floor(numPerShard - correctionRatio);
158         // Based of the parameters, distribute the tests accross shards.
159         shards = balancedDistrib(fullList, shardCount, numPerShard, needsCorrection);
160         // Do last minute rebalancing
161         topBottom(shards, shardCount);
162         return shards;
163     }
164 
balancedDistrib( List<IRemoteTest> fullList, int shardCount, int numPerShard, boolean needsCorrection)165     private List<List<IRemoteTest>> balancedDistrib(
166             List<IRemoteTest> fullList, int shardCount, int numPerShard, boolean needsCorrection) {
167         List<List<IRemoteTest>> shards = new ArrayList<>();
168         List<IRemoteTest> correctionList = new ArrayList<>();
169         int correctionSize = 0;
170 
171         // Generate all the shards
172         for (int i = 0; i < shardCount; i++) {
173             List<IRemoteTest> shardList;
174             if (i >= fullList.size()) {
175                 // Return empty list when we don't have enough tests for all the shards.
176                 shardList = new ArrayList<IRemoteTest>();
177                 shards.add(shardList);
178                 continue;
179             }
180 
181             if (i == shardCount - 1) {
182                 // last shard take everything remaining except the correction:
183                 if (needsCorrection) {
184                     // We omit the size of the correction needed.
185                     correctionSize = fullList.size() - (numPerShard + (i * numPerShard));
186                     correctionList =
187                             fullList.subList(fullList.size() - correctionSize, fullList.size());
188                 }
189                 shardList = fullList.subList(i * numPerShard, fullList.size() - correctionSize);
190                 shards.add(new ArrayList<>(shardList));
191                 continue;
192             }
193             shardList = fullList.subList(i * numPerShard, numPerShard + (i * numPerShard));
194             shards.add(new ArrayList<>(shardList));
195         }
196 
197         // If we have correction omitted tests, disperse them on each shard, at this point the
198         // number of tests in correction is ensured to be bellow the number of shards.
199         for (int i = 0; i < shardCount; i++) {
200             if (i < correctionList.size()) {
201                 shards.get(i).add(correctionList.get(i));
202             } else {
203                 break;
204             }
205         }
206         return shards;
207     }
208 
209     /**
210      * Move around predictably the tests in order to have a better uniformization of the tests in
211      * each shard.
212      */
normalizeDistribution(List<IRemoteTest> listAllTests, int shardCount)213     private void normalizeDistribution(List<IRemoteTest> listAllTests, int shardCount) {
214         final int numRound = shardCount;
215         final int distance = shardCount - 1;
216         for (int i = 0; i < numRound; i++) {
217             for (int j = 0; j < listAllTests.size(); j = j + distance) {
218                 // Push the test at the end
219                 IRemoteTest push = listAllTests.remove(j);
220                 listAllTests.add(push);
221             }
222         }
223     }
224 
225     /**
226      * Special handling for suite from {@link ITestSuite}. We aggregate the tests in the same shard
227      * in order to optimize target_preparation step.
228      *
229      * @param tests the {@link List} of {@link IRemoteTest} for that shard.
230      */
aggregateSuiteModules(List<IRemoteTest> tests)231     private void aggregateSuiteModules(List<IRemoteTest> tests) {
232         List<IRemoteTest> dupList = new ArrayList<>(tests);
233         for (int i = 0; i < dupList.size(); i++) {
234             if (dupList.get(i) instanceof ITestSuite) {
235                 // We iterate the other tests to see if we can find another from the same module.
236                 for (int j = i + 1; j < dupList.size(); j++) {
237                     // If the test was not already merged
238                     if (tests.contains(dupList.get(j))) {
239                         if (dupList.get(j) instanceof ITestSuite) {
240                             if (ModuleMerger.arePartOfSameSuite(
241                                     (ITestSuite) dupList.get(i), (ITestSuite) dupList.get(j))) {
242                                 ModuleMerger.mergeSplittedITestSuite(
243                                         (ITestSuite) dupList.get(i), (ITestSuite) dupList.get(j));
244                                 tests.remove(dupList.get(j));
245                             }
246                         }
247                     }
248                 }
249             }
250         }
251     }
252 
topBottom(List<List<IRemoteTest>> allShards, int shardCount)253     private void topBottom(List<List<IRemoteTest>> allShards, int shardCount) {
254         // We only attempt this when the number of shard is pretty high
255         if (shardCount < 4) {
256             return;
257         }
258         // Generate approximate RuntimeHint for each shard
259         int index = 0;
260         List<SortShardObj> shardTimes = new ArrayList<>();
261         for (List<IRemoteTest> shard : allShards) {
262             long aggTime = 0l;
263             CLog.d("++++++++++++++++++ SHARD %s +++++++++++++++", index);
264             for (IRemoteTest test : shard) {
265                 if (test instanceof IRuntimeHintProvider) {
266                     aggTime += ((IRuntimeHintProvider) test).getRuntimeHint();
267                 }
268             }
269             CLog.d("Shard %s approximate time: %s", index, TimeUtil.formatElapsedTime(aggTime));
270             shardTimes.add(new SortShardObj(index, aggTime));
271             index++;
272             CLog.d("+++++++++++++++++++++++++++++++++++++++++++");
273         }
274 
275         Collections.sort(shardTimes);
276         if ((shardTimes.get(0).mAggTime - shardTimes.get(shardTimes.size() - 1).mAggTime)
277                 < 60 * 60 * 1000l) {
278             return;
279         }
280 
281         // take 30% top shard (10 shard = top 3 shards)
282         for (int i = 0; i < (shardCount * 0.3); i++) {
283             CLog.d(
284                     "Top shard %s is index %s with %s",
285                     i,
286                     shardTimes.get(i).mIndex,
287                     TimeUtil.formatElapsedTime(shardTimes.get(i).mAggTime));
288             int give = shardTimes.get(i).mIndex;
289             int receive = shardTimes.get(shardTimes.size() - 1 - i).mIndex;
290             CLog.d("Giving from shard %s to shard %s", give, receive);
291             for (int j = 0; j < (allShards.get(give).size() * (0.2f / (i + 1))); j++) {
292                 IRemoteTest givetest = allShards.get(give).remove(0);
293                 allShards.get(receive).add(givetest);
294             }
295         }
296     }
297 
298     /** Object holder for shard, their index and their aggregated execution time. */
299     private class SortShardObj implements Comparable<SortShardObj> {
300         public final int mIndex;
301         public final Long mAggTime;
302 
SortShardObj(int index, long aggTime)303         public SortShardObj(int index, long aggTime) {
304             mIndex = index;
305             mAggTime = aggTime;
306         }
307 
308         @Override
compareTo(SortShardObj obj)309         public int compareTo(SortShardObj obj) {
310             return obj.mAggTime.compareTo(mAggTime);
311         }
312     }
313 }
314