1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 package com.android.tradefed.invoker.shard; 17 18 import com.android.annotations.VisibleForTesting; 19 import com.android.tradefed.config.IConfiguration; 20 import com.android.tradefed.invoker.IRescheduler; 21 import com.android.tradefed.invoker.TestInformation; 22 import com.android.tradefed.log.ITestLogger; 23 import com.android.tradefed.log.LogUtil.CLog; 24 import com.android.tradefed.result.ITestLoggerReceiver; 25 import com.android.tradefed.testtype.IBuildReceiver; 26 import com.android.tradefed.testtype.IDeviceTest; 27 import com.android.tradefed.testtype.IInvocationContextReceiver; 28 import com.android.tradefed.testtype.IRemoteTest; 29 import com.android.tradefed.testtype.IRuntimeHintProvider; 30 import com.android.tradefed.testtype.IShardableTest; 31 import com.android.tradefed.testtype.suite.ITestSuite; 32 import com.android.tradefed.testtype.suite.ModuleMerger; 33 import com.android.tradefed.util.TimeUtil; 34 35 import java.util.ArrayList; 36 import java.util.Collection; 37 import java.util.Collections; 38 import java.util.List; 39 40 /** Sharding strategy to create strict shards that do not report together, */ 41 public class StrictShardHelper extends ShardHelper { 42 43 /** {@inheritDoc} */ 44 @Override shardConfig( IConfiguration config, TestInformation testInfo, IRescheduler rescheduler, ITestLogger logger)45 public boolean shardConfig( 46 IConfiguration config, 47 TestInformation testInfo, 48 IRescheduler rescheduler, 49 ITestLogger logger) { 50 Integer shardCount = config.getCommandOptions().getShardCount(); 51 Integer shardIndex = config.getCommandOptions().getShardIndex(); 52 53 if (shardIndex == null) { 54 return super.shardConfig(config, testInfo, rescheduler, logger); 55 } 56 if (shardCount == null) { 57 throw new RuntimeException("shard-count is null while shard-index is " + shardIndex); 58 } 59 60 // Split tests in place, without actually sharding. 61 List<IRemoteTest> listAllTests = getAllTests(config, shardCount, testInfo, logger); 62 // We cannot shuffle to get better average results 63 normalizeDistribution(listAllTests, shardCount); 64 List<IRemoteTest> splitList; 65 if (shardCount == 1) { 66 // not sharded 67 splitList = listAllTests; 68 } else { 69 splitList = splitTests(listAllTests, shardCount).get(shardIndex); 70 } 71 aggregateSuiteModules(splitList); 72 config.setTests(splitList); 73 return false; 74 } 75 76 /** 77 * Helper to return the full list of {@link IRemoteTest} based on {@link IShardableTest} split. 78 * 79 * @param config the {@link IConfiguration} describing the invocation. 80 * @param shardCount the shard count hint to be provided to some tests. 81 * @param testInfo the {@link TestInformation} of the parent invocation. 82 * @return the list of all {@link IRemoteTest}. 83 */ getAllTests( IConfiguration config, Integer shardCount, TestInformation testInfo, ITestLogger logger)84 private List<IRemoteTest> getAllTests( 85 IConfiguration config, 86 Integer shardCount, 87 TestInformation testInfo, 88 ITestLogger logger) { 89 List<IRemoteTest> allTests = new ArrayList<>(); 90 for (IRemoteTest test : config.getTests()) { 91 if (test instanceof IShardableTest) { 92 // Inject current information to help with sharding 93 if (test instanceof IBuildReceiver) { 94 ((IBuildReceiver) test).setBuild(testInfo.getBuildInfo()); 95 } 96 if (test instanceof IDeviceTest) { 97 ((IDeviceTest) test).setDevice(testInfo.getDevice()); 98 } 99 if (test instanceof IInvocationContextReceiver) { 100 ((IInvocationContextReceiver) test).setInvocationContext(testInfo.getContext()); 101 } 102 if (test instanceof ITestLoggerReceiver) { 103 ((ITestLoggerReceiver) test).setTestLogger(logger); 104 } 105 106 // Handling of the ITestSuite is a special case, we do not allow pool of tests 107 // since each shard needs to be independent. 108 if (test instanceof ITestSuite) { 109 ((ITestSuite) test).setShouldMakeDynamicModule(false); 110 } 111 112 Collection<IRemoteTest> subTests = 113 ((IShardableTest) test).split(shardCount, testInfo); 114 if (subTests == null) { 115 // test did not shard so we add it as is. 116 allTests.add(test); 117 } else { 118 allTests.addAll(subTests); 119 } 120 } else { 121 // if test is not shardable we add it as is. 122 allTests.add(test); 123 } 124 } 125 return allTests; 126 } 127 128 /** 129 * Split the list of tests to run however the implementation see fit. Sharding needs to be 130 * consistent. It is acceptable to return an empty list if no tests can be run in the shard. 131 * 132 * <p>Implement this in order to provide a test suite specific sharding. The default 133 * implementation attempts to balance the number of IRemoteTest per shards as much as possible 134 * as a first step, then use a minor criteria or run-hint to adjust the lists a bit more. 135 * 136 * @param fullList the initial full list of {@link IRemoteTest} containing all the tests that 137 * need to run. 138 * @param shardCount the total number of shard that need to run. 139 * @return a list of list {@link IRemoteTest}s that have been assigned to each shard. The list 140 * size will be the shardCount. 141 */ 142 @VisibleForTesting splitTests(List<IRemoteTest> fullList, int shardCount)143 protected List<List<IRemoteTest>> splitTests(List<IRemoteTest> fullList, int shardCount) { 144 List<List<IRemoteTest>> shards = new ArrayList<>(); 145 // We are using Match.ceil to avoid the last shard having too much extra. 146 int numPerShard = (int) Math.ceil(fullList.size() / (float) shardCount); 147 148 boolean needsCorrection = false; 149 float correctionRatio = 0f; 150 if (fullList.size() > shardCount) { 151 // In some cases because of the Math.ceil, some combination might run out of tests 152 // before the last shard, in that case we populate a correction to rebalance the tests. 153 needsCorrection = (numPerShard * (shardCount - 1)) > fullList.size(); 154 correctionRatio = numPerShard - (fullList.size() / (float) shardCount); 155 } 156 // Recalculate the number of tests per shard with the correction taken into account. 157 numPerShard = (int) Math.floor(numPerShard - correctionRatio); 158 // Based of the parameters, distribute the tests accross shards. 159 shards = balancedDistrib(fullList, shardCount, numPerShard, needsCorrection); 160 // Do last minute rebalancing 161 topBottom(shards, shardCount); 162 return shards; 163 } 164 balancedDistrib( List<IRemoteTest> fullList, int shardCount, int numPerShard, boolean needsCorrection)165 private List<List<IRemoteTest>> balancedDistrib( 166 List<IRemoteTest> fullList, int shardCount, int numPerShard, boolean needsCorrection) { 167 List<List<IRemoteTest>> shards = new ArrayList<>(); 168 List<IRemoteTest> correctionList = new ArrayList<>(); 169 int correctionSize = 0; 170 171 // Generate all the shards 172 for (int i = 0; i < shardCount; i++) { 173 List<IRemoteTest> shardList; 174 if (i >= fullList.size()) { 175 // Return empty list when we don't have enough tests for all the shards. 176 shardList = new ArrayList<IRemoteTest>(); 177 shards.add(shardList); 178 continue; 179 } 180 181 if (i == shardCount - 1) { 182 // last shard take everything remaining except the correction: 183 if (needsCorrection) { 184 // We omit the size of the correction needed. 185 correctionSize = fullList.size() - (numPerShard + (i * numPerShard)); 186 correctionList = 187 fullList.subList(fullList.size() - correctionSize, fullList.size()); 188 } 189 shardList = fullList.subList(i * numPerShard, fullList.size() - correctionSize); 190 shards.add(new ArrayList<>(shardList)); 191 continue; 192 } 193 shardList = fullList.subList(i * numPerShard, numPerShard + (i * numPerShard)); 194 shards.add(new ArrayList<>(shardList)); 195 } 196 197 // If we have correction omitted tests, disperse them on each shard, at this point the 198 // number of tests in correction is ensured to be bellow the number of shards. 199 for (int i = 0; i < shardCount; i++) { 200 if (i < correctionList.size()) { 201 shards.get(i).add(correctionList.get(i)); 202 } else { 203 break; 204 } 205 } 206 return shards; 207 } 208 209 /** 210 * Move around predictably the tests in order to have a better uniformization of the tests in 211 * each shard. 212 */ normalizeDistribution(List<IRemoteTest> listAllTests, int shardCount)213 private void normalizeDistribution(List<IRemoteTest> listAllTests, int shardCount) { 214 final int numRound = shardCount; 215 final int distance = shardCount - 1; 216 for (int i = 0; i < numRound; i++) { 217 for (int j = 0; j < listAllTests.size(); j = j + distance) { 218 // Push the test at the end 219 IRemoteTest push = listAllTests.remove(j); 220 listAllTests.add(push); 221 } 222 } 223 } 224 225 /** 226 * Special handling for suite from {@link ITestSuite}. We aggregate the tests in the same shard 227 * in order to optimize target_preparation step. 228 * 229 * @param tests the {@link List} of {@link IRemoteTest} for that shard. 230 */ aggregateSuiteModules(List<IRemoteTest> tests)231 private void aggregateSuiteModules(List<IRemoteTest> tests) { 232 List<IRemoteTest> dupList = new ArrayList<>(tests); 233 for (int i = 0; i < dupList.size(); i++) { 234 if (dupList.get(i) instanceof ITestSuite) { 235 // We iterate the other tests to see if we can find another from the same module. 236 for (int j = i + 1; j < dupList.size(); j++) { 237 // If the test was not already merged 238 if (tests.contains(dupList.get(j))) { 239 if (dupList.get(j) instanceof ITestSuite) { 240 if (ModuleMerger.arePartOfSameSuite( 241 (ITestSuite) dupList.get(i), (ITestSuite) dupList.get(j))) { 242 ModuleMerger.mergeSplittedITestSuite( 243 (ITestSuite) dupList.get(i), (ITestSuite) dupList.get(j)); 244 tests.remove(dupList.get(j)); 245 } 246 } 247 } 248 } 249 } 250 } 251 } 252 topBottom(List<List<IRemoteTest>> allShards, int shardCount)253 private void topBottom(List<List<IRemoteTest>> allShards, int shardCount) { 254 // We only attempt this when the number of shard is pretty high 255 if (shardCount < 4) { 256 return; 257 } 258 // Generate approximate RuntimeHint for each shard 259 int index = 0; 260 List<SortShardObj> shardTimes = new ArrayList<>(); 261 for (List<IRemoteTest> shard : allShards) { 262 long aggTime = 0l; 263 CLog.d("++++++++++++++++++ SHARD %s +++++++++++++++", index); 264 for (IRemoteTest test : shard) { 265 if (test instanceof IRuntimeHintProvider) { 266 aggTime += ((IRuntimeHintProvider) test).getRuntimeHint(); 267 } 268 } 269 CLog.d("Shard %s approximate time: %s", index, TimeUtil.formatElapsedTime(aggTime)); 270 shardTimes.add(new SortShardObj(index, aggTime)); 271 index++; 272 CLog.d("+++++++++++++++++++++++++++++++++++++++++++"); 273 } 274 275 Collections.sort(shardTimes); 276 if ((shardTimes.get(0).mAggTime - shardTimes.get(shardTimes.size() - 1).mAggTime) 277 < 60 * 60 * 1000l) { 278 return; 279 } 280 281 // take 30% top shard (10 shard = top 3 shards) 282 for (int i = 0; i < (shardCount * 0.3); i++) { 283 CLog.d( 284 "Top shard %s is index %s with %s", 285 i, 286 shardTimes.get(i).mIndex, 287 TimeUtil.formatElapsedTime(shardTimes.get(i).mAggTime)); 288 int give = shardTimes.get(i).mIndex; 289 int receive = shardTimes.get(shardTimes.size() - 1 - i).mIndex; 290 CLog.d("Giving from shard %s to shard %s", give, receive); 291 for (int j = 0; j < (allShards.get(give).size() * (0.2f / (i + 1))); j++) { 292 IRemoteTest givetest = allShards.get(give).remove(0); 293 allShards.get(receive).add(givetest); 294 } 295 } 296 } 297 298 /** Object holder for shard, their index and their aggregated execution time. */ 299 private class SortShardObj implements Comparable<SortShardObj> { 300 public final int mIndex; 301 public final Long mAggTime; 302 SortShardObj(int index, long aggTime)303 public SortShardObj(int index, long aggTime) { 304 mIndex = index; 305 mAggTime = aggTime; 306 } 307 308 @Override compareTo(SortShardObj obj)309 public int compareTo(SortShardObj obj) { 310 return obj.mAggTime.compareTo(mAggTime); 311 } 312 } 313 } 314