1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
20 #include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
21 #include <tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h>
22 #include <tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h>
23 #include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
24
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
28
29 #include "ActivationFunctor.h"
30 #include "CpuOperationUtils.h"
31 #include "HalInterfaces.h"
32 #include "OperationResolver.h"
33 #include "OperationsUtils.h"
34 #include "Tracing.h"
35
36 namespace android {
37 namespace nn {
38
39 using namespace hal;
40
41 namespace activation {
42
// The single input: the tensor to be activated.
constexpr uint32_t kNumInputs = 1;
constexpr uint32_t kInputTensor = 0;

// The single output: the activated tensor, same dimensions as the input.
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
48
49 namespace {
50
51 template <typename T>
reluFloat(const T * inputData,const Shape & inputShape,T * outputData,const Shape & outputShape,float reluMin=0.f,float reluMax=std::numeric_limits<float>::max ())52 bool reluFloat(const T* inputData, const Shape& inputShape, T* outputData, const Shape& outputShape,
53 float reluMin = 0.f, float reluMax = std::numeric_limits<float>::max()) {
54 NNTRACE_COMP("reluX");
55 int numElements = getNumberOfElements(inputShape);
56 for (int i = 0; i < numElements; i++, inputData++, outputData++) {
57 *outputData = static_cast<T>(
58 std::min(std::max(reluMin, static_cast<float>(*inputData)), reluMax));
59 }
60 return true;
61 }
62 template bool reluFloat<float>(const float* inputData, const Shape& inputShape, float* outputData,
63 const Shape& outputShape, float reluMin, float reluMax);
64 template bool reluFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
65 _Float16* outputData, const Shape& outputShape, float reluMin,
66 float reluMax);
67
68 template <typename T>
relu1Float(const T * inputData,const Shape & inputShape,T * outputData,const Shape & outputShape)69 bool relu1Float(const T* inputData, const Shape& inputShape, T* outputData,
70 const Shape& outputShape) {
71 return reluFloat(inputData, inputShape, outputData, outputShape, -1.f, 1.f);
72 }
73 template bool relu1Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
74 const Shape& outputShape);
75 template bool relu1Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
76 _Float16* outputData, const Shape& outputShape);
77
78 template <typename T>
relu6Float(const T * inputData,const Shape & inputShape,T * outputData,const Shape & outputShape)79 bool relu6Float(const T* inputData, const Shape& inputShape, T* outputData,
80 const Shape& outputShape) {
81 return reluFloat(inputData, inputShape, outputData, outputShape, 0.f, 6.f);
82 }
83 template bool relu6Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
84 const Shape& outputShape);
85 template bool relu6Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
86 _Float16* outputData, const Shape& outputShape);
87
tanhFloat16(const _Float16 * inputData,const Shape & inputShape,_Float16 * outputData,const Shape & outputShape)88 bool tanhFloat16(const _Float16* inputData, const Shape& inputShape, _Float16* outputData,
89 const Shape& outputShape) {
90 NNTRACE_COMP("tanhFloat16");
91 int numElements = getNumberOfElements(inputShape);
92 for (int i = 0; i < numElements; i++, inputData++, outputData++) {
93 *outputData = static_cast<_Float16>(std::tanh(static_cast<float>(*inputData)));
94 }
95 return true;
96 }
97
tanhFloat32(const float * inputData,const Shape & inputShape,float * outputData,const Shape & outputShape)98 bool tanhFloat32(const float* inputData, const Shape& inputShape, float* outputData,
99 const Shape& outputShape) {
100 NNTRACE_COMP("tanhFloat32");
101 int numElements = getNumberOfElements(inputShape);
102 for (int i = 0; i < numElements; i++, inputData++, outputData++) {
103 *outputData = std::tanh(*inputData);
104 }
105 return true;
106 }
107
108 template <typename T>
logisticFloat(const T * inputData,const Shape & inputShape,T * outputData,const Shape & outputShape)109 bool logisticFloat(const T* inputData, const Shape& inputShape, T* outputData,
110 const Shape& outputShape) {
111 NNTRACE_COMP("logisticFloat");
112 int numElements = getNumberOfElements(inputShape);
113 for (int i = 0; i < numElements; i++, inputData++, outputData++) {
114 *outputData = static_cast<T>(1.f / (1.f + std::exp(static_cast<float>(-*inputData))));
115 }
116 return true;
117 }
118 template bool logisticFloat<float>(const float* inputData, const Shape& inputShape,
119 float* outputData, const Shape& outputShape);
120 template bool logisticFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
121 _Float16* outputData, const Shape& outputShape);
122
123 template <ActivationFn activation>
reluXQuant8(const uint8_t * inputData,const Shape & inputShape,uint8_t * outputData,const Shape & outputShape)124 inline bool reluXQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
125 const Shape& outputShape) {
126 int numElements = getNumberOfElements(inputShape);
127 int32_t output_activation_min = 0;
128 int32_t output_activation_max = 0;
129
130 CalculateActivationRangeUint8(activation, inputShape, &output_activation_min,
131 &output_activation_max);
132
133 for (int i = 0; i < numElements; i++, inputData++, outputData++) {
134 *outputData = std::min((uint8_t)output_activation_max,
135 std::max((uint8_t)output_activation_min, *inputData));
136 }
137 return true;
138 }
139
reluQuant8(const uint8_t * inputData,const Shape & inputShape,uint8_t * outputData,const Shape & outputShape)140 bool reluQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
141 const Shape& outputShape) {
142 NNTRACE_COMP("reluQuant8");
143 return reluXQuant8<kActivationRelu>(inputData, inputShape, outputData, outputShape);
144 }
145
relu1Quant8(const uint8_t * inputData,const Shape & inputShape,uint8_t * outputData,const Shape & outputShape)146 bool relu1Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
147 const Shape& outputShape) {
148 NNTRACE_COMP("relu1Quant8");
149 return reluXQuant8<kActivationRelu1>(inputData, inputShape, outputData, outputShape);
150 }
151
relu6Quant8(const uint8_t * inputData,const Shape & inputShape,uint8_t * outputData,const Shape & outputShape)152 bool relu6Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
153 const Shape& outputShape) {
154 NNTRACE_COMP("relu6Quant8");
155 return reluXQuant8<kActivationRelu6>(inputData, inputShape, outputData, outputShape);
156 }
157
tanhQuant8(const uint8_t * inputData,const Shape & inputShape,uint8_t * outputData,const Shape & outputShape)158 bool tanhQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
159 const Shape& outputShape) {
160 NNTRACE_TRANS("tanhQuant8");
161 if (outputShape.offset != 128 || outputShape.scale != 1.f / 128) {
162 LOG(ERROR) << "incorrect scale or offset for TANH output";
163 return false;
164 }
165
166 int numElements = getNumberOfElements(inputShape);
167 static constexpr int kInputIntegerBits = 4;
168
169 const double input_real_multiplier =
170 inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));
171
172 int32_t input_multiplier = 0;
173 int32_t input_left_shift = 0;
174 if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
175 &input_left_shift)) {
176 return false;
177 }
178 int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);
179
180 NNTRACE_COMP_SWITCH("optimized_ops::Tanh");
181 tflite::optimized_ops::Tanh(inputData, convertShapeToTflshape(inputShape), inputShape.offset,
182 input_range_radius, input_multiplier, input_left_shift, outputData,
183 convertShapeToTflshape(outputShape));
184
185 return true;
186 }
187
logisticQuant8(const uint8_t * inputData,const Shape & inputShape,uint8_t * outputData,const Shape & outputShape)188 bool logisticQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
189 const Shape& outputShape) {
190 NNTRACE_TRANS("logisticQuant8");
191 if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) {
192 LOG(ERROR) << "incorrect scale / offset for output";
193 return false;
194 }
195
196 int numElements = getNumberOfElements(inputShape);
197 static constexpr int kInputIntegerBits = 4;
198
199 const double input_real_multiplier =
200 inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));
201
202 int32_t input_multiplier = 0;
203 int32_t input_left_shift = 0;
204 if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
205 &input_left_shift)) {
206 return false;
207 }
208 int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);
209
210 NNTRACE_COMP_SWITCH("optimized_ops::Logistic");
211 tflite::optimized_ops::Logistic(
212 inputData, convertShapeToTflshape(inputShape), inputShape.offset, input_range_radius,
213 input_multiplier, input_left_shift, outputData, convertShapeToTflshape(outputShape));
214
215 return true;
216 }
217
218 template <ActivationFn activation>
reluXQuant8Signed(const int8_t * inputData,const Shape & inputShape,int8_t * outputData,const Shape & outputShape)219 inline bool reluXQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
220 const Shape& outputShape) {
221 int numElements = getNumberOfElements(inputShape);
222 int32_t output_activation_min = 0;
223 int32_t output_activation_max = 0;
224
225 CalculateActivationRangeInt8(activation, inputShape, &output_activation_min,
226 &output_activation_max);
227
228 for (int i = 0; i < numElements; i++, inputData++, outputData++) {
229 *outputData = std::min((int8_t)output_activation_max,
230 std::max((int8_t)output_activation_min, *inputData));
231 }
232 return true;
233 }
234
reluQuant8Signed(const int8_t * inputData,const Shape & inputShape,int8_t * outputData,const Shape & outputShape)235 bool reluQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
236 const Shape& outputShape) {
237 NNTRACE_COMP("reluQuant8");
238 return reluXQuant8Signed<kActivationRelu>(inputData, inputShape, outputData, outputShape);
239 }
240
relu1Quant8Signed(const int8_t * inputData,const Shape & inputShape,int8_t * outputData,const Shape & outputShape)241 bool relu1Quant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
242 const Shape& outputShape) {
243 NNTRACE_COMP("relu1Quant8");
244 return reluXQuant8Signed<kActivationRelu1>(inputData, inputShape, outputData, outputShape);
245 }
246
relu6Quant8Signed(const int8_t * inputData,const Shape & inputShape,int8_t * outputData,const Shape & outputShape)247 bool relu6Quant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
248 const Shape& outputShape) {
249 NNTRACE_COMP("relu6Quant8");
250 return reluXQuant8Signed<kActivationRelu6>(inputData, inputShape, outputData, outputShape);
251 }
252
// TANH on quantized int8 data via the TFLite integer reference kernel. The
// NNAPI spec requires the output to use scale 1/128 and zero point 0 so the
// result range [-1, 1] maps onto [-128, 128).
bool tanhQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                      const Shape& outputShape) {
    NNTRACE_TRANS("tanhQuant8Signed");
    if (outputShape.offset != 0 || outputShape.scale != 1.f / 128) {
        LOG(ERROR) << "incorrect scale or offset for TANH output";
        return false;
    }

    int numElements = getNumberOfElements(inputShape);
    // Number of integer bits of the fixed-point input representation; the
    // remaining bits hold the fraction.
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    // Quantized inputs farther than this radius from the zero point saturate
    // to the extremes of the output range.
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("reference_integer_ops::Tanh");
    tflite::reference_integer_ops::Tanh(inputShape.offset, input_range_radius, input_multiplier,
                                        input_left_shift, numElements, inputData, outputData);

    return true;
}
281
// LOGISTIC (sigmoid) on quantized int8 data via the TFLite integer reference
// kernel. The NNAPI spec requires the output to use scale 1/256 and zero
// point -128 so the result range [0, 1] maps onto [-128, 128).
bool logisticQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                          const Shape& outputShape) {
    NNTRACE_TRANS("logisticQuant8Signed");
    if (outputShape.offset != -128 || outputShape.scale != 1.f / 256) {
        LOG(ERROR) << "incorrect scale / offset for output";
        return false;
    }

    int numElements = getNumberOfElements(inputShape);
    // Number of integer bits of the fixed-point input representation; the
    // remaining bits hold the fraction.
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    // Quantized inputs farther than this radius from the zero point saturate
    // to the extremes of the output range.
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("reference_integer_ops::Logistic");
    tflite::reference_integer_ops::Logistic(inputShape.offset, input_range_radius, input_multiplier,
                                            input_left_shift, numElements, inputData, outputData);

    return true;
}
310
// Reduces a non-negative 32-bit fixed-point multiplier to 16 bits with
// round-to-nearest, saturating at the int16 maximum. Mirrors the helper used
// by TFLite's quantized HardSwish; kept byte-identical to upstream.
void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32, int16_t* multiplier_int16) {
    TFLITE_DCHECK_GE(multiplier_int32, 0);
    static constexpr int32_t kRoundingOffset = 1 << 15;
    // Adding the rounding offset would overflow int32; clamp to int16 max.
    if (multiplier_int32 >= std::numeric_limits<int32_t>::max() - kRoundingOffset) {
        *multiplier_int16 = std::numeric_limits<int16_t>::max();
        return;
    }
    const int32_t result = (multiplier_int32 + kRoundingOffset) >> 16;
    // Sanity checks: the rounded result stays within half a unit (in the
    // 16-bit-shifted domain) of the original multiplier.
    TFLITE_DCHECK_LE(result << 16, multiplier_int32 + kRoundingOffset);
    TFLITE_DCHECK_GT(result << 16, multiplier_int32 - kRoundingOffset);
    *multiplier_int16 = result;
    TFLITE_DCHECK_EQ(*multiplier_int16, result);
}
324
325 template <typename T>
hardSwishQuant(const T * inputData,const Shape & inputShape,T * outputData,const Shape & outputShape)326 bool hardSwishQuant(const T* inputData, const Shape& inputShape, T* outputData,
327 const Shape& outputShape) {
328 tflite::HardSwishParams params;
329 params.input_zero_point = inputShape.offset;
330 params.output_zero_point = outputShape.offset;
331 const float input_scale = inputShape.scale;
332 const float hires_input_scale = (1.0f / 128.0f) * input_scale;
333 const float reluish_scale = 3.0f / 32768.0f;
334 const float output_scale = outputShape.scale;
335
336 const float output_multiplier = hires_input_scale / output_scale;
337
338 int32_t output_multiplier_fixedpoint_int32;
339 NN_RET_CHECK(QuantizeMultiplier(output_multiplier, &output_multiplier_fixedpoint_int32,
340 ¶ms.output_multiplier_exponent));
341 DownScaleInt32ToInt16Multiplier(output_multiplier_fixedpoint_int32,
342 ¶ms.output_multiplier_fixedpoint_int16);
343 NN_RET_CHECK(params.output_multiplier_exponent <= 0);
344
345 const float reluish_multiplier = hires_input_scale / reluish_scale;
346 int32_t reluish_multiplier_fixedpoint_int32;
347 NN_RET_CHECK(QuantizeMultiplier(reluish_multiplier, &reluish_multiplier_fixedpoint_int32,
348 ¶ms.reluish_multiplier_exponent));
349 DownScaleInt32ToInt16Multiplier(reluish_multiplier_fixedpoint_int32,
350 ¶ms.reluish_multiplier_fixedpoint_int16);
351
352 tflite::reference_ops::HardSwish(params, convertShapeToTflshape(inputShape), inputData,
353 convertShapeToTflshape(outputShape), outputData);
354 return true;
355 }
356
357 } // namespace
358
validate(OperationType opType,const IOperationValidationContext * context)359 bool validate(OperationType opType, const IOperationValidationContext* context) {
360 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
361 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
362 auto inputType = context->getInputType(kInputTensor);
363 if (inputType == OperandType::TENSOR_FLOAT32) {
364 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
365 } else if (inputType == OperandType::TENSOR_FLOAT16) {
366 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
367 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
368 if (opType == OperationType::TANH) {
369 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
370 } else {
371 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
372 }
373 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
374 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_3));
375 } else {
376 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);
377 }
378 const Shape& input = context->getInputShape(kInputTensor);
379 if (hasKnownRank(input)) {
380 NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
381 }
382 return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType});
383 }
384
validateHardSwish(const IOperationValidationContext * context)385 bool validateHardSwish(const IOperationValidationContext* context) {
386 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
387 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
388 auto inputType = context->getInputType(kInputTensor);
389 if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 ||
390 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
391 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
392 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_3));
393 } else {
394 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ELU";
395 }
396 return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType});
397 }
398
// Computes the output shape for all activation ops. The output mirrors the
// input, except that LOGISTIC and TANH on quantized tensors force the
// scale/offset mandated by the NNAPI spec, and HARD_SWISH adopts the
// quantization parameters of the model's declared output operand.
bool prepare(OperationType opType, IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    if (opType != OperationType::HARD_SWISH) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    }
    Shape output = input;
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM ||
        input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        bool isSigned = input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED;
        switch (opType) {
            case OperationType::HARD_SWISH: {
                // HARD_SWISH keeps whatever scale/offset the model declared
                // for its output.
                auto outputShape = context->getOutputShape(kOutputTensor);
                output.scale = outputShape.scale;
                output.offset = outputShape.offset;
            } break;
            case OperationType::RELU:
            case OperationType::RELU1:
            case OperationType::RELU6:
                // RELU variants reuse the input quantization unchanged.
                break;
            case OperationType::LOGISTIC:
                // Sigmoid range [0, 1] mapped onto the full quantized range.
                output.scale = 1.f / 256;
                output.offset = isSigned ? -128 : 0;
                break;
            case OperationType::TANH:
                // Tanh range [-1, 1] mapped onto the full quantized range.
                output.scale = 1.f / 128;
                output.offset = isSigned ? 0 : 128;
                break;
            default:
                NN_RET_CHECK_FAIL() << "Unsupported operation type";
        }
    }
    return context->setOutputShape(kOutputTensor, output);
}
432
executeRelu(IOperationExecutionContext * context)433 bool executeRelu(IOperationExecutionContext* context) {
434 // Bypass execution in the case of zero-sized input.
435 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
436 switch (context->getInputType(kInputTensor)) {
437 case OperandType::TENSOR_FLOAT16:
438 return reluFloat(context->getInputBuffer<_Float16>(kInputTensor),
439 context->getInputShape(kInputTensor),
440 context->getOutputBuffer<_Float16>(kOutputTensor),
441 context->getOutputShape(kOutputTensor));
442 case OperandType::TENSOR_FLOAT32:
443 return reluFloat(context->getInputBuffer<float>(kInputTensor),
444 context->getInputShape(kInputTensor),
445 context->getOutputBuffer<float>(kOutputTensor),
446 context->getOutputShape(kOutputTensor));
447 case OperandType::TENSOR_QUANT8_ASYMM:
448 return reluQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
449 context->getInputShape(kInputTensor),
450 context->getOutputBuffer<uint8_t>(kOutputTensor),
451 context->getOutputShape(kOutputTensor));
452 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
453 return reluQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
454 context->getInputShape(kInputTensor),
455 context->getOutputBuffer<int8_t>(kOutputTensor),
456 context->getOutputShape(kOutputTensor));
457 default:
458 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU";
459 }
460 }
461
executeRelu1(IOperationExecutionContext * context)462 bool executeRelu1(IOperationExecutionContext* context) {
463 // Bypass execution in the case of zero-sized input.
464 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
465 switch (context->getInputType(kInputTensor)) {
466 case OperandType::TENSOR_FLOAT16:
467 return relu1Float(context->getInputBuffer<_Float16>(kInputTensor),
468 context->getInputShape(kInputTensor),
469 context->getOutputBuffer<_Float16>(kOutputTensor),
470 context->getOutputShape(kOutputTensor));
471 case OperandType::TENSOR_FLOAT32:
472 return relu1Float(context->getInputBuffer<float>(kInputTensor),
473 context->getInputShape(kInputTensor),
474 context->getOutputBuffer<float>(kOutputTensor),
475 context->getOutputShape(kOutputTensor));
476 case OperandType::TENSOR_QUANT8_ASYMM:
477 return relu1Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
478 context->getInputShape(kInputTensor),
479 context->getOutputBuffer<uint8_t>(kOutputTensor),
480 context->getOutputShape(kOutputTensor));
481 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
482 return relu1Quant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
483 context->getInputShape(kInputTensor),
484 context->getOutputBuffer<int8_t>(kOutputTensor),
485 context->getOutputShape(kOutputTensor));
486 default:
487 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU1";
488 }
489 }
490
executeRelu6(IOperationExecutionContext * context)491 bool executeRelu6(IOperationExecutionContext* context) {
492 // Bypass execution in the case of zero-sized input.
493 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
494 switch (context->getInputType(kInputTensor)) {
495 case OperandType::TENSOR_FLOAT16:
496 return relu6Float(context->getInputBuffer<_Float16>(kInputTensor),
497 context->getInputShape(kInputTensor),
498 context->getOutputBuffer<_Float16>(kOutputTensor),
499 context->getOutputShape(kOutputTensor));
500 case OperandType::TENSOR_FLOAT32:
501 return relu6Float(context->getInputBuffer<float>(kInputTensor),
502 context->getInputShape(kInputTensor),
503 context->getOutputBuffer<float>(kOutputTensor),
504 context->getOutputShape(kOutputTensor));
505 case OperandType::TENSOR_QUANT8_ASYMM:
506 return relu6Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
507 context->getInputShape(kInputTensor),
508 context->getOutputBuffer<uint8_t>(kOutputTensor),
509 context->getOutputShape(kOutputTensor));
510 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
511 return relu6Quant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
512 context->getInputShape(kInputTensor),
513 context->getOutputBuffer<int8_t>(kOutputTensor),
514 context->getOutputShape(kOutputTensor));
515 default:
516 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU6";
517 }
518 }
519
executeLogistic(IOperationExecutionContext * context)520 bool executeLogistic(IOperationExecutionContext* context) {
521 // Bypass execution in the case of zero-sized input.
522 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
523 switch (context->getInputType(kInputTensor)) {
524 case OperandType::TENSOR_FLOAT16:
525 return logisticFloat(context->getInputBuffer<_Float16>(kInputTensor),
526 context->getInputShape(kInputTensor),
527 context->getOutputBuffer<_Float16>(kOutputTensor),
528 context->getOutputShape(kOutputTensor));
529 case OperandType::TENSOR_FLOAT32:
530 return logisticFloat(context->getInputBuffer<float>(kInputTensor),
531 context->getInputShape(kInputTensor),
532 context->getOutputBuffer<float>(kOutputTensor),
533 context->getOutputShape(kOutputTensor));
534 case OperandType::TENSOR_QUANT8_ASYMM:
535 return logisticQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
536 context->getInputShape(kInputTensor),
537 context->getOutputBuffer<uint8_t>(kOutputTensor),
538 context->getOutputShape(kOutputTensor));
539 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
540 return logisticQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
541 context->getInputShape(kInputTensor),
542 context->getOutputBuffer<int8_t>(kOutputTensor),
543 context->getOutputShape(kOutputTensor));
544 default:
545 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC";
546 }
547 }
548
executeTanh(IOperationExecutionContext * context)549 bool executeTanh(IOperationExecutionContext* context) {
550 // Bypass execution in the case of zero-sized input.
551 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
552 switch (context->getInputType(kInputTensor)) {
553 case OperandType::TENSOR_FLOAT16:
554 return tanhFloat16(context->getInputBuffer<_Float16>(kInputTensor),
555 context->getInputShape(kInputTensor),
556 context->getOutputBuffer<_Float16>(kOutputTensor),
557 context->getOutputShape(kOutputTensor));
558 case OperandType::TENSOR_FLOAT32:
559 return tanhFloat32(context->getInputBuffer<float>(kInputTensor),
560 context->getInputShape(kInputTensor),
561 context->getOutputBuffer<float>(kOutputTensor),
562 context->getOutputShape(kOutputTensor));
563 case OperandType::TENSOR_QUANT8_ASYMM:
564 return tanhQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
565 context->getInputShape(kInputTensor),
566 context->getOutputBuffer<uint8_t>(kOutputTensor),
567 context->getOutputShape(kOutputTensor));
568 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
569 return tanhQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
570 context->getInputShape(kInputTensor),
571 context->getOutputBuffer<int8_t>(kOutputTensor),
572 context->getOutputShape(kOutputTensor));
573 default:
574 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH";
575 }
576 }
577
executeHardSwish(IOperationExecutionContext * context)578 bool executeHardSwish(IOperationExecutionContext* context) {
579 // Bypass execution in the case of zero-sized input.
580 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
581 switch (context->getInputType(kInputTensor)) {
582 case OperandType::TENSOR_FLOAT16: {
583 const Shape& inputShape = context->getInputShape(kInputTensor);
584 const Shape& outputShape = context->getOutputShape(kOutputTensor);
585 std::vector<float> inputFloat(getNumberOfElements(inputShape));
586 std::vector<float> outputFloat(getNumberOfElements(outputShape));
587 convertFloat16ToFloat32(context->getInputBuffer<_Float16>(kInputTensor), &inputFloat);
588 tflite::reference_ops::HardSwish(convertShapeToTflshape(inputShape), inputFloat.data(),
589 convertShapeToTflshape(outputShape),
590 outputFloat.data());
591 convertFloat32ToFloat16(outputFloat, context->getOutputBuffer<_Float16>(kOutputTensor));
592 return true;
593 }
594 case OperandType::TENSOR_FLOAT32: {
595 tflite::reference_ops::HardSwish(
596 convertShapeToTflshape(context->getInputShape(kInputTensor)),
597 context->getInputBuffer<float>(kInputTensor),
598 convertShapeToTflshape(context->getOutputShape(kOutputTensor)),
599 context->getOutputBuffer<float>(kOutputTensor));
600 return true;
601 }
602 case OperandType::TENSOR_QUANT8_ASYMM:
603 return hardSwishQuant(context->getInputBuffer<uint8_t>(kInputTensor),
604 context->getInputShape(kInputTensor),
605 context->getOutputBuffer<uint8_t>(kOutputTensor),
606 context->getOutputShape(kOutputTensor));
607 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
608 return hardSwishQuant(context->getInputBuffer<int8_t>(kInputTensor),
609 context->getInputShape(kInputTensor),
610 context->getOutputBuffer<int8_t>(kOutputTensor),
611 context->getOutputShape(kOutputTensor));
612 default:
613 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH";
614 }
615 }
616
617 } // namespace activation
618
using std::placeholders::_1;
// Register each activation with the operation resolver. validate/prepare are
// shared across operations by binding the OperationType; every activation
// opts into zero-sized inputs (execution is bypassed in that case).
NN_REGISTER_OPERATION(RELU, "RELU", std::bind(activation::validate, OperationType::RELU, _1),
                      std::bind(activation::prepare, OperationType::RELU, _1),
                      activation::executeRelu, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU1, "RELU1", std::bind(activation::validate, OperationType::RELU1, _1),
                      std::bind(activation::prepare, OperationType::RELU1, _1),
                      activation::executeRelu1, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU6, "RELU6", std::bind(activation::validate, OperationType::RELU6, _1),
                      std::bind(activation::prepare, OperationType::RELU6, _1),
                      activation::executeRelu6, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(LOGISTIC, "LOGISTIC",
                      std::bind(activation::validate, OperationType::LOGISTIC, _1),
                      std::bind(activation::prepare, OperationType::LOGISTIC, _1),
                      activation::executeLogistic, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(TANH, "TANH", std::bind(activation::validate, OperationType::TANH, _1),
                      std::bind(activation::prepare, OperationType::TANH, _1),
                      activation::executeTanh, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(HARD_SWISH, "HARD_SWISH", activation::validateHardSwish,
                      std::bind(activation::prepare, OperationType::HARD_SWISH, _1),
                      activation::executeHardSwish, .allowZeroSizedInput = true);
639
640 } // namespace nn
641 } // namespace android
642