1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H
18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H
19 
#include <functional>
#include <utility>

#include "HalInterfaces.h"
#include "OperationsUtils.h"
22 
23 namespace android {
24 namespace nn {
25 
26 // Encapsulates an operation implementation.
27 struct OperationRegistration {
28     hal::OperationType type;
29     const char* name;
30 
31     // Validates operand types, shapes, and any values known during graph creation.
32     std::function<bool(const IOperationValidationContext*)> validate;
33 
34     // prepare is called when the inputs this operation depends on have been
35     // computed. Typically, prepare does any remaining validation and sets
36     // output shapes via context->setOutputShape(...).
37     std::function<bool(IOperationExecutionContext*)> prepare;
38 
39     // Executes the operation, reading from context->getInputBuffer(...)
40     // and writing to context->getOutputBuffer(...).
41     std::function<bool(IOperationExecutionContext*)> execute;
42 
43     struct Flag {
44         // Whether the operation allows at least one operand to be omitted.
45         bool allowOmittedOperand = false;
46         // Whether the operation allows at least one input operand to be a zero-sized tensor.
47         bool allowZeroSizedInput = false;
48     } flags;
49 
OperationRegistrationOperationRegistration50     OperationRegistration(hal::OperationType type, const char* name,
51                           std::function<bool(const IOperationValidationContext*)> validate,
52                           std::function<bool(IOperationExecutionContext*)> prepare,
53                           std::function<bool(IOperationExecutionContext*)> execute, Flag flags)
54         : type(type),
55           name(name),
56           validate(validate),
57           prepare(prepare),
58           execute(execute),
59           flags(flags) {}
60 };
61 
62 // A registry of operation implementations.
63 class IOperationResolver {
64    public:
65     virtual const OperationRegistration* findOperation(hal::OperationType operationType) const = 0;
~IOperationResolver()66     virtual ~IOperationResolver() {}
67 };
68 
69 // A registry of builtin operation implementations.
70 //
71 // Note that some operations bypass BuiltinOperationResolver (b/124041202).
72 //
73 // Usage:
74 //   const OperationRegistration* operationRegistration =
75 //           BuiltinOperationResolver::get()->findOperation(operationType);
76 //   NN_RET_CHECK(operationRegistration != nullptr);
77 //   NN_RET_CHECK(operationRegistration->validate != nullptr);
78 //   NN_RET_CHECK(operationRegistration->validate(&context));
79 //
80 class BuiltinOperationResolver : public IOperationResolver {
81     DISALLOW_COPY_AND_ASSIGN(BuiltinOperationResolver);
82 
83    public:
get()84     static const BuiltinOperationResolver* get() {
85         static BuiltinOperationResolver instance;
86         return &instance;
87     }
88 
89     const OperationRegistration* findOperation(hal::OperationType operationType) const override;
90 
91    private:
92     BuiltinOperationResolver();
93 
94     void registerOperation(const OperationRegistration* operationRegistration);
95 
96     const OperationRegistration* mRegistrations[kNumberOfOperationTypes] = {};
97 };
98 
// NN_REGISTER_OPERATION defines a function named register_<identifier>() that
// returns a pointer to a function-local static OperationRegistration, for
// consumption by BuiltinOperationResolver.
// NOTE(review): the generated function is neither static nor inline, so it has
// external linkage. Presumably the resolver's implementation file declares and
// calls each register_* function — confirm there.
//
// `identifier` must name an enumerator of hal::OperationType, because it is
// pasted into hal::OperationType::identifier.
//
// Usage:
// (check OperationRegistration::Flag for available fields and default values.)
// Any arguments after `execute` are forwarded verbatim as a brace-initializer
// for OperationRegistration::Flag. When using designated initializers, list the
// fields in the order they are declared in Flag.
//
// - With default flags.
//   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                         foo_op::prepare, foo_op::execute);
//
// - With a customized flag.
//   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                         foo_op::prepare, foo_op::execute, .allowZeroSizedInput = true);
//
// - With multiple customized flags.
//   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                         foo_op::prepare, foo_op::execute, .allowOmittedOperand = true,
//                         .allowZeroSizedInput = true);
//
#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
// Full variant: keeps prepare and execute so that the operation can also run
// on the CPU.
#define NN_REGISTER_OPERATION(identifier, operationName, validate, prepare, execute, ...)        \
    const OperationRegistration* register_##identifier() {                                       \
        static OperationRegistration registration(hal::OperationType::identifier, operationName, \
                                                  validate, prepare, execute, {__VA_ARGS__});    \
        return &registration;                                                                    \
    }
#else
// Validation-only variant: ignores CPU execution logic (prepare and execute).
// nullptr is passed in their place, so the resulting registration holds empty
// prepare/execute callbacks. Because the prepare/execute arguments are never
// referenced, the compiler is expected to omit that code so that only
// validation logic makes it into libneuralnetworks_utils.
#define NN_REGISTER_OPERATION(identifier, operationName, validate, unused_prepare, unused_execute, \
                              ...)                                                                 \
    const OperationRegistration* register_##identifier() {                                         \
        static OperationRegistration registration(hal::OperationType::identifier, operationName,   \
                                                  validate, nullptr, nullptr, {__VA_ARGS__});      \
        return &registration;                                                                      \
    }
#endif
137 
138 }  // namespace nn
139 }  // namespace android
140 
141 #endif  // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H
142