1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_H 18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_H 19 20 #include "operations/BidirectionalSequenceLSTM.h" 21 #include "operations/Cast.h" 22 #include "operations/EmbeddingLookup.h" 23 #include "operations/ExpandDims.h" 24 #include "operations/HashtableLookup.h" 25 #include "operations/LSHProjection.h" 26 #include "operations/LSTM.h" 27 #include "operations/MaximumMinimum.h" 28 #include "operations/Multinomial.h" 29 #include "operations/Pow.h" 30 #include "operations/QuantizedLSTM.h" 31 #include "operations/RNN.h" 32 #include "operations/SVDF.h" 33 #include "operations/Tile.h" 34 35 #include <stddef.h> 36 37 #include <cstdint> 38 #include <vector> 39 40 namespace android { 41 namespace nn { 42 43 struct Shape; 44 45 bool floorFloat16(const _Float16* inputData, _Float16* outputData, const Shape& shape); 46 bool floorFloat32(const float* inputData, float* outputData, const Shape& shape); 47 48 bool depthwiseConvFloat16(const _Float16* inputData, const Shape& inputShape, 49 const _Float16* filterData, const Shape& filterShape, 50 const _Float16* biasData, const Shape& biasShape, int32_t paddingLeft, 51 int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom, 52 int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor, 53 int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation, 54 _Float16* outputData, const Shape& outputShape); 55 bool depthwiseConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData, 56 const Shape& filterShape, const float* biasData, const Shape& biasShape, 57 int32_t paddingLeft, int32_t paddingRight, int32_t paddingTop, 58 int32_t paddingBottom, int32_t strideWidth, int32_t strideHeight, 59 int32_t dilationWidthFactor, int32_t dilationHeightFactor, 60 int32_t depthMultiplier, int32_t activation, float* outputData, 61 const Shape& outputShape); 62 bool depthwiseConvQuant8(const uint8_t* inputData, const Shape& inputShape, 63 const uint8_t* filterData, const Shape& filterShape, 64 const int32_t* biasData, const Shape& biasShape, int32_t paddingLeft, 65 int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom, 66 int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor, 67 int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation, 68 uint8_t* outputData, const Shape& outputShape); 69 bool depthwiseConvQuant8PerChannel(const uint8_t* inputData, const Shape& inputShape, 70 const int8_t* filterData, const Shape& filterShape, 71 const float* filterScales, const int32_t* biasData, 72 const Shape& biasShape, int32_t paddingLeft, 73 int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom, 74 int32_t strideWidth, int32_t strideHeight, 75 int32_t dilationWidthFactor, int32_t dilationHeightFactor, 76 int32_t depthMultiplier, int32_t activation, uint8_t* outputData, 77 const Shape& outputShape); 78 79 bool localResponseNormFloat16(const _Float16* inputData, const Shape& inputShape, int32_t radius, 80 float bias, float alpha, float beta, int32_t axis, 81 _Float16* outputData, const Shape& outputShape); 82 bool localResponseNormFloat32(const float* inputData, const Shape& inputShape, int32_t radius, 83 float bias, float alpha, float beta, int32_t axis, float* outputData, 84 const Shape& outputShape); 85 86 bool copyData(const void* inputData, const Shape& inputShape, void* outputData, 87 const Shape& outputShape); 88 89 template <typename T> 90 bool depthToSpaceGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize, 91 T* outputData, const Shape& outputShape); 92 template <typename T> 93 bool spaceToDepthGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize, 94 T* outputData, const Shape& outputShape); 95 96 template <typename T> 97 bool padGeneric(const T* inputData, const Shape& inputShape, const int32_t* paddings, T pad_value, 98 T* outputData, const Shape& outputShape); 99 100 template <typename T> 101 bool batchToSpaceGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize, 102 T* outputData, const Shape& outputShape); 103 104 template <typename T> 105 bool spaceToBatchGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize, 106 const int32_t* padding, const Shape& paddingShape, T* outputData, 107 const Shape& outputShape); 108 109 bool meanFloat16(_Float16* inputData, const Shape& inputShape, const int32_t* axis, 110 const Shape& axisShape, bool keepDims, _Float16* outputData, 111 const Shape& outputShape); 112 template <typename T, typename U> 113 bool meanGeneric(T* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape, 114 bool keepDims, T* outputData, const Shape& outputShape); 115 116 bool stridedSliceGeneric(const uint8_t* inputData, const Shape& inputShape, 117 const int32_t* beginData, const int32_t* endData, 118 const int32_t* stridesData, int32_t beginMask, int32_t endMask, 119 int32_t shrinkAxisMask, uint8_t* outputData, const Shape& outputShape); 120 121 bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t axis, 122 bool isArgMin, uint8_t* outputData, const Shape& outputShape); 123 124 bool splitFloat16(const _Float16* inputData, const Shape& inputShape, int32_t axis, 125 const std::vector<_Float16*>* outputDataPtrs, 126 const std::vector<Shape>& outputShapes); 127 128 bool splitFloat32(const float* inputData, const Shape& inputShape, const int32_t axis, 129 const std::vector<float*>* outputDataPtrs, 130 const std::vector<Shape>& outputShapes); 131 132 bool splitInt32(const int32_t* inputData, const Shape& inputShape, const int32_t axis, 133 const std::vector<int32_t*>* outputDataPtrs, 134 const std::vector<Shape>& outputShapes); 135 136 bool splitQuant8(const uint8_t* inputData, const Shape& inputShape, const int32_t axis, 137 const std::vector<uint8_t*>* outputDataPtrs, 138 const std::vector<Shape>& outputShapes); 139 140 bool splitQuant8Signed(const int8_t* inputData, const Shape& inputShape, const int32_t axis, 141 const std::vector<int8_t*>* outputDataPtrs, 142 const std::vector<Shape>& outputShapes); 143 144 bool groupedConvFloat16(const _Float16* inputData, const Shape& inputShape, 145 const _Float16* filterData, const Shape& filterShape, 146 const _Float16* biasData, const Shape& biasShape, int32_t numGroups, 147 int32_t padding_left, int32_t padding_right, int32_t padding_top, 148 int32_t padding_bottom, int32_t stride_width, int32_t stride_height, 149 int32_t activation, _Float16* outputData, const Shape& outputShape); 150 151 bool groupedConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData, 152 const Shape& filterShape, const float* biasData, const Shape& biasShape, 153 int32_t numGroups, int32_t padding_left, int32_t padding_right, 154 int32_t padding_top, int32_t padding_bottom, int32_t stride_width, 155 int32_t stride_height, int32_t activation, float* outputData, 156 const Shape& outputShape); 157 158 template <typename T> 159 bool groupedConvQuant8(const T* inputData, const Shape& inputShape, const T* filterData, 160 const Shape& filterShape, const int32_t* biasData, const Shape& biasShape, 161 int32_t numGroups, int32_t padding_left, int32_t padding_right, 162 int32_t padding_top, int32_t padding_bottom, int32_t stride_width, 163 int32_t stride_height, int32_t activation, T* outputData, 164 const Shape& outputShape); 165 166 template <typename T> 167 bool groupedConvQuant8PerChannel(const T* inputData, const Shape& inputShape, 168 const int8_t* filterData, const Shape& filterShape, 169 const float* filterScales, const int32_t* biasData, 170 const Shape& biasShape, int32_t padding_left, 171 int32_t padding_right, int32_t padding_top, int32_t padding_bottom, 172 int32_t stride_width, int32_t stride_height, int32_t numGroups, 173 int32_t activation, T* outputData, const Shape& outputShape); 174 175 bool channelShuffleGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t numGroups, 176 int32_t axis, uint8_t* outputData, const Shape& outputShape); 177 } // namespace nn 178 } // namespace android 179 #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_H 180