Lines Matching refs:context

63 bool validate(const IOperationValidationContext* context) {  in validate()  argument
64 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); in validate()
65 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); in validate()
66 auto inputType = context->getInputType(kInputTensor); in validate()
72 const Shape& inputShape = context->getInputShape(kInputTensor); in validate()
76 NN_RET_CHECK(validateInputTypes(context, {inputType, OperandType::INT32, OperandType::INT32})); in validate()
77 NN_RET_CHECK(validateOutputTypes(context, {inputType})); in validate()
79 return validateHalVersion(context, HalVersion::V1_3); in validate()
81 return validateHalVersion(context, HalVersion::V1_2); in validate()
85 bool prepare(IOperationExecutionContext* context) { in prepare() argument
86 Shape input = context->getInputShape(kInputTensor); in prepare()
87 int32_t numGroups = context->getInputValue<int32_t>(kNumGroups); in prepare()
88 int32_t axis = context->getInputValue<int32_t>(kInputAxis); in prepare()
92 return context->setOutputShape(kOutputTensor, input); in prepare()
95 bool execute(IOperationExecutionContext* context) { in execute() argument
96 int32_t numGroups = context->getInputValue<int32_t>(kNumGroups); in execute()
97 int32_t axis = context->getInputValue<int32_t>(kInputAxis); in execute()
98 NN_RET_CHECK(handleNegativeAxis(context->getInputShape(kInputTensor), &axis)); in execute()
99 switch (context->getInputType(kInputTensor)) { in execute()
101 return eval(context->getInputBuffer<_Float16>(kInputTensor), in execute()
102 context->getInputShape(kInputTensor), numGroups, axis, in execute()
103 context->getOutputBuffer<_Float16>(kOutputTensor)); in execute()
105 return eval(context->getInputBuffer<float>(kInputTensor), in execute()
106 context->getInputShape(kInputTensor), numGroups, axis, in execute()
107 context->getOutputBuffer<float>(kOutputTensor)); in execute()
109 return eval(context->getInputBuffer<uint8_t>(kInputTensor), in execute()
110 context->getInputShape(kInputTensor), numGroups, axis, in execute()
111 context->getOutputBuffer<uint8_t>(kOutputTensor)); in execute()
113 return eval(context->getInputBuffer<int8_t>(kInputTensor), in execute()
114 context->getInputShape(kInputTensor), numGroups, axis, in execute()
115 context->getOutputBuffer<int8_t>(kOutputTensor)); in execute()