1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "OperationResolver"
18 
19 #include "OperationResolver.h"
20 
21 #include "NeuralNetworks.h"
22 
23 namespace android {
24 namespace nn {
25 
26 using namespace hal;
27 
28 // TODO(b/119608412): Find a way to not reference every operation here.
29 const OperationRegistration* register_ABS();
30 const OperationRegistration* register_ADD();
31 const OperationRegistration* register_AVERAGE_POOL_2D();
32 const OperationRegistration* register_AXIS_ALIGNED_BBOX_TRANSFORM();
33 const OperationRegistration* register_BIDIRECTIONAL_SEQUENCE_RNN();
34 const OperationRegistration* register_BOX_WITH_NMS_LIMIT();
35 const OperationRegistration* register_CHANNEL_SHUFFLE();
36 const OperationRegistration* register_CONCATENATION();
37 const OperationRegistration* register_CONV_2D();
38 const OperationRegistration* register_DEPTHWISE_CONV_2D();
39 const OperationRegistration* register_DEQUANTIZE();
40 const OperationRegistration* register_DETECTION_POSTPROCESSING();
41 const OperationRegistration* register_DIV();
42 const OperationRegistration* register_ELU();
43 const OperationRegistration* register_EQUAL();
44 const OperationRegistration* register_EXP();
45 const OperationRegistration* register_FILL();
46 const OperationRegistration* register_FLOOR();
47 const OperationRegistration* register_FULLY_CONNECTED();
48 const OperationRegistration* register_GATHER();
49 const OperationRegistration* register_GENERATE_PROPOSALS();
50 const OperationRegistration* register_GREATER();
51 const OperationRegistration* register_GREATER_EQUAL();
52 const OperationRegistration* register_HARD_SWISH();
53 const OperationRegistration* register_HEATMAP_MAX_KEYPOINT();
54 const OperationRegistration* register_INSTANCE_NORMALIZATION();
55 const OperationRegistration* register_L2_NORMALIZATION();
56 const OperationRegistration* register_L2_POOL_2D();
57 const OperationRegistration* register_LESS();
58 const OperationRegistration* register_LESS_EQUAL();
59 const OperationRegistration* register_LOCAL_RESPONSE_NORMALIZATION();
60 const OperationRegistration* register_LOG();
61 const OperationRegistration* register_LOGICAL_AND();
62 const OperationRegistration* register_LOGICAL_NOT();
63 const OperationRegistration* register_LOGICAL_OR();
64 const OperationRegistration* register_LOGISTIC();
65 const OperationRegistration* register_LOG_SOFTMAX();
66 const OperationRegistration* register_MAX_POOL_2D();
67 const OperationRegistration* register_MUL();
68 const OperationRegistration* register_NEG();
69 const OperationRegistration* register_NOT_EQUAL();
70 const OperationRegistration* register_PRELU();
71 const OperationRegistration* register_QUANTIZE();
72 const OperationRegistration* register_QUANTIZED_LSTM();
73 const OperationRegistration* register_RANK();
74 const OperationRegistration* register_REDUCE_ALL();
75 const OperationRegistration* register_REDUCE_ANY();
76 const OperationRegistration* register_REDUCE_MAX();
77 const OperationRegistration* register_REDUCE_MIN();
78 const OperationRegistration* register_REDUCE_PROD();
79 const OperationRegistration* register_REDUCE_SUM();
80 const OperationRegistration* register_RELU();
81 const OperationRegistration* register_RELU1();
82 const OperationRegistration* register_RELU6();
83 const OperationRegistration* register_RESIZE_BILINEAR();
84 const OperationRegistration* register_RESIZE_NEAREST_NEIGHBOR();
85 const OperationRegistration* register_ROI_ALIGN();
86 const OperationRegistration* register_ROI_POOLING();
87 const OperationRegistration* register_RSQRT();
88 const OperationRegistration* register_SELECT();
89 const OperationRegistration* register_SIN();
90 const OperationRegistration* register_SLICE();
91 const OperationRegistration* register_SOFTMAX();
92 const OperationRegistration* register_SQRT();
93 const OperationRegistration* register_SQUEEZE();
94 const OperationRegistration* register_STRIDED_SLICE();
95 const OperationRegistration* register_SUB();
96 const OperationRegistration* register_TANH();
97 const OperationRegistration* register_TOPK_V2();
98 const OperationRegistration* register_TRANSPOSE();
99 const OperationRegistration* register_TRANSPOSE_CONV_2D();
100 const OperationRegistration* register_UNIDIRECTIONAL_SEQUENCE_LSTM();
101 const OperationRegistration* register_UNIDIRECTIONAL_SEQUENCE_RNN();
102 
BuiltinOperationResolver()103 BuiltinOperationResolver::BuiltinOperationResolver() {
104     registerOperation(register_ABS());
105     registerOperation(register_ADD());
106     registerOperation(register_AVERAGE_POOL_2D());
107     registerOperation(register_AXIS_ALIGNED_BBOX_TRANSFORM());
108     registerOperation(register_BIDIRECTIONAL_SEQUENCE_RNN());
109     registerOperation(register_BOX_WITH_NMS_LIMIT());
110     registerOperation(register_CHANNEL_SHUFFLE());
111     registerOperation(register_CONCATENATION());
112     registerOperation(register_CONV_2D());
113     registerOperation(register_DEPTHWISE_CONV_2D());
114     registerOperation(register_DEQUANTIZE());
115     registerOperation(register_DETECTION_POSTPROCESSING());
116     registerOperation(register_DIV());
117     registerOperation(register_ELU());
118     registerOperation(register_EQUAL());
119     registerOperation(register_EXP());
120     registerOperation(register_FILL());
121     registerOperation(register_FLOOR());
122     registerOperation(register_FULLY_CONNECTED());
123     registerOperation(register_GATHER());
124     registerOperation(register_GENERATE_PROPOSALS());
125     registerOperation(register_GREATER());
126     registerOperation(register_GREATER_EQUAL());
127     registerOperation(register_HARD_SWISH());
128     registerOperation(register_HEATMAP_MAX_KEYPOINT());
129     registerOperation(register_INSTANCE_NORMALIZATION());
130     registerOperation(register_L2_NORMALIZATION());
131     registerOperation(register_L2_POOL_2D());
132     registerOperation(register_LESS());
133     registerOperation(register_LESS_EQUAL());
134     registerOperation(register_LOCAL_RESPONSE_NORMALIZATION());
135     registerOperation(register_LOG());
136     registerOperation(register_LOGICAL_AND());
137     registerOperation(register_LOGICAL_NOT());
138     registerOperation(register_LOGICAL_OR());
139     registerOperation(register_LOGISTIC());
140     registerOperation(register_LOG_SOFTMAX());
141     registerOperation(register_MAX_POOL_2D());
142     registerOperation(register_MUL());
143     registerOperation(register_NEG());
144     registerOperation(register_NOT_EQUAL());
145     registerOperation(register_PRELU());
146     registerOperation(register_QUANTIZE());
147     registerOperation(register_QUANTIZED_LSTM());
148     registerOperation(register_RANK());
149     registerOperation(register_REDUCE_ALL());
150     registerOperation(register_REDUCE_ANY());
151     registerOperation(register_REDUCE_MAX());
152     registerOperation(register_REDUCE_MIN());
153     registerOperation(register_REDUCE_PROD());
154     registerOperation(register_REDUCE_SUM());
155     registerOperation(register_RELU());
156     registerOperation(register_RELU1());
157     registerOperation(register_RELU6());
158     registerOperation(register_RESIZE_BILINEAR());
159     registerOperation(register_RESIZE_NEAREST_NEIGHBOR());
160     registerOperation(register_ROI_ALIGN());
161     registerOperation(register_ROI_POOLING());
162     registerOperation(register_RSQRT());
163     registerOperation(register_SELECT());
164     registerOperation(register_SIN());
165     registerOperation(register_SLICE());
166     registerOperation(register_SOFTMAX());
167     registerOperation(register_SQRT());
168     registerOperation(register_SQUEEZE());
169     registerOperation(register_STRIDED_SLICE());
170     registerOperation(register_SUB());
171     registerOperation(register_TANH());
172     registerOperation(register_TOPK_V2());
173     registerOperation(register_TRANSPOSE());
174     registerOperation(register_TRANSPOSE_CONV_2D());
175     registerOperation(register_UNIDIRECTIONAL_SEQUENCE_LSTM());
176     registerOperation(register_UNIDIRECTIONAL_SEQUENCE_RNN());
177 }
178 
findOperation(OperationType operationType) const179 const OperationRegistration* BuiltinOperationResolver::findOperation(
180         OperationType operationType) const {
181     auto index = static_cast<int32_t>(operationType);
182     if (index < 0 || index >= kNumberOfOperationTypes) {
183         return nullptr;
184     }
185     return mRegistrations[index];
186 }
187 
registerOperation(const OperationRegistration * operationRegistration)188 void BuiltinOperationResolver::registerOperation(
189         const OperationRegistration* operationRegistration) {
190     CHECK(operationRegistration != nullptr);
191     auto index = static_cast<int32_t>(operationRegistration->type);
192     CHECK_LE(0, index);
193     CHECK_LT(index, kNumberOfOperationTypes);
194     CHECK(mRegistrations[index] == nullptr);
195     mRegistrations[index] = operationRegistration;
196 }
197 
198 }  // namespace nn
199 }  // namespace android
200