/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_FUZZING_RANDOM_VARIABLE_H
#define ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_FUZZING_RANDOM_VARIABLE_H

#include <algorithm>
#include <climits>
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

namespace android {
namespace nn {
namespace fuzzing_test {

// NOTE(review): not referenced in this header; presumably a value bound used by the
// implementation file — confirm.
constexpr int kMaxValue = 10000;

// Sentinel passed as a bound to setRange() to mean "unlimited".
constexpr int kInvalidValue = INT_MIN;

// Describe the search range for the value of a random variable.
//
// The range is an explicit list of candidate values. The constructors below keep mChoices sorted
// in ascending order without duplicates; has() (binary search), min() and max() rely on that.
class RandomVariableRange {
   public:
    // An empty range, i.e. no valid choice.
    RandomVariableRange() = default;

    // A range holding exactly one choice.
    explicit RandomVariableRange(int value) : mChoices({value}) {}

    // The closed interval [lower, upper]. An inverted interval (lower > upper) yields an empty
    // range: the unguarded size `upper - lower + 1` would be negative and convert to a huge
    // size_t.
    RandomVariableRange(int lower, int upper) : mChoices(lower <= upper ? upper - lower + 1 : 0) {
        std::iota(mChoices.begin(), mChoices.end(), lower);
    }

    // A range holding the given choices, which may arrive in any order. They are sorted and
    // de-duplicated to establish the class invariant.
    explicit RandomVariableRange(const std::vector<int>& vec) : mChoices(vec) {
        std::sort(mChoices.begin(), mChoices.end());
        mChoices.erase(std::unique(mChoices.begin(), mChoices.end()), mChoices.end());
    }

    // A range holding the elements of `st` (a std::set is already sorted and unique).
    explicit RandomVariableRange(const std::set<int>& st) : mChoices(st.begin(), st.end()) {}

    RandomVariableRange(const RandomVariableRange&) = default;
    RandomVariableRange& operator=(const RandomVariableRange&) = default;

    bool empty() const { return mChoices.empty(); }

    // O(log n) membership test; valid because mChoices is sorted.
    bool has(int value) const {
        return std::binary_search(mChoices.begin(), mChoices.end(), value);
    }

    size_t size() const { return mChoices.size(); }

    // Smallest / largest choice. Precondition: !empty().
    int min() const { return mChoices.front(); }
    int max() const { return mChoices.back(); }

    const std::vector<int>& getChoices() const { return mChoices; }

    // Narrow down the range to fit [lower, upper]. Use kInvalidValue to indicate unlimited bound.
    void setRange(int lower, int upper);

    // Narrow down the range to a random selected choice. Return the chosen value.
    int toConst();

    // Calculate the intersection of two ranges.
    friend RandomVariableRange operator&(const RandomVariableRange& lhs,
                                         const RandomVariableRange& rhs);

   private:
    // Always in ascending order, without duplicates.
    std::vector<int> mChoices;
};

// Defines the interface for an operation applying to RandomVariables.
class IRandomVariableOp {
   public:
    virtual ~IRandomVariableOp() = default;

    // Forward evaluation of two values.
    virtual int eval(int lhs, int rhs) const = 0;

    // Gets the range of the operation outcomes. The returned range must include all possible
    // outcomes of this operation, but may contain invalid results.
    virtual RandomVariableRange getInitRange(const RandomVariableRange& lhs,
                                             const RandomVariableRange& rhs) const;

    // Provides faster range evaluation for evalSubnetSingleOpHelper if possible.
    virtual void eval(const std::set<int>* parent1In, const std::set<int>* parent2In,
                      const std::set<int>* childIn, std::set<int>* parent1Out,
                      std::set<int>* parent2Out, std::set<int>* childOut) const;

    // For debugging purpose.
    virtual const char* getName() const = 0;
};

// FREE: value still to be chosen from a range; CONST: a fixed value;
// OP: the result of an operation between two other random variables.
enum class RandomVariableType { FREE = 0, CONST = 1, OP = 2 };

// A node of the random variable network. Nodes are shared through RandomVariableNode
// (std::shared_ptr) and are neither copyable nor assignable.
struct RandomVariableBase {
    // Each RandomVariableBase is assigned a unique index for debugging purpose.
    static unsigned int globalIndex;
    int index = 0;

    RandomVariableType type = RandomVariableType::FREE;
    RandomVariableRange range;
    int value = 0;
    std::shared_ptr<const IRandomVariableOp> op = nullptr;

    // Network structural information.
    std::shared_ptr<RandomVariableBase> parent1 = nullptr;
    std::shared_ptr<RandomVariableBase> parent2 = nullptr;
    // Non-owning links to dependent nodes; weak_ptr presumably avoids a shared_ptr cycle with the
    // owning parent1/parent2 links.
    std::vector<std::weak_ptr<RandomVariableBase>> children;

    // The last time that this RandomVariableBase is modified.
    int timestamp = 0;

    // NOTE(review): these presumably mirror the RandomVariable constructors (CONST value,
    // FREE interval, FREE choices, OP of two parents); definitions are not in this header.
    explicit RandomVariableBase(int value);
    RandomVariableBase(int lower, int upper);
    explicit RandomVariableBase(const std::vector<int>& choices);
    RandomVariableBase(const std::shared_ptr<RandomVariableBase>& lhs,
                       const std::shared_ptr<RandomVariableBase>& rhs,
                       const std::shared_ptr<const IRandomVariableOp>& op);
    RandomVariableBase(const RandomVariableBase&) = delete;
    RandomVariableBase& operator=(const RandomVariableBase&) = delete;

    // Freeze FREE RandomVariable to one valid choice.
    // Should only invoke on FREE RandomVariable.
    void freeze();

    // Get CONST value or calculate from parents.
    // Should not invoke on FREE RandomVariable.
    int getValue() const;

    // Update the timestamp to the latest global time.
    void updateTimestamp();
};

using RandomVariableNode = std::shared_ptr<RandomVariableBase>;

// A wrapper class of RandomVariableBase that manages RandomVariableBase with shared_ptr and
// provides useful methods and operator overloading to build the random variable network.
class RandomVariable {
   public:
    // Construct a placeholder RandomVariable with nullptr.
    RandomVariable() : mVar(nullptr) {}

    // Construct a CONST RandomVariable with specified value.
    /* implicit */ RandomVariable(int value);

    // Construct a FREE RandomVariable with range [lower, upper].
    RandomVariable(int lower, int upper);

    // Construct a FREE RandomVariable with specified value choices.
    explicit RandomVariable(const std::vector<int>& choices);

    // This is for RandomVariableType::FREE only.
    // Construct a FREE RandomVariable with default range [1, defaultValue].
    /* implicit */ RandomVariable(RandomVariableType type);

    // RandomVariables share the same RandomVariableBase if copied or copy-assigned.
    RandomVariable(const RandomVariable& other) = default;
    RandomVariable& operator=(const RandomVariable& other) = default;

    // Get the value of the RandomVariable, the value must be deterministic.
    // Must not be called on a placeholder: mVar is dereferenced without a null check.
    int getValue() const { return mVar->getValue(); }

    // Get the underlying managed RandomVariableNode.
    RandomVariableNode get() const { return mVar; }

    // Placeholder checks.
    bool operator==(std::nullptr_t) const { return mVar == nullptr; }
    bool operator!=(std::nullptr_t) const { return mVar != nullptr; }

    // Arithmetic operators and methods on RandomVariables.
    friend RandomVariable operator+(const RandomVariable& lhs, const RandomVariable& rhs);
    friend RandomVariable operator-(const RandomVariable& lhs, const RandomVariable& rhs);
    friend RandomVariable operator*(const RandomVariable& lhs, const RandomVariable& rhs);
    friend RandomVariable operator*(const RandomVariable& lhs, const float& rhs);
    friend RandomVariable operator/(const RandomVariable& lhs, const RandomVariable& rhs);
    friend RandomVariable operator%(const RandomVariable& lhs, const RandomVariable& rhs);
    friend RandomVariable max(const RandomVariable& lhs, const RandomVariable& rhs);
    friend RandomVariable min(const RandomVariable& lhs, const RandomVariable& rhs);
    RandomVariable exactDiv(const RandomVariable& other);

    // Set constraints on the RandomVariable. Use kInvalidValue to indicate unlimited bound.
    void setRange(int lower, int upper);
    RandomVariable setEqual(const RandomVariable& other) const;
    RandomVariable setGreaterThan(const RandomVariable& other) const;
    RandomVariable setGreaterEqual(const RandomVariable& other) const;

    // A FREE RandomVariable is constructed with default range [1, defaultValue].
    static int defaultValue;

   private:
    // Construct a RandomVariable as the result of an OP between two other RandomVariables.
    RandomVariable(const RandomVariable& lhs, const RandomVariable& rhs,
                   const std::shared_ptr<const IRandomVariableOp>& op);

    RandomVariableNode mVar;
};

using EvaluationOrder = std::vector<RandomVariableNode>;

// The base class of a network consisting of disjoint subnets.
class DisjointNetwork {
   public:
    // Add a node to the network, join the parent subnets if needed.
    void add(const RandomVariableNode& var);

    // Similar to join(int, int), but accept RandomVariableNodes. A node that is not part of the
    // network maps to the invalid subnet index -1. (operator[] would instead insert it as subnet
    // 0, which is a real subnet index.)
    int join(const RandomVariableNode& var1, const RandomVariableNode& var2) {
        return DisjointNetwork::join(subnetIndexOf(var1), subnetIndexOf(var2));
    }

   protected:
    DisjointNetwork() = default;
    DisjointNetwork(const DisjointNetwork&) = default;
    DisjointNetwork& operator=(const DisjointNetwork&) = default;

    // Join two subnets by appending every node in ind2 after ind1, return the resulting subnet
    // index. Use -1 for invalid subnet index.
    int join(int ind1, int ind2);

    // Returns the subnet index of `var`, or -1 if `var` is not in the network. Unlike
    // mIndexMap[var], this never inserts into the map.
    int subnetIndexOf(const RandomVariableNode& var) const {
        const auto it = mIndexMap.find(var);
        return it == mIndexMap.end() ? -1 : it->second;
    }

    // A map from the network node to the corresponding subnet index.
    std::unordered_map<RandomVariableNode, int> mIndexMap;

    // A map from the subnet index to the set of nodes within the subnet. The nodes are maintained
    // in a valid evaluation order, that is, a valid topological sort.
    std::map<int, EvaluationOrder> mEvalOrderMap;

    // The next index for a new disjoint subnet component.
    int mNextIndex = 0;
};

// Manages the active RandomVariable network. Only one instance of this class will exist.
class RandomVariableNetwork : public DisjointNetwork {
   public:
    // Returns the singleton network instance.
    static RandomVariableNetwork* get();

    // Re-initialization. Should be called every time a new random graph is being generated.
    void initialize(int defaultValue);

    // Set the elementwise equality of the two vectors of RandomVariables iff it results in a
    // soluble network.
    bool setEqualIfCompatible(const std::vector<RandomVariable>& lhs,
                              const std::vector<RandomVariable>& rhs);

    // Freeze all FREE RandomVariables in the network to a random valid combination.
    bool freeze();

    // Check if node2 is FREE and can be evaluated after node1.
    bool isSubordinate(const RandomVariableNode& node1, const RandomVariableNode& node2);

    // Get and then advance the current global timestamp.
    int getGlobalTime() { return mGlobalTime++; }

    // Add a special constraint on dimension product.
    void addDimensionProd(const std::vector<RandomVariable>& dims);

   private:
    RandomVariableNetwork() = default;
    RandomVariableNetwork(const RandomVariableNetwork&) = default;
    RandomVariableNetwork& operator=(const RandomVariableNetwork&) = default;

    // A class to revert all the changes made to RandomVariableNetwork since the Reverter object is
    // constructed. Only used when setEqualIfCompatible results in incompatible.
    class Reverter;

    // Find valid choices for all RandomVariables in the network. Update the RandomVariableRange
    // if the network is soluble, otherwise, return false and leave the ranges unchanged.
    bool evalRange();

    // Next value handed out by getGlobalTime().
    int mGlobalTime = 0;
    // NOTE(review): semantics are defined in the implementation file; presumably the global time
    // of the last evaluation — confirm.
    int mTimestamp = -1;

    // Dimension groups registered through addDimensionProd().
    std::vector<EvaluationOrder> mDimProd;
};

}  // namespace fuzzing_test
}  // namespace nn
}  // namespace android

#endif  // ANDROID_FRAMEWORK_ML_NN_RUNTIME_TEST_FUZZING_RANDOM_VARIABLE_H