Lines Matching refs:context

98 inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {  in hasTensor()  argument
99 return context->getInputBuffer(tensor) != nullptr; in hasTensor()
106 bool validate(const IOperationValidationContext* context) { in validate() argument
107 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); in validate()
108 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); in validate()
143 NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); in validate()
152 NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes)); in validate()
154 return validateHalVersion(context, HalVersion::V1_3); in validate()
157 bool prepare(IOperationExecutionContext* context) { in prepare() argument
174 NN_RET_CHECK(!context->isOmittedInput(tensor)) in prepare()
178 const Shape inputShape = context->getInputShape(kInputTensor); in prepare()
185 const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor); in prepare()
190 const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor); in prepare()
195 if (hasTensor(context, kInputToInputWeightsTensor)) { in prepare()
196 const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor); in prepare()
202 const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor); in prepare()
206 const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor); in prepare()
211 if (hasTensor(context, kRecurrentToInputWeightsTensor)) { in prepare()
212 const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor); in prepare()
218 const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor); in prepare()
222 const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor); in prepare()
229 const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) && in prepare()
230 hasTensor(context, kRecurrentToInputWeightsTensor)) || in prepare()
231 (!hasTensor(context, kInputToInputWeightsTensor) && in prepare()
232 !hasTensor(context, kRecurrentToInputWeightsTensor)); in prepare()
235 if (hasTensor(context, kCellToInputWeightsTensor)) { in prepare()
236 const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor); in prepare()
241 if (hasTensor(context, kCellToForgetWeightsTensor)) { in prepare()
242 const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor); in prepare()
247 if (hasTensor(context, kCellToOutputWeightsTensor)) { in prepare()
248 const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor); in prepare()
254 const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor); in prepare()
256 ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) && in prepare()
257 hasTensor(context, kCellToForgetWeightsTensor) && in prepare()
258 hasTensor(context, kCellToOutputWeightsTensor)) || in prepare()
259 (!hasTensor(context, kCellToInputWeightsTensor) && in prepare()
260 !hasTensor(context, kCellToForgetWeightsTensor) && in prepare()
261 !hasTensor(context, kCellToOutputWeightsTensor)); in prepare()
265 NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor)); in prepare()
266 const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor); in prepare()
270 NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor)) in prepare()
274 const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor); in prepare()
277 const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor); in prepare()
280 const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor); in prepare()
284 if (hasTensor(context, kProjectionWeightsTensor)) { in prepare()
285 const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor); in prepare()
291 if (hasTensor(context, kProjectionBiasTensor)) { in prepare()
292 const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor); in prepare()
297 const Shape outputStateShape = context->getInputShape(kPrevOutputTensor); in prepare()
301 const Shape cellStateShape = context->getInputShape(kPrevCellStateTensor); in prepare()
306 if (hasTensor(context, kInputLayerNormTensor)) { in prepare()
307 const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor); in prepare()
312 if (hasTensor(context, kForgetLayerNormTensor)) { in prepare()
313 const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor); in prepare()
318 if (hasTensor(context, kCellLayerNormTensor)) { in prepare()
319 const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor); in prepare()
324 if (hasTensor(context, kOutputLayerNormTensor)) { in prepare()
325 const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor); in prepare()
331 NN_RET_CHECK(!hasTensor(context, kInputLayerNormTensor)) in prepare()
333 const bool layerNormWeightsAllOrNoneCifg = (hasTensor(context, kForgetLayerNormTensor) && in prepare()
334 hasTensor(context, kCellLayerNormTensor) && in prepare()
335 hasTensor(context, kOutputLayerNormTensor)) || in prepare()
336 (!hasTensor(context, kForgetLayerNormTensor) && in prepare()
337 !hasTensor(context, kCellLayerNormTensor) && in prepare()
338 !hasTensor(context, kOutputLayerNormTensor)); in prepare()
341 const bool layerNormWeightsAllOrNone = (hasTensor(context, kInputLayerNormTensor) && in prepare()
342 hasTensor(context, kForgetLayerNormTensor) && in prepare()
343 hasTensor(context, kCellLayerNormTensor) && in prepare()
344 hasTensor(context, kOutputLayerNormTensor)) || in prepare()
345 (!hasTensor(context, kInputLayerNormTensor) && in prepare()
346 !hasTensor(context, kForgetLayerNormTensor) && in prepare()
347 !hasTensor(context, kCellLayerNormTensor) && in prepare()
348 !hasTensor(context, kOutputLayerNormTensor)); in prepare()
352 const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor); in prepare()
353 Shape outputShape = context->getOutputShape(kOutputTensor); in prepare()
356 const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor); in prepare()
357 Shape cellStateOutShape = context->getOutputShape(kCellStateOutTensor); in prepare()
360 return context->setOutputShape(kOutputStateOutTensor, outputShape) && in prepare()
361 context->setOutputShape(kCellStateOutTensor, cellStateOutShape) && in prepare()
362 context->setOutputShape(kOutputTensor, outputShape); in prepare()
365 bool execute(IOperationExecutionContext* context) { in execute() argument
367 const Shape inputShape = context->getInputShape(kInputTensor); in execute()
368 const Shape inputToInputWeightsShape = context->getInputShape(kInputToInputWeightsTensor); in execute()
370 context->getInputShape(kRecurrentToInputWeightsTensor); in execute()
371 const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor); in execute()
372 const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor); in execute()
373 const Shape inputToForgetWeightsShape = context->getInputShape(kInputToForgetWeightsTensor); in execute()
375 context->getInputShape(kRecurrentToForgetWeightsTensor); in execute()
376 const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor); in execute()
377 const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor); in execute()
378 const Shape inputToCellWeightsShape = context->getInputShape(kInputToCellWeightsTensor); in execute()
379 const Shape recurrentToCellWeightsShape = context->getInputShape(kRecurrentToCellWeightsTensor); in execute()
380 const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor); in execute()
381 const Shape inputToOutputWeightsShape = context->getInputShape(kInputToOutputWeightsTensor); in execute()
383 context->getInputShape(kRecurrentToOutputWeightsTensor); in execute()
384 const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor); in execute()
385 const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor); in execute()
386 const Shape projectionWeightsShape = context->getInputShape(kProjectionWeightsTensor); in execute()
387 const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor); in execute()
388 const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor); in execute()
395 const float cellClip = context->getInputValue<float>(kCellClip); in execute()
396 const float projectionClip = context->getInputValue<float>(kProjectionClip); in execute()
397 const float inputIntermediateScale = context->getInputValue<float>(kInputIntermediateScale); in execute()
398 const float forgetIntermediateScale = context->getInputValue<float>(kForgetIntermediateScale); in execute()
399 const float cellIntermediateScale = context->getInputValue<float>(kCellIntermediateScale); in execute()
400 const float outputIntermediateScale = context->getInputValue<float>(kOutputIntermediateScale); in execute()
401 const int8_t hiddenStateZeroPoint = context->getInputValue<int8_t>(kHiddenStateZeroPoint); in execute()
402 const float hiddenStateScale = context->getInputValue<float>(kHiddenStateScale); in execute()
405 reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputTensor)); in execute()
408 reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToInputWeightsTensor)); in execute()
411 context->getInputBuffer(kRecurrentToInputWeightsTensor)); in execute()
413 reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToInputWeightsTensor)); in execute()
415 reinterpret_cast<const int16_t*>(context->getInputBuffer(kInputLayerNormTensor)); in execute()
417 reinterpret_cast<const int32_t*>(context->getInputBuffer(kInputGateBiasTensor)); in execute()
420 reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToForgetWeightsTensor)); in execute()
422 context->getInputBuffer(kRecurrentToForgetWeightsTensor)); in execute()
424 reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToForgetWeightsTensor)); in execute()
426 reinterpret_cast<const int16_t*>(context->getInputBuffer(kForgetLayerNormTensor)); in execute()
428 reinterpret_cast<const int32_t*>(context->getInputBuffer(kForgetGateBiasTensor)); in execute()
431 reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToCellWeightsTensor)); in execute()
433 reinterpret_cast<const int8_t*>(context->getInputBuffer(kRecurrentToCellWeightsTensor)); in execute()
435 reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellLayerNormTensor)); in execute()
437 reinterpret_cast<const int32_t*>(context->getInputBuffer(kCellGateBiasTensor)); in execute()
440 reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToOutputWeightsTensor)); in execute()
442 context->getInputBuffer(kRecurrentToOutputWeightsTensor)); in execute()
444 reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToOutputWeightsTensor)); in execute()
446 reinterpret_cast<const int16_t*>(context->getInputBuffer(kOutputLayerNormTensor)); in execute()
448 reinterpret_cast<const int32_t*>(context->getInputBuffer(kOutputGateBiasTensor)); in execute()
451 reinterpret_cast<const int8_t*>(context->getInputBuffer(kProjectionWeightsTensor)); in execute()
453 reinterpret_cast<const int32_t*>(context->getInputBuffer(kProjectionBiasTensor)); in execute()
456 reinterpret_cast<const int8_t*>(context->getInputBuffer(kPrevOutputTensor)); in execute()
458 reinterpret_cast<const int16_t*>(context->getInputBuffer(kPrevCellStateTensor)); in execute()
461 reinterpret_cast<uint8_t*>(context->getOutputBuffer(kOutputStateOutTensor)); in execute()
463 reinterpret_cast<int16_t*>(context->getOutputBuffer(kCellStateOutTensor)); in execute()
464 int8_t* outputBuffer = reinterpret_cast<int8_t*>(context->getOutputBuffer(kOutputTensor)); in execute()