/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>

#include <algorithm>
#include <limits>
#include <vector>

#include "HalInterfaces.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Tracing.h"

namespace android {
namespace nn {
namespace reduce {

constexpr uint32_t kNumInputs = 3;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kInputAxes = 1;
constexpr uint32_t kInputKeepDims = 2;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

// Values from
// https://en.wikipedia.org/wiki/Half-precision_floating-point_format#IEEE_754_half-precision_binary_floating-point_format:_binary16
constexpr _Float16 kFloat16Max = 65504;
constexpr _Float16 kFloat16Lowest = -kFloat16Max;

namespace {

using namespace hal;

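// Shared reduction helper: folds the axes listed in kInputAxes using the given
// identity value and binary reducer, delegating the element loop to TFLite's
// reference ReduceGeneric kernel.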
template <typename T>
inline bool compute(IOperationExecutionContext* context, T init, T func(T, T)) {
    const Shape inputShape = context->getInputShape(kInputTensor);
    const Shape axesShape = context->getInputShape(kInputAxes);
    const Shape outputShape = context->getOutputShape(kOutputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    const uint32_t numAxes = getNumberOfElements(axesShape);
    std::vector<int> tempIndex(inputShape.dimensions.size());
    std::vector<int> tempAxes(numAxes);
    return tflite::reference_ops::ReduceGeneric<T>(
            context->getInputBuffer<T>(kInputTensor),
            reinterpret_cast<const int32_t*>(inputShape.dimensions.data()), inputRank,
            context->getOutputBuffer<T>(kOutputTensor),
            reinterpret_cast<const int32_t*>(outputShape.dimensions.data()),
            outputShape.dimensions.size(), context->getInputBuffer<int32_t>(kInputAxes), numAxes,
            context->getInputValue<bool8>(kInputKeepDims), tempIndex.data(), tempAxes.data(), init,
            func);
}

}  // namespace

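// REDUCE_PROD and REDUCE_SUM accept only floating-point tensors and require
// HAL version 1.2 or later.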
bool validateProdSum(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
                 inputType == OperandType::TENSOR_FLOAT32)
            << "Unsupported tensor type for REDUCE_PROD or REDUCE_SUM";
    NN_RET_CHECK(
            validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    const Shape& input = context->getInputShape(kInputTensor);
    if (hasKnownRank(input)) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    }
    return validateHalVersion(context, HalVersion::V1_2);
}

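// REDUCE_MAX and REDUCE_MIN additionally accept quantized tensors; the signed
// quantized variant (TENSOR_QUANT8_ASYMM_SIGNED) raises the requirement to
// HAL version 1.3.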
bool validateMaxMin(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
                 inputType == OperandType::TENSOR_FLOAT32 ||
                 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
                 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)
            << "Unsupported tensor type for REDUCE_MAX or REDUCE_MIN";
    NN_RET_CHECK(
            validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    auto minHalVersion = HalVersion::V1_2;
    if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        minHalVersion = HalVersion::V1_3;
    }
    const Shape& input = context->getInputShape(kInputTensor);
    if (hasKnownRank(input)) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    }
    return validateHalVersion(context, minHalVersion);
}

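// REDUCE_ANY and REDUCE_ALL operate on 8-bit boolean tensors only.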
bool validateLogical(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    NN_RET_CHECK(inputType == OperandType::TENSOR_BOOL8)
            << "Unsupported tensor type for REDUCE_ANY or REDUCE_ALL";
    NN_RET_CHECK(
            validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    const Shape& input = context->getInputShape(kInputTensor);
    if (hasKnownRank(input)) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    }
    return validateHalVersion(context, HalVersion::V1_2);
}

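// Derives the output shape from the input shape and the axes tensor: reduced
// dimensions are kept with size 1 when keep_dims is true and dropped
// otherwise. If every dimension is reduced away, the output falls back to a
// single dimension of size 1.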
bool prepare(IOperationExecutionContext* context) {
    Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_LE(inputRank, 4);

    std::vector<bool> shouldReduce(inputRank);
    const int32_t* axes = context->getInputBuffer<int32_t>(kInputAxes);
    Shape axesShape = context->getInputShape(kInputAxes);
    NN_RET_CHECK_EQ(getNumberOfDimensions(axesShape), 1u);
    const uint32_t numAxes = getNumberOfElements(axesShape);
    for (uint32_t i = 0; i < numAxes; ++i) {
        int32_t axis = axes[i];
        NN_RET_CHECK(handleNegativeAxis(inputRank, &axis));
        shouldReduce[axis] = true;
    }

    // Input and output must have the same quantization parameters, etc.
    Shape outputShape = inputShape;
    outputShape.dimensions.clear();
    bool keepDims = context->getInputValue<bool8>(kInputKeepDims);
    for (uint32_t axis = 0; axis < inputRank; ++axis) {
        if (shouldReduce[axis]) {
            if (keepDims) {
                outputShape.dimensions.push_back(1);
            }
        } else {
            outputShape.dimensions.push_back(getSizeOfDimension(inputShape, axis));
        }
    }

    // Handle the case when all dimensions are removed
    if (outputShape.dimensions.empty()) {
        outputShape.dimensions.push_back(1);
    }

    return context->setOutputShape(kOutputTensor, outputShape);
}

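// Each execute* entry point picks the identity value and binary reducer that
// match the input tensor type, then forwards to compute().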
bool executeProd(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return compute<_Float16>(context, 1, [](_Float16 a, _Float16 b) -> _Float16 {
                // Handle the zero case because 0 * inf evaluates to nan.
                if (a == 0 || b == 0) return 0;
                return a * b;
            });
        case OperandType::TENSOR_FLOAT32:
            return compute<float>(context, 1, [](float a, float b) -> float {
                // Handle the zero case because 0 * inf evaluates to nan.
                if (a == 0 || b == 0) return 0;
                return a * b;
            });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_PROD";
    }
}

bool executeSum(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return compute<_Float16>(context, 0, [](_Float16 a, _Float16 b) { return a + b; });
        case OperandType::TENSOR_FLOAT32:
            return compute<float>(context, 0, [](float a, float b) { return a + b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_SUM";
    }
}

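// For REDUCE_MAX the identity is the lowest representable value of each
// element type; the hardcoded kFloat16Lowest stands in for a
// std::numeric_limits lookup on _Float16.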
bool executeMax(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return compute<_Float16>(context, kFloat16Lowest,
                                     [](_Float16 a, _Float16 b) { return std::max(a, b); });
        case OperandType::TENSOR_FLOAT32:
            return compute<float>(context, std::numeric_limits<float>::lowest(),
                                  [](float a, float b) { return std::max(a, b); });
        case OperandType::TENSOR_QUANT8_ASYMM:
            return compute<uint8_t>(context, std::numeric_limits<uint8_t>::lowest(),
                                    [](uint8_t a, uint8_t b) { return std::max(a, b); });
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return compute<int8_t>(context, std::numeric_limits<int8_t>::lowest(),
                                   [](int8_t a, int8_t b) { return std::max(a, b); });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_MAX";
    }
}

bool executeMin(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return compute<_Float16>(context, kFloat16Max,
                                     [](_Float16 a, _Float16 b) { return std::min(a, b); });
        case OperandType::TENSOR_FLOAT32:
            return compute<float>(context, std::numeric_limits<float>::max(),
                                  [](float a, float b) { return std::min(a, b); });
        case OperandType::TENSOR_QUANT8_ASYMM:
            return compute<uint8_t>(context, std::numeric_limits<uint8_t>::max(),
                                    [](uint8_t a, uint8_t b) { return std::min(a, b); });
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return compute<int8_t>(context, std::numeric_limits<int8_t>::max(),
                                   [](int8_t a, int8_t b) { return std::min(a, b); });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_MIN";
    }
}

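// The logical reductions use false and true as the identities for OR and AND
// respectively.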
bool executeAny(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_BOOL8:
            return compute<bool8>(context, false,
                                  [](bool8 a, bool8 b) { return static_cast<bool8>(a || b); });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_ANY";
    }
}

bool executeAll(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_BOOL8:
            return compute<bool8>(context, true,
                                  [](bool8 a, bool8 b) { return static_cast<bool8>(a && b); });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_ALL";
    }
}

}  // namespace reduce

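// Register the six REDUCE_* operations. All share the common prepare();
// validation is shared per family (prod/sum, max/min, any/all), while each
// operation has its own execute function.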
NN_REGISTER_OPERATION(REDUCE_PROD, "REDUCE_PROD", reduce::validateProdSum, reduce::prepare,
                      reduce::executeProd);
NN_REGISTER_OPERATION(REDUCE_SUM, "REDUCE_SUM", reduce::validateProdSum, reduce::prepare,
                      reduce::executeSum);
NN_REGISTER_OPERATION(REDUCE_MAX, "REDUCE_MAX", reduce::validateMaxMin, reduce::prepare,
                      reduce::executeMax);
NN_REGISTER_OPERATION(REDUCE_MIN, "REDUCE_MIN", reduce::validateMaxMin, reduce::prepare,
                      reduce::executeMin);
NN_REGISTER_OPERATION(REDUCE_ANY, "REDUCE_ANY", reduce::validateLogical, reduce::prepare,
                      reduce::executeAny);
NN_REGISTER_OPERATION(REDUCE_ALL, "REDUCE_ALL", reduce::validateLogical, reduce::prepare,
                      reduce::executeAll);

}  // namespace nn
}  // namespace android