Lines Matching refs:context

65 bool validate(const IOperationValidationContext* context) {  in validate()  argument
66 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); in validate()
67 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); in validate()
68 OperandType inputType = context->getInputType(kInputTensor); in validate()
75 NN_RET_CHECK(validateInputTypes(context, in validate()
77 NN_RET_CHECK(validateOutputTypes(context, {inputType})); in validate()
79 return validateHalVersion(context, HalVersion::V1_3); in validate()
81 return validateHalVersion(context, HalVersion::V1_2); in validate()
85 bool prepare(IOperationExecutionContext* context) { in prepare() argument
86 Shape input = context->getInputShape(kInputTensor); in prepare()
87 int32_t axis = context->getInputValue<int32_t>(kInputAxis); in prepare()
89 Shape indices = context->getInputShape(kInputIndices); in prepare()
90 Shape output = context->getOutputShape(kOutputTensor); in prepare()
101 return context->setOutputShape(kOutputTensor, output); in prepare()
104 bool execute(IOperationExecutionContext* context) { in execute() argument
105 int32_t axis = context->getInputValue<int32_t>(kInputAxis); in execute()
106 NN_RET_CHECK(handleNegativeAxis(context->getInputShape(kInputTensor), &axis)); in execute()
107 switch (context->getInputType(kInputTensor)) { in execute()
109 return eval(context->getInputBuffer<_Float16>(kInputTensor), in execute()
110 context->getInputShape(kInputTensor), axis, in execute()
111 context->getInputBuffer<int32_t>(kInputIndices), in execute()
112 context->getInputShape(kInputIndices), in execute()
113 context->getOutputBuffer<_Float16>(kOutputTensor)); in execute()
115 return eval(context->getInputBuffer<float>(kInputTensor), in execute()
116 context->getInputShape(kInputTensor), axis, in execute()
117 context->getInputBuffer<int32_t>(kInputIndices), in execute()
118 context->getInputShape(kInputIndices), in execute()
119 context->getOutputBuffer<float>(kOutputTensor)); in execute()
121 return eval(context->getInputBuffer<int32_t>(kInputTensor), in execute()
122 context->getInputShape(kInputTensor), axis, in execute()
123 context->getInputBuffer<int32_t>(kInputIndices), in execute()
124 context->getInputShape(kInputIndices), in execute()
125 context->getOutputBuffer<int32_t>(kOutputTensor)); in execute()
127 return eval(context->getInputBuffer<uint8_t>(kInputTensor), in execute()
128 context->getInputShape(kInputTensor), axis, in execute()
129 context->getInputBuffer<int32_t>(kInputIndices), in execute()
130 context->getInputShape(kInputIndices), in execute()
131 context->getOutputBuffer<uint8_t>(kOutputTensor)); in execute()
133 return eval(context->getInputBuffer<int8_t>(kInputTensor), in execute()
134 context->getInputShape(kInputTensor), axis, in execute()
135 context->getInputBuffer<int32_t>(kInputIndices), in execute()
136 context->getInputShape(kInputIndices), in execute()
137 context->getOutputBuffer<int8_t>(kOutputTensor)); in execute()