/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains the implementation of the operations.

#define LOG_TAG "Operations"

#include <tensorflow/lite/kernels/internal/optimized/integer_ops/add.h>
#include <tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h>
#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/add.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/mul.h>
#include <tensorflow/lite/kernels/internal/types.h>

#include <algorithm>
#include <vector>

#include "CpuOperationUtils.h"
#include "HalInterfaces.h"
#include "IndexedShapeWrapper.h"
#include "OperationResolver.h"
#include "Tracing.h"

namespace android {
namespace nn {

using namespace hal;

namespace broadcast {

constexpr uint32_t kNumInputs = 3;
constexpr uint32_t kInputTensor1 = 0;
constexpr uint32_t kInputTensor2 = 1;
constexpr uint32_t kActivationScalar = 2;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

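// Maps the runtime activation scalar (FusedActivationFunc) to the compile-time
// tflite::FusedActivationFunctionType template argument expected by the legacy
// optimized kernels. For example, ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_ADD)
// expands to a switch that calls optimized_ops::Add<...::kRelu6> when
// activation == RELU6, and fails with an error for unknown activation values.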
#define ANDROID_NN_MACRO_DISPATCH(macro)                                 \
    switch (activation) {                                                \
        case (int32_t)FusedActivationFunc::NONE:                         \
            macro(kNone);                                                \
            break;                                                       \
        case (int32_t)FusedActivationFunc::RELU:                         \
            macro(kRelu);                                                \
            break;                                                       \
        case (int32_t)FusedActivationFunc::RELU1:                        \
            macro(kRelu1);                                               \
            break;                                                       \
        case (int32_t)FusedActivationFunc::RELU6:                        \
            macro(kRelu6);                                               \
            break;                                                       \
        default:                                                         \
            LOG(ERROR) << "Unsupported fused activation function type";  \
            return false;                                                \
    }

using binaryFunctionFloat32 = std::function<bool(
        const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
        int32_t activation, float* out, const Shape& shapeOut)>;

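// Runs an fp16 binary operation by widening both inputs to fp32, invoking the
// corresponding fp32 kernel, and narrowing the result back to fp16. This is how
// addFloat16, mulFloat16, subFloat16, and divFloat16 below are implemented.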
bool binaryOperationFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2,
                            const Shape& shape2, int32_t activation, _Float16* out,
                            const Shape& shapeOut, binaryFunctionFloat32 operationFloat32) {
    std::vector<float> in1_float32(getNumberOfElements(shape1));
    convertFloat16ToFloat32(in1, &in1_float32);
    std::vector<float> in2_float32(getNumberOfElements(shape2));
    convertFloat16ToFloat32(in2, &in2_float32);
    std::vector<float> out_float32(getNumberOfElements(shapeOut));

    operationFloat32(in1_float32.data(), shape1, in2_float32.data(), shape2, activation,
                     out_float32.data(), shapeOut);
    convertFloat32ToFloat16(out_float32, out);

    return true;
}

bool addFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                               \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Add");
#define ANDROID_NN_ADD(activation)                                                 \
    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(   \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_ADD)
#undef ANDROID_NN_ADD
    }

    return true;
}

bool addFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &addFloat32);
}

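// Quantized ADD: both inputs are first rescaled onto a common real scale,
// twice_max_input_scale, with left_shift (20) bits of integer headroom, and each
// real multiplier below is turned into a fixed-point (multiplier, shift) pair by
// QuantizeMultiplierSmallerThanOneExp. The fused activation is applied by clamping
// the output to [output_activation_min, output_activation_max].
// Worked example: with shape1.scale = 0.5, shape2.scale = 0.25, and shapeOut.scale = 0.5,
// twice_max_input_scale = 1.0, so real_input1_multiplier = 0.5,
// real_input2_multiplier = 0.25, and real_output_multiplier = 1.0 / (2^20 * 0.5).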
template <typename T>
bool addQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addQuant8");
    const bool needBroadcast = !SameShape(shape1, shape2);

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier,
                                                     &input1_shift));
    int32_t input2_multiplier;
    int32_t input2_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier,
                                                     &input2_shift));
    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier,
                                                     &output_shift));

    int32_t output_activation_min;
    int32_t output_activation_max;
    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.left_shift = left_shift;
    op_params.input1_offset = input1_offset;
    op_params.input1_multiplier = input1_multiplier;
    op_params.input1_shift = input1_shift;
    op_params.input2_offset = input2_offset;
    op_params.input2_multiplier = input2_multiplier;
    op_params.input2_shift = input2_shift;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    if (needBroadcast) {
        if constexpr (isSignedOp) {
            NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow");
            tflite::reference_integer_ops::BroadcastAdd4DSlow(
                    op_params, convertShapeToTflshape(shape1), in1, convertShapeToTflshape(shape2),
                    in2, convertShapeToTflshape(shapeOut), out);
        } else {
            NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow");
            tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1),
                                                      in1, convertShapeToTflshape(shape2), in2,
                                                      convertShapeToTflshape(shapeOut), out);
        }
    } else {
        if constexpr (isSignedOp) {
            NNTRACE_COMP_SWITCH("optimized_integer_ops::Add");
            tflite::optimized_integer_ops::Add(op_params, convertShapeToTflshape(shape1), in1,
                                               convertShapeToTflshape(shape2), in2,
                                               convertShapeToTflshape(shapeOut), out);
        } else {
            NNTRACE_COMP_SWITCH("optimized_ops::Add");
            tflite::optimized_ops::Add(op_params, convertShapeToTflshape(shape1), in1,
                                       convertShapeToTflshape(shape2), in2,
                                       convertShapeToTflshape(shapeOut), out);
        }
    }

    return true;
}

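// Generic element-wise loop used for TENSOR_INT32 ADD/MUL/SUB/DIV. It walks every
// output index, maps it back to the (possibly broadcast) input indices via
// IndexedShapeWrapper, and applies `func` to the corresponding pair of elements.
// Fused activations are not defined for TENSOR_INT32, so only
// ANEURALNETWORKS_FUSED_NONE is accepted.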
bool executeInt32(const int32_t* aData, const Shape& aShape, const int32_t* bData,
                  const Shape& bShape, int32_t activation, int32_t* outputData,
                  const Shape& outputShape, int32_t func(int32_t, int32_t)) {
    NN_RET_CHECK_EQ(activation, ANEURALNETWORKS_FUSED_NONE);
    IndexedShapeWrapper aShapeIndexed(aShape);
    IndexedShapeWrapper bShapeIndexed(bShape);
    IndexedShapeWrapper outputShapeIndexed(outputShape);
    std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0);
    bool lastIndex = false;
    do {
        uint32_t outputFlatIndex;
        NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex));
        uint32_t aFlatIndex;
        NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex));
        uint32_t bFlatIndex;
        NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));

        outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]);

        NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
    } while (!lastIndex);
    return true;
}

bool mulFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
#define ANDROID_NN_BROADCAST_MUL(activation)                                               \
    tflite::optimized_ops::BroadcastMul<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL)
#undef ANDROID_NN_BROADCAST_MUL
    } else {
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

        NNTRACE_COMP_SWITCH("optimized_ops::Mul");
        tflite::optimized_ops::Mul(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }

    return true;
}

bool mulFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &mulFloat32);
}

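// Quantized MUL: the product of real values rescales by shape1.scale * shape2.scale,
// so a single output multiplier (shape1.scale * shape2.scale) / shapeOut.scale is
// needed. For TENSOR_QUANT8_ASYMM, validate() below guarantees
// shapeOut.scale > shape1.scale * shape2.scale, keeping that real multiplier below
// one. The broadcast kernel is used unconditionally.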
template <typename T>
bool mulQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulQuant8");
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const double input_product_scale = shape1.scale * shape2.scale;
    const double real_multiplier = input_product_scale / shapeOut.scale;
    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier,
                                                     &output_shift));

    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    int32_t output_activation_min;
    int32_t output_activation_max;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.input1_offset = input1_offset;
    op_params.input2_offset = input2_offset;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    if constexpr (isSignedOp) {
        NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastMul4DSlow");
        tflite::reference_integer_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1),
                                                          in1, convertShapeToTflshape(shape2), in2,
                                                          convertShapeToTflshape(shapeOut), out);
    } else {
        NNTRACE_COMP_SWITCH("reference_ops::BroadcastMul4DSlow");
        tflite::reference_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1), in1,
                                                  convertShapeToTflshape(shape2), in2,
                                                  convertShapeToTflshape(shapeOut), out);
    }

    return true;
}

bool subFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat32");
    NNTRACE_COMP_SWITCH("optimized_ops::Sub");
    tflite::optimized_ops::Sub(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                               out, convertShapeToDims(shapeOut));

    // TFLite does not apply activation to broadcast sub.
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
    uint32_t numOutputElements = getNumberOfElements(shapeOut);
    for (uint32_t i = 0; i < numOutputElements; i++) {
        out[i] = std::min(std::max(out[i], output_activation_min), output_activation_max);
    }
    return true;
}

bool subFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &subFloat32);
}

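// Quantized SUB reuses the quantized ADD rescaling scheme: the same common scale,
// left_shift headroom, and per-input multipliers are computed, and the second input's
// multiplier is then negated so that an Add kernel effectively computes in1 - in2.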
template <typename T>
bool subQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subQuant8");

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier,
                                                     &input1_shift));
    int32_t input2_multiplier;
    int32_t input2_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier,
                                                     &input2_shift));
    // Negate multiplier of the second input, so that we can use Add kernels.
    input2_multiplier *= -1;

    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier,
                                                     &output_shift));

    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    int32_t output_activation_min;
    int32_t output_activation_max;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.left_shift = left_shift;
    op_params.input1_offset = input1_offset;
    op_params.input1_multiplier = input1_multiplier;
    op_params.input1_shift = input1_shift;
    op_params.input2_offset = input2_offset;
    op_params.input2_multiplier = input2_multiplier;
    op_params.input2_shift = input2_shift;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    // The broadcast Add kernels (BroadcastAdd4DSlow) are used unconditionally here
    // because the non-broadcast Add kernels fail to pass some of the
    // sub_quantized_different_scales tests.
    if constexpr (isSignedOp) {
        NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow");
        tflite::reference_integer_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1),
                                                          in1, convertShapeToTflshape(shape2), in2,
                                                          convertShapeToTflshape(shapeOut), out);
    } else {
        NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow");
        tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1), in1,
                                                  convertShapeToTflshape(shape2), in2,
                                                  convertShapeToTflshape(shapeOut), out);
    }

    return true;
}

bool divFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat32");
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastDiv");
        tflite::optimized_ops::BroadcastDiv(
                in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut));
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Div");
        tflite::optimized_ops::Div(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }
    return true;
}

bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &divFloat32);
}

}  // namespace

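// Checks operand types and the minimum HAL version for ADD/MUL/SUB/DIV: ADD and MUL
// were introduced in HAL V1_0 and SUB and DIV in V1_1; TENSOR_FLOAT16 inputs require
// V1_2 (as does TENSOR_QUANT8_ASYMM SUB), and TENSOR_QUANT8_ASYMM_SIGNED and
// TENSOR_INT32 inputs require V1_3. DIV has no quantized variant.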
bool validate(OperationType opType, const IOperationValidationContext* context) {
    const HalVersion opIntroducedAt = (opType == OperationType::DIV || opType == OperationType::SUB)
                                              ? HalVersion::V1_1
                                              : HalVersion::V1_0;
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor1);
    if (inputType == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_2, opIntroducedAt)));
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        if (opType == OperationType::SUB) {
            NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_2, opIntroducedAt)));
        } else if (opType == OperationType::DIV) {
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
        } else if (opType == OperationType::MUL) {
            Shape output = context->getOutputShape(kOutputTensor);
            Shape input1 = context->getInputShape(kInputTensor1);
            Shape input2 = context->getInputShape(kInputTensor2);
            NN_RET_CHECK_GT(output.scale, input1.scale * input2.scale);
            NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
        } else {
            NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
        }
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED ||
               inputType == OperandType::TENSOR_INT32) {
        NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_3, opIntroducedAt)));
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);
    }
    const Shape& input1 = context->getInputShape(kInputTensor1);
    const Shape& input2 = context->getInputShape(kInputTensor2);
    if (hasKnownRank(input1) && hasKnownRank(input2)) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4);
        NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4);
    }
    return validateInputTypes(context, {inputType, inputType, OperandType::INT32}) &&
           validateOutputTypes(context, {inputType});
}

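// Computes the output shape by broadcasting the two input shapes against each other
// (trailing dimensions aligned; a dimension of 1 stretches to match), with both inputs
// limited to rank 4. For example, input shapes {2, 1, 3} and {4, 3} broadcast to an
// output shape of {2, 4, 3}.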
bool prepare(IOperationExecutionContext* context) {
    Shape input1 = context->getInputShape(kInputTensor1);
    Shape input2 = context->getInputShape(kInputTensor2);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4);
    NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4);
    NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output));
    return context->setOutputShape(kOutputTensor, output);
}

bool executeAdd(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return addFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return addFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return addQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return addQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a + b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ADD";
    }
}

bool executeMul(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return mulFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return mulFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return mulQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return mulQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a * b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MUL";
    }
}

bool executeSub(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return subFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return subFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return subQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return subQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a - b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation SUB";
    }
}

bool executeDiv(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return divFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return divFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor), [](int32_t a, int32_t b) {
                                    // In NNAPI, DIV by zero is undefined, but should not crash.
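                                    // Worked example of the floor semantics below: for a = -7,
                                    // b = 2, truncating division gives -3 with a nonzero
                                    // remainder and differing signs, so the result is
                                    // decremented to -4 == floor(-3.5).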
                                    if (b == 0) return 0;
                                    int32_t result = a / b;
                                    if (a % b != 0 && ((a < 0) != (b < 0))) {
                                        // Implement "floor division".
                                        --result;
                                    }
                                    return result;
                                });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
    }
}

}  // namespace broadcast

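// All four operations share the same validation and shape preparation; std::bind fixes
// the OperationType argument of broadcast::validate for each registration.
// allowZeroSizedInput marks the operations as accepting zero-sized inputs; each
// execute* function above bypasses computation when the output has zero elements.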
using std::placeholders::_1;
NN_REGISTER_OPERATION(ADD, "ADD", std::bind(broadcast::validate, OperationType::ADD, _1),
                      broadcast::prepare, broadcast::executeAdd, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(MUL, "MUL", std::bind(broadcast::validate, OperationType::MUL, _1),
                      broadcast::prepare, broadcast::executeMul, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(SUB, "SUB", std::bind(broadcast::validate, OperationType::SUB, _1),
                      broadcast::prepare, broadcast::executeSub, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(DIV, "DIV", std::bind(broadcast::validate, OperationType::DIV, _1),
                      broadcast::prepare, broadcast::executeDiv, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android
695