/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h>
#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>

#include <algorithm>
#include <limits>
#include <vector>

#include "ActivationFunctor.h"
#include "CpuOperationUtils.h"
#include "HalInterfaces.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Tracing.h"

namespace android {
namespace nn {

using namespace hal;

namespace activation {

constexpr uint32_t kNumInputs = 1;
constexpr uint32_t kInputTensor = 0;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

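// Generic float RELU family: clamps every element of the input tensor to
// [reluMin, reluMax]. The defaults (0, +inf) give plain RELU; relu1Float and
// relu6Float below reuse this with bounds [-1, 1] and [0, 6]. _Float16 values
// are widened to float for the comparison and narrowed again on store.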
template <typename T>
bool reluFloat(const T* inputData, const Shape& inputShape, T* outputData, const Shape& outputShape,
               float reluMin = 0.f, float reluMax = std::numeric_limits<float>::max()) {
    NNTRACE_COMP("reluX");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = static_cast<T>(
                std::min(std::max(reluMin, static_cast<float>(*inputData)), reluMax));
    }
    return true;
}
template bool reluFloat<float>(const float* inputData, const Shape& inputShape, float* outputData,
                               const Shape& outputShape, float reluMin, float reluMax);
template bool reluFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                  _Float16* outputData, const Shape& outputShape, float reluMin,
                                  float reluMax);

template <typename T>
bool relu1Float(const T* inputData, const Shape& inputShape, T* outputData,
                const Shape& outputShape) {
    return reluFloat(inputData, inputShape, outputData, outputShape, -1.f, 1.f);
}
template bool relu1Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
                                const Shape& outputShape);
template bool relu1Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                   _Float16* outputData, const Shape& outputShape);

template <typename T>
bool relu6Float(const T* inputData, const Shape& inputShape, T* outputData,
                const Shape& outputShape) {
    return reluFloat(inputData, inputShape, outputData, outputShape, 0.f, 6.f);
}
template bool relu6Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
                                const Shape& outputShape);
template bool relu6Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                   _Float16* outputData, const Shape& outputShape);

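// Half-precision tanh is evaluated in single precision: each element is
// widened to float, passed through std::tanh, and narrowed back to _Float16.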
bool tanhFloat16(const _Float16* inputData, const Shape& inputShape, _Float16* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("tanhFloat16");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = static_cast<_Float16>(std::tanh(static_cast<float>(*inputData)));
    }
    return true;
}

bool tanhFloat32(const float* inputData, const Shape& inputShape, float* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("tanhFloat32");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::tanh(*inputData);
    }
    return true;
}

template <typename T>
bool logisticFloat(const T* inputData, const Shape& inputShape, T* outputData,
                   const Shape& outputShape) {
    NNTRACE_COMP("logisticFloat");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = static_cast<T>(1.f / (1.f + std::exp(static_cast<float>(-*inputData))));
    }
    return true;
}
template bool logisticFloat<float>(const float* inputData, const Shape& inputShape,
                                   float* outputData, const Shape& outputShape);
template bool logisticFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                      _Float16* outputData, const Shape& outputShape);

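// Quantized RELU family: CalculateActivationRangeUint8 converts the float
// clamp bounds into quantized-domain limits once (using the tensor's scale and
// zero point), so the per-element work is a plain integer min/max with no
// requantization.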
template <ActivationFn activation>
inline bool reluXQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                        const Shape& outputShape) {
    int numElements = getNumberOfElements(inputShape);
    int32_t output_activation_min = 0;
    int32_t output_activation_max = 0;

    CalculateActivationRangeUint8(activation, inputShape, &output_activation_min,
                                  &output_activation_max);

    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::min((uint8_t)output_activation_max,
                               std::max((uint8_t)output_activation_min, *inputData));
    }
    return true;
}

bool reluQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                const Shape& outputShape) {
    NNTRACE_COMP("reluQuant8");
    return reluXQuant8<kActivationRelu>(inputData, inputShape, outputData, outputShape);
}

bool relu1Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("relu1Quant8");
    return reluXQuant8<kActivationRelu1>(inputData, inputShape, outputData, outputShape);
}

bool relu6Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("relu6Quant8");
    return reluXQuant8<kActivationRelu6>(inputData, inputShape, outputData, outputShape);
}

bool tanhQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                const Shape& outputShape) {
    NNTRACE_TRANS("tanhQuant8");
    if (outputShape.offset != 128 || outputShape.scale != 1.f / 128) {
        LOG(ERROR) << "incorrect scale or offset for TANH output";
        return false;
    }

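    // The input is rescaled into a fixed-point format with kInputIntegerBits
    // integer bits (tanh saturates quickly, so 4 integer bits suffice):
    // input_real_multiplier folds the input scale into a 2^27 shift, and
    // QuantizeMultiplierGreaterThanOne splits it into an int32 multiplier plus
    // a left shift. input_range_radius is the quantized magnitude beyond which
    // the kernel can simply clamp the output to its extremes.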
    int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("optimized_ops::Tanh");
    tflite::optimized_ops::Tanh(inputData, convertShapeToTflshape(inputShape), inputShape.offset,
                                input_range_radius, input_multiplier, input_left_shift, outputData,
                                convertShapeToTflshape(outputShape));

    return true;
}

bool logisticQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                    const Shape& outputShape) {
    NNTRACE_TRANS("logisticQuant8");
    if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) {
        LOG(ERROR) << "incorrect scale / offset for output";
        return false;
    }

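    // The logistic output range (0, 1) is covered exactly by the fixed output
    // quantization checked above (scale 1/256, zero point 0); the input is
    // prepared with the same fixed-point rescaling as in tanhQuant8.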
    int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("optimized_ops::Logistic");
    tflite::optimized_ops::Logistic(
            inputData, convertShapeToTflshape(inputShape), inputShape.offset, input_range_radius,
            input_multiplier, input_left_shift, outputData, convertShapeToTflshape(outputShape));

    return true;
}

template <ActivationFn activation>
inline bool reluXQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                              const Shape& outputShape) {
    int numElements = getNumberOfElements(inputShape);
    int32_t output_activation_min = 0;
    int32_t output_activation_max = 0;

    CalculateActivationRangeInt8(activation, inputShape, &output_activation_min,
                                 &output_activation_max);

    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::min((int8_t)output_activation_max,
                               std::max((int8_t)output_activation_min, *inputData));
    }
    return true;
}

bool reluQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                      const Shape& outputShape) {
    NNTRACE_COMP("reluQuant8");
    return reluXQuant8Signed<kActivationRelu>(inputData, inputShape, outputData, outputShape);
}

bool relu1Quant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                       const Shape& outputShape) {
    NNTRACE_COMP("relu1Quant8");
    return reluXQuant8Signed<kActivationRelu1>(inputData, inputShape, outputData, outputShape);
}

bool relu6Quant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                       const Shape& outputShape) {
    NNTRACE_COMP("relu6Quant8");
    return reluXQuant8Signed<kActivationRelu6>(inputData, inputShape, outputData, outputShape);
}

bool tanhQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                      const Shape& outputShape) {
    NNTRACE_TRANS("tanhQuant8Signed");
    if (outputShape.offset != 0 || outputShape.scale != 1.f / 128) {
        LOG(ERROR) << "incorrect scale or offset for TANH output";
        return false;
    }

    int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("reference_integer_ops::Tanh");
    tflite::reference_integer_ops::Tanh(inputShape.offset, input_range_radius, input_multiplier,
                                        input_left_shift, numElements, inputData, outputData);

    return true;
}

bool logisticQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                          const Shape& outputShape) {
    NNTRACE_TRANS("logisticQuant8Signed");
    if (outputShape.offset != -128 || outputShape.scale != 1.f / 256) {
        LOG(ERROR) << "incorrect scale / offset for output";
        return false;
    }

    int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("reference_integer_ops::Logistic");
    tflite::reference_integer_ops::Logistic(inputShape.offset, input_range_radius, input_multiplier,
                                            input_left_shift, numElements, inputData, outputData);

    return true;
}

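// Rounds an int32 fixed-point multiplier down to the int16 multiplier expected
// by tflite::HardSwishParams: add half of 2^16 for round-to-nearest, shift
// right by 16, and saturate at int16 max when the addition would overflow.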
void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32, int16_t* multiplier_int16) {
    TFLITE_DCHECK_GE(multiplier_int32, 0);
    static constexpr int32_t kRoundingOffset = 1 << 15;
    if (multiplier_int32 >= std::numeric_limits<int32_t>::max() - kRoundingOffset) {
        *multiplier_int16 = std::numeric_limits<int16_t>::max();
        return;
    }
    const int32_t result = (multiplier_int32 + kRoundingOffset) >> 16;
    TFLITE_DCHECK_LE(result << 16, multiplier_int32 + kRoundingOffset);
    TFLITE_DCHECK_GT(result << 16, multiplier_int32 - kRoundingOffset);
    *multiplier_int16 = result;
    TFLITE_DCHECK_EQ(*multiplier_int16, result);
}

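// Quantized hard-swish, x * relu6(x + 3) / 6, computed in fixed point by the
// TFLite reference kernel. The kernel operates on a higher-resolution view of
// the input (scale input_scale / 128); reluish_scale = 3 / 32768 places the
// relu6(x + 3) factor on an int16 grid, and the two QuantizeMultiplier calls
// derive the fixed-point multipliers that bring the high-resolution product
// back to the output scale.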
template <typename T>
bool hardSwishQuant(const T* inputData, const Shape& inputShape, T* outputData,
                    const Shape& outputShape) {
    tflite::HardSwishParams params;
    params.input_zero_point = inputShape.offset;
    params.output_zero_point = outputShape.offset;
    const float input_scale = inputShape.scale;
    const float hires_input_scale = (1.0f / 128.0f) * input_scale;
    const float reluish_scale = 3.0f / 32768.0f;
    const float output_scale = outputShape.scale;

    const float output_multiplier = hires_input_scale / output_scale;

    int32_t output_multiplier_fixedpoint_int32;
    NN_RET_CHECK(QuantizeMultiplier(output_multiplier, &output_multiplier_fixedpoint_int32,
                                    &params.output_multiplier_exponent));
    DownScaleInt32ToInt16Multiplier(output_multiplier_fixedpoint_int32,
                                    &params.output_multiplier_fixedpoint_int16);
    NN_RET_CHECK(params.output_multiplier_exponent <= 0);

    const float reluish_multiplier = hires_input_scale / reluish_scale;
    int32_t reluish_multiplier_fixedpoint_int32;
    NN_RET_CHECK(QuantizeMultiplier(reluish_multiplier, &reluish_multiplier_fixedpoint_int32,
                                    &params.reluish_multiplier_exponent));
    DownScaleInt32ToInt16Multiplier(reluish_multiplier_fixedpoint_int32,
                                    &params.reluish_multiplier_fixedpoint_int16);

    tflite::reference_ops::HardSwish(params, convertShapeToTflshape(inputShape), inputData,
                                     convertShapeToTflshape(outputShape), outputData);
    return true;
}

}  // namespace

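// Validates the single-input, single-output activation signature and gates
// each operand type on the HAL version that introduced it: float32 since 1.0,
// float16 since 1.2, unsigned quant8 since 1.0 (except TANH, which gained a
// quant8 variant in 1.2), and signed quant8 since 1.3. Inputs of known rank
// must be at most 4-D.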
bool validate(OperationType opType, const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor);
    if (inputType == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        if (opType == OperationType::TANH) {
            NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
        } else {
            NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
        }
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_3));
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);
    }
    const Shape& input = context->getInputShape(kInputTensor);
    if (hasKnownRank(input)) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    }
    return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType});
}

bool validateHardSwish(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor);
    if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 ||
        inputType == OperandType::TENSOR_QUANT8_ASYMM ||
        inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_3));
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation HARD_SWISH";
    }
    return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType});
}

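// Computes the output shape: same dimensions and type as the input. The RELU
// family keeps the input quantization; for quantized LOGISTIC and TANH the
// output quantization is fixed (matching the checks in the kernels above), so
// prepare() overrides the scale and offset; HARD_SWISH keeps the quantization
// declared on the model's output operand.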
bool prepare(OperationType opType, IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    if (opType != OperationType::HARD_SWISH) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    }
    Shape output = input;
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM ||
        input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        bool isSigned = input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED;
        switch (opType) {
            case OperationType::HARD_SWISH: {
                auto outputShape = context->getOutputShape(kOutputTensor);
                output.scale = outputShape.scale;
                output.offset = outputShape.offset;
            } break;
            case OperationType::RELU:
            case OperationType::RELU1:
            case OperationType::RELU6:
                break;
            case OperationType::LOGISTIC:
                output.scale = 1.f / 256;
                output.offset = isSigned ? -128 : 0;
                break;
            case OperationType::TANH:
                output.scale = 1.f / 128;
                output.offset = isSigned ? 0 : 128;
                break;
            default:
                NN_RET_CHECK_FAIL() << "Unsupported operation type";
        }
    }
    return context->setOutputShape(kOutputTensor, output);
}

bool executeRelu(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return reluFloat(context->getInputBuffer<_Float16>(kInputTensor),
                             context->getInputShape(kInputTensor),
                             context->getOutputBuffer<_Float16>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return reluFloat(context->getInputBuffer<float>(kInputTensor),
                             context->getInputShape(kInputTensor),
                             context->getOutputBuffer<float>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return reluQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return reluQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                    context->getInputShape(kInputTensor),
                                    context->getOutputBuffer<int8_t>(kOutputTensor),
                                    context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU";
    }
}

bool executeRelu1(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return relu1Float(context->getInputBuffer<_Float16>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return relu1Float(context->getInputBuffer<float>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return relu1Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<uint8_t>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return relu1Quant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                     context->getInputShape(kInputTensor),
                                     context->getOutputBuffer<int8_t>(kOutputTensor),
                                     context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU1";
    }
}

bool executeRelu6(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return relu6Float(context->getInputBuffer<_Float16>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return relu6Float(context->getInputBuffer<float>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return relu6Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<uint8_t>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return relu6Quant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                     context->getInputShape(kInputTensor),
                                     context->getOutputBuffer<int8_t>(kOutputTensor),
                                     context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU6";
    }
}

bool executeLogistic(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return logisticFloat(context->getInputBuffer<_Float16>(kInputTensor),
                                 context->getInputShape(kInputTensor),
                                 context->getOutputBuffer<_Float16>(kOutputTensor),
                                 context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return logisticFloat(context->getInputBuffer<float>(kInputTensor),
                                 context->getInputShape(kInputTensor),
                                 context->getOutputBuffer<float>(kOutputTensor),
                                 context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return logisticQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                                  context->getInputShape(kInputTensor),
                                  context->getOutputBuffer<uint8_t>(kOutputTensor),
                                  context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return logisticQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                        context->getInputShape(kInputTensor),
                                        context->getOutputBuffer<int8_t>(kOutputTensor),
                                        context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC";
    }
}

bool executeTanh(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return tanhFloat16(context->getInputBuffer<_Float16>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<_Float16>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return tanhFloat32(context->getInputBuffer<float>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<float>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return tanhQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return tanhQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                    context->getInputShape(kInputTensor),
                                    context->getOutputBuffer<int8_t>(kOutputTensor),
                                    context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH";
    }
}

bool executeHardSwish(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16: {
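            // The fp16 path runs the float32 reference kernel on temporary
            // buffers: convert the input up, compute, convert the result back.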
            const Shape& inputShape = context->getInputShape(kInputTensor);
            const Shape& outputShape = context->getOutputShape(kOutputTensor);
            std::vector<float> inputFloat(getNumberOfElements(inputShape));
            std::vector<float> outputFloat(getNumberOfElements(outputShape));
            convertFloat16ToFloat32(context->getInputBuffer<_Float16>(kInputTensor), &inputFloat);
            tflite::reference_ops::HardSwish(convertShapeToTflshape(inputShape), inputFloat.data(),
                                             convertShapeToTflshape(outputShape),
                                             outputFloat.data());
            convertFloat32ToFloat16(outputFloat, context->getOutputBuffer<_Float16>(kOutputTensor));
            return true;
        }
        case OperandType::TENSOR_FLOAT32: {
            tflite::reference_ops::HardSwish(
                    convertShapeToTflshape(context->getInputShape(kInputTensor)),
                    context->getInputBuffer<float>(kInputTensor),
                    convertShapeToTflshape(context->getOutputShape(kOutputTensor)),
                    context->getOutputBuffer<float>(kOutputTensor));
            return true;
        }
        case OperandType::TENSOR_QUANT8_ASYMM:
            return hardSwishQuant(context->getInputBuffer<uint8_t>(kInputTensor),
                                  context->getInputShape(kInputTensor),
                                  context->getOutputBuffer<uint8_t>(kOutputTensor),
                                  context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return hardSwishQuant(context->getInputBuffer<int8_t>(kInputTensor),
                                  context->getInputShape(kInputTensor),
                                  context->getOutputBuffer<int8_t>(kOutputTensor),
                                  context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation HARD_SWISH";
    }
}

}  // namespace activation

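// Registers each activation with the operation resolver. validate and prepare
// are shared across the ops via std::bind on the OperationType; HARD_SWISH has
// its own validator since it requires HAL version 1.3 for every operand type.
// All six ops accept zero-sized inputs, which the execute* functions above
// turn into a no-op.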
using std::placeholders::_1;
NN_REGISTER_OPERATION(RELU, "RELU", std::bind(activation::validate, OperationType::RELU, _1),
                      std::bind(activation::prepare, OperationType::RELU, _1),
                      activation::executeRelu, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU1, "RELU1", std::bind(activation::validate, OperationType::RELU1, _1),
                      std::bind(activation::prepare, OperationType::RELU1, _1),
                      activation::executeRelu1, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU6, "RELU6", std::bind(activation::validate, OperationType::RELU6, _1),
                      std::bind(activation::prepare, OperationType::RELU6, _1),
                      activation::executeRelu6, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(LOGISTIC, "LOGISTIC",
                      std::bind(activation::validate, OperationType::LOGISTIC, _1),
                      std::bind(activation::prepare, OperationType::LOGISTIC, _1),
                      activation::executeLogistic, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(TANH, "TANH", std::bind(activation::validate, OperationType::TANH, _1),
                      std::bind(activation::prepare, OperationType::TANH, _1),
                      activation::executeTanh, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(HARD_SWISH, "HARD_SWISH", activation::validateHardSwish,
                      std::bind(activation::prepare, OperationType::HARD_SWISH, _1),
                      activation::executeHardSwish, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android