/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains the implementation of the broadcast operations: ADD, MUL, SUB, and DIV.

#define LOG_TAG "Operations"

#include <tensorflow/lite/kernels/internal/optimized/integer_ops/add.h>
#include <tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h>
#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/add.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/mul.h>
#include <tensorflow/lite/kernels/internal/types.h>

#include <algorithm>
#include <vector>

#include "CpuOperationUtils.h"
#include "HalInterfaces.h"
#include "IndexedShapeWrapper.h"
#include "OperationResolver.h"
#include "Tracing.h"

namespace android {
namespace nn {

using namespace hal;

namespace broadcast {

constexpr uint32_t kNumInputs = 3;
constexpr uint32_t kInputTensor1 = 0;
constexpr uint32_t kInputTensor2 = 1;
constexpr uint32_t kActivationScalar = 2;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

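// Expands `macro` once with the TFLite compile-time activation type that
// matches the runtime `activation` scalar; fails on unknown activation values.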
#define ANDROID_NN_MACRO_DISPATCH(macro)                                \
    switch (activation) {                                               \
        case (int32_t)FusedActivationFunc::NONE:                        \
            macro(kNone);                                               \
            break;                                                      \
        case (int32_t)FusedActivationFunc::RELU:                        \
            macro(kRelu);                                               \
            break;                                                      \
        case (int32_t)FusedActivationFunc::RELU1:                       \
            macro(kRelu1);                                              \
            break;                                                      \
        case (int32_t)FusedActivationFunc::RELU6:                       \
            macro(kRelu6);                                              \
            break;                                                      \
        default:                                                        \
            LOG(ERROR) << "Unsupported fused activation function type"; \
            return false;                                               \
    }

using binaryFunctionFloat32 = std::function<bool(
        const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
        int32_t activation, float* out, const Shape& shapeOut)>;

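// Implements an fp16 binary operation in terms of its fp32 counterpart: widen
// both inputs to fp32, run the fp32 kernel, then narrow the result back to
// fp16.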
bool binaryOperationFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2,
                            const Shape& shape2, int32_t activation, _Float16* out,
                            const Shape& shapeOut, binaryFunctionFloat32 operationFloat32) {
    std::vector<float> in1_float32(getNumberOfElements(shape1));
    convertFloat16ToFloat32(in1, &in1_float32);
    std::vector<float> in2_float32(getNumberOfElements(shape2));
    convertFloat16ToFloat32(in2, &in2_float32);
    std::vector<float> out_float32(getNumberOfElements(shapeOut));

    operationFloat32(in1_float32.data(), shape1, in2_float32.data(), shape2, activation,
                     out_float32.data(), shapeOut);
    convertFloat32ToFloat16(out_float32, out);

    return true;
}

bool addFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                              \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Add");
#define ANDROID_NN_ADD(activation)                                               \
    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_ADD)
#undef ANDROID_NN_ADD
    }

    return true;
}

bool addFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &addFloat32);
}

template <typename T>
bool addQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addQuant8");
    const bool needBroadcast = !SameShape(shape1, shape2);

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
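    // Rescale the offset-corrected inputs to a common scale (twice the larger
    // of the two input scales) with 20 bits of left-shift headroom so the
    // fixed-point multiplies keep precision; the sum is then rescaled down to
    // the output scale. Each real multiplier below must be < 1 so that it can
    // be encoded as a 32-bit fixed-point multiplier plus a right shift.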
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier,
                                                     &input1_shift));
    int32_t input2_multiplier;
    int32_t input2_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier,
                                                     &input2_shift));
    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier,
                                                     &output_shift));

    int32_t output_activation_min;
    int32_t output_activation_max;
    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.left_shift = left_shift;
    op_params.input1_offset = input1_offset;
    op_params.input1_multiplier = input1_multiplier;
    op_params.input1_shift = input1_shift;
    op_params.input2_offset = input2_offset;
    op_params.input2_multiplier = input2_multiplier;
    op_params.input2_shift = input2_shift;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

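    // Mismatched shapes take the slow 4D broadcast kernels; identical shapes
    // take the optimized elementwise kernels.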
    if (needBroadcast) {
        if constexpr (isSignedOp) {
            NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow");
            tflite::reference_integer_ops::BroadcastAdd4DSlow(
                    op_params, convertShapeToTflshape(shape1), in1, convertShapeToTflshape(shape2),
                    in2, convertShapeToTflshape(shapeOut), out);
        } else {
            NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow");
            tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1),
                                                      in1, convertShapeToTflshape(shape2), in2,
                                                      convertShapeToTflshape(shapeOut), out);
        }
    } else {
        if constexpr (isSignedOp) {
            NNTRACE_COMP_SWITCH("optimized_integer_ops::Add");
            tflite::optimized_integer_ops::Add(op_params, convertShapeToTflshape(shape1), in1,
                                               convertShapeToTflshape(shape2), in2,
                                               convertShapeToTflshape(shapeOut), out);
        } else {
            NNTRACE_COMP_SWITCH("optimized_ops::Add");
            tflite::optimized_ops::Add(op_params, convertShapeToTflshape(shape1), in1,
                                       convertShapeToTflshape(shape2), in2,
                                       convertShapeToTflshape(shapeOut), out);
        }
    }

    return true;
}

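// Generic elementwise loop used for TENSOR_INT32 operands: walks every output
// index, maps it to the (possibly broadcast) input indices, and applies
// `func`. Fused activations are not supported for int32 inputs (checked
// below).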
bool executeInt32(const int32_t* aData, const Shape& aShape, const int32_t* bData,
                  const Shape& bShape, int32_t activation, int32_t* outputData,
                  const Shape& outputShape, int32_t func(int32_t, int32_t)) {
    NN_RET_CHECK_EQ(activation, ANEURALNETWORKS_FUSED_NONE);
    IndexedShapeWrapper aShapeIndexed(aShape);
    IndexedShapeWrapper bShapeIndexed(bShape);
    IndexedShapeWrapper outputShapeIndexed(outputShape);
    std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0);
    bool lastIndex = false;
    do {
        uint32_t outputFlatIndex;
        NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex));
        uint32_t aFlatIndex;
        NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex));
        uint32_t bFlatIndex;
        NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));

        outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]);

        NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
    } while (!lastIndex);
    return true;
}

bool mulFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
#define ANDROID_NN_BROADCAST_MUL(activation)                                              \
    tflite::optimized_ops::BroadcastMul<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL)
#undef ANDROID_NN_BROADCAST_MUL
    } else {
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

        NNTRACE_COMP_SWITCH("optimized_ops::Mul");
        tflite::optimized_ops::Mul(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }

    return true;
}

bool mulFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &mulFloat32);
}

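// Quantized multiply: the product of two offset-corrected quantized values has
// scale shape1.scale * shape2.scale, so a single fixed-point multiplier
// (input_product_scale / shapeOut.scale, checked to be < 1) rescales the raw
// integer product to the output scale.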
template <typename T>
bool mulQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulQuant8");
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const double input_product_scale = shape1.scale * shape2.scale;
    const double real_multiplier = input_product_scale / shapeOut.scale;
    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier,
                                                     &output_shift));

    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    int32_t output_activation_min;
    int32_t output_activation_max;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.input1_offset = input1_offset;
    op_params.input2_offset = input2_offset;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    if constexpr (isSignedOp) {
        NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastMul4DSlow");
        tflite::reference_integer_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1),
                                                          in1, convertShapeToTflshape(shape2), in2,
                                                          convertShapeToTflshape(shapeOut), out);
    } else {
        NNTRACE_COMP_SWITCH("reference_ops::BroadcastMul4DSlow");
        tflite::reference_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1), in1,
                                                  convertShapeToTflshape(shape2), in2,
                                                  convertShapeToTflshape(shapeOut), out);
    }

    return true;
}

bool subFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat32");
    NNTRACE_COMP_SWITCH("optimized_ops::Sub");
    tflite::optimized_ops::Sub(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                               out, convertShapeToDims(shapeOut));

    // TFLite does not apply activation to broadcast sub.
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
    uint32_t numOutputElements = getNumberOfElements(shapeOut);
    for (uint32_t i = 0; i < numOutputElements; i++) {
        out[i] = std::min(std::max(out[i], output_activation_min), output_activation_max);
    }
    return true;
}

bool subFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &subFloat32);
}

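// Quantized subtract reuses the quantized Add fixed-point pipeline: a - b is
// computed as a + (-b) by negating the multiplier of the second input.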
template <typename T>
bool subQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subQuant8");

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier,
                                                     &input1_shift));
    int32_t input2_multiplier;
    int32_t input2_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier,
                                                     &input2_shift));
    // Negate multiplier of the second input, so that we can use Add kernels.
    input2_multiplier *= -1;

    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier,
                                                     &output_shift));

    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    int32_t output_activation_min;
    int32_t output_activation_max;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.left_shift = left_shift;
    op_params.input1_offset = input1_offset;
    op_params.input1_multiplier = input1_multiplier;
    op_params.input1_shift = input1_shift;
    op_params.input2_offset = input2_offset;
    op_params.input2_multiplier = input2_multiplier;
    op_params.input2_shift = input2_shift;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    // We use the BroadcastAdd4DSlow kernels unconditionally here because the
    // non-broadcast Add kernels fail to pass some of the
    // sub_quantized_different_scales tests.
    if constexpr (isSignedOp) {
        NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow");
        tflite::reference_integer_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1),
                                                          in1, convertShapeToTflshape(shape2), in2,
                                                          convertShapeToTflshape(shapeOut), out);
    } else {
        NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow");
        tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1), in1,
                                                  convertShapeToTflshape(shape2), in2,
                                                  convertShapeToTflshape(shapeOut), out);
    }

    return true;
}

bool divFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat32");
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastDiv");
        tflite::optimized_ops::BroadcastDiv(
                in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut));
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Div");
        tflite::optimized_ops::Div(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }
    return true;
}

bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &divFloat32);
}

}  // namespace

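// Validates operand counts and types, gating each tensor type on the HAL
// version that introduced it (the SUB and DIV operations themselves first
// appeared in HAL version 1.1).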
bool validate(OperationType opType, const IOperationValidationContext* context) {
    const HalVersion opIntroducedAt = (opType == OperationType::DIV || opType == OperationType::SUB)
                                              ? HalVersion::V1_1
                                              : HalVersion::V1_0;
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor1);
    if (inputType == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_2, opIntroducedAt)));
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        if (opType == OperationType::SUB) {
            NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_2, opIntroducedAt)));
        } else if (opType == OperationType::DIV) {
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
        } else if (opType == OperationType::MUL) {
            Shape output = context->getOutputShape(kOutputTensor);
            Shape input1 = context->getInputShape(kInputTensor1);
            Shape input2 = context->getInputShape(kInputTensor2);
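            // The output scale must exceed the product of the input scales so
            // that the fixed-point multiplier computed in mulQuant8 is < 1.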
            NN_RET_CHECK_GT(output.scale, input1.scale * input2.scale);
            NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
        } else {
            NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
        }
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED ||
               inputType == OperandType::TENSOR_INT32) {
        NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_3, opIntroducedAt)));
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);
    }
    const Shape& input1 = context->getInputShape(kInputTensor1);
    const Shape& input2 = context->getInputShape(kInputTensor2);
    if (hasKnownRank(input1) && hasKnownRank(input2)) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4);
        NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4);
    }
    return validateInputTypes(context, {inputType, inputType, OperandType::INT32}) &&
           validateOutputTypes(context, {inputType});
}

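// Computes the output shape as the broadcast of the two input shapes, each
// limited to rank 4.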
bool prepare(IOperationExecutionContext* context) {
    Shape input1 = context->getInputShape(kInputTensor1);
    Shape input2 = context->getInputShape(kInputTensor2);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4);
    NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4);
    NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output));
    return context->setOutputShape(kOutputTensor, output);
}

bool executeAdd(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return addFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return addFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return addQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return addQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a + b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ADD";
    }
}

bool executeMul(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return mulFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return mulFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return mulQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return mulQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a * b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MUL";
    }
}

bool executeSub(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return subFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return subFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return subQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return subQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a - b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation SUB";
    }
}

bool executeDiv(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return divFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return divFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor), [](int32_t a, int32_t b) {
                                    // In NNAPI, DIV by zero is undefined, but should not crash.
                                    if (b == 0) return 0;
                                    int32_t result = a / b;
                                    if (a % b != 0 && ((a < 0) != (b < 0))) {
                                        // Implement "floor division".
                                        --result;
                                    }
                                    return result;
                                });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
    }
}

}  // namespace broadcast

using std::placeholders::_1;
NN_REGISTER_OPERATION(ADD, "ADD", std::bind(broadcast::validate, OperationType::ADD, _1),
                      broadcast::prepare, broadcast::executeAdd, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(MUL, "MUL", std::bind(broadcast::validate, OperationType::MUL, _1),
                      broadcast::prepare, broadcast::executeMul, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(SUB, "SUB", std::bind(broadcast::validate, OperationType::SUB, _1),
                      broadcast::prepare, broadcast::executeSub, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(DIV, "DIV", std::bind(broadcast::validate, OperationType::DIV, _1),
                      broadcast::prepare, broadcast::executeDiv, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android