Lines matching refs:context (cross-reference hits for the context parameter):
 52  inline bool compute(IOperationExecutionContext* context, T init, T func(T, T)) {
 53      const Shape inputShape = context->getInputShape(kInputTensor);
 54      const Shape axesShape = context->getInputShape(kInputAxes);
 55      const Shape outputShape = context->getOutputShape(kOutputTensor);
 61              context->getInputBuffer<T>(kInputTensor),
 63              context->getOutputBuffer<T>(kOutputTensor),
 65              outputShape.dimensions.size(), context->getInputBuffer<int32_t>(kInputAxes), numAxes,
 66              context->getInputValue<bool8>(kInputKeepDims), tempIndex.data(), tempAxes.data(), init,
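For orientation, here is a minimal self-contained sketch of the kind of axis reduction that compute<T>() forwards its buffers, axes, keepDims flag, and init/func pair to. The helper it actually calls is not among the matches above, so the name reduceAlongAxes, the row-major layout, and the assumption of non-negative, de-duplicated axes are illustrative only, not the production implementation.

    // Illustrative only: fold a row-major tensor along a set of axes with a
    // caller-supplied identity value and binary functor, the same contract the
    // compute<T>() wrapper above forwards to its reduction helper.
    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <set>
    #include <vector>

    template <typename T>
    std::vector<T> reduceAlongAxes(const std::vector<T>& input,
                                   const std::vector<uint32_t>& dims,
                                   const std::set<uint32_t>& axes, T init, T op(T, T)) {
        // Reduced axes collapse to size 1 in the output shape.
        std::vector<uint32_t> outDims(dims);
        for (uint32_t a : axes) outDims[a] = 1;
        const size_t outSize = std::accumulate(outDims.begin(), outDims.end(), size_t{1},
                                               std::multiplies<size_t>());
        std::vector<T> output(outSize, init);

        // Visit every input element and fold it into the output cell reached by
        // zeroing the coordinates of the reduced axes.
        std::vector<uint32_t> index(dims.size(), 0);
        for (size_t flat = 0; flat < input.size(); ++flat) {
            size_t outOffset = 0;
            for (size_t d = 0; d < dims.size(); ++d) {
                const uint32_t coord = axes.count(static_cast<uint32_t>(d)) ? 0 : index[d];
                outOffset = outOffset * outDims[d] + coord;
            }
            output[outOffset] = op(output[outOffset], input[flat]);
            for (int d = static_cast<int>(dims.size()) - 1; d >= 0; --d) {
                if (++index[d] < dims[d]) break;  // advance the row-major index
                index[d] = 0;
            }
        }
        return output;
    }

Summing a 2x3 tensor over its last axis would then be reduceAlongAxes<float>(data, {2, 3}, {1}, 0.0f, [](float a, float b) { return a + b; }), producing two partial sums.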

 72  bool validateProdSum(const IOperationValidationContext* context) {
 73      NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
 74      NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
 75      OperandType inputType = context->getInputType(kInputTensor);
 80          validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
 81      NN_RET_CHECK(validateOutputTypes(context, {inputType}));
 82      const Shape& input = context->getInputShape(kInputTensor);
 86      return validateHalVersion(context, HalVersion::V1_2);

 89  bool validateMaxMin(const IOperationValidationContext* context) {
 90      NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
 91      NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
 92      OperandType inputType = context->getInputType(kInputTensor);
 99          validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
100      NN_RET_CHECK(validateOutputTypes(context, {inputType}));
105      const Shape& input = context->getInputShape(kInputTensor);
109      return validateHalVersion(context, minHalVersion);

112  bool validateLogical(const IOperationValidationContext* context) {
113      NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
114      NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
115      OperandType inputType = context->getInputType(kInputTensor);
119          validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
120      NN_RET_CHECK(validateOutputTypes(context, {inputType}));
121      const Shape& input = context->getInputShape(kInputTensor);
125      return validateHalVersion(context, HalVersion::V1_2);

128  bool prepare(IOperationExecutionContext* context) {
129      Shape inputShape = context->getInputShape(kInputTensor);
134      const int32_t* axes = context->getInputBuffer<int32_t>(kInputAxes);
135      Shape axesShape = context->getInputShape(kInputAxes);
147      bool keepDims = context->getInputValue<bool8>(kInputKeepDims);
163      return context->setOutputShape(kOutputTensor, outputShape);
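The matches above only show where prepare() reads its operands; the output shape it then computes follows the usual reduction rule: wrap negative axes to non-negative ones, then keep each reduced dimension as size 1 or drop it, depending on keepDims. A simplified sketch of that rule, with a hypothetical name and without the duplicate-axis and range checks the real code performs:

    #include <cstdint>
    #include <set>
    #include <vector>

    // Illustrative only: derive the output dimensions of a reduction.
    std::vector<uint32_t> reducedShape(const std::vector<uint32_t>& inDims,
                                       const std::vector<int32_t>& axes, bool keepDims) {
        const int32_t rank = static_cast<int32_t>(inDims.size());
        std::set<int32_t> reduced;
        for (int32_t axis : axes) {
            // Negative axes count from the end, e.g. -1 is the last dimension.
            reduced.insert(axis < 0 ? axis + rank : axis);
        }
        std::vector<uint32_t> outDims;
        for (int32_t d = 0; d < rank; ++d) {
            if (reduced.count(d)) {
                if (keepDims) outDims.push_back(1);  // keep the reduced axis as size 1
            } else {
                outDims.push_back(inDims[d]);        // untouched axes pass through
            }
        }
        return outDims;
    }

For example, with inDims = {2, 3, 4} and axes = {-1}, this returns {2, 3} when keepDims is false and {2, 3, 1} when it is true.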

166  bool executeProd(IOperationExecutionContext* context) {
167      switch (context->getInputType(kInputTensor)) {
169              return compute<_Float16>(context, 1, [](_Float16 a, _Float16 b) -> _Float16 {
175              return compute<float>(context, 1, [](float a, float b) -> float {

185  bool executeSum(IOperationExecutionContext* context) {
186      switch (context->getInputType(kInputTensor)) {
188              return compute<_Float16>(context, 0, [](_Float16 a, _Float16 b) { return a + b; });
190              return compute<float>(context, 0, [](float a, float b) { return a + b; });

196  bool executeMax(IOperationExecutionContext* context) {
197      switch (context->getInputType(kInputTensor)) {
199              return compute<_Float16>(context, kFloat16Lowest,
202              return compute<float>(context, std::numeric_limits<float>::lowest(),
205              return compute<uint8_t>(context, std::numeric_limits<uint8_t>::lowest(),
208              return compute<int8_t>(context, std::numeric_limits<int8_t>::lowest(),

215  bool executeMin(IOperationExecutionContext* context) {
216      switch (context->getInputType(kInputTensor)) {
218              return compute<_Float16>(context, kFloat16Max,
221              return compute<float>(context, std::numeric_limits<float>::max(),
224              return compute<uint8_t>(context, std::numeric_limits<uint8_t>::max(),
227              return compute<int8_t>(context, std::numeric_limits<int8_t>::max(),

234  bool executeAny(IOperationExecutionContext* context) {
235      switch (context->getInputType(kInputTensor)) {
237              return compute<bool8>(context, false,

244  bool executeAll(IOperationExecutionContext* context) {
245      switch (context->getInputType(kInputTensor)) {
247              return compute<bool8>(context, true,
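Taken together, the execute* dispatchers follow a single identity-element pattern: product folds from 1, sum from 0, max from the lowest representable value of the element type, min from its largest, and the any/all reductions from false and true respectively, so compute<T>() never needs a special case for the first element. A small self-contained illustration of those starting values over a flat buffer (not taken from the sources matched above):

    #include <algorithm>
    #include <iostream>
    #include <limits>
    #include <numeric>
    #include <vector>

    int main() {
        const std::vector<float> v = {3.5f, -1.0f, 7.25f};
        // Sum starts from 0, product from 1 (the neutral elements of + and *).
        const float sum = std::accumulate(v.begin(), v.end(), 0.0f);
        const float prod = std::accumulate(v.begin(), v.end(), 1.0f,
                                           [](float a, float b) { return a * b; });
        // Max starts from the lowest float, min from the largest, so the first
        // real element always wins the first comparison.
        const float mx = std::accumulate(v.begin(), v.end(),
                                         std::numeric_limits<float>::lowest(),
                                         [](float a, float b) { return std::max(a, b); });
        const float mn = std::accumulate(v.begin(), v.end(),
                                         std::numeric_limits<float>::max(),
                                         [](float a, float b) { return std::min(a, b); });
        std::cout << sum << ' ' << prod << ' ' << mx << ' ' << mn << '\n';
        return 0;
    }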