Lines Matching refs:context

93 inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {  in hasTensor()  argument
94 return context->getInputBuffer(tensor) != nullptr; in hasTensor()
97 inline bool isTimeMajor(IOperationExecutionContext* context) { in isTimeMajor() argument
98 return context->getInputValue<bool>(kTimeMajorParam); in isTimeMajor()
102 inline LSTMParams getLSTMParams(IOperationExecutionContext* context) { in getLSTMParams() argument
105 static_cast<TfLiteFusedActivation>(context->getInputValue<int32_t>(kActivationParam)); in getLSTMParams()
106 params.cell_clip = static_cast<float>(context->getInputValue<T>(kCellClipParam)); in getLSTMParams()
107 params.proj_clip = static_cast<float>(context->getInputValue<T>(kProjClipParam)); in getLSTMParams()
108 params.use_cifg = !hasTensor(context, kInputToInputWeightsTensor); in getLSTMParams()
109 params.use_peephole = hasTensor(context, kCellToOutputWeightsTensor); in getLSTMParams()
110 params.use_layer_norm = hasTensor(context, kOutputLayerNormWeightsTensor); in getLSTMParams()
111 params.use_projection_weight = hasTensor(context, kProjectionWeightsTensor); in getLSTMParams()
112 params.use_projection_bias = hasTensor(context, kProjectionBiasTensor); in getLSTMParams()
118 bool validate(const IOperationValidationContext* context) { in validate() argument
119 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); in validate()
120 const uint32_t numOutputs = context->getNumOutputs(); in validate()
122 const OperandType inputType = context->getInputType(kInputTensor); in validate()
163 if (context->getNumOutputs() == kNumOutputsWithState) { in validate()
167 NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); in validate()
168 NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes)); in validate()
169 return validateHalVersion(context, minHalVersionSupported); in validate()
172 bool prepare(IOperationExecutionContext* context) { in prepare() argument
193 NN_RET_CHECK(!context->isOmittedInput(requiredInput)) in prepare()
197 const Shape inputShape = context->getInputShape(kInputTensor); in prepare()
201 const uint32_t maxTime = getSizeOfDimension(inputShape, isTimeMajor(context) ? 0 : 1); in prepare()
202 const uint32_t batchSize = getSizeOfDimension(inputShape, isTimeMajor(context) ? 1 : 0); in prepare()
205 const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor); in prepare()
210 const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor); in prepare()
215 if (hasTensor(context, kInputToInputWeightsTensor)) { in prepare()
216 const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor); in prepare()
222 const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor); in prepare()
226 const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor); in prepare()
231 if (hasTensor(context, kRecurrentToInputWeightsTensor)) { in prepare()
232 const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor); in prepare()
238 const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor); in prepare()
242 const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor); in prepare()
249 const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) && in prepare()
250 hasTensor(context, kRecurrentToInputWeightsTensor)) || in prepare()
251 (!hasTensor(context, kInputToInputWeightsTensor) && in prepare()
252 !hasTensor(context, kRecurrentToInputWeightsTensor)); in prepare()
255 if (hasTensor(context, kCellToInputWeightsTensor)) { in prepare()
256 const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor); in prepare()
261 if (hasTensor(context, kCellToForgetWeightsTensor)) { in prepare()
262 const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor); in prepare()
267 if (hasTensor(context, kCellToOutputWeightsTensor)) { in prepare()
268 const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor); in prepare()
274 const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor); in prepare()
276 ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) && in prepare()
277 hasTensor(context, kCellToForgetWeightsTensor) && in prepare()
278 hasTensor(context, kCellToOutputWeightsTensor)) || in prepare()
279 (!hasTensor(context, kCellToInputWeightsTensor) && in prepare()
280 !hasTensor(context, kCellToForgetWeightsTensor) && in prepare()
281 !hasTensor(context, kCellToOutputWeightsTensor)); in prepare()
285 NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor)); in prepare()
286 const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor); in prepare()
290 NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor)) in prepare()
294 const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor); in prepare()
297 const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor); in prepare()
300 const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor); in prepare()
304 if (hasTensor(context, kProjectionWeightsTensor)) { in prepare()
305 const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor); in prepare()
311 if (hasTensor(context, kProjectionBiasTensor)) { in prepare()
312 const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor); in prepare()
317 const Shape outputStateShape = context->getInputShape(kOutputStateInTensor); in prepare()
321 const Shape cellStateShape = context->getInputShape(kCellStateInTensor); in prepare()
326 if (hasTensor(context, kInputLayerNormWeightsTensor)) { in prepare()
327 const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormWeightsTensor); in prepare()
332 if (hasTensor(context, kForgetLayerNormWeightsTensor)) { in prepare()
333 const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormWeightsTensor); in prepare()
338 if (hasTensor(context, kCellLayerNormWeightsTensor)) { in prepare()
339 const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormWeightsTensor); in prepare()
344 if (hasTensor(context, kOutputLayerNormWeightsTensor)) { in prepare()
345 const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormWeightsTensor); in prepare()
351 NN_RET_CHECK(!hasTensor(context, kInputLayerNormWeightsTensor)) in prepare()
354 (hasTensor(context, kForgetLayerNormWeightsTensor) && in prepare()
355 hasTensor(context, kCellLayerNormWeightsTensor) && in prepare()
356 hasTensor(context, kOutputLayerNormWeightsTensor)) || in prepare()
357 (!hasTensor(context, kForgetLayerNormWeightsTensor) && in prepare()
358 !hasTensor(context, kCellLayerNormWeightsTensor) && in prepare()
359 !hasTensor(context, kOutputLayerNormWeightsTensor)); in prepare()
363 (hasTensor(context, kInputLayerNormWeightsTensor) && in prepare()
364 hasTensor(context, kForgetLayerNormWeightsTensor) && in prepare()
365 hasTensor(context, kCellLayerNormWeightsTensor) && in prepare()
366 hasTensor(context, kOutputLayerNormWeightsTensor)) || in prepare()
367 (!hasTensor(context, kInputLayerNormWeightsTensor) && in prepare()
368 !hasTensor(context, kForgetLayerNormWeightsTensor) && in prepare()
369 !hasTensor(context, kCellLayerNormWeightsTensor) && in prepare()
370 !hasTensor(context, kOutputLayerNormWeightsTensor)); in prepare()
374 Shape outputShape = context->getInputShape(kInputTensor); in prepare()
377 if (context->getNumOutputs() == kNumOutputsWithState) { in prepare()
378 NN_RET_CHECK(!context->isOmittedOutput(kOutputStateOutTensor)); in prepare()
379 NN_RET_CHECK(!context->isOmittedOutput(kCellStateOutTensor)); in prepare()
381 Shape outputStateOutTensor = context->getInputShape(kOutputStateInTensor); in prepare()
385 NN_RET_CHECK(context->setOutputShape(kOutputStateOutTensor, outputStateOutTensor)); in prepare()
387 Shape cellStateOutTensor = context->getInputShape(kCellStateInTensor); in prepare()
391 NN_RET_CHECK(context->setOutputShape(kCellStateOutTensor, cellStateOutTensor)); in prepare()
394 return context->setOutputShape(kOutputTensor, outputShape); in prepare()
397 bool execute(IOperationExecutionContext* context) { in execute() argument
398 const auto outputStateSize = getNumberOfElements(context->getInputShape(kOutputStateInTensor)); in execute()
399 const auto cellStateSize = getNumberOfElements(context->getInputShape(kCellStateInTensor)); in execute()
400 const bool use_cifg = !hasTensor(context, kInputToInputWeightsTensor); in execute()
402 const bool useStateOutTensors = (context->getNumOutputs() == kNumOutputsWithState); in execute()
404 const OperandType inputType = context->getInputType(kInputTensor); in execute()
413 outputStateOut = context->getOutputBuffer<float>(kOutputStateOutTensor); in execute()
414 cellStateOut = context->getOutputBuffer<float>(kCellStateOutTensor); in execute()
423 getLSTMParams<float>(context), context->getInputBuffer<float>(kInputTensor), in execute()
424 context->getInputShape(kInputTensor), in execute()
425 context->getInputBuffer<float>(kInputToInputWeightsTensor), in execute()
426 context->getInputBuffer<float>(kInputToForgetWeightsTensor), in execute()
427 context->getInputBuffer<float>(kInputToCellWeightsTensor), in execute()
428 context->getInputBuffer<float>(kInputToOutputWeightsTensor), in execute()
429 context->getInputShape(kInputToOutputWeightsTensor), in execute()
430 context->getInputBuffer<float>(kRecurrentToInputWeightsTensor), in execute()
431 context->getInputBuffer<float>(kRecurrentToForgetWeightsTensor), in execute()
432 context->getInputBuffer<float>(kRecurrentToCellWeightsTensor), in execute()
433 context->getInputBuffer<float>(kRecurrentToOutputWeightsTensor), in execute()
434 context->getInputShape(kRecurrentToOutputWeightsTensor), in execute()
435 context->getInputBuffer<float>(kCellToInputWeightsTensor), in execute()
436 context->getInputBuffer<float>(kCellToForgetWeightsTensor), in execute()
437 context->getInputBuffer<float>(kCellToOutputWeightsTensor), in execute()
443 context->getInputBuffer<float>(kInputGateBiasTensor), in execute()
444 context->getInputBuffer<float>(kForgetGateBiasTensor), in execute()
445 context->getInputBuffer<float>(kCellGateBiasTensor), in execute()
446 context->getInputBuffer<float>(kOutputGateBiasTensor), in execute()
447 context->getInputBuffer<float>(kProjectionWeightsTensor), in execute()
448 context->getInputBuffer<float>(kProjectionBiasTensor), in execute()
449 context->getInputBuffer<float>(kOutputStateInTensor), in execute()
450 context->getInputBuffer<float>(kCellStateInTensor), in execute()
451 context->getInputBuffer<float>(kInputLayerNormWeightsTensor), in execute()
452 context->getInputBuffer<float>(kForgetLayerNormWeightsTensor), in execute()
453 context->getInputBuffer<float>(kCellLayerNormWeightsTensor), in execute()
454 context->getInputBuffer<float>(kOutputLayerNormWeightsTensor), outputStateOut, in execute()
455 cellStateOut, context->getOutputBuffer<float>(kOutputTensor), in execute()
456 scratchBuffer.data(), isTimeMajor(context)); in execute()
465 outputStateOut = context->getOutputBuffer<_Float16>(kOutputStateOutTensor); in execute()
466 cellStateOut = context->getOutputBuffer<_Float16>(kCellStateOutTensor); in execute()
475 getLSTMParams<_Float16>(context), in execute()
476 context->getInputBuffer<_Float16>(kInputTensor), in execute()
477 context->getInputShape(kInputTensor), in execute()
478 context->getInputBuffer<_Float16>(kInputToInputWeightsTensor), in execute()
479 context->getInputBuffer<_Float16>(kInputToForgetWeightsTensor), in execute()
480 context->getInputBuffer<_Float16>(kInputToCellWeightsTensor), in execute()
481 context->getInputBuffer<_Float16>(kInputToOutputWeightsTensor), in execute()
482 context->getInputShape(kInputToOutputWeightsTensor), in execute()
483 context->getInputBuffer<_Float16>(kRecurrentToInputWeightsTensor), in execute()
484 context->getInputBuffer<_Float16>(kRecurrentToForgetWeightsTensor), in execute()
485 context->getInputBuffer<_Float16>(kRecurrentToCellWeightsTensor), in execute()
486 context->getInputBuffer<_Float16>(kRecurrentToOutputWeightsTensor), in execute()
487 context->getInputShape(kRecurrentToOutputWeightsTensor), in execute()
488 context->getInputBuffer<_Float16>(kCellToInputWeightsTensor), in execute()
489 context->getInputBuffer<_Float16>(kCellToForgetWeightsTensor), in execute()
490 context->getInputBuffer<_Float16>(kCellToOutputWeightsTensor), in execute()
496 context->getInputBuffer<_Float16>(kInputGateBiasTensor), in execute()
497 context->getInputBuffer<_Float16>(kForgetGateBiasTensor), in execute()
498 context->getInputBuffer<_Float16>(kCellGateBiasTensor), in execute()
499 context->getInputBuffer<_Float16>(kOutputGateBiasTensor), in execute()
500 context->getInputBuffer<_Float16>(kProjectionWeightsTensor), in execute()
501 context->getInputBuffer<_Float16>(kProjectionBiasTensor), in execute()
502 context->getInputBuffer<_Float16>(kOutputStateInTensor), in execute()
503 context->getInputBuffer<_Float16>(kCellStateInTensor), in execute()
504 context->getInputBuffer<_Float16>(kInputLayerNormWeightsTensor), in execute()
505 context->getInputBuffer<_Float16>(kForgetLayerNormWeightsTensor), in execute()
506 context->getInputBuffer<_Float16>(kCellLayerNormWeightsTensor), in execute()
507 context->getInputBuffer<_Float16>(kOutputLayerNormWeightsTensor), in execute()
508 outputStateOut, cellStateOut, context->getOutputBuffer<_Float16>(kOutputTensor), in execute()
509 scratchBuffer.data(), isTimeMajor(context)); in execute()