/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #define LOG_TAG "OperationResolver" #include "OperationResolver.h" #include "NeuralNetworks.h" namespace android { namespace nn { using namespace hal; // TODO(b/119608412): Find a way to not reference every operation here. const OperationRegistration* register_ABS(); const OperationRegistration* register_ADD(); const OperationRegistration* register_AVERAGE_POOL_2D(); const OperationRegistration* register_AXIS_ALIGNED_BBOX_TRANSFORM(); const OperationRegistration* register_BIDIRECTIONAL_SEQUENCE_RNN(); const OperationRegistration* register_BOX_WITH_NMS_LIMIT(); const OperationRegistration* register_CHANNEL_SHUFFLE(); const OperationRegistration* register_CONCATENATION(); const OperationRegistration* register_CONV_2D(); const OperationRegistration* register_DEPTHWISE_CONV_2D(); const OperationRegistration* register_DEQUANTIZE(); const OperationRegistration* register_DETECTION_POSTPROCESSING(); const OperationRegistration* register_DIV(); const OperationRegistration* register_ELU(); const OperationRegistration* register_EQUAL(); const OperationRegistration* register_EXP(); const OperationRegistration* register_FILL(); const OperationRegistration* register_FLOOR(); const OperationRegistration* register_FULLY_CONNECTED(); const OperationRegistration* register_GATHER(); const OperationRegistration* register_GENERATE_PROPOSALS(); const OperationRegistration* register_GREATER(); const OperationRegistration* register_GREATER_EQUAL(); const OperationRegistration* register_HARD_SWISH(); const OperationRegistration* register_HEATMAP_MAX_KEYPOINT(); const OperationRegistration* register_INSTANCE_NORMALIZATION(); const OperationRegistration* register_L2_NORMALIZATION(); const OperationRegistration* register_L2_POOL_2D(); const OperationRegistration* register_LESS(); const OperationRegistration* register_LESS_EQUAL(); const OperationRegistration* register_LOCAL_RESPONSE_NORMALIZATION(); const OperationRegistration* register_LOG(); const OperationRegistration* register_LOGICAL_AND(); const OperationRegistration* register_LOGICAL_NOT(); const OperationRegistration* register_LOGICAL_OR(); const OperationRegistration* register_LOGISTIC(); const OperationRegistration* register_LOG_SOFTMAX(); const OperationRegistration* register_MAX_POOL_2D(); const OperationRegistration* register_MUL(); const OperationRegistration* register_NEG(); const OperationRegistration* register_NOT_EQUAL(); const OperationRegistration* register_PRELU(); const OperationRegistration* register_QUANTIZE(); const OperationRegistration* register_QUANTIZED_LSTM(); const OperationRegistration* register_RANK(); const OperationRegistration* register_REDUCE_ALL(); const OperationRegistration* register_REDUCE_ANY(); const OperationRegistration* register_REDUCE_MAX(); const OperationRegistration* register_REDUCE_MIN(); const OperationRegistration* register_REDUCE_PROD(); const OperationRegistration* register_REDUCE_SUM(); const OperationRegistration* register_RELU(); const OperationRegistration* register_RELU1(); const OperationRegistration* register_RELU6(); const OperationRegistration* register_RESIZE_BILINEAR(); const OperationRegistration* register_RESIZE_NEAREST_NEIGHBOR(); const OperationRegistration* register_ROI_ALIGN(); const OperationRegistration* register_ROI_POOLING(); const OperationRegistration* register_RSQRT(); const OperationRegistration* register_SELECT(); const OperationRegistration* register_SIN(); const OperationRegistration* register_SLICE(); const OperationRegistration* register_SOFTMAX(); const OperationRegistration* register_SQRT(); const OperationRegistration* register_SQUEEZE(); const OperationRegistration* register_STRIDED_SLICE(); const OperationRegistration* register_SUB(); const OperationRegistration* register_TANH(); const OperationRegistration* register_TOPK_V2(); const OperationRegistration* register_TRANSPOSE(); const OperationRegistration* register_TRANSPOSE_CONV_2D(); const OperationRegistration* register_UNIDIRECTIONAL_SEQUENCE_LSTM(); const OperationRegistration* register_UNIDIRECTIONAL_SEQUENCE_RNN(); BuiltinOperationResolver::BuiltinOperationResolver() { registerOperation(register_ABS()); registerOperation(register_ADD()); registerOperation(register_AVERAGE_POOL_2D()); registerOperation(register_AXIS_ALIGNED_BBOX_TRANSFORM()); registerOperation(register_BIDIRECTIONAL_SEQUENCE_RNN()); registerOperation(register_BOX_WITH_NMS_LIMIT()); registerOperation(register_CHANNEL_SHUFFLE()); registerOperation(register_CONCATENATION()); registerOperation(register_CONV_2D()); registerOperation(register_DEPTHWISE_CONV_2D()); registerOperation(register_DEQUANTIZE()); registerOperation(register_DETECTION_POSTPROCESSING()); registerOperation(register_DIV()); registerOperation(register_ELU()); registerOperation(register_EQUAL()); registerOperation(register_EXP()); registerOperation(register_FILL()); registerOperation(register_FLOOR()); registerOperation(register_FULLY_CONNECTED()); registerOperation(register_GATHER()); registerOperation(register_GENERATE_PROPOSALS()); registerOperation(register_GREATER()); registerOperation(register_GREATER_EQUAL()); registerOperation(register_HARD_SWISH()); registerOperation(register_HEATMAP_MAX_KEYPOINT()); registerOperation(register_INSTANCE_NORMALIZATION()); registerOperation(register_L2_NORMALIZATION()); registerOperation(register_L2_POOL_2D()); registerOperation(register_LESS()); registerOperation(register_LESS_EQUAL()); registerOperation(register_LOCAL_RESPONSE_NORMALIZATION()); registerOperation(register_LOG()); registerOperation(register_LOGICAL_AND()); registerOperation(register_LOGICAL_NOT()); registerOperation(register_LOGICAL_OR()); registerOperation(register_LOGISTIC()); registerOperation(register_LOG_SOFTMAX()); registerOperation(register_MAX_POOL_2D()); registerOperation(register_MUL()); registerOperation(register_NEG()); registerOperation(register_NOT_EQUAL()); registerOperation(register_PRELU()); registerOperation(register_QUANTIZE()); registerOperation(register_QUANTIZED_LSTM()); registerOperation(register_RANK()); registerOperation(register_REDUCE_ALL()); registerOperation(register_REDUCE_ANY()); registerOperation(register_REDUCE_MAX()); registerOperation(register_REDUCE_MIN()); registerOperation(register_REDUCE_PROD()); registerOperation(register_REDUCE_SUM()); registerOperation(register_RELU()); registerOperation(register_RELU1()); registerOperation(register_RELU6()); registerOperation(register_RESIZE_BILINEAR()); registerOperation(register_RESIZE_NEAREST_NEIGHBOR()); registerOperation(register_ROI_ALIGN()); registerOperation(register_ROI_POOLING()); registerOperation(register_RSQRT()); registerOperation(register_SELECT()); registerOperation(register_SIN()); registerOperation(register_SLICE()); registerOperation(register_SOFTMAX()); registerOperation(register_SQRT()); registerOperation(register_SQUEEZE()); registerOperation(register_STRIDED_SLICE()); registerOperation(register_SUB()); registerOperation(register_TANH()); registerOperation(register_TOPK_V2()); registerOperation(register_TRANSPOSE()); registerOperation(register_TRANSPOSE_CONV_2D()); registerOperation(register_UNIDIRECTIONAL_SEQUENCE_LSTM()); registerOperation(register_UNIDIRECTIONAL_SEQUENCE_RNN()); } const OperationRegistration* BuiltinOperationResolver::findOperation( OperationType operationType) const { auto index = static_cast(operationType); if (index < 0 || index >= kNumberOfOperationTypes) { return nullptr; } return mRegistrations[index]; } void BuiltinOperationResolver::registerOperation( const OperationRegistration* operationRegistration) { CHECK(operationRegistration != nullptr); auto index = static_cast(operationRegistration->type); CHECK_LE(0, index); CHECK_LT(index, kNumberOfOperationTypes); CHECK(mRegistrations[index] == nullptr); mRegistrations[index] = operationRegistration; } } // namespace nn } // namespace android