1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_LSTM_H 18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_LSTM_H 19 20 #include <tensorflow/lite/kernels/internal/tensor_utils.h> 21 22 #include <algorithm> 23 #include <cmath> 24 #include <vector> 25 26 #include "ActivationFunctor.h" 27 #include "HalInterfaces.h" 28 29 namespace android { 30 namespace nn { 31 32 struct LSTMParams { 33 TfLiteFusedActivation activation; 34 float cell_clip; 35 float proj_clip; 36 bool use_cifg; 37 bool use_peephole; 38 bool use_layer_norm; 39 bool use_projection_weight; 40 bool use_projection_bias; 41 bool merge_outputs; 42 bool time_major; 43 bool output_state; 44 }; 45 46 struct RunTimeOperandInfo; 47 struct Shape; 48 49 class LSTMCell { 50 public: 51 LSTMCell(const hal::Operation& operation, RunTimeOperandInfo* operands); 52 53 bool Prepare(const hal::Operation& operation, RunTimeOperandInfo* operands, Shape* scratchShape, 54 Shape* outputStateShape, Shape* cellStateShape, Shape* outputShape); 55 bool Eval(); 56 57 // Input Tensors of size {n_batch, n_input} 58 static constexpr int kInputTensor = 0; 59 60 // Input weight tensors of size: {n_cell, n_input} 61 static constexpr int kInputToInputWeightsTensor = 1; // Optional 62 static constexpr int kInputToForgetWeightsTensor = 2; 63 static constexpr int kInputToCellWeightsTensor = 3; 64 static constexpr int kInputToOutputWeightsTensor = 4; 65 66 // Recurrent weight tensors of size {n_cell, n_output} 67 static constexpr int kRecurrentToInputWeightsTensor = 5; // Optional 68 static constexpr int kRecurrentToForgetWeightsTensor = 6; 69 static constexpr int kRecurrentToCellWeightsTensor = 7; 70 static constexpr int kRecurrentToOutputWeightsTensor = 8; 71 72 // Peephole weights tensors of size {n_cell}, representing a diagonal matrix. 73 static constexpr int kCellToInputWeightsTensor = 9; // Optional 74 static constexpr int kCellToForgetWeightsTensor = 10; // Optional 75 static constexpr int kCellToOutputWeightsTensor = 11; // Optional 76 77 // Gates bias tensors of size {n_cell} 78 static constexpr int kInputGateBiasTensor = 12; // Optional 79 static constexpr int kForgetGateBiasTensor = 13; 80 static constexpr int kCellGateBiasTensor = 14; 81 static constexpr int kOutputGateBiasTensor = 15; 82 83 // Projection weight tensor of size {n_output, n_cell} 84 static constexpr int kProjectionWeightsTensor = 16; // Optional 85 // Projection bias tensor of size {n_output} 86 static constexpr int kProjectionBiasTensor = 17; // Optional 87 88 static constexpr int kOutputStateInTensor = 18; 89 static constexpr int kCellStateInTensor = 19; 90 91 static constexpr int kActivationParam = 20; 92 static constexpr int kCellClipParam = 21; 93 static constexpr int kProjClipParam = 22; 94 95 // Layer norm weights tensors of size {n_cell}, representing a diagonal matrix. 96 static constexpr int kInputLayerNormWeightsTensor = 23; 97 static constexpr int kForgetLayerNormWeightsTensor = 24; 98 static constexpr int kCellLayerNormWeightsTensor = 25; 99 static constexpr int kOutputLayerNormWeightsTensor = 26; 100 101 // Output tensors. 102 static constexpr int kScratchBufferTensor = 0; 103 static constexpr int kOutputStateOutTensor = 1; 104 static constexpr int kCellStateOutTensor = 2; 105 static constexpr int kOutputTensor = 3; 106 107 static bool LSTMEvalFloat32( 108 const LSTMParams& params, const float* input_buffer, const Shape& input_shape, 109 const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer, 110 const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer, 111 const Shape& input_to_output_weights_shape, 112 const float* recurrent_to_input_weights_buffer, 113 const float* recurrent_to_forget_weights_buffer, 114 const float* recurrent_to_cell_weights_buffer, 115 const float* recurrent_to_output_weights_buffer, 116 const Shape& recurrent_to_output_weights_shape, 117 const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer, 118 const float* cell_to_output_weights_buffer, const float* aux_input_buffer, 119 const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights, 120 const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights, 121 const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer, 122 const float* cell_bias_buffer, const float* output_gate_bias_buffer, 123 const float* projection_weights_buffer, const float* projection_bias_buffer, 124 const float* output_state_in_buffer, const float* cell_state_in_buffer, 125 const float* input_layer_norm_weights_buffer, 126 const float* forget_layer_norm_weights_buffer, 127 const float* cell_layer_norm_weights_buffer, 128 const float* output_layer_norm_weights_buffer, float* output_state_out_buffer, 129 float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer, 130 bool timeMajor = true, bool forwardSequence = true); 131 132 static bool LSTMEvalFloat16( 133 const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape, 134 const _Float16* input_to_input_weights_buffer, 135 const _Float16* input_to_forget_weights_buffer, 136 const _Float16* input_to_cell_weights_buffer, 137 const _Float16* input_to_output_weights_buffer, 138 const Shape& input_to_output_weights_shape, 139 const _Float16* recurrent_to_input_weights_buffer, 140 const _Float16* recurrent_to_forget_weights_buffer, 141 const _Float16* recurrent_to_cell_weights_buffer, 142 const _Float16* recurrent_to_output_weights_buffer, 143 const Shape& recurrent_to_output_weights_shape, 144 const _Float16* cell_to_input_weights_buffer, 145 const _Float16* cell_to_forget_weights_buffer, 146 const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer, 147 const _Float16* aux_input_to_input_weights, const _Float16* aux_input_to_forget_weights, 148 const _Float16* aux_input_to_cell_weights, const _Float16* aux_input_to_output_weights, 149 const _Float16* input_gate_bias_buffer, const _Float16* forget_gate_bias_buffer, 150 const _Float16* cell_bias_buffer, const _Float16* output_gate_bias_buffer, 151 const _Float16* projection_weights_buffer, const _Float16* projection_bias_buffer, 152 const _Float16* output_state_in_buffer, const _Float16* cell_state_in_buffer, 153 const _Float16* input_layer_norm_weights_buffer, 154 const _Float16* forget_layer_norm_weights_buffer, 155 const _Float16* cell_layer_norm_weights_buffer, 156 const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer, 157 _Float16* cell_state_out_buffer, _Float16* output_buffer, 158 _Float16* scratch_buffer_buffer, bool timeMajor = true, bool forwardSequence = true); 159 160 static bool LSTMStep( 161 const LSTMParams& params, const float* input_buffer, const Shape& input_shape, 162 const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer, 163 const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer, 164 const Shape& input_to_output_weights_shape, 165 const float* recurrent_to_input_weights_buffer, 166 const float* recurrent_to_forget_weights_buffer, 167 const float* recurrent_to_cell_weights_buffer, 168 const float* recurrent_to_output_weights_buffer, 169 const Shape& recurrent_to_output_weights_shape, 170 const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer, 171 const float* cell_to_output_weights_buffer, const float* aux_input_buffer, 172 const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights, 173 const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights, 174 const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer, 175 const float* cell_bias_buffer, const float* output_gate_bias_buffer, 176 const float* projection_weights_buffer, const float* projection_bias_buffer, 177 const float* output_state_in_buffer, const float* cell_state_in_buffer, 178 const float* input_layer_norm_weights_buffer, 179 const float* forget_layer_norm_weights_buffer, 180 const float* cell_layer_norm_weights_buffer, 181 const float* output_layer_norm_weights_buffer, float* output_state_out_buffer, 182 float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer); 183 184 static bool CheckInputTensorDimensions( 185 const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights, 186 const RunTimeOperandInfo* input_to_forget_weights, 187 const RunTimeOperandInfo* input_to_cell_weights, 188 const RunTimeOperandInfo* input_to_output_weights, 189 const RunTimeOperandInfo* recurrent_to_input_weights, 190 const RunTimeOperandInfo* recurrent_to_forget_weights, 191 const RunTimeOperandInfo* recurrent_to_cell_weights, 192 const RunTimeOperandInfo* recurrent_to_output_weights, 193 const RunTimeOperandInfo* cell_to_input_weights, 194 const RunTimeOperandInfo* cell_to_forget_weights, 195 const RunTimeOperandInfo* cell_to_output_weights, 196 const RunTimeOperandInfo* input_gate_bias, const RunTimeOperandInfo* forget_gate_bias, 197 const RunTimeOperandInfo* cell_bias, const RunTimeOperandInfo* output_gate_bias, 198 const RunTimeOperandInfo* projection_weights, const RunTimeOperandInfo* projection_bias, 199 const RunTimeOperandInfo* input_layer_norm_weights, 200 const RunTimeOperandInfo* forget_layer_norm_weights, 201 const RunTimeOperandInfo* cell_layer_norm_weights, 202 const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, 203 uint32_t n_output, uint32_t n_cell, LSTMParams* params); 204 205 private: 206 LSTMParams params_; 207 const RunTimeOperandInfo* input_; 208 209 const RunTimeOperandInfo* input_to_input_weights_; 210 const RunTimeOperandInfo* input_to_forget_weights_; 211 const RunTimeOperandInfo* input_to_cell_weights_; 212 const RunTimeOperandInfo* input_to_output_weights_; 213 214 const RunTimeOperandInfo* recurrent_to_input_weights_; 215 const RunTimeOperandInfo* recurrent_to_forget_weights_; 216 const RunTimeOperandInfo* recurrent_to_cell_weights_; 217 const RunTimeOperandInfo* recurrent_to_output_weights_; 218 219 const RunTimeOperandInfo* cell_to_input_weights_; 220 const RunTimeOperandInfo* cell_to_forget_weights_; 221 const RunTimeOperandInfo* cell_to_output_weights_; 222 223 const RunTimeOperandInfo* input_gate_bias_; 224 const RunTimeOperandInfo* forget_gate_bias_; 225 const RunTimeOperandInfo* cell_bias_; 226 const RunTimeOperandInfo* output_gate_bias_; 227 228 const RunTimeOperandInfo* projection_weights_; 229 const RunTimeOperandInfo* projection_bias_; 230 231 const RunTimeOperandInfo* output_state_in_; 232 const RunTimeOperandInfo* cell_state_in_; 233 234 const RunTimeOperandInfo* input_layer_norm_weights_; 235 const RunTimeOperandInfo* forget_layer_norm_weights_; 236 const RunTimeOperandInfo* cell_layer_norm_weights_; 237 const RunTimeOperandInfo* output_layer_norm_weights_; 238 239 RunTimeOperandInfo* output_state_out_; 240 RunTimeOperandInfo* cell_state_out_; 241 RunTimeOperandInfo* output_; 242 243 RunTimeOperandInfo* scratch_buffer_; 244 }; 245 246 } // namespace nn 247 } // namespace android 248 249 #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_LSTM_H 250