/* * Copyright (C) 2019 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #define LOG_TAG "MetaModel" #include "MetaModel.h" #include #include #include #include #include #include #include "GraphDump.h" #include "HalInterfaces.h" #include "Utils.h" namespace android::nn { using namespace hal; namespace { // Add an element to the end of the vector and return a pair consisting of the // index of the new element and a pointer to the new element. template std::pair extend(hidl_vec* vec) { size_t nextIndex = vec->size(); vec->resize(nextIndex + 1); return {nextIndex, &(*vec)[nextIndex]}; } // Add an element to the end of the vector, set it to the specified value, and // return a pair consisting of the index of the new element and a pointer to the // new element. template std::pair extend(hidl_vec* vec, const T& val) { auto extended = extend(vec); *extended.second = val; return extended; } template bool operator<(const hidl_vec& a, const hidl_vec& b) { return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end()); } // Compile-time mapping from a particular Model type to a name for that type. template struct ModelVersion; template <> struct ModelVersion { static constexpr char name[] = "V1_0"; }; template <> struct ModelVersion { static constexpr char name[] = "V1_1"; }; template <> struct ModelVersion { static constexpr char name[] = "V1_2"; }; template <> struct ModelVersion { static constexpr char name[] = "V1_3"; }; // Dispatcher mechanism for calling an appropriate uncheckedConvertToV1_* // given the desired return type. template T_ReturnType uncheckedConvertTo(OperationType type); template <> hal::V1_0::OperationType uncheckedConvertTo(OperationType type) { return uncheckedConvertToV1_0(type); } template <> hal::V1_1::OperationType uncheckedConvertTo(OperationType type) { return uncheckedConvertToV1_1(type); } template <> hal::V1_2::OperationType uncheckedConvertTo(OperationType type) { return uncheckedConvertToV1_2(type); } // Dispatcher mechanism for calling an appropriate convertToV1_* given the // desired return type. Note that there is no V1_1::Operand type. template T_ReturnType convertTo(Operand operand); template <> hal::V1_0::Operand convertTo(Operand operand) { return convertToV1_0(operand); } template <> hal::V1_2::Operand convertTo(Operand operand) { return convertToV1_2(operand); } // Dispatcher mechanism for calling an appropriate convertToV1_* given the // desired return type. Note that there are no V1_[12]::OperandLifeTime types. template T_ReturnType convertTo(OperandLifeTime lifetime); template <> hal::V1_0::OperandLifeTime convertTo(OperandLifeTime lifetime) { return convertToV1_0(lifetime); } template <> hal::V1_3::OperandLifeTime convertTo(OperandLifeTime lifetime) { return lifetime; } // Dispatcher mechanism for calling an appropriate compliantWithV1_* given the // desired target model type. template void getNoncompliantOperations(const hal::V1_3::Model& model, std::set* noncompliantOperations); template <> void getNoncompliantOperations(const hal::V1_3::Model& model, std::set* noncompliantOperations) { compliantWithV1_0(model, noncompliantOperations); } template <> void getNoncompliantOperations(const hal::V1_3::Model& model, std::set* noncompliantOperations) { compliantWithV1_1(model, noncompliantOperations); } template <> void getNoncompliantOperations(const hal::V1_3::Model& model, std::set* noncompliantOperations) { compliantWithV1_2(model, noncompliantOperations); } template bool invalid(const T_SlicedModel& model, bool strictSlicing) { // A model must have at least one operation. However, it's possible that a // slice has no operations (because no operations from the original model // are compliant with the sliced model type). In this case, the sliced // model would be invalid. const bool looksEmpty = (model.operations.size() == 0); if (strictSlicing) { CHECK_EQ(looksEmpty, (model.operands.size() == 0)); } if (looksEmpty) return true; // A model must have at least one output. However, it's possible for a // model to contain dead operations (i.e., outputs on which no model outputs // are data dependent). A slice might contain only dead operations, and // hence have no model outputs. In this case, the sliced model would be // invalid. if (model.outputIndexes.size() == 0) return true; // We shouldn't have to check whether the model is valid. // However, it could be invalid if: // - there is an error in the slicing algorithm; or // - there is an error in compliantWith (see http://b/131845106) if (!validateModel(model)) { LOG(WARNING) << "Sliced model fails validateModel()"; CHECK(!strictSlicing); return true; } return false; } } // anonymous namespace template MetaModel::ReturnedSlice MetaModel::getSlice(Slice* slice) const { CHECK(slice != nullptr); if (slice->mState == SliceState::UNINITIALIZED) { *slice = makeSlice(); } if (slice->mState == SliceState::INVALID) { return {}; } return MetaModel::ReturnedSlice(std::make_pair( slice->mHidlModel, Mapper([slice](uint32_t slicedOperationIndex) { return slice->mSlicedOperationIndexToOrigIndex.at(slicedOperationIndex); }))); } template MetaModel::ReturnedSlice MetaModel::getSlice( Slice* slice) const; template MetaModel::ReturnedSlice MetaModel::getSlice( Slice* slice) const; template MetaModel::ReturnedSlice MetaModel::getSlice( Slice* slice) const; // When adding HAL version 1.4, make sure to handle control flow and referenced // subgraphs here properly. A V1_3 sliced model should contain an IF/WHILE and // its referenced subgraphs only if there are no V1_4+ operations in those // subgraphs. // template MetaModel::ReturnedSlice MetaModel::getSlice( // Slice* slice) const; // Utility class for makeSlice(). // // For each output operand of a noncompliant operation that is the input // operand of at least one compliant operation, we will ensure that there is // a sliced model input whose "type" is that of the output operand. This is // a map from operand "type" (in the original model) to model input // operand index (in the sliced model). Unfortunately, there is no // representation of operand "type" defined in the HAL that we can use // naively here -- we want (OperandType, dimensions, scale, zeroPoint, // extraParams), but these fields exist in Operand along with other fields // that need to be excluded from the map key (numberOfConsumers, lifetime, // location). There are several choices: // - Don't have a map -- each output identified above gets its own sliced // model input (no sharing of sliced model inputs). // - Create an operand "type" representation solely for use as a map key. // - Write a tailored comparison function that ignores the excluded fields. // We choose to write a tailored comparison function. If Treble were to // generate a comparison function for us (http://b/130567619) then it might // be better to instead reset the excluded fields to canonical values -- // then we could use the Treble provided comparison function, and the // solution would be robust (in a correctness sense, not a sharing sense) if // more fields are added and we neglect to canonicalize them. // // We also use this map for model input operands of the original model that // become input operands of the sliced model. This means that an original // model input operand might be commoned with other original model input // operands and/or with original model temporary operands. template class MetaModel::OrigOperandToSlicedInputOperandIndex { public: OrigOperandToSlicedInputOperandIndex(hidl_vec* slicedOperands, hidl_vec* slicedInputIndexes) : mSlicedOperands(*slicedOperands), mSlicedInputIndexes(*slicedInputIndexes) {} // Given an operand from the original model, return the index of the // corresponding model input operand from the sliced model. Creates a // new operand in the sliced model if necessary. uint32_t getIndex(Operand operand) { // Lookup auto it = mMap.find(operand); if (it != mMap.end()) { VLOG(COMPILATION) << "OrigOperandToSlicedInputOperandIndex::getIndex looked for " << toString(operand) << " and found " << it->second << ": " << toString(it->first); return it->second; } // Create operand.numberOfConsumers = 0; operand.lifetime = convertTo(OperandLifeTime::SUBGRAPH_INPUT); operand.location = {}; uint32_t slicedOperandIndex = extend(&mSlicedOperands, convertTo(operand)).first; mMap[operand] = slicedOperandIndex; extend(&mSlicedInputIndexes, slicedOperandIndex); VLOG(COMPILATION) << "OrigOperandToSlicedInputOperandIndex::getIndex created " << slicedOperandIndex << ": " << toString(operand); return slicedOperandIndex; } private: class Compare { public: bool operator()(const Operand& a, const Operand& b) const { if (a.type != b.type) { return a.type < b.type; } if (a.dimensions != b.dimensions) { return a.dimensions < b.dimensions; } if (a.scale != b.scale) { return a.scale < b.scale; } if (a.zeroPoint != b.zeroPoint) { return a.zeroPoint < b.zeroPoint; } return compare(a.extraParams, b.extraParams); } private: static bool compare(const SymmPerChannelQuantParams& a, const SymmPerChannelQuantParams& b) { if (a.scales != b.scales) { return a.scales < b.scales; } return a.channelDim < b.channelDim; } static bool compare(const OperandExtraParams& a, const OperandExtraParams& b) { if (a.getDiscriminator() != b.getDiscriminator()) { return a.getDiscriminator() < b.getDiscriminator(); } switch (a.getDiscriminator()) { case OperandExtraParams::hidl_discriminator::channelQuant: return compare(a.channelQuant(), b.channelQuant()); case OperandExtraParams::hidl_discriminator::extension: return a.extension() < b.extension(); case OperandExtraParams::hidl_discriminator::none: return false; default: CHECK(false) << "Unexpected"; return false; } } }; std::map mMap; hidl_vec& mSlicedOperands; hidl_vec& mSlicedInputIndexes; }; template void MetaModel::processOperations( Slice* slice, std::map* origOperandIndexToSlicedIndex, OrigOperandToSlicedInputOperandIndex::Operand>* origOperandToSlicedInputOperandIndex, const std::set& noncompliantOperations, const std::set& inputOperandIndexesOfCompliantOperations) const { using SlicedOperand = typename Slice::Operand; using SlicedOperation = typename Slice::Operation; using SlicedOperationType = typename Slice::OperationType; const auto& origOperands = mHidlModel.main.operands; const auto& origOperations = mHidlModel.main.operations; auto& slicedOperands = slice->mHidlModel.operands; auto& slicedOperations = slice->mHidlModel.operations; for (uint32_t origOperationIndex = 0; origOperationIndex < origOperations.size(); ++origOperationIndex) { const Operation& origOperation = origOperations[origOperationIndex]; if (noncompliantOperations.count(origOperationIndex)) { for (uint32_t output : origOperation.outputs) { if (!inputOperandIndexesOfCompliantOperations.count(output)) { continue; } const uint32_t slicedIndex = origOperandToSlicedInputOperandIndex->getIndex(origOperands[output]); (*origOperandIndexToSlicedIndex)[output] = slicedIndex; VLOG(COMPILATION) << "origOperandIndexToSlicedIndex noncompliant output processing created " << output << " -> " << slicedIndex << ": " << toString(slicedOperands[slicedIndex]); } } else { slice->mSlicedOperationIndexToOrigIndex.push_back(origOperationIndex); SlicedOperation& slicedOperation = *extend(&slicedOperations).second; CHECK_EQ(slice->mSlicedOperationIndexToOrigIndex.size(), slicedOperations.size()); slicedOperation.type = uncheckedConvertTo(origOperation.type); // Model is topologically sorted, so all operation inputs must be // present in origOperandIndexToSlicedIndex, and no operation // outputs may be. // Operation inputs // - Fill in slicedOperation.inputs // - Update number of consumers for each input operand slicedOperation.inputs.resize(origOperation.inputs.size()); std::transform( origOperation.inputs.begin(), origOperation.inputs.end(), slicedOperation.inputs.begin(), [&origOperandIndexToSlicedIndex, &slicedOperands](uint32_t origOperandIndex) { uint32_t slicedOperandIndex = origOperandIndexToSlicedIndex->at(origOperandIndex); slicedOperands[slicedOperandIndex].numberOfConsumers++; VLOG(COMPILATION) << "origOperandIndexToSlicedIndex compliant input " "processing created " << origOperandIndex << " -> " << slicedOperandIndex << ": " << toString(slicedOperands[slicedOperandIndex]); return slicedOperandIndex; }); // Operation outputs // - Add new operands to slicedOperands // - Update origOperandIndexToSlicedIndex // - Fill in slicedOperation.outputs // - Record as a model output, if necessary const uint32_t firstOutputSlicedOperandIndex = slicedOperands.size(); slicedOperands.resize(firstOutputSlicedOperandIndex + origOperation.outputs.size()); slicedOperation.outputs.resize(origOperation.outputs.size()); for (uint32_t outputNum = 0; outputNum < slicedOperation.outputs.size(); ++outputNum) { uint32_t origOperandIndex = origOperation.outputs[outputNum]; uint32_t slicedOperandIndex = firstOutputSlicedOperandIndex + outputNum; auto& slicedOperand = slicedOperands[slicedOperandIndex]; const auto& origOperand = origOperands[origOperandIndex]; slicedOperand = convertTo(origOperand); slicedOperand.numberOfConsumers = 0; CHECK_EQ(origOperandIndexToSlicedIndex->count(origOperandIndex), size_t(0)); (*origOperandIndexToSlicedIndex)[origOperandIndex] = slicedOperandIndex; slicedOperation.outputs[outputNum] = slicedOperandIndex; const auto subgraphOutputLifetime = convertTo( OperandLifeTime::SUBGRAPH_OUTPUT); if (!inputOperandIndexesOfCompliantOperations.count(origOperandIndex) && origOperand.numberOfConsumers) { // Was consumed only by noncompliant operations; convert to // an output of the sliced model. slicedOperand.lifetime = subgraphOutputLifetime; } VLOG(COMPILATION) << "origOperandIndexToSlicedIndex compliant output created " << origOperandIndex << " -> " << slicedOperandIndex << ": " << toString(slicedOperand); if (slicedOperand.lifetime == subgraphOutputLifetime) { extend(&slice->mHidlModel.outputIndexes, slicedOperandIndex); } } } } } template MetaModel::Slice MetaModel::makeSlice() const { using SlicedOperand = typename Slice::Operand; Slice slice; const auto& origOperands = mHidlModel.main.operands; const auto& origOperations = mHidlModel.main.operations; auto& slicedOperands = slice.mHidlModel.operands; // Indexes of elements of noncompliant origOperations std::set noncompliantOperations; getNoncompliantOperations(mHidlModel, &noncompliantOperations); // Map from an operand index in origOperands to the corresponding operand index in // slicedOperands std::map origOperandIndexToSlicedIndex; // Collect the operand indexes of every operand that is an input to a // compliant operation. If the operand is a CONSTANT_* or a NO_VALUE, copy // it to the sliced model and update origOperandIndexToSlicedIndex // accordingly. Otherwise, we'll deal with the operand in the subsequent // "Main loop", where we process operation outputs (intermediates and model // outputs). std::set inputOperandIndexesOfCompliantOperations; for (uint32_t origOperationIndex = 0; origOperationIndex < origOperations.size(); ++origOperationIndex) { if (noncompliantOperations.count(origOperationIndex)) { continue; } for (uint32_t input : origOperations[origOperationIndex].inputs) { if (inputOperandIndexesOfCompliantOperations.insert(input).second) { const Operand& origOperand = origOperands[input]; switch (origOperand.lifetime) { case OperandLifeTime::CONSTANT_COPY: case OperandLifeTime::CONSTANT_REFERENCE: case OperandLifeTime::NO_VALUE: { const uint32_t slicedOperandIndex = extend(&slicedOperands, convertTo(origOperand)) .first; slicedOperands[slicedOperandIndex].numberOfConsumers = 0; origOperandIndexToSlicedIndex[input] = slicedOperandIndex; VLOG(COMPILATION) << "origOperandIndexToSlicedIndex initialization created " << input << " -> " << slicedOperandIndex << ": " << toString(slicedOperands[slicedOperandIndex]); break; } default: break; } } } } OrigOperandToSlicedInputOperandIndex origOperandToSlicedInputOperandIndex( &slicedOperands, &slice.mHidlModel.inputIndexes); // An input of the original model is an input of the sliced model if and // only if it is consumed by at least one compliant operation. Note that in // the sliced model we share all model inputs of the same "type"; and that // we may later add model inputs to the sliced model. for (uint32_t origInputIndex : mHidlModel.main.inputIndexes) { if (inputOperandIndexesOfCompliantOperations.count(origInputIndex)) { const uint32_t slicedIndex = origOperandToSlicedInputOperandIndex.getIndex(origOperands[origInputIndex]); origOperandIndexToSlicedIndex[origInputIndex] = slicedIndex; VLOG(COMPILATION) << "origOperandIndexToSlicedIndex inputIndexes processing created " << origInputIndex << " -> " << slicedIndex << ": " << toString(slicedOperands[slicedIndex]); } } // Main loop: Process each operation of the original model. processOperations(&slice, &origOperandIndexToSlicedIndex, &origOperandToSlicedInputOperandIndex, noncompliantOperations, inputOperandIndexesOfCompliantOperations); // To keep things simple, we copy over these fields as-is. We could instead // opt to regenerate them based on the operands present in the sliced model: // This would be more complex and probably take more computation time, but // it would reduce the size of the sliced model, and hence the time spent // copying it around and passing it across the HAL interface. slice.mHidlModel.operandValues = mHidlModel.operandValues; slice.mHidlModel.pools = mHidlModel.pools; if (VLOG_IS_ON(COMPILATION)) { { std::ostringstream fromName; fromName << "Slice: From " << ModelVersion::name; graphDump(fromName.str().c_str(), mHidlModel); } { std::ostringstream toName; toName << "Slice: To " << ModelVersion::name; graphDump(toName.str().c_str(), convertToV1_3(slice.mHidlModel)); } } slice.mState = invalid(slice.mHidlModel, mStrictSlicing) ? SliceState::INVALID : SliceState::NORMAL; return slice; } template MetaModel::Slice MetaModel::makeSlice() const; template MetaModel::Slice MetaModel::makeSlice() const; } // namespace android::nn