1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H 18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H 19 20 #include "HalInterfaces.h" 21 #include "OperationsUtils.h" 22 23 namespace android { 24 namespace nn { 25 26 // Encapsulates an operation implementation. 27 struct OperationRegistration { 28 hal::OperationType type; 29 const char* name; 30 31 // Validates operand types, shapes, and any values known during graph creation. 32 std::function<bool(const IOperationValidationContext*)> validate; 33 34 // prepare is called when the inputs this operation depends on have been 35 // computed. Typically, prepare does any remaining validation and sets 36 // output shapes via context->setOutputShape(...). 37 std::function<bool(IOperationExecutionContext*)> prepare; 38 39 // Executes the operation, reading from context->getInputBuffer(...) 40 // and writing to context->getOutputBuffer(...). 41 std::function<bool(IOperationExecutionContext*)> execute; 42 43 struct Flag { 44 // Whether the operation allows at least one operand to be omitted. 45 bool allowOmittedOperand = false; 46 // Whether the operation allows at least one input operand to be a zero-sized tensor. 47 bool allowZeroSizedInput = false; 48 } flags; 49 OperationRegistrationOperationRegistration50 OperationRegistration(hal::OperationType type, const char* name, 51 std::function<bool(const IOperationValidationContext*)> validate, 52 std::function<bool(IOperationExecutionContext*)> prepare, 53 std::function<bool(IOperationExecutionContext*)> execute, Flag flags) 54 : type(type), 55 name(name), 56 validate(validate), 57 prepare(prepare), 58 execute(execute), 59 flags(flags) {} 60 }; 61 62 // A registry of operation implementations. 63 class IOperationResolver { 64 public: 65 virtual const OperationRegistration* findOperation(hal::OperationType operationType) const = 0; ~IOperationResolver()66 virtual ~IOperationResolver() {} 67 }; 68 69 // A registry of builtin operation implementations. 70 // 71 // Note that some operations bypass BuiltinOperationResolver (b/124041202). 72 // 73 // Usage: 74 // const OperationRegistration* operationRegistration = 75 // BuiltinOperationResolver::get()->findOperation(operationType); 76 // NN_RET_CHECK(operationRegistration != nullptr); 77 // NN_RET_CHECK(operationRegistration->validate != nullptr); 78 // NN_RET_CHECK(operationRegistration->validate(&context)); 79 // 80 class BuiltinOperationResolver : public IOperationResolver { 81 DISALLOW_COPY_AND_ASSIGN(BuiltinOperationResolver); 82 83 public: get()84 static const BuiltinOperationResolver* get() { 85 static BuiltinOperationResolver instance; 86 return &instance; 87 } 88 89 const OperationRegistration* findOperation(hal::OperationType operationType) const override; 90 91 private: 92 BuiltinOperationResolver(); 93 94 void registerOperation(const OperationRegistration* operationRegistration); 95 96 const OperationRegistration* mRegistrations[kNumberOfOperationTypes] = {}; 97 }; 98 99 // NN_REGISTER_OPERATION creates OperationRegistration for consumption by 100 // OperationResolver. 101 // 102 // Usage: 103 // (check OperationRegistration::Flag for available fields and default values.) 104 // 105 // - With default flags. 106 // NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate, 107 // foo_op::prepare, foo_op::execute); 108 // 109 // - With a customized flag. 110 // NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate, 111 // foo_op::prepare, foo_op::execute, .allowZeroSizedInput = true); 112 // 113 // - With multiple customized flags. 114 // NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate, 115 // foo_op::prepare, foo_op::execute, .allowOmittedOperand = true, 116 // .allowZeroSizedInput = true); 117 // 118 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION 119 #define NN_REGISTER_OPERATION(identifier, operationName, validate, prepare, execute, ...) \ 120 const OperationRegistration* register_##identifier() { \ 121 static OperationRegistration registration(hal::OperationType::identifier, operationName, \ 122 validate, prepare, execute, {__VA_ARGS__}); \ 123 return ®istration; \ 124 } 125 #else 126 // This version ignores CPU execution logic (prepare and execute). 127 // The compiler is supposed to omit that code so that only validation logic 128 // makes it into libneuralnetworks_utils. 129 #define NN_REGISTER_OPERATION(identifier, operationName, validate, unused_prepare, unused_execute, \ 130 ...) \ 131 const OperationRegistration* register_##identifier() { \ 132 static OperationRegistration registration(hal::OperationType::identifier, operationName, \ 133 validate, nullptr, nullptr, {__VA_ARGS__}); \ 134 return ®istration; \ 135 } 136 #endif 137 138 } // namespace nn 139 } // namespace android 140 141 #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H 142