Lines Matching refs:context

359 bool validate(OperationType opType, const IOperationValidationContext* context) {  in validate()  argument
360 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); in validate()
361 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); in validate()
362 auto inputType = context->getInputType(kInputTensor); in validate()
364 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0)); in validate()
366 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2)); in validate()
369 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2)); in validate()
371 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0)); in validate()
374 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_3)); in validate()
378 const Shape& input = context->getInputShape(kInputTensor); in validate()
382 return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType}); in validate()
385 bool validateHardSwish(const IOperationValidationContext* context) { in validateHardSwish() argument
386 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); in validateHardSwish()
387 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); in validateHardSwish()
388 auto inputType = context->getInputType(kInputTensor); in validateHardSwish()
392 NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_3)); in validateHardSwish()
396 return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType}); in validateHardSwish()
399 bool prepare(OperationType opType, IOperationExecutionContext* context) { in prepare() argument
400 Shape input = context->getInputShape(kInputTensor); in prepare()
410 auto outputShape = context->getOutputShape(kOutputTensor); in prepare()
430 return context->setOutputShape(kOutputTensor, output); in prepare()
433 bool executeRelu(IOperationExecutionContext* context) { in executeRelu() argument
435 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; in executeRelu()
436 switch (context->getInputType(kInputTensor)) { in executeRelu()
438 return reluFloat(context->getInputBuffer<_Float16>(kInputTensor), in executeRelu()
439 context->getInputShape(kInputTensor), in executeRelu()
440 context->getOutputBuffer<_Float16>(kOutputTensor), in executeRelu()
441 context->getOutputShape(kOutputTensor)); in executeRelu()
443 return reluFloat(context->getInputBuffer<float>(kInputTensor), in executeRelu()
444 context->getInputShape(kInputTensor), in executeRelu()
445 context->getOutputBuffer<float>(kOutputTensor), in executeRelu()
446 context->getOutputShape(kOutputTensor)); in executeRelu()
448 return reluQuant8(context->getInputBuffer<uint8_t>(kInputTensor), in executeRelu()
449 context->getInputShape(kInputTensor), in executeRelu()
450 context->getOutputBuffer<uint8_t>(kOutputTensor), in executeRelu()
451 context->getOutputShape(kOutputTensor)); in executeRelu()
453 return reluQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor), in executeRelu()
454 context->getInputShape(kInputTensor), in executeRelu()
455 context->getOutputBuffer<int8_t>(kOutputTensor), in executeRelu()
456 context->getOutputShape(kOutputTensor)); in executeRelu()
462 bool executeRelu1(IOperationExecutionContext* context) { in executeRelu1() argument
464 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; in executeRelu1()
465 switch (context->getInputType(kInputTensor)) { in executeRelu1()
467 return relu1Float(context->getInputBuffer<_Float16>(kInputTensor), in executeRelu1()
468 context->getInputShape(kInputTensor), in executeRelu1()
469 context->getOutputBuffer<_Float16>(kOutputTensor), in executeRelu1()
470 context->getOutputShape(kOutputTensor)); in executeRelu1()
472 return relu1Float(context->getInputBuffer<float>(kInputTensor), in executeRelu1()
473 context->getInputShape(kInputTensor), in executeRelu1()
474 context->getOutputBuffer<float>(kOutputTensor), in executeRelu1()
475 context->getOutputShape(kOutputTensor)); in executeRelu1()
477 return relu1Quant8(context->getInputBuffer<uint8_t>(kInputTensor), in executeRelu1()
478 context->getInputShape(kInputTensor), in executeRelu1()
479 context->getOutputBuffer<uint8_t>(kOutputTensor), in executeRelu1()
480 context->getOutputShape(kOutputTensor)); in executeRelu1()
482 return relu1Quant8Signed(context->getInputBuffer<int8_t>(kInputTensor), in executeRelu1()
483 context->getInputShape(kInputTensor), in executeRelu1()
484 context->getOutputBuffer<int8_t>(kOutputTensor), in executeRelu1()
485 context->getOutputShape(kOutputTensor)); in executeRelu1()
491 bool executeRelu6(IOperationExecutionContext* context) { in executeRelu6() argument
493 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; in executeRelu6()
494 switch (context->getInputType(kInputTensor)) { in executeRelu6()
496 return relu6Float(context->getInputBuffer<_Float16>(kInputTensor), in executeRelu6()
497 context->getInputShape(kInputTensor), in executeRelu6()
498 context->getOutputBuffer<_Float16>(kOutputTensor), in executeRelu6()
499 context->getOutputShape(kOutputTensor)); in executeRelu6()
501 return relu6Float(context->getInputBuffer<float>(kInputTensor), in executeRelu6()
502 context->getInputShape(kInputTensor), in executeRelu6()
503 context->getOutputBuffer<float>(kOutputTensor), in executeRelu6()
504 context->getOutputShape(kOutputTensor)); in executeRelu6()
506 return relu6Quant8(context->getInputBuffer<uint8_t>(kInputTensor), in executeRelu6()
507 context->getInputShape(kInputTensor), in executeRelu6()
508 context->getOutputBuffer<uint8_t>(kOutputTensor), in executeRelu6()
509 context->getOutputShape(kOutputTensor)); in executeRelu6()
511 return relu6Quant8Signed(context->getInputBuffer<int8_t>(kInputTensor), in executeRelu6()
512 context->getInputShape(kInputTensor), in executeRelu6()
513 context->getOutputBuffer<int8_t>(kOutputTensor), in executeRelu6()
514 context->getOutputShape(kOutputTensor)); in executeRelu6()
520 bool executeLogistic(IOperationExecutionContext* context) { in executeLogistic() argument
522 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; in executeLogistic()
523 switch (context->getInputType(kInputTensor)) { in executeLogistic()
525 return logisticFloat(context->getInputBuffer<_Float16>(kInputTensor), in executeLogistic()
526 context->getInputShape(kInputTensor), in executeLogistic()
527 context->getOutputBuffer<_Float16>(kOutputTensor), in executeLogistic()
528 context->getOutputShape(kOutputTensor)); in executeLogistic()
530 return logisticFloat(context->getInputBuffer<float>(kInputTensor), in executeLogistic()
531 context->getInputShape(kInputTensor), in executeLogistic()
532 context->getOutputBuffer<float>(kOutputTensor), in executeLogistic()
533 context->getOutputShape(kOutputTensor)); in executeLogistic()
535 return logisticQuant8(context->getInputBuffer<uint8_t>(kInputTensor), in executeLogistic()
536 context->getInputShape(kInputTensor), in executeLogistic()
537 context->getOutputBuffer<uint8_t>(kOutputTensor), in executeLogistic()
538 context->getOutputShape(kOutputTensor)); in executeLogistic()
540 return logisticQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor), in executeLogistic()
541 context->getInputShape(kInputTensor), in executeLogistic()
542 context->getOutputBuffer<int8_t>(kOutputTensor), in executeLogistic()
543 context->getOutputShape(kOutputTensor)); in executeLogistic()
549 bool executeTanh(IOperationExecutionContext* context) { in executeTanh() argument
551 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; in executeTanh()
552 switch (context->getInputType(kInputTensor)) { in executeTanh()
554 return tanhFloat16(context->getInputBuffer<_Float16>(kInputTensor), in executeTanh()
555 context->getInputShape(kInputTensor), in executeTanh()
556 context->getOutputBuffer<_Float16>(kOutputTensor), in executeTanh()
557 context->getOutputShape(kOutputTensor)); in executeTanh()
559 return tanhFloat32(context->getInputBuffer<float>(kInputTensor), in executeTanh()
560 context->getInputShape(kInputTensor), in executeTanh()
561 context->getOutputBuffer<float>(kOutputTensor), in executeTanh()
562 context->getOutputShape(kOutputTensor)); in executeTanh()
564 return tanhQuant8(context->getInputBuffer<uint8_t>(kInputTensor), in executeTanh()
565 context->getInputShape(kInputTensor), in executeTanh()
566 context->getOutputBuffer<uint8_t>(kOutputTensor), in executeTanh()
567 context->getOutputShape(kOutputTensor)); in executeTanh()
569 return tanhQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor), in executeTanh()
570 context->getInputShape(kInputTensor), in executeTanh()
571 context->getOutputBuffer<int8_t>(kOutputTensor), in executeTanh()
572 context->getOutputShape(kOutputTensor)); in executeTanh()
578 bool executeHardSwish(IOperationExecutionContext* context) { in executeHardSwish() argument
580 if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true; in executeHardSwish()
581 switch (context->getInputType(kInputTensor)) { in executeHardSwish()
583 const Shape& inputShape = context->getInputShape(kInputTensor); in executeHardSwish()
584 const Shape& outputShape = context->getOutputShape(kOutputTensor); in executeHardSwish()
587 convertFloat16ToFloat32(context->getInputBuffer<_Float16>(kInputTensor), &inputFloat); in executeHardSwish()
591 convertFloat32ToFloat16(outputFloat, context->getOutputBuffer<_Float16>(kOutputTensor)); in executeHardSwish()
596 convertShapeToTflshape(context->getInputShape(kInputTensor)), in executeHardSwish()
597 context->getInputBuffer<float>(kInputTensor), in executeHardSwish()
598 convertShapeToTflshape(context->getOutputShape(kOutputTensor)), in executeHardSwish()
599 context->getOutputBuffer<float>(kOutputTensor)); in executeHardSwish()
603 return hardSwishQuant(context->getInputBuffer<uint8_t>(kInputTensor), in executeHardSwish()
604 context->getInputShape(kInputTensor), in executeHardSwish()
605 context->getOutputBuffer<uint8_t>(kOutputTensor), in executeHardSwish()
606 context->getOutputShape(kOutputTensor)); in executeHardSwish()
608 return hardSwishQuant(context->getInputBuffer<int8_t>(kInputTensor), in executeHardSwish()
609 context->getInputShape(kInputTensor), in executeHardSwish()
610 context->getOutputBuffer<int8_t>(kOutputTensor), in executeHardSwish()
611 context->getOutputShape(kOutputTensor)); in executeHardSwish()