1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_BIDIRECTIONAL_SEQUENCE_LSTM_H
18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_BIDIRECTIONAL_SEQUENCE_LSTM_H
19 
20 #include <tensorflow/lite/kernels/internal/tensor_utils.h>
21 
22 #include <algorithm>
23 #include <cmath>
24 #include <vector>
25 
26 #include "ActivationFunctor.h"
27 #include "LSTM.h"
28 #include "OperationsUtils.h"
29 
30 namespace android {
31 namespace nn {
32 
33 struct RunTimeOperandInfo;
34 
// Runtime implementation of the NNAPI BIDIRECTIONAL_SEQUENCE_LSTM operation:
// two independent LSTM cells (forward and backward) run over the time
// dimension of the input, optionally merging their outputs. The operand
// index constants below define the operation's input/output signature
// (indices 0-60 for inputs, 0-5 for outputs).
class BidirectionalSequenceLSTM {
   public:
    // Caches pointers into `operands` for every input/output tensor of
    // `operation`; does not validate shapes (that happens in Prepare()).
    BidirectionalSequenceLSTM(const hal::Operation& operation, RunTimeOperandInfo* operands);

    // Checks operand shapes for consistency and fills in the shapes of the
    // outputs and of the returned activation/cell state tensors. Returns
    // false on invalid input. Must be called before Eval().
    bool Prepare(const hal::Operation& operation, RunTimeOperandInfo* operands,
                 Shape* fwOutputShape, Shape* bwOutputShape, Shape* fwOutputActivationState,
                 Shape* fwOutputCellState, Shape* bwOutputActivationState,
                 Shape* bwOutputCellState);
    // Runs the forward and backward LSTM passes over the cached operands.
    // Returns false on failure.
    bool Eval();

    // Input Tensors of size {max_time, n_batch, n_input}
    static constexpr int kInputTensor = 0;

    // Forward LSTM cell tensors.
    // Input weight tensors of size: {n_cell, n_input}
    static constexpr int kFwInputToInputWeightsTensor = 1;  // Optional
    static constexpr int kFwInputToForgetWeightsTensor = 2;
    static constexpr int kFwInputToCellWeightsTensor = 3;
    static constexpr int kFwInputToOutputWeightsTensor = 4;

    // Recurrent weight tensors of size {n_cell, n_output}
    static constexpr int kFwRecurrentToInputWeightsTensor = 5;  // Optional
    static constexpr int kFwRecurrentToForgetWeightsTensor = 6;
    static constexpr int kFwRecurrentToCellWeightsTensor = 7;
    static constexpr int kFwRecurrentToOutputWeightsTensor = 8;

    // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
    static constexpr int kFwCellToInputWeightsTensor = 9;    // Optional
    static constexpr int kFwCellToForgetWeightsTensor = 10;  // Optional
    static constexpr int kFwCellToOutputWeightsTensor = 11;  // Optional

    // Gates bias tensors of size {n_cell}
    static constexpr int kFwInputGateBiasTensor = 12;  // Optional
    static constexpr int kFwForgetGateBiasTensor = 13;
    static constexpr int kFwCellGateBiasTensor = 14;
    static constexpr int kFwOutputGateBiasTensor = 15;

    // Projection weight tensor of size {n_output, n_cell}
    static constexpr int kFwProjectionWeightsTensor = 16;  // Optional
    // Projection bias tensor of size {n_output}
    static constexpr int kFwProjectionBiasTensor = 17;  // Optional

    // Backward LSTM cell tensors.
    // Input weight tensors of size: {n_cell, n_input}
    static constexpr int kBwInputToInputWeightsTensor = 18;  // Optional
    static constexpr int kBwInputToForgetWeightsTensor = 19;
    static constexpr int kBwInputToCellWeightsTensor = 20;
    static constexpr int kBwInputToOutputWeightsTensor = 21;

    // Recurrent weight tensors of size {n_cell, n_output}
    static constexpr int kBwRecurrentToInputWeightsTensor = 22;  // Optional
    static constexpr int kBwRecurrentToForgetWeightsTensor = 23;
    static constexpr int kBwRecurrentToCellWeightsTensor = 24;
    static constexpr int kBwRecurrentToOutputWeightsTensor = 25;

    // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
    static constexpr int kBwCellToInputWeightsTensor = 26;   // Optional
    static constexpr int kBwCellToForgetWeightsTensor = 27;  // Optional
    static constexpr int kBwCellToOutputWeightsTensor = 28;  // Optional

    // Gates bias tensors of size {n_cell}
    static constexpr int kBwInputGateBiasTensor = 29;  // Optional
    static constexpr int kBwForgetGateBiasTensor = 30;
    static constexpr int kBwCellGateBiasTensor = 31;
    static constexpr int kBwOutputGateBiasTensor = 32;

    // Projection weight tensor of size {n_output, n_cell}
    static constexpr int kBwProjectionWeightsTensor = 33;  // Optional
    // Projection bias tensor of size {n_output}
    static constexpr int kBwProjectionBiasTensor = 34;  // Optional

    // Stateful input tensors that are variables and will be modified by the Op.
    // Activation state tensors of size {n_batch, n_output}
    static constexpr int kFwInputActivationStateTensor = 35;
    // Cell state tensors of size {n_batch, n_cell}
    static constexpr int kFwInputCellStateTensor = 36;
    // Activation state tensors of size {n_batch, n_output}
    static constexpr int kBwInputActivationStateTensor = 37;
    // Cell state tensors of size {n_batch, n_cell}
    static constexpr int kBwInputCellStateTensor = 38;

    // Used as auxiliary input and weights when stacking for
    // tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); Used as input
    // to the backward cell when stacking for tf.nn.static_bidirectional_rnn case
    // (without cross links).
    static constexpr int kAuxInputTensor = 39;  // Optional
    // Forward weights.
    static constexpr int kFwAuxInputToInputWeightsTensor = 40;   // Optional
    static constexpr int kFwAuxInputToForgetWeightsTensor = 41;  // Optional
    static constexpr int kFwAuxInputToCellWeightsTensor = 42;    // Optional
    static constexpr int kFwAuxInputToOutputWeightsTensor = 43;  // Optional
    // Backward weights.
    static constexpr int kBwAuxInputToInputWeightsTensor = 44;   // Optional
    static constexpr int kBwAuxInputToForgetWeightsTensor = 45;  // Optional
    static constexpr int kBwAuxInputToCellWeightsTensor = 46;    // Optional
    static constexpr int kBwAuxInputToOutputWeightsTensor = 47;  // Optional

    // Scalar parameters: activation function, cell/projection clipping values,
    // whether the fw/bw outputs are merged, and whether input is time-major.
    static constexpr int kActivationParam = 48;
    static constexpr int kCellClipParam = 49;
    static constexpr int kProjClipParam = 50;
    static constexpr int kMergeOutputsParam = 51;
    static constexpr int kTimeMajorParam = 52;

    // Forward layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
    static constexpr int kFwInputLayerNormWeightsTensor = 53;   // Optional
    static constexpr int kFwForgetLayerNormWeightsTensor = 54;  // Optional
    static constexpr int kFwCellLayerNormWeightsTensor = 55;    // Optional
    static constexpr int kFwOutputLayerNormWeightsTensor = 56;  // Optional
    // Backward layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
    static constexpr int kBwInputLayerNormWeightsTensor = 57;   // Optional
    static constexpr int kBwForgetLayerNormWeightsTensor = 58;  // Optional
    static constexpr int kBwCellLayerNormWeightsTensor = 59;    // Optional
    static constexpr int kBwOutputLayerNormWeightsTensor = 60;  // Optional

    // Output tensors.
    static constexpr int kFwOutputTensor = 0;
    static constexpr int kBwOutputTensor = 1;  // Ignored if merge_outputs is set.

    // Output state tensors (the updated activation/cell states after the run).
    static constexpr int kFwOutputActivationStateTensor = 2;
    static constexpr int kFwOutputCellStateTensor = 3;
    static constexpr int kBwOutputActivationStateTensor = 4;
    static constexpr int kBwOutputCellStateTensor = 5;

   private:
    // Parameters shared by both cells (activation, clip values, etc.),
    // populated from the scalar operands.
    LSTMParams params_;
    // Shapes of the per-cell scratch buffers computed during Prepare().
    Shape fw_scratch_shape_;
    Shape bw_scratch_shape_;

    const RunTimeOperandInfo* input_;

    // Auxiliary input and its cross-link weights (see kAuxInputTensor above).
    const RunTimeOperandInfo* aux_input_;
    const RunTimeOperandInfo* fw_aux_input_to_input_weights_;
    const RunTimeOperandInfo* fw_aux_input_to_forget_weights_;
    const RunTimeOperandInfo* fw_aux_input_to_cell_weights_;
    const RunTimeOperandInfo* fw_aux_input_to_output_weights_;
    const RunTimeOperandInfo* bw_aux_input_to_input_weights_;
    const RunTimeOperandInfo* bw_aux_input_to_forget_weights_;
    const RunTimeOperandInfo* bw_aux_input_to_cell_weights_;
    const RunTimeOperandInfo* bw_aux_input_to_output_weights_;

    // Forward cell operands; each pointer mirrors the kFw* index above.
    const RunTimeOperandInfo* fw_input_to_input_weights_;
    const RunTimeOperandInfo* fw_input_to_forget_weights_;
    const RunTimeOperandInfo* fw_input_to_cell_weights_;
    const RunTimeOperandInfo* fw_input_to_output_weights_;

    const RunTimeOperandInfo* fw_recurrent_to_input_weights_;
    const RunTimeOperandInfo* fw_recurrent_to_forget_weights_;
    const RunTimeOperandInfo* fw_recurrent_to_cell_weights_;
    const RunTimeOperandInfo* fw_recurrent_to_output_weights_;

    const RunTimeOperandInfo* fw_cell_to_input_weights_;
    const RunTimeOperandInfo* fw_cell_to_forget_weights_;
    const RunTimeOperandInfo* fw_cell_to_output_weights_;

    const RunTimeOperandInfo* fw_input_gate_bias_;
    const RunTimeOperandInfo* fw_forget_gate_bias_;
    const RunTimeOperandInfo* fw_cell_bias_;
    const RunTimeOperandInfo* fw_output_gate_bias_;

    const RunTimeOperandInfo* fw_projection_weights_;
    const RunTimeOperandInfo* fw_projection_bias_;

    const RunTimeOperandInfo* fw_input_layer_norm_weights_;
    const RunTimeOperandInfo* fw_forget_layer_norm_weights_;
    const RunTimeOperandInfo* fw_cell_layer_norm_weights_;
    const RunTimeOperandInfo* fw_output_layer_norm_weights_;

    const RunTimeOperandInfo* fw_activation_state_;
    const RunTimeOperandInfo* fw_cell_state_;
    RunTimeOperandInfo* fw_output_;

    // Backward cell operands; each pointer mirrors the kBw* index above.
    const RunTimeOperandInfo* bw_input_to_input_weights_;
    const RunTimeOperandInfo* bw_input_to_forget_weights_;
    const RunTimeOperandInfo* bw_input_to_cell_weights_;
    const RunTimeOperandInfo* bw_input_to_output_weights_;

    const RunTimeOperandInfo* bw_recurrent_to_input_weights_;
    const RunTimeOperandInfo* bw_recurrent_to_forget_weights_;
    const RunTimeOperandInfo* bw_recurrent_to_cell_weights_;
    const RunTimeOperandInfo* bw_recurrent_to_output_weights_;

    const RunTimeOperandInfo* bw_cell_to_input_weights_;
    const RunTimeOperandInfo* bw_cell_to_forget_weights_;
    const RunTimeOperandInfo* bw_cell_to_output_weights_;

    const RunTimeOperandInfo* bw_input_gate_bias_;
    const RunTimeOperandInfo* bw_forget_gate_bias_;
    const RunTimeOperandInfo* bw_cell_bias_;
    const RunTimeOperandInfo* bw_output_gate_bias_;

    const RunTimeOperandInfo* bw_projection_weights_;
    const RunTimeOperandInfo* bw_projection_bias_;

    const RunTimeOperandInfo* bw_input_layer_norm_weights_;
    const RunTimeOperandInfo* bw_forget_layer_norm_weights_;
    const RunTimeOperandInfo* bw_cell_layer_norm_weights_;
    const RunTimeOperandInfo* bw_output_layer_norm_weights_;

    const RunTimeOperandInfo* bw_activation_state_;
    const RunTimeOperandInfo* bw_cell_state_;
    RunTimeOperandInfo* bw_output_;

    // Output operands receiving the post-run activation/cell states.
    RunTimeOperandInfo* fw_output_activation_state_;
    RunTimeOperandInfo* fw_output_cell_state_;
    RunTimeOperandInfo* bw_output_activation_state_;
    RunTimeOperandInfo* bw_output_cell_state_;
};
242 
243 }  // namespace nn
244 }  // namespace android
245 
246 #endif  // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_BIDIRECTIONAL_SEQUENCE_LSTM_H
247