1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_LSTM_H
18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_LSTM_H
19 
20 #include <tensorflow/lite/kernels/internal/tensor_utils.h>
21 
22 #include <algorithm>
23 #include <cmath>
24 #include <vector>
25 
26 #include "ActivationFunctor.h"
27 #include "HalInterfaces.h"
28 
29 namespace android {
30 namespace nn {
31 
32 struct LSTMParams {
33     TfLiteFusedActivation activation;
34     float cell_clip;
35     float proj_clip;
36     bool use_cifg;
37     bool use_peephole;
38     bool use_layer_norm;
39     bool use_projection_weight;
40     bool use_projection_bias;
41     bool merge_outputs;
42     bool time_major;
43     bool output_state;
44 };
45 
// Runtime types that this header only uses through pointers and references;
// they are forward-declared here and defined elsewhere in the NN runtime.
struct Shape;
struct RunTimeOperandInfo;
48 
49 class LSTMCell {
50    public:
51     LSTMCell(const hal::Operation& operation, RunTimeOperandInfo* operands);
52 
53     bool Prepare(const hal::Operation& operation, RunTimeOperandInfo* operands, Shape* scratchShape,
54                  Shape* outputStateShape, Shape* cellStateShape, Shape* outputShape);
55     bool Eval();
56 
57     // Input Tensors of size {n_batch, n_input}
58     static constexpr int kInputTensor = 0;
59 
60     // Input weight tensors of size: {n_cell, n_input}
61     static constexpr int kInputToInputWeightsTensor = 1;  // Optional
62     static constexpr int kInputToForgetWeightsTensor = 2;
63     static constexpr int kInputToCellWeightsTensor = 3;
64     static constexpr int kInputToOutputWeightsTensor = 4;
65 
66     // Recurrent weight tensors of size {n_cell, n_output}
67     static constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
68     static constexpr int kRecurrentToForgetWeightsTensor = 6;
69     static constexpr int kRecurrentToCellWeightsTensor = 7;
70     static constexpr int kRecurrentToOutputWeightsTensor = 8;
71 
72     // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
73     static constexpr int kCellToInputWeightsTensor = 9;    // Optional
74     static constexpr int kCellToForgetWeightsTensor = 10;  // Optional
75     static constexpr int kCellToOutputWeightsTensor = 11;  // Optional
76 
77     // Gates bias tensors of size {n_cell}
78     static constexpr int kInputGateBiasTensor = 12;  // Optional
79     static constexpr int kForgetGateBiasTensor = 13;
80     static constexpr int kCellGateBiasTensor = 14;
81     static constexpr int kOutputGateBiasTensor = 15;
82 
83     // Projection weight tensor of size {n_output, n_cell}
84     static constexpr int kProjectionWeightsTensor = 16;  // Optional
85     // Projection bias tensor of size {n_output}
86     static constexpr int kProjectionBiasTensor = 17;  // Optional
87 
88     static constexpr int kOutputStateInTensor = 18;
89     static constexpr int kCellStateInTensor = 19;
90 
91     static constexpr int kActivationParam = 20;
92     static constexpr int kCellClipParam = 21;
93     static constexpr int kProjClipParam = 22;
94 
95     // Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
96     static constexpr int kInputLayerNormWeightsTensor = 23;
97     static constexpr int kForgetLayerNormWeightsTensor = 24;
98     static constexpr int kCellLayerNormWeightsTensor = 25;
99     static constexpr int kOutputLayerNormWeightsTensor = 26;
100 
101     // Output tensors.
102     static constexpr int kScratchBufferTensor = 0;
103     static constexpr int kOutputStateOutTensor = 1;
104     static constexpr int kCellStateOutTensor = 2;
105     static constexpr int kOutputTensor = 3;
106 
107     static bool LSTMEvalFloat32(
108             const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
109             const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
110             const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
111             const Shape& input_to_output_weights_shape,
112             const float* recurrent_to_input_weights_buffer,
113             const float* recurrent_to_forget_weights_buffer,
114             const float* recurrent_to_cell_weights_buffer,
115             const float* recurrent_to_output_weights_buffer,
116             const Shape& recurrent_to_output_weights_shape,
117             const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer,
118             const float* cell_to_output_weights_buffer, const float* aux_input_buffer,
119             const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights,
120             const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights,
121             const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
122             const float* cell_bias_buffer, const float* output_gate_bias_buffer,
123             const float* projection_weights_buffer, const float* projection_bias_buffer,
124             const float* output_state_in_buffer, const float* cell_state_in_buffer,
125             const float* input_layer_norm_weights_buffer,
126             const float* forget_layer_norm_weights_buffer,
127             const float* cell_layer_norm_weights_buffer,
128             const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
129             float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
130             bool timeMajor = true, bool forwardSequence = true);
131 
132     static bool LSTMEvalFloat16(
133             const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
134             const _Float16* input_to_input_weights_buffer,
135             const _Float16* input_to_forget_weights_buffer,
136             const _Float16* input_to_cell_weights_buffer,
137             const _Float16* input_to_output_weights_buffer,
138             const Shape& input_to_output_weights_shape,
139             const _Float16* recurrent_to_input_weights_buffer,
140             const _Float16* recurrent_to_forget_weights_buffer,
141             const _Float16* recurrent_to_cell_weights_buffer,
142             const _Float16* recurrent_to_output_weights_buffer,
143             const Shape& recurrent_to_output_weights_shape,
144             const _Float16* cell_to_input_weights_buffer,
145             const _Float16* cell_to_forget_weights_buffer,
146             const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
147             const _Float16* aux_input_to_input_weights, const _Float16* aux_input_to_forget_weights,
148             const _Float16* aux_input_to_cell_weights, const _Float16* aux_input_to_output_weights,
149             const _Float16* input_gate_bias_buffer, const _Float16* forget_gate_bias_buffer,
150             const _Float16* cell_bias_buffer, const _Float16* output_gate_bias_buffer,
151             const _Float16* projection_weights_buffer, const _Float16* projection_bias_buffer,
152             const _Float16* output_state_in_buffer, const _Float16* cell_state_in_buffer,
153             const _Float16* input_layer_norm_weights_buffer,
154             const _Float16* forget_layer_norm_weights_buffer,
155             const _Float16* cell_layer_norm_weights_buffer,
156             const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
157             _Float16* cell_state_out_buffer, _Float16* output_buffer,
158             _Float16* scratch_buffer_buffer, bool timeMajor = true, bool forwardSequence = true);
159 
160     static bool LSTMStep(
161             const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
162             const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
163             const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
164             const Shape& input_to_output_weights_shape,
165             const float* recurrent_to_input_weights_buffer,
166             const float* recurrent_to_forget_weights_buffer,
167             const float* recurrent_to_cell_weights_buffer,
168             const float* recurrent_to_output_weights_buffer,
169             const Shape& recurrent_to_output_weights_shape,
170             const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer,
171             const float* cell_to_output_weights_buffer, const float* aux_input_buffer,
172             const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights,
173             const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights,
174             const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
175             const float* cell_bias_buffer, const float* output_gate_bias_buffer,
176             const float* projection_weights_buffer, const float* projection_bias_buffer,
177             const float* output_state_in_buffer, const float* cell_state_in_buffer,
178             const float* input_layer_norm_weights_buffer,
179             const float* forget_layer_norm_weights_buffer,
180             const float* cell_layer_norm_weights_buffer,
181             const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
182             float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer);
183 
184     static bool CheckInputTensorDimensions(
185             const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights,
186             const RunTimeOperandInfo* input_to_forget_weights,
187             const RunTimeOperandInfo* input_to_cell_weights,
188             const RunTimeOperandInfo* input_to_output_weights,
189             const RunTimeOperandInfo* recurrent_to_input_weights,
190             const RunTimeOperandInfo* recurrent_to_forget_weights,
191             const RunTimeOperandInfo* recurrent_to_cell_weights,
192             const RunTimeOperandInfo* recurrent_to_output_weights,
193             const RunTimeOperandInfo* cell_to_input_weights,
194             const RunTimeOperandInfo* cell_to_forget_weights,
195             const RunTimeOperandInfo* cell_to_output_weights,
196             const RunTimeOperandInfo* input_gate_bias, const RunTimeOperandInfo* forget_gate_bias,
197             const RunTimeOperandInfo* cell_bias, const RunTimeOperandInfo* output_gate_bias,
198             const RunTimeOperandInfo* projection_weights, const RunTimeOperandInfo* projection_bias,
199             const RunTimeOperandInfo* input_layer_norm_weights,
200             const RunTimeOperandInfo* forget_layer_norm_weights,
201             const RunTimeOperandInfo* cell_layer_norm_weights,
202             const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input,
203             uint32_t n_output, uint32_t n_cell, LSTMParams* params);
204 
205    private:
206     LSTMParams params_;
207     const RunTimeOperandInfo* input_;
208 
209     const RunTimeOperandInfo* input_to_input_weights_;
210     const RunTimeOperandInfo* input_to_forget_weights_;
211     const RunTimeOperandInfo* input_to_cell_weights_;
212     const RunTimeOperandInfo* input_to_output_weights_;
213 
214     const RunTimeOperandInfo* recurrent_to_input_weights_;
215     const RunTimeOperandInfo* recurrent_to_forget_weights_;
216     const RunTimeOperandInfo* recurrent_to_cell_weights_;
217     const RunTimeOperandInfo* recurrent_to_output_weights_;
218 
219     const RunTimeOperandInfo* cell_to_input_weights_;
220     const RunTimeOperandInfo* cell_to_forget_weights_;
221     const RunTimeOperandInfo* cell_to_output_weights_;
222 
223     const RunTimeOperandInfo* input_gate_bias_;
224     const RunTimeOperandInfo* forget_gate_bias_;
225     const RunTimeOperandInfo* cell_bias_;
226     const RunTimeOperandInfo* output_gate_bias_;
227 
228     const RunTimeOperandInfo* projection_weights_;
229     const RunTimeOperandInfo* projection_bias_;
230 
231     const RunTimeOperandInfo* output_state_in_;
232     const RunTimeOperandInfo* cell_state_in_;
233 
234     const RunTimeOperandInfo* input_layer_norm_weights_;
235     const RunTimeOperandInfo* forget_layer_norm_weights_;
236     const RunTimeOperandInfo* cell_layer_norm_weights_;
237     const RunTimeOperandInfo* output_layer_norm_weights_;
238 
239     RunTimeOperandInfo* output_state_out_;
240     RunTimeOperandInfo* cell_state_out_;
241     RunTimeOperandInfo* output_;
242 
243     RunTimeOperandInfo* scratch_buffer_;
244 };
245 
246 }  // namespace nn
247 }  // namespace android
248 
249 #endif  // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_LSTM_H
250