1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
20 #include <tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h>
21 
22 #include <vector>
23 
24 #include "CpuOperationUtils.h"
25 #include "HalInterfaces.h"
26 #include "OperationResolver.h"
27 #include "Tracing.h"
28 
29 namespace android {
30 namespace nn {
31 
32 using namespace hal;
33 
34 namespace pooling {
35 
// Operand index of the input tensor (common to all signatures).
constexpr uint32_t kInputTensor = 0;

// Every pooling operation produces exactly one output tensor, at index 0.
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
40 
41 namespace {
42 
// Scalar parameters shared by AVERAGE_POOL_2D, L2_POOL_2D and MAX_POOL_2D,
// decoded from the operation's scalar inputs.
struct PoolingParam {
    // Explicit padding, in pixels, on each edge of the spatial input.
    int32_t padding_left, padding_right;
    int32_t padding_top, padding_bottom;
    // Stride of the pooling window, in pixels.
    int32_t stride_width, stride_height;
    // Size of the pooling window, in pixels.
    int32_t filter_width, filter_height;
    // Fused activation function code (0 = NONE; see CalculateActivationRange*).
    int32_t activation;
    // Data layout of the input/output tensors: true = NCHW, false = NHWC.
    bool useNchw = false;

    // Decodes the scalar inputs from |context| and sanity-checks them.
    // Two signatures exist:
    //   - explicit padding (>= 10 inputs): inputs 1-4 are the four paddings,
    //     5-6 strides, 7-8 filter size, 9 activation, optional 10 layout flag;
    //   - implicit padding (< 10 inputs): input 1 is the implicit padding
    //     scheme, 2-3 strides, 4-5 filter size, 6 activation, optional 7
    //     layout flag. Explicit paddings are then derived from the input
    //     shape via calculateExplicitPadding.
    // Returns false (via NN_RET_CHECK) on any invalid value.
    bool initialize(const IOperationExecutionContext* context) {
        uint32_t inCount = context->getNumInputs();
        int32_t padding_implicit = 0;
        if (inCount >= 10) {
            padding_left = context->getInputValue<int32_t>(1);
            padding_right = context->getInputValue<int32_t>(2);
            padding_top = context->getInputValue<int32_t>(3);
            padding_bottom = context->getInputValue<int32_t>(4);
            stride_width = context->getInputValue<int32_t>(5);
            stride_height = context->getInputValue<int32_t>(6);
            filter_width = context->getInputValue<int32_t>(7);
            filter_height = context->getInputValue<int32_t>(8);
            activation = context->getInputValue<int32_t>(9);
            if (inCount == 11) {
                useNchw = context->getInputValue<bool>(10);
            }
        } else {
            padding_implicit = context->getInputValue<int32_t>(1);
            stride_width = context->getInputValue<int32_t>(2);
            stride_height = context->getInputValue<int32_t>(3);
            filter_width = context->getInputValue<int32_t>(4);
            filter_height = context->getInputValue<int32_t>(5);
            activation = context->getInputValue<int32_t>(6);
            if (inCount == 8) {
                useNchw = context->getInputValue<bool>(7);
            }
        }
        if (inCount <= 8) {
            // Implicit-padding signature: compute the explicit paddings from
            // the input's spatial dimensions. The layout flag must already be
            // read at this point, since it selects the H/W dimension indices.
            Shape inputShape = context->getInputShape(kInputTensor);
            int32_t input_height = getSizeOfDimension(inputShape, useNchw ? 2 : 1);
            int32_t input_width = getSizeOfDimension(inputShape, useNchw ? 3 : 2);
            calculateExplicitPadding(input_width, stride_width, filter_width, padding_implicit,
                                     &padding_left, &padding_right);
            calculateExplicitPadding(input_height, stride_height, filter_height, padding_implicit,
                                     &padding_top, &padding_bottom);
        }
        NN_RET_CHECK_GE(padding_left, 0);
        NN_RET_CHECK_GE(padding_right, 0);
        NN_RET_CHECK_GE(padding_top, 0);
        NN_RET_CHECK_GE(padding_bottom, 0);
        NN_RET_CHECK_GT(stride_width, 0);
        NN_RET_CHECK_GT(stride_height, 0);
        NN_RET_CHECK_GT(filter_width, 0);
        NN_RET_CHECK_GT(filter_height, 0);
        NN_RET_CHECK_GE(activation, 0);
        // Each padding must be strictly smaller than the filter in the same
        // dimension; otherwise a window could cover only padding.
        NN_RET_CHECK_GT(filter_width, padding_left);
        NN_RET_CHECK_GT(filter_width, padding_right);
        NN_RET_CHECK_GT(filter_height, padding_top);
        NN_RET_CHECK_GT(filter_height, padding_bottom);
        return true;
    }

    // Converts these parameters to the TFLite kernel parameter struct,
    // including the activation clamp range derived from |output|'s type
    // (quantized uint8/int8 ranges, or float min/max otherwise).
    // Note: only the left/top paddings are passed; TFLite kernels derive the
    // right/bottom padding from the shapes.
    tflite::PoolParams toTfliteParam(const Shape& output) const {
        tflite::PoolParams params = {
                .padding_values = {.width = static_cast<int16_t>(padding_left),
                                   .height = static_cast<int16_t>(padding_top),
                                   .width_offset = 0,
                                   .height_offset = 0},
                .stride_height = stride_height,
                .stride_width = stride_width,
                .filter_height = filter_height,
                .filter_width = filter_width,
        };
        if (output.type == OperandType::TENSOR_QUANT8_ASYMM) {
            int32_t output_activation_min = 0;
            int32_t output_activation_max = 0;
            CalculateActivationRangeUint8(activation, output, &output_activation_min,
                                          &output_activation_max);
            params.quantized_activation_min = output_activation_min;
            params.quantized_activation_max = output_activation_max;
        } else if (output.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
            int32_t output_activation_min = 0;
            int32_t output_activation_max = 0;
            CalculateActivationRangeInt8(activation, output, &output_activation_min,
                                         &output_activation_max);
            params.quantized_activation_min = output_activation_min;
            params.quantized_activation_max = output_activation_max;
        } else {
            float output_activation_min, output_activation_max;
            CalculateActivationRangeFloat(activation, &output_activation_min,
                                          &output_activation_max);
            params.float_activation_min = output_activation_min;
            params.float_activation_max = output_activation_max;
        }
        return params;
    }
};
138 
averagePoolNhwc(const float * inputData,const Shape & inputShape,const PoolingParam & param,float * outputData,const Shape & outputShape)139 bool averagePoolNhwc(const float* inputData, const Shape& inputShape, const PoolingParam& param,
140                      float* outputData, const Shape& outputShape) {
141     NNTRACE_TRANS("averagePoolFloat32");
142     auto op_params = param.toTfliteParam(outputShape);
143     NNTRACE_COMP_SWITCH("optimized_ops::AveragePool");
144     tflite::optimized_ops::AveragePool(op_params, convertShapeToTflshape(inputShape), inputData,
145                                        convertShapeToTflshape(outputShape), outputData);
146     return true;
147 }
148 
averagePoolNhwc(const _Float16 * inputData,const Shape & inputShape,const PoolingParam & param,_Float16 * outputData,const Shape & outputShape)149 bool averagePoolNhwc(const _Float16* inputData, const Shape& inputShape, const PoolingParam& param,
150                      _Float16* outputData, const Shape& outputShape) {
151     NNTRACE_TRANS("averagePoolFloat16");
152     std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
153     std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
154 
155     convertFloat16ToFloat32(inputData, &inputDataFloat32);
156     averagePoolNhwc(inputDataFloat32.data(), inputShape, param, outputDataFloat32.data(),
157                     outputShape);
158     convertFloat32ToFloat16(outputDataFloat32, outputData);
159     return true;
160 }
161 
averagePoolNhwc(const uint8_t * inputData,const Shape & inputShape,const PoolingParam & param,uint8_t * outputData,const Shape & outputShape)162 bool averagePoolNhwc(const uint8_t* inputData, const Shape& inputShape, const PoolingParam& param,
163                      uint8_t* outputData, const Shape& outputShape) {
164     NNTRACE_TRANS("averagePoolQuant8");
165     auto op_params = param.toTfliteParam(outputShape);
166     NNTRACE_COMP_SWITCH("optimized_ops::AveragePool");
167     tflite::optimized_ops::AveragePool(op_params, convertShapeToTflshape(inputShape), inputData,
168                                        convertShapeToTflshape(outputShape), outputData);
169     return true;
170 }
171 
averagePoolNhwc(const int8_t * inputData,const Shape & inputShape,const PoolingParam & param,int8_t * outputData,const Shape & outputShape)172 bool averagePoolNhwc(const int8_t* inputData, const Shape& inputShape, const PoolingParam& param,
173                      int8_t* outputData, const Shape& outputShape) {
174     NNTRACE_TRANS("averagePoolQuant8Signed");
175     auto op_params = param.toTfliteParam(outputShape);
176     NNTRACE_COMP_SWITCH("optimized_integer_ops::AveragePool");
177     // We are using reference implementation of the AveragePool op because the
178     // optimized version fails to pass some of the quantization coupling tests.
179     tflite::reference_integer_ops::AveragePool(op_params, convertShapeToTflshape(inputShape),
180                                                inputData, convertShapeToTflshape(outputShape),
181                                                outputData);
182     return true;
183 }
184 
l2PoolNhwc(const float * inputData,const Shape & inputShape,const PoolingParam & param,float * outputData,const Shape & outputShape)185 bool l2PoolNhwc(const float* inputData, const Shape& inputShape, const PoolingParam& param,
186                 float* outputData, const Shape& outputShape) {
187     NNTRACE_TRANS("l2PoolFloat32");
188     auto op_params = param.toTfliteParam(outputShape);
189     NNTRACE_COMP_SWITCH("optimized_ops::L2Pool");
190     tflite::optimized_ops::L2Pool(op_params, convertShapeToTflshape(inputShape), inputData,
191                                   convertShapeToTflshape(outputShape), outputData);
192     return true;
193 }
194 
l2PoolNhwc(const _Float16 * inputData,const Shape & inputShape,const PoolingParam & param,_Float16 * outputData,const Shape & outputShape)195 bool l2PoolNhwc(const _Float16* inputData, const Shape& inputShape, const PoolingParam& param,
196                 _Float16* outputData, const Shape& outputShape) {
197     NNTRACE_TRANS("l2PoolFloat16");
198     std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
199     std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
200 
201     convertFloat16ToFloat32(inputData, &inputDataFloat32);
202     l2PoolNhwc(inputDataFloat32.data(), inputShape, param, outputDataFloat32.data(), outputShape);
203     convertFloat32ToFloat16(outputDataFloat32, outputData);
204     return true;
205 }
206 
maxPoolNhwc(const float * inputData,const Shape & inputShape,const PoolingParam & param,float * outputData,const Shape & outputShape)207 bool maxPoolNhwc(const float* inputData, const Shape& inputShape, const PoolingParam& param,
208                  float* outputData, const Shape& outputShape) {
209     NNTRACE_TRANS("maxPoolFloat32");
210     auto op_params = param.toTfliteParam(outputShape);
211     NNTRACE_COMP_SWITCH("optimized_ops::MaxPool");
212     tflite::optimized_ops::MaxPool(op_params, convertShapeToTflshape(inputShape), inputData,
213                                    convertShapeToTflshape(outputShape), outputData);
214     return true;
215 }
216 
maxPoolNhwc(const uint8_t * inputData,const Shape & inputShape,const PoolingParam & param,uint8_t * outputData,const Shape & outputShape)217 bool maxPoolNhwc(const uint8_t* inputData, const Shape& inputShape, const PoolingParam& param,
218                  uint8_t* outputData, const Shape& outputShape) {
219     NNTRACE_TRANS("maxPoolQuant8");
220     auto op_params = param.toTfliteParam(outputShape);
221     NNTRACE_COMP_SWITCH("optimized_ops::MaxPool");
222     tflite::optimized_ops::MaxPool(op_params, convertShapeToTflshape(inputShape), inputData,
223                                    convertShapeToTflshape(outputShape), outputData);
224     return true;
225 }
226 
maxPoolNhwc(const int8_t * inputData,const Shape & inputShape,const PoolingParam & param,int8_t * outputData,const Shape & outputShape)227 bool maxPoolNhwc(const int8_t* inputData, const Shape& inputShape, const PoolingParam& param,
228                  int8_t* outputData, const Shape& outputShape) {
229     NNTRACE_TRANS("maxPoolQuant8Signed");
230     auto op_params = param.toTfliteParam(outputShape);
231     NNTRACE_COMP_SWITCH("optimized_integer_ops::MaxPool");
232     // We are using reference implementation of the MaxPool op because the
233     // optimized version fails to pass some of the quantization coupling tests.
234     tflite::reference_integer_ops::MaxPool(op_params, convertShapeToTflshape(inputShape), inputData,
235                                            convertShapeToTflshape(outputShape), outputData);
236     return true;
237 }
238 
maxPoolNhwc(const _Float16 * inputData,const Shape & inputShape,const PoolingParam & param,_Float16 * outputData,const Shape & outputShape)239 bool maxPoolNhwc(const _Float16* inputData, const Shape& inputShape, const PoolingParam& param,
240                  _Float16* outputData, const Shape& outputShape) {
241     NNTRACE_TRANS("maxPoolFloat16");
242     std::vector<float> inputData_float32(getNumberOfElements(inputShape));
243     std::vector<float> outputData_float32(getNumberOfElements(outputShape));
244 
245     convertFloat16ToFloat32(inputData, &inputData_float32);
246     maxPoolNhwc(inputData_float32.data(), inputShape, param, outputData_float32.data(),
247                 outputShape);
248     convertFloat32ToFloat16(outputData_float32, outputData);
249     return true;
250 }
251 
252 template <typename T>
averagePool(const T * inputData,const Shape & inputShape,const PoolingParam & param,T * outputData,const Shape & outputShape)253 bool averagePool(const T* inputData, const Shape& inputShape, const PoolingParam& param,
254                  T* outputData, const Shape& outputShape) {
255     InputWithLayout<T> input(param.useNchw);
256     OutputWithLayout<T> output(param.useNchw);
257     NN_RET_CHECK(input.initialize(inputData, inputShape));
258     NN_RET_CHECK(output.initialize(outputData, outputShape));
259     NN_RET_CHECK(averagePoolNhwc(input.getNhwcBuffer(), input.getNhwcShape(), param,
260                                  output.getNhwcBuffer(), output.getNhwcShape()));
261     NN_RET_CHECK(output.commit());
262     return true;
263 }
264 
265 template <typename T>
l2Pool(const T * inputData,const Shape & inputShape,const PoolingParam & param,T * outputData,const Shape & outputShape)266 bool l2Pool(const T* inputData, const Shape& inputShape, const PoolingParam& param, T* outputData,
267             const Shape& outputShape) {
268     InputWithLayout<T> input(param.useNchw);
269     OutputWithLayout<T> output(param.useNchw);
270     NN_RET_CHECK(input.initialize(inputData, inputShape));
271     NN_RET_CHECK(output.initialize(outputData, outputShape));
272     NN_RET_CHECK(l2PoolNhwc(input.getNhwcBuffer(), input.getNhwcShape(), param,
273                             output.getNhwcBuffer(), output.getNhwcShape()));
274     NN_RET_CHECK(output.commit());
275     return true;
276 }
277 
278 template <typename T>
maxPool(const T * inputData,const Shape & inputShape,const PoolingParam & param,T * outputData,const Shape & outputShape)279 bool maxPool(const T* inputData, const Shape& inputShape, const PoolingParam& param, T* outputData,
280              const Shape& outputShape) {
281     InputWithLayout<T> input(param.useNchw);
282     OutputWithLayout<T> output(param.useNchw);
283     NN_RET_CHECK(input.initialize(inputData, inputShape));
284     NN_RET_CHECK(output.initialize(outputData, outputShape));
285     NN_RET_CHECK(maxPoolNhwc(input.getNhwcBuffer(), input.getNhwcShape(), param,
286                              output.getNhwcBuffer(), output.getNhwcShape()));
287     NN_RET_CHECK(output.commit());
288     return true;
289 }
290 
291 }  // namespace
292 
// Validates operand counts, operand types, and the minimum HAL version for a
// pooling operation. Supported signatures:
//   7 inputs  - implicit padding, NHWC only
//   8 inputs  - implicit padding + layout flag (requires HAL V1_2)
//   10 inputs - explicit padding, NHWC only
//   11 inputs - explicit padding + layout flag (requires HAL V1_2)
// Quantized tensor types are rejected for L2_POOL_2D.
bool validate(OperationType opType, const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputCount = context->getNumInputs();
    NN_RET_CHECK(inputCount == 11 || inputCount == 10 || inputCount == 8 || inputCount == 7);
    auto inputType = context->getInputType(kInputTensor);
    // Build the expected types for the 7-input (implicit padding, no layout
    // flag) form first; extra scalars are appended below.
    std::vector<OperandType> inExpectedTypes;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
        inExpectedTypes = {
                inputType,          OperandType::INT32, OperandType::INT32, OperandType::INT32,
                OperandType::INT32, OperandType::INT32, OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
        inExpectedTypes = {
                OperandType::TENSOR_FLOAT16, OperandType::INT32, OperandType::INT32,
                OperandType::INT32,          OperandType::INT32, OperandType::INT32,
                OperandType::INT32,
        };
    } else if (opType != OperationType::L2_POOL_2D &&
               inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
        inExpectedTypes = {
                OperandType::TENSOR_QUANT8_ASYMM,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
        };
    } else if (opType != OperationType::L2_POOL_2D &&
               inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_3));
        inExpectedTypes = {
                OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
        };
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation "
                            << getOperationName(opType);
    }

    // Explicit-padding form carries three extra INT32 scalars (the implicit
    // padding scheme is replaced by four explicit paddings).
    if (inputCount >= 10) {
        std::vector<OperandType> explicitScalarTypes(3, OperandType::INT32);
        inExpectedTypes.insert(inExpectedTypes.end(), explicitScalarTypes.begin(),
                               explicitScalarTypes.end());
    }
    // The trailing BOOL layout flag was introduced in HAL V1_2.
    if (inputCount == 11 || inputCount == 8) {
        inExpectedTypes.push_back(OperandType::BOOL);
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
    } else {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
    }
    return validateInputTypes(context, inExpectedTypes) &&
           validateOutputTypes(context, {inputType});
}
355 
prepare(IOperationExecutionContext * context)356 bool prepare(IOperationExecutionContext* context) {
357     Shape input = context->getInputShape(kInputTensor);
358     NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
359 
360     PoolingParam param;
361     NN_RET_CHECK(param.initialize(context));
362 
363     // Only batches can be zero.
364     uint32_t batches = getSizeOfDimension(input, 0);
365     uint32_t height = getSizeOfDimension(input, param.useNchw ? 2 : 1);
366     uint32_t width = getSizeOfDimension(input, param.useNchw ? 3 : 2);
367     uint32_t channels = getSizeOfDimension(input, param.useNchw ? 1 : 3);
368     NN_RET_CHECK_GT(height, 0);
369     NN_RET_CHECK_GT(width, 0);
370     NN_RET_CHECK_GT(channels, 0);
371 
372     uint32_t outWidth = computeOutSize(width, param.filter_width, param.stride_width,
373                                        param.padding_left, param.padding_right);
374     uint32_t outHeight = computeOutSize(height, param.filter_height, param.stride_height,
375                                         param.padding_top, param.padding_bottom);
376 
377     Shape output = input;
378     if (param.useNchw) {
379         output.dimensions = {batches, channels, outHeight, outWidth};
380     } else {
381         output.dimensions = {batches, outHeight, outWidth, channels};
382     }
383     return context->setOutputShape(kOutputTensor, output);
384 }
385 
// Expands to one case of a switch over the input operand type: fetches the
// typed input/output buffers and shapes from `context` and forwards them,
// together with `param` (expected in the enclosing scope), to the given
// pooling function.
#define POOLING_DISPATCH_INPUT_TYPE(name, type, cppType)              \
    case OperandType::type:                                           \
        return name(context->getInputBuffer<cppType>(kInputTensor),   \
                    context->getInputShape(kInputTensor), param,      \
                    context->getOutputBuffer<cppType>(kOutputTensor), \
                    context->getOutputShape(kOutputTensor))
392 
// Executes AVERAGE_POOL_2D: reads the scalar parameters and dispatches on the
// input operand type to the matching typed implementation.
bool executeAveragePool(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    PoolingParam param;
    NN_RET_CHECK(param.initialize(context));
    switch (context->getInputType(kInputTensor)) {
        POOLING_DISPATCH_INPUT_TYPE(averagePool, TENSOR_FLOAT32, float);
        POOLING_DISPATCH_INPUT_TYPE(averagePool, TENSOR_FLOAT16, _Float16);
        POOLING_DISPATCH_INPUT_TYPE(averagePool, TENSOR_QUANT8_ASYMM, uint8_t);
        POOLING_DISPATCH_INPUT_TYPE(averagePool, TENSOR_QUANT8_ASYMM_SIGNED, int8_t);
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation AVERAGE_POOL_2D";
    }
}
407 
// Executes L2_POOL_2D: reads the scalar parameters and dispatches on the
// input operand type. Only float types are supported (no quantized variants).
bool executeL2Pool(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    PoolingParam param;
    NN_RET_CHECK(param.initialize(context));
    switch (context->getInputType(kInputTensor)) {
        POOLING_DISPATCH_INPUT_TYPE(l2Pool, TENSOR_FLOAT32, float);
        POOLING_DISPATCH_INPUT_TYPE(l2Pool, TENSOR_FLOAT16, _Float16);
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation L2_POOL_2D";
    }
}
420 
// Executes MAX_POOL_2D: reads the scalar parameters and dispatches on the
// input operand type to the matching typed implementation.
bool executeMaxPool(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    PoolingParam param;
    NN_RET_CHECK(param.initialize(context));
    switch (context->getInputType(kInputTensor)) {
        POOLING_DISPATCH_INPUT_TYPE(maxPool, TENSOR_FLOAT32, float);
        POOLING_DISPATCH_INPUT_TYPE(maxPool, TENSOR_FLOAT16, _Float16);
        POOLING_DISPATCH_INPUT_TYPE(maxPool, TENSOR_QUANT8_ASYMM, uint8_t);
        POOLING_DISPATCH_INPUT_TYPE(maxPool, TENSOR_QUANT8_ASYMM_SIGNED, int8_t);
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MAX_POOL_2D";
    }
}
435 
436 #undef POOLING_DISPATCH_INPUT_TYPE
437 
438 }  // namespace pooling
439 
using std::placeholders::_1;
// Register the three pooling operations with the operation resolver. Each
// binds its OperationType into the shared validate(); all share prepare(),
// and zero-sized inputs are allowed (short-circuited in the execute functions).
NN_REGISTER_OPERATION(AVERAGE_POOL_2D, "AVERAGE_POOL_2D",
                      std::bind(pooling::validate, OperationType::AVERAGE_POOL_2D, _1),
                      pooling::prepare, pooling::executeAveragePool, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(L2_POOL_2D, "L2_POOL_2D",
                      std::bind(pooling::validate, OperationType::L2_POOL_2D, _1), pooling::prepare,
                      pooling::executeL2Pool, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(MAX_POOL_2D, "MAX_POOL_2D",
                      std::bind(pooling::validate, OperationType::MAX_POOL_2D, _1),
                      pooling::prepare, pooling::executeMaxPool, .allowZeroSizedInput = true);
450 
451 }  // namespace nn
452 }  // namespace android
453