1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_H
18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_H
19 
20 #include "operations/BidirectionalSequenceLSTM.h"
21 #include "operations/Cast.h"
22 #include "operations/EmbeddingLookup.h"
23 #include "operations/ExpandDims.h"
24 #include "operations/HashtableLookup.h"
25 #include "operations/LSHProjection.h"
26 #include "operations/LSTM.h"
27 #include "operations/MaximumMinimum.h"
28 #include "operations/Multinomial.h"
29 #include "operations/Pow.h"
30 #include "operations/QuantizedLSTM.h"
31 #include "operations/RNN.h"
32 #include "operations/SVDF.h"
33 #include "operations/Tile.h"
34 
35 #include <stddef.h>
36 
37 #include <cstdint>
38 #include <vector>
39 
40 namespace android {
41 namespace nn {
42 
43 struct Shape;
44 
45 bool floorFloat16(const _Float16* inputData, _Float16* outputData, const Shape& shape);
46 bool floorFloat32(const float* inputData, float* outputData, const Shape& shape);
47 
48 bool depthwiseConvFloat16(const _Float16* inputData, const Shape& inputShape,
49                           const _Float16* filterData, const Shape& filterShape,
50                           const _Float16* biasData, const Shape& biasShape, int32_t paddingLeft,
51                           int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom,
52                           int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor,
53                           int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation,
54                           _Float16* outputData, const Shape& outputShape);
55 bool depthwiseConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData,
56                           const Shape& filterShape, const float* biasData, const Shape& biasShape,
57                           int32_t paddingLeft, int32_t paddingRight, int32_t paddingTop,
58                           int32_t paddingBottom, int32_t strideWidth, int32_t strideHeight,
59                           int32_t dilationWidthFactor, int32_t dilationHeightFactor,
60                           int32_t depthMultiplier, int32_t activation, float* outputData,
61                           const Shape& outputShape);
62 bool depthwiseConvQuant8(const uint8_t* inputData, const Shape& inputShape,
63                          const uint8_t* filterData, const Shape& filterShape,
64                          const int32_t* biasData, const Shape& biasShape, int32_t paddingLeft,
65                          int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom,
66                          int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor,
67                          int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation,
68                          uint8_t* outputData, const Shape& outputShape);
69 bool depthwiseConvQuant8PerChannel(const uint8_t* inputData, const Shape& inputShape,
70                                    const int8_t* filterData, const Shape& filterShape,
71                                    const float* filterScales, const int32_t* biasData,
72                                    const Shape& biasShape, int32_t paddingLeft,
73                                    int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom,
74                                    int32_t strideWidth, int32_t strideHeight,
75                                    int32_t dilationWidthFactor, int32_t dilationHeightFactor,
76                                    int32_t depthMultiplier, int32_t activation, uint8_t* outputData,
77                                    const Shape& outputShape);
78 
79 bool localResponseNormFloat16(const _Float16* inputData, const Shape& inputShape, int32_t radius,
80                               float bias, float alpha, float beta, int32_t axis,
81                               _Float16* outputData, const Shape& outputShape);
82 bool localResponseNormFloat32(const float* inputData, const Shape& inputShape, int32_t radius,
83                               float bias, float alpha, float beta, int32_t axis, float* outputData,
84                               const Shape& outputShape);
85 
86 bool copyData(const void* inputData, const Shape& inputShape, void* outputData,
87               const Shape& outputShape);
88 
89 template <typename T>
90 bool depthToSpaceGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize,
91                          T* outputData, const Shape& outputShape);
92 template <typename T>
93 bool spaceToDepthGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize,
94                          T* outputData, const Shape& outputShape);
95 
96 template <typename T>
97 bool padGeneric(const T* inputData, const Shape& inputShape, const int32_t* paddings, T pad_value,
98                 T* outputData, const Shape& outputShape);
99 
100 template <typename T>
101 bool batchToSpaceGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize,
102                          T* outputData, const Shape& outputShape);
103 
104 template <typename T>
105 bool spaceToBatchGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize,
106                          const int32_t* padding, const Shape& paddingShape, T* outputData,
107                          const Shape& outputShape);
108 
109 bool meanFloat16(_Float16* inputData, const Shape& inputShape, const int32_t* axis,
110                  const Shape& axisShape, bool keepDims, _Float16* outputData,
111                  const Shape& outputShape);
112 template <typename T, typename U>
113 bool meanGeneric(T* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape,
114                  bool keepDims, T* outputData, const Shape& outputShape);
115 
116 bool stridedSliceGeneric(const uint8_t* inputData, const Shape& inputShape,
117                          const int32_t* beginData, const int32_t* endData,
118                          const int32_t* stridesData, int32_t beginMask, int32_t endMask,
119                          int32_t shrinkAxisMask, uint8_t* outputData, const Shape& outputShape);
120 
121 bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t axis,
122                       bool isArgMin, uint8_t* outputData, const Shape& outputShape);
123 
124 bool splitFloat16(const _Float16* inputData, const Shape& inputShape, int32_t axis,
125                   const std::vector<_Float16*>* outputDataPtrs,
126                   const std::vector<Shape>& outputShapes);
127 
128 bool splitFloat32(const float* inputData, const Shape& inputShape, const int32_t axis,
129                   const std::vector<float*>* outputDataPtrs,
130                   const std::vector<Shape>& outputShapes);
131 
132 bool splitInt32(const int32_t* inputData, const Shape& inputShape, const int32_t axis,
133                 const std::vector<int32_t*>* outputDataPtrs,
134                 const std::vector<Shape>& outputShapes);
135 
136 bool splitQuant8(const uint8_t* inputData, const Shape& inputShape, const int32_t axis,
137                  const std::vector<uint8_t*>* outputDataPtrs,
138                  const std::vector<Shape>& outputShapes);
139 
140 bool splitQuant8Signed(const int8_t* inputData, const Shape& inputShape, const int32_t axis,
141                        const std::vector<int8_t*>* outputDataPtrs,
142                        const std::vector<Shape>& outputShapes);
143 
144 bool groupedConvFloat16(const _Float16* inputData, const Shape& inputShape,
145                         const _Float16* filterData, const Shape& filterShape,
146                         const _Float16* biasData, const Shape& biasShape, int32_t numGroups,
147                         int32_t padding_left, int32_t padding_right, int32_t padding_top,
148                         int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
149                         int32_t activation, _Float16* outputData, const Shape& outputShape);
150 
151 bool groupedConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData,
152                         const Shape& filterShape, const float* biasData, const Shape& biasShape,
153                         int32_t numGroups, int32_t padding_left, int32_t padding_right,
154                         int32_t padding_top, int32_t padding_bottom, int32_t stride_width,
155                         int32_t stride_height, int32_t activation, float* outputData,
156                         const Shape& outputShape);
157 
158 template <typename T>
159 bool groupedConvQuant8(const T* inputData, const Shape& inputShape, const T* filterData,
160                        const Shape& filterShape, const int32_t* biasData, const Shape& biasShape,
161                        int32_t numGroups, int32_t padding_left, int32_t padding_right,
162                        int32_t padding_top, int32_t padding_bottom, int32_t stride_width,
163                        int32_t stride_height, int32_t activation, T* outputData,
164                        const Shape& outputShape);
165 
166 template <typename T>
167 bool groupedConvQuant8PerChannel(const T* inputData, const Shape& inputShape,
168                                  const int8_t* filterData, const Shape& filterShape,
169                                  const float* filterScales, const int32_t* biasData,
170                                  const Shape& biasShape, int32_t padding_left,
171                                  int32_t padding_right, int32_t padding_top, int32_t padding_bottom,
172                                  int32_t stride_width, int32_t stride_height, int32_t numGroups,
173                                  int32_t activation, T* outputData, const Shape& outputShape);
174 
175 bool channelShuffleGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t numGroups,
176                            int32_t axis, uint8_t* outputData, const Shape& outputShape);
177 }  // namespace nn
178 }  // namespace android
179 #endif  // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_H
180