/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "LSTM.h"

#include <vector>

#include "CpuExecutor.h"
#include "CpuOperationUtils.h"
#include "HalInterfaces.h"
#include "OperationsUtils.h"
#include "Tracing.h"
#include "Utils.h"

namespace android {
namespace nn {

namespace {

using namespace hal;

template <typename T>
inline T* GetBuffer(RunTimeOperandInfo* operand) {
    return reinterpret_cast<T*>(operand->buffer);
}

template <typename T>
inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
    return reinterpret_cast<const T*>(operand->buffer);
}

template <typename T>
inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
    return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
}

}  // anonymous namespace

LSTMCell::LSTMCell(const Operation& operation, RunTimeOperandInfo* operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    input_to_input_weights_ =
            GetInput(operation, operands, kInputToInputWeightsTensor);  // optional
    input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
    input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
    input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);

    recurrent_to_input_weights_ =
            GetInput(operation, operands, kRecurrentToInputWeightsTensor);  // optional
    recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);

    cell_to_input_weights_ = GetInput(operation, operands, kCellToInputWeightsTensor);  // optional
    cell_to_forget_weights_ =
            GetInput(operation, operands, kCellToForgetWeightsTensor);  // optional
    cell_to_output_weights_ =
            GetInput(operation, operands, kCellToOutputWeightsTensor);  // optional

    input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor);
    forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor);
    cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor);
    output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor);

    projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor);  // optional
    projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor);        // optional

    output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
    cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);

    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    params_.activation = static_cast<TfLiteFusedActivation>(getScalarDataWithDefault<int32_t>(
            activationOperand, TfLiteFusedActivation::kTfLiteActNone));

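    // The cell clip and projection clip parameters are stored as FLOAT32 or
    // FLOAT16 scalars depending on the input tensor type; both variants are
    // widened to float for LSTMParams.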
    const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam);
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        params_.cell_clip = getScalarDataWithDefault<float>(cellClipOperand, 0.0f);
        params_.proj_clip = getScalarDataWithDefault<float>(projClipOperand, 0.0f);
    } else {
        params_.cell_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(cellClipOperand, 0.0f));
        params_.proj_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(projClipOperand, 0.0f));
    }

    // We determine the LSTM version from the number of inputs to the op: LSTM
    // version 1.0 has 23 inputs, while version 1.2 has 27.
    if (operation.inputs.size() == 27) {
        input_layer_norm_weights_ =
                GetInput(operation, operands, kInputLayerNormWeightsTensor);  // optional
        forget_layer_norm_weights_ =
                GetInput(operation, operands, kForgetLayerNormWeightsTensor);  // optional
        cell_layer_norm_weights_ =
                GetInput(operation, operands, kCellLayerNormWeightsTensor);  // optional
        output_layer_norm_weights_ =
                GetInput(operation, operands, kOutputLayerNormWeightsTensor);  // optional
    } else {
        // For LSTM from HAL v1.0 assign operands with no values
        static RunTimeOperandInfo no_value;
        no_value.lifetime = OperandLifeTime::NO_VALUE;

        input_layer_norm_weights_ = &no_value;
        forget_layer_norm_weights_ = &no_value;
        cell_layer_norm_weights_ = &no_value;
        output_layer_norm_weights_ = &no_value;
    }

    output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
    cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);

    scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
}

// static
bool LSTMCell::CheckInputTensorDimensions(
        const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights,
        const RunTimeOperandInfo* input_to_forget_weights,
        const RunTimeOperandInfo* input_to_cell_weights,
        const RunTimeOperandInfo* input_to_output_weights,
        const RunTimeOperandInfo* recurrent_to_input_weights,
        const RunTimeOperandInfo* recurrent_to_forget_weights,
        const RunTimeOperandInfo* recurrent_to_cell_weights,
        const RunTimeOperandInfo* recurrent_to_output_weights,
        const RunTimeOperandInfo* cell_to_input_weights,
        const RunTimeOperandInfo* cell_to_forget_weights,
        const RunTimeOperandInfo* cell_to_output_weights, const RunTimeOperandInfo* input_gate_bias,
        const RunTimeOperandInfo* forget_gate_bias, const RunTimeOperandInfo* cell_bias,
        const RunTimeOperandInfo* output_gate_bias, const RunTimeOperandInfo* projection_weights,
        const RunTimeOperandInfo* projection_bias,
        const RunTimeOperandInfo* input_layer_norm_weights,
        const RunTimeOperandInfo* forget_layer_norm_weights,
        const RunTimeOperandInfo* cell_layer_norm_weights,
        const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, uint32_t n_output,
        uint32_t n_cell, LSTMParams* params) {
    // Making sure clipping parameters have valid values.
    // == 0 means no clipping
    //  > 0 means clipping
    NN_CHECK(params->cell_clip >= 0);
    NN_CHECK(params->proj_clip >= 0);

    if (!IsNullInput(input_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
    }

    NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);

    NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);

    if (!IsNullInput(recurrent_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
    }

    NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);

    NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifg_weights_all_or_none =
            (!IsNullInput(input_to_input_weights) && !IsNullInput(recurrent_to_input_weights)) ||
            (IsNullInput(input_to_input_weights) && IsNullInput(recurrent_to_input_weights));
    NN_CHECK(cifg_weights_all_or_none);

    if (!IsNullInput(cell_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
    }

    if (!IsNullInput(cell_to_forget_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
    }

    if (!IsNullInput(cell_to_output_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
    }

    // Making sure the peephole weights are either all present or all absent.
    params->use_cifg = IsNullInput(input_to_input_weights);
    const bool peephole_weights_all_or_none =
            ((!IsNullInput(cell_to_input_weights) || params->use_cifg) &&
             !IsNullInput(cell_to_forget_weights) && !IsNullInput(cell_to_output_weights)) ||
            (IsNullInput(cell_to_input_weights) && IsNullInput(cell_to_forget_weights) &&
             IsNullInput(cell_to_output_weights));
    NN_CHECK(peephole_weights_all_or_none);

    // Since we have already checked that the weights are either all present or
    // all absent, checking the existence of only one of them is enough.
    params->use_peephole = !IsNullInput(cell_to_output_weights);
    // Check the output layer norm weights instead of the input ones because the
    // input layer norm weights can be omitted when CIFG LSTM is used.
    params->use_layer_norm = !IsNullInput(output_layer_norm_weights);

    params->use_projection_weight = (projection_weights->lifetime != OperandLifeTime::NO_VALUE);
    params->use_projection_bias = (projection_bias->lifetime != OperandLifeTime::NO_VALUE);

    // Make sure the input gate bias is present only when not a CIFG-LSTM.
    if (params->use_cifg) {
        NN_CHECK(IsNullInput(input_gate_bias));
    } else {
        NN_CHECK_EQ(NumDimensions(input_gate_bias), 1);
        NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
    }

    NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);

    NN_CHECK_EQ(NumDimensions(cell_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);

    NN_CHECK_EQ(NumDimensions(output_gate_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);

    if (!IsNullInput(projection_weights)) {
        NN_CHECK_EQ(NumDimensions(projection_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
    }

    if (!IsNullInput(projection_bias)) {
        NN_CHECK_EQ(NumDimensions(projection_bias), 1);
        NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
    }

    // Making sure the projection tensors are consistent:
    // 1) If projection weight is not present, then projection bias should not be
    // present.
    // 2) If projection weight is present, then projection bias is optional.
    // TODO: make sure this is correct.
    const bool projection_tensors_consistent =
            (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
    NN_CHECK(projection_tensors_consistent);

    if (!IsNullInput(input_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(input_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(forget_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(cell_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(output_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(output_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights, 0), n_cell);
    }

    if (params->use_cifg) {
        NN_RET_CHECK(IsNullInput(input_layer_norm_weights))
                << "input_layer_norm_weights are provided while CIFG is used";
        const bool layer_norm_weights_all_or_none_cifg =
                (IsNullInput(forget_layer_norm_weights) && IsNullInput(cell_layer_norm_weights) &&
                 IsNullInput(output_layer_norm_weights)) ||
                (!IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
                 !IsNullInput(output_layer_norm_weights));
        NN_RET_CHECK(layer_norm_weights_all_or_none_cifg);
    } else {
        const bool layer_norm_weights_all_or_none =
                (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
                 IsNullInput(cell_layer_norm_weights) && IsNullInput(output_layer_norm_weights)) ||
                (!IsNullInput(input_layer_norm_weights) &&
                 !IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
                 !IsNullInput(output_layer_norm_weights));
        NN_RET_CHECK(layer_norm_weights_all_or_none);
    }

    return true;
}

bool LSTMCell::Prepare(const Operation& operation, RunTimeOperandInfo* operands,
                       Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape,
                       Shape* outputShape) {
    // Check we have all the inputs and outputs we need.
    NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
             NumInputsWithValues(operation, operands) <= 27);
    constexpr int requiredInputs[] = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kOutputStateInTensor,
            kCellStateInTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput)))
                << "required input " << requiredInput << " is omitted";
    }
    NN_CHECK_EQ(NumOutputs(operation), 4);

    // Check that the scalar operands' buffers are large enough.
    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    NN_RET_CHECK(activationOperand.length >= sizeof(int32_t));
    const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam);
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(cellClipOperand.length >= sizeof(float));
        NN_RET_CHECK(projClipOperand.length >= sizeof(float));
    } else {
        NN_RET_CHECK(cellClipOperand.length >= sizeof(_Float16));
        NN_RET_CHECK(projClipOperand.length >= sizeof(_Float16));
    }

    // Inferring batch size, number of outputs and number of cells from the
    // input tensors.
    NN_CHECK(NumDimensions(input_) > 1);
    const uint32_t n_batch = SizeOfDimension(input_, 0);
    const uint32_t n_input = SizeOfDimension(input_, 1);

    const uint32_t n_cell = SizeOfDimension(input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(input_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_output_weights_, 1), n_input);

    NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights_, 0), n_cell);
    const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1);

    // Check that the input tensor dimensions are consistent with each other.
    if (!CheckInputTensorDimensions(
                input_, input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_,
                input_to_output_weights_, recurrent_to_input_weights_, recurrent_to_forget_weights_,
                recurrent_to_cell_weights_, recurrent_to_output_weights_, cell_to_input_weights_,
                cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_,
                forget_gate_bias_, cell_bias_, output_gate_bias_, projection_weights_,
                projection_bias_, input_layer_norm_weights_, forget_layer_norm_weights_,
                cell_layer_norm_weights_, output_layer_norm_weights_, n_input, n_output, n_cell,
                &params_)) {
        return false;
    }

    // Resize the output and output_state tensors.
    const Shape& inputShape = input_->shape();

    outputShape->type = inputShape.type;
    outputShape->dimensions = {n_batch, n_output};
    outputShape->offset = inputShape.offset;
    outputShape->scale = inputShape.scale;

    outputStateShape->type = inputShape.type;
    outputStateShape->dimensions = {n_batch, n_output};
    outputStateShape->offset = inputShape.offset;
    outputStateShape->scale = inputShape.scale;

    cellStateShape->type = inputShape.type;
    cellStateShape->dimensions = {n_batch, n_cell};
    cellStateShape->offset = inputShape.offset;
    cellStateShape->scale = inputShape.scale;

    if (params_.use_cifg) {
        // Reserving space for Cell, Forget, Output gates
        scratchShape->dimensions = {n_batch, n_cell * 3};
    } else {
        // Reserving space for Input, Cell, Forget, Output gates
        scratchShape->dimensions = {n_batch, n_cell * 4};
    }
    scratchShape->type = inputShape.type;
    scratchShape->offset = inputShape.offset;
    scratchShape->scale = inputShape.scale;

    return true;
}

// static
bool LSTMCell::LSTMEvalFloat32(
        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
        const float* recurrent_to_forget_weights_buffer,
        const float* recurrent_to_cell_weights_buffer,
        const float* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
        const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
        const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
        const float* aux_input_to_forget_weights_buffer,
        const float* aux_input_to_cell_weights_buffer,
        const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
        const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
        const float* output_gate_bias_buffer, const float* projection_weights_buffer,
        const float* projection_bias_buffer, const float* output_state_in_buffer,
        const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
        const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
        const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
        float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
        bool timeMajor, bool forwardSequence) {
    NNTRACE_COMP("LSTMCell::LSTMEvalFloat32");

    const uint32_t inputRank = getNumberOfDimensions(input_shape);
    NN_CHECK(inputRank == 2 || inputRank == 3);

    const uint32_t maxTime =
            (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
    const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
                                                : getSizeOfDimension(input_shape, 0);
    const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
    const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
    const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);

    Shape batchInputShape = input_shape;
    batchInputShape.dimensions = {batchSize, inputSize};
    const uint32_t batchInputSize = batchSize * inputSize;
    const uint32_t batchOutputSize = batchSize * outputSize;

    std::vector<float> transposedInput;
    const bool hasAuxInput = (aux_input_buffer != nullptr);
    std::vector<float> transposedAuxInput;
    std::vector<float> transposedOutput;
    Shape transposedInputShape;
    Shape transposedOutputShape;
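    // When the input is batch-major ([batch, time, input]), transpose it to
    // time-major so that each time step is a contiguous [batch, input] slice;
    // the output is produced time-major and transposed back at the end.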
    if (!timeMajor) {
        transposedInput.resize(maxTime * batchInputSize);
        transposeFirstTwoDimensions<float>(input_buffer, input_shape, transposedInput.data());
        if (hasAuxInput) {
            transposedAuxInput.resize(maxTime * batchInputSize);
            transposeFirstTwoDimensions<float>(aux_input_buffer, input_shape,
                                               transposedAuxInput.data());
        }
        transposeFirstTwoDimensions(input_shape, &transposedInputShape);
        transposedOutput.resize(maxTime * batchOutputSize);
        transposedOutputShape = transposedInputShape;
        transposedOutputShape.dimensions[2] = outputSize;
    }
    const float* inputData = timeMajor ? input_buffer : transposedInput.data();
    const float* auxInputData =
            hasAuxInput ? (timeMajor ? aux_input_buffer : transposedAuxInput.data()) : nullptr;
    float* outputData = timeMajor ? output_buffer : transposedOutput.data();

    std::vector<float> outputStateInCurrentTimeStep(
            output_state_in_buffer, output_state_in_buffer + batchSize * outputSize);
    std::vector<float> cellStateInCurrentTimeStep(cell_state_in_buffer,
                                                  cell_state_in_buffer + batchSize * numCells);
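    // Walk the sequence forwards or backwards: start at the first or the last
    // time step and advance the input/output pointers by one batch-sized slice
    // per iteration (negative stride when processing the sequence in reverse).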
    const float* inputCurrentTimeStep =
            inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
    const float* auxInputCurrentTimeStep =
            hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
                        : nullptr;
    float* outputCurrentTimeStep =
            outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
    const int batchInputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchInputSize);
    const int batchOutputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchOutputSize);

    for (int t = 0; t < maxTime; ++t) {
        LSTMStep(params, inputCurrentTimeStep, batchInputShape, input_to_input_weights_buffer,
                 input_to_forget_weights_buffer, input_to_cell_weights_buffer,
                 input_to_output_weights_buffer, input_to_output_weights_shape,
                 recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer,
                 recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer,
                 recurrent_to_output_weights_shape, cell_to_input_weights_buffer,
                 cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
                 auxInputCurrentTimeStep, aux_input_to_input_weights_buffer,
                 aux_input_to_forget_weights_buffer, aux_input_to_cell_weights_buffer,
                 aux_input_to_output_weights_buffer, input_gate_bias_buffer,
                 forget_gate_bias_buffer, cell_bias_buffer, output_gate_bias_buffer,
                 projection_weights_buffer, projection_bias_buffer,
                 outputStateInCurrentTimeStep.data(), cellStateInCurrentTimeStep.data(),
                 input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
                 cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
                 output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep,
                 scratch_buffer_buffer);
        inputCurrentTimeStep += batchInputDelta;
        if (hasAuxInput) {
            auxInputCurrentTimeStep += batchInputDelta;
        }
        outputCurrentTimeStep += batchOutputDelta;
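        // The state computed at this step is fed back as the input state for
        // the next time step.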
        outputStateInCurrentTimeStep.assign(output_state_out_buffer,
                                            output_state_out_buffer + batchSize * outputSize);
        cellStateInCurrentTimeStep.assign(cell_state_out_buffer,
                                          cell_state_out_buffer + batchSize * numCells);
    }

    if (!timeMajor) {
        transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
                                           output_buffer);
    }

    return true;
}

// static
bool LSTMCell::LSTMEvalFloat16(
        const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
        const _Float16* input_to_input_weights_buffer,
        const _Float16* input_to_forget_weights_buffer,
        const _Float16* input_to_cell_weights_buffer,
        const _Float16* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape,
        const _Float16* recurrent_to_input_weights_buffer,
        const _Float16* recurrent_to_forget_weights_buffer,
        const _Float16* recurrent_to_cell_weights_buffer,
        const _Float16* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape,
        const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer,
        const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
        const _Float16* aux_input_to_input_weights_buffer,
        const _Float16* aux_input_to_forget_weights_buffer,
        const _Float16* aux_input_to_cell_weights_buffer,
        const _Float16* aux_input_to_output_weights_buffer, const _Float16* input_gate_bias_buffer,
        const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer,
        const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer,
        const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer,
        const _Float16* cell_state_in_buffer, const _Float16* input_layer_norm_weights_buffer,
        const _Float16* forget_layer_norm_weights_buffer,
        const _Float16* cell_layer_norm_weights_buffer,
        const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
        _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer,
        bool timeMajor, bool forwardSequence) {
    NNTRACE_COMP("LSTMCell::LSTMEvalFloat16");

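    // The half-precision path converts all tensors to float, reuses the float32
    // LSTMStep implementation, and converts the results back to _Float16 at the
    // end.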
    const uint32_t inputRank = getNumberOfDimensions(input_shape);
    NN_CHECK(inputRank == 2 || inputRank == 3);

    const uint32_t maxTime =
            (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
    const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
                                                : getSizeOfDimension(input_shape, 0);
    const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
    const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
    const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);

    Shape batchInputShape = input_shape;
    batchInputShape.dimensions = {batchSize, inputSize};
    const uint32_t batchInputSize = batchSize * inputSize;
    const uint32_t batchOutputSize = batchSize * outputSize;

    std::vector<float> input_float32(maxTime * batchInputSize);
    convertFloat16ToFloat32(input_buffer, &input_float32);
    std::vector<float> input_to_input_weights_float32(numCells * inputSize);
    if (input_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(input_to_input_weights_buffer, &input_to_input_weights_float32);
    }
    std::vector<float> input_to_forget_weights_float32(numCells * inputSize);
    convertFloat16ToFloat32(input_to_forget_weights_buffer, &input_to_forget_weights_float32);
    std::vector<float> input_to_cell_weights_float32(numCells * inputSize);
    convertFloat16ToFloat32(input_to_cell_weights_buffer, &input_to_cell_weights_float32);
    std::vector<float> input_to_output_weights_float32(numCells * inputSize);
    convertFloat16ToFloat32(input_to_output_weights_buffer, &input_to_output_weights_float32);

    std::vector<float> recurrent_to_input_weights_float32(numCells * outputSize);
    if (recurrent_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(recurrent_to_input_weights_buffer,
                                &recurrent_to_input_weights_float32);
    }
    std::vector<float> recurrent_to_forget_weights_float32(numCells * outputSize);
    convertFloat16ToFloat32(recurrent_to_forget_weights_buffer,
                            &recurrent_to_forget_weights_float32);
    std::vector<float> recurrent_to_cell_weights_float32(numCells * outputSize);
    convertFloat16ToFloat32(recurrent_to_cell_weights_buffer, &recurrent_to_cell_weights_float32);
    std::vector<float> recurrent_to_output_weights_float32(numCells * outputSize);
    convertFloat16ToFloat32(recurrent_to_output_weights_buffer,
                            &recurrent_to_output_weights_float32);

    std::vector<float> cell_to_input_weights_float32(numCells);
    if (cell_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_to_input_weights_buffer, &cell_to_input_weights_float32);
    }
    std::vector<float> cell_to_forget_weights_float32(numCells);
    if (cell_to_forget_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_to_forget_weights_buffer, &cell_to_forget_weights_float32);
    }
    std::vector<float> cell_to_output_weights_float32(numCells);
    if (cell_to_output_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_to_output_weights_buffer, &cell_to_output_weights_float32);
    }

    std::vector<float> aux_input_float32(maxTime * batchInputSize);
    if (aux_input_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_buffer, &aux_input_float32);
    }
    std::vector<float> aux_input_to_input_weights_float32(numCells * inputSize);
    if (aux_input_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_input_weights_buffer,
                                &aux_input_to_input_weights_float32);
    }
    std::vector<float> aux_input_to_forget_weights_float32(numCells * inputSize);
    if (aux_input_to_forget_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_forget_weights_buffer,
                                &aux_input_to_forget_weights_float32);
    }
    std::vector<float> aux_input_to_cell_weights_float32(numCells * inputSize);
    if (aux_input_to_cell_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_cell_weights_buffer,
                                &aux_input_to_cell_weights_float32);
    }
    std::vector<float> aux_input_to_output_weights_float32(numCells * inputSize);
    if (aux_input_to_output_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_output_weights_buffer,
                                &aux_input_to_output_weights_float32);
    }

    std::vector<float> input_gate_bias_float32(numCells);
    if (input_gate_bias_buffer != nullptr) {
        convertFloat16ToFloat32(input_gate_bias_buffer, &input_gate_bias_float32);
    }
    std::vector<float> forget_gate_bias_float32(numCells);
    convertFloat16ToFloat32(forget_gate_bias_buffer, &forget_gate_bias_float32);
    std::vector<float> cell_bias_float32(numCells);
    convertFloat16ToFloat32(cell_bias_buffer, &cell_bias_float32);
    std::vector<float> output_gate_bias_float32(numCells);
    convertFloat16ToFloat32(output_gate_bias_buffer, &output_gate_bias_float32);

    std::vector<float> projection_weights_float32(numCells * outputSize);
    if (projection_weights_buffer != nullptr) {
        convertFloat16ToFloat32(projection_weights_buffer, &projection_weights_float32);
    }
    std::vector<float> projection_bias_float32(outputSize);
    if (projection_bias_buffer != nullptr) {
        convertFloat16ToFloat32(projection_bias_buffer, &projection_bias_float32);
    }

    std::vector<float> input_layer_norm_weights_float32(numCells);
    if (input_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(input_layer_norm_weights_buffer, &input_layer_norm_weights_float32);
    }
    std::vector<float> forget_layer_norm_weights_float32(numCells);
    if (forget_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(forget_layer_norm_weights_buffer,
                                &forget_layer_norm_weights_float32);
    }
    std::vector<float> cell_layer_norm_weights_float32(numCells);
    if (cell_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_layer_norm_weights_buffer, &cell_layer_norm_weights_float32);
    }
    std::vector<float> output_layer_norm_weights_float32(numCells);
    if (output_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(output_layer_norm_weights_buffer,
                                &output_layer_norm_weights_float32);
    }

    std::vector<float> output_state_out_float32(batchOutputSize);
    convertFloat16ToFloat32(output_state_out_buffer, &output_state_out_float32);
    std::vector<float> cell_state_out_float32(batchSize * numCells);
    convertFloat16ToFloat32(cell_state_out_buffer, &cell_state_out_float32);

    std::vector<float> output_float32(maxTime * batchOutputSize);
    convertFloat16ToFloat32(output_buffer, &output_float32);
    std::vector<float> scratch_buffer_float32(params.use_cifg ? 3 * batchSize * numCells
                                                              : 4 * batchSize * numCells);
    convertFloat16ToFloat32(scratch_buffer_buffer, &scratch_buffer_float32);

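    // As in the float32 path, batch-major inputs are transposed to time-major
    // before the per-time-step loop and the output is transposed back after it.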
    std::vector<float> transposedInput;
    const bool hasAuxInput = (aux_input_buffer != nullptr);
    std::vector<float> transposedAuxInput;
    std::vector<float> transposedOutput;
    Shape transposedInputShape;
    Shape transposedOutputShape;
    if (!timeMajor) {
        transposedInput.resize(maxTime * batchInputSize);
        transposeFirstTwoDimensions<float>(input_float32.data(), input_shape,
                                           transposedInput.data());
        if (hasAuxInput) {
            transposedAuxInput.resize(maxTime * batchInputSize);
            transposeFirstTwoDimensions<float>(aux_input_float32.data(), input_shape,
                                               transposedAuxInput.data());
        }
        transposeFirstTwoDimensions(input_shape, &transposedInputShape);
        transposedOutput.resize(maxTime * batchOutputSize);
        transposedOutputShape = transposedInputShape;
        transposedOutputShape.dimensions[2] = outputSize;
    }
    const float* inputData = timeMajor ? input_float32.data() : transposedInput.data();
    const float* auxInputData =
            hasAuxInput ? (timeMajor ? aux_input_float32.data() : transposedAuxInput.data())
                        : nullptr;
    float* outputData = timeMajor ? output_float32.data() : transposedOutput.data();

    std::vector<float> outputStateInCurrentTimeStep(batchSize * outputSize);
    convertFloat16ToFloat32(output_state_in_buffer, &outputStateInCurrentTimeStep);
    std::vector<float> cellStateInCurrentTimeStep(batchSize * numCells);
    convertFloat16ToFloat32(cell_state_in_buffer, &cellStateInCurrentTimeStep);

    const float* inputCurrentTimeStep =
            inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
    const float* auxInputCurrentTimeStep =
            hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
                        : nullptr;
    float* outputCurrentTimeStep =
            outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
    const int batchInputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchInputSize);
    const int batchOutputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchOutputSize);

    for (int t = 0; t < maxTime; ++t) {
        LSTMStep(params, inputCurrentTimeStep, batchInputShape,
                 input_to_input_weights_float32.data(), input_to_forget_weights_float32.data(),
                 input_to_cell_weights_float32.data(), input_to_output_weights_float32.data(),
                 input_to_output_weights_shape, recurrent_to_input_weights_float32.data(),
                 recurrent_to_forget_weights_float32.data(),
                 recurrent_to_cell_weights_float32.data(),
                 recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape,
                 cell_to_input_weights_float32.data(), cell_to_forget_weights_float32.data(),
                 cell_to_output_weights_float32.data(), auxInputCurrentTimeStep,
                 aux_input_to_input_weights_float32.data(),
                 aux_input_to_forget_weights_float32.data(),
                 aux_input_to_cell_weights_float32.data(),
                 aux_input_to_output_weights_float32.data(), input_gate_bias_float32.data(),
                 forget_gate_bias_float32.data(), cell_bias_float32.data(),
                 output_gate_bias_float32.data(), projection_weights_float32.data(),
                 projection_bias_float32.data(), outputStateInCurrentTimeStep.data(),
                 cellStateInCurrentTimeStep.data(), input_layer_norm_weights_float32.data(),
                 forget_layer_norm_weights_float32.data(), cell_layer_norm_weights_float32.data(),
                 output_layer_norm_weights_float32.data(), output_state_out_float32.data(),
                 cell_state_out_float32.data(), outputCurrentTimeStep,
                 scratch_buffer_float32.data());
        inputCurrentTimeStep += batchInputDelta;
        if (hasAuxInput) {
            auxInputCurrentTimeStep += batchInputDelta;
        }
        outputCurrentTimeStep += batchOutputDelta;
        outputStateInCurrentTimeStep = output_state_out_float32;
        cellStateInCurrentTimeStep = cell_state_out_float32;
    }

    if (!timeMajor) {
        transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
                                           output_float32.data());
    }

    convertFloat32ToFloat16(output_state_out_float32, output_state_out_buffer);
    convertFloat32ToFloat16(cell_state_out_float32, cell_state_out_buffer);
    convertFloat32ToFloat16(output_float32, output_buffer);
    convertFloat32ToFloat16(scratch_buffer_float32, scratch_buffer_buffer);
    return true;
}

// static
bool LSTMCell::LSTMStep(
        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
        const float* recurrent_to_forget_weights_buffer,
        const float* recurrent_to_cell_weights_buffer,
        const float* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
        const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
        const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
        const float* aux_input_to_forget_weights_buffer,
        const float* aux_input_to_cell_weights_buffer,
        const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
        const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
        const float* output_gate_bias_buffer, const float* projection_weights_buffer,
        const float* projection_bias_buffer, const float* output_state_in_buffer,
        const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
        const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
        const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
        float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer) {
    NNTRACE_COMP("LSTMCell::LSTMStep");

    const uint32_t n_batch = input_shape.dimensions[0];
    const uint32_t n_input = input_shape.dimensions[1];
    // n_cell and n_output will be the same size when there is no projection.
    const uint32_t n_cell = input_to_output_weights_shape.dimensions[0];
    const uint32_t n_output = recurrent_to_output_weights_shape.dimensions[1];
    const uint32_t n_aux_input = aux_input_buffer == nullptr ? 0 : n_input;

    // Point the per-gate scratch pointers into the global scratch buffer.
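    // Layout without CIFG: [input | cell | forget | output] gate buffers, each
    // of size n_cell * n_batch. With CIFG the input gate buffer is omitted.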
    float* input_gate_scratch = nullptr;
    float* cell_scratch = nullptr;
    float* forget_gate_scratch = nullptr;
    float* output_gate_scratch = nullptr;
    if (params.use_cifg) {
        cell_scratch = scratch_buffer_buffer;
        forget_gate_scratch = cell_scratch + n_cell * n_batch;
        output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
    } else {
        input_gate_scratch = scratch_buffer_buffer;
        cell_scratch = input_gate_scratch + n_cell * n_batch;
        forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch;
        output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
    }

    if (!params.use_layer_norm) {
        // Initialize scratch buffers with bias.
        if (!params.use_cifg) {
            tflite::tensor_utils::VectorBatchVectorAssign(input_gate_bias_buffer, n_cell, n_batch,
                                                          input_gate_scratch);
        }
        tflite::tensor_utils::VectorBatchVectorAssign(forget_gate_bias_buffer, n_cell, n_batch,
                                                      forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAssign(cell_bias_buffer, n_cell, n_batch,
                                                      cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAssign(output_gate_bias_buffer, n_cell, n_batch,
                                                      output_gate_scratch);
    } else {
        // Initialize scratch buffers with zeroes.
        if (!params.use_cifg) {
            std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
        }
        std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
        std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
        std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
    }

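    // Each gate's pre-activation accumulates the input, auxiliary input (if
    // present), and recurrent contributions on top of the initial bias (or on
    // top of zero when layer normalization is used, since the bias is then
    // added after normalization).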
833     // For each batch and cell: compute input_weight * input.
834     if (!params.use_cifg) {
835         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
836                 input_to_input_weights_buffer, n_cell, n_input, input_buffer, n_batch,
837                 input_gate_scratch, /*result_stride*/ 1);
838     }
839     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
840             input_to_forget_weights_buffer, n_cell, n_input, input_buffer, n_batch,
841             forget_gate_scratch, /*result_stride*/ 1);
842     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_buffer, n_cell,
843                                                               n_input, input_buffer, n_batch,
844                                                               cell_scratch, /*result_stride*/ 1);
845     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
846             input_to_output_weights_buffer, n_cell, n_input, input_buffer, n_batch,
847             output_gate_scratch, /*result_stride*/ 1);
848 
849     // If auxiliary input is available then compute aux_input_weight * aux_input
850     if (aux_input_buffer != nullptr) {
851         if (!params.use_cifg) {
852             tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
853                     aux_input_to_input_weights_buffer, n_cell, n_aux_input, aux_input_buffer,
854                     n_batch, input_gate_scratch,
855                     /*result_stride=*/1);
856         }
857 
858         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
859                 aux_input_to_forget_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
860                 forget_gate_scratch, /*result_stride=*/1);
861         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
862                 aux_input_to_cell_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
863                 cell_scratch, /*result_stride=*/1);
864         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
865                 aux_input_to_output_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
866                 output_gate_scratch, /*result_stride=*/1);
867     }
868 
869     // For each batch and cell: compute recurrent_weight * output_state.
870     if (!params.use_cifg) {
871         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
872                 recurrent_to_input_weights_buffer, n_cell, n_output, output_state_in_buffer,
873                 n_batch, input_gate_scratch,
874                 /*result_stride*/ 1);
875     }
876     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
877             recurrent_to_forget_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
878             forget_gate_scratch, /*result_stride*/ 1);
879     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
880             recurrent_to_cell_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
881             cell_scratch, /*result_stride*/ 1);
882     tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
883             recurrent_to_output_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
884             output_gate_scratch, /*result_stride*/ 1);
885 
886     // For each batch and cell: update input gate.
    if (!params.use_cifg) {
        if (params.use_peephole) {
            tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
                    cell_to_input_weights_buffer, n_cell, cell_state_in_buffer, n_batch,
                    input_gate_scratch);
        }
        if (params.use_layer_norm) {
            tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch, input_gate_scratch,
                                                          n_cell, n_batch);
            tflite::tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weights_buffer,
                                                                n_cell, input_gate_scratch, n_batch,
                                                                input_gate_scratch);
            tflite::tensor_utils::VectorBatchVectorAdd(input_gate_bias_buffer, n_cell, n_batch,
                                                       input_gate_scratch);
        }
        tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                                   input_gate_scratch);
    }

    // For each batch and cell: update forget gate.
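    // Analogously,
    //   f_t = sigmoid(W_xf * x_t + W_hf * h_{t-1} + w_cf .* c_{t-1} + b_f)
    // with the peephole and layer-normalization terms applied only when the
    // corresponding optional parameters are present.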
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_forget_weights_buffer,
                                                                      n_cell, cell_state_in_buffer,
                                                                      n_batch, forget_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch, forget_gate_scratch,
                                                      n_cell, n_batch);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weights_buffer,
                                                            n_cell, forget_gate_scratch, n_batch,
                                                            forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(forget_gate_bias_buffer, n_cell, n_batch,
                                                   forget_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                               forget_gate_scratch);

    // For each batch and cell: update the cell.
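    // Cell-state update:
    //   c_t = f_t .* c_{t-1} + i_t .* g(W_xc * x_t + W_hc * h_{t-1} + b_c)
    // where g is params.activation (typically tanh) and, under CIFG, i_t is
    // replaced by (1 - f_t). The updated state is optionally clipped to
    // [-cell_clip, cell_clip].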
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, n_batch);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(cell_layer_norm_weights_buffer, n_cell,
                                                            cell_scratch, n_batch, cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(cell_bias_buffer, n_cell, n_batch, cell_scratch);
    }
    tflite::tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_in_buffer,
                                                   n_batch * n_cell, cell_state_out_buffer);
    tflite::tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, params.activation,
                                                  cell_scratch);
    if (params.use_cifg) {
        tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                                         forget_gate_scratch);
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    } else {
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    }
    if (params.cell_clip > 0.0) {
        tflite::tensor_utils::ClipVector(cell_state_out_buffer, n_batch * n_cell, params.cell_clip,
                                         cell_state_out_buffer);
    }

    // For each batch and cell: update the output gate.
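    // Output gate and hidden state:
    //   o_t = sigmoid(W_xo * x_t + W_ho * h_{t-1} + w_co .* c_t + b_o)
    //   h_t = o_t .* g(c_t)
    // Note that the peephole term uses the freshly updated cell state c_t. The
    // result h_t is written back into output_gate_scratch.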
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_output_weights_buffer,
                                                                      n_cell, cell_state_out_buffer,
                                                                      n_batch, output_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch, output_gate_scratch,
                                                      n_cell, n_batch);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weights_buffer,
                                                            n_cell, output_gate_scratch, n_batch,
                                                            output_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(output_gate_bias_buffer, n_cell, n_batch,
                                                   output_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                               output_gate_scratch);
    tflite::tensor_utils::ApplyActivationToVector(cell_state_out_buffer, n_batch * n_cell,
                                                  params.activation, cell_scratch);
    tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                                   n_batch * n_cell, output_gate_scratch);

    // For each batch: update the projection and output_state.
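    // With a projection layer the emitted output is
    //   output_t = clip(W_proj * h_t + b_proj, proj_clip)
    // otherwise output_t = h_t directly (the copy below then assumes
    // n_output == n_cell). The output is also stored into output_state_out for
    // use as h_{t-1} at the next time step.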
    if (params.use_projection_weight) {
        if (params.use_projection_bias) {
            tflite::tensor_utils::VectorBatchVectorAssign(projection_bias_buffer, n_output, n_batch,
                                                          output_buffer);
        } else {
            std::fill_n(output_buffer, n_batch * n_output, 0.0f);
        }
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                projection_weights_buffer, n_output, n_cell, output_gate_scratch, n_batch,
                output_buffer,
                /*result_stride=*/1);
        if (params.proj_clip > 0.0) {
            tflite::tensor_utils::ClipVector(output_buffer, n_batch * n_output, params.proj_clip,
                                             output_buffer);
        }
    } else {
        std::copy_n(output_gate_scratch, n_batch * n_output, output_buffer);
    }
    std::copy_n(output_buffer, n_batch * n_output, output_state_out_buffer);
    return true;
}

bool LSTMCell::Eval() {
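    // Dispatch on the data type of the input tensor: each case fetches the
    // operand buffers at the matching precision and forwards them to the
    // corresponding LSTMEval* helper. No auxiliary input is wired up for this
    // single-cell operation, so the aux_* arguments are passed as nullptr.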
    switch (input_->type) {
        case OperandType::TENSOR_FLOAT32: {
            LSTMEvalFloat32(params_, GetBuffer<const float>(input_), input_->shape(),
                            GetBuffer<const float>(input_to_input_weights_),
                            GetBuffer<const float>(input_to_forget_weights_),
                            GetBuffer<const float>(input_to_cell_weights_),
                            GetBuffer<const float>(input_to_output_weights_),
                            input_to_output_weights_->shape(),
                            GetBuffer<const float>(recurrent_to_input_weights_),
                            GetBuffer<const float>(recurrent_to_forget_weights_),
                            GetBuffer<const float>(recurrent_to_cell_weights_),
                            GetBuffer<const float>(recurrent_to_output_weights_),
                            recurrent_to_output_weights_->shape(),
                            GetBuffer<const float>(cell_to_input_weights_),
                            GetBuffer<const float>(cell_to_forget_weights_),
                            GetBuffer<const float>(cell_to_output_weights_),
                            /*aux_input_buffer=*/nullptr,
                            /*aux_input_to_input_weights_buffer=*/nullptr,
                            /*aux_input_to_forget_weights_buffer=*/nullptr,
                            /*aux_input_to_cell_weights_buffer=*/nullptr,
                            /*aux_input_to_output_weights_buffer=*/nullptr,
                            GetBuffer<const float>(input_gate_bias_),
                            GetBuffer<const float>(forget_gate_bias_),
                            GetBuffer<const float>(cell_bias_),
                            GetBuffer<const float>(output_gate_bias_),
                            GetBuffer<const float>(projection_weights_),
                            GetBuffer<const float>(projection_bias_),
                            GetBuffer<const float>(output_state_in_),
                            GetBuffer<const float>(cell_state_in_),
                            GetBuffer<const float>(input_layer_norm_weights_),
                            GetBuffer<const float>(forget_layer_norm_weights_),
                            GetBuffer<const float>(cell_layer_norm_weights_),
                            GetBuffer<const float>(output_layer_norm_weights_),
                            GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
                            GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
        } break;
        case OperandType::TENSOR_FLOAT16: {
            LSTMEvalFloat16(params_, GetBuffer<const _Float16>(input_), input_->shape(),
                            GetOptionalBuffer<const _Float16>(input_to_input_weights_),
                            GetBuffer<const _Float16>(input_to_forget_weights_),
                            GetBuffer<const _Float16>(input_to_cell_weights_),
                            GetBuffer<const _Float16>(input_to_output_weights_),
                            input_to_output_weights_->shape(),
                            GetOptionalBuffer<const _Float16>(recurrent_to_input_weights_),
                            GetBuffer<const _Float16>(recurrent_to_forget_weights_),
                            GetBuffer<const _Float16>(recurrent_to_cell_weights_),
                            GetBuffer<const _Float16>(recurrent_to_output_weights_),
                            recurrent_to_output_weights_->shape(),
                            GetOptionalBuffer<const _Float16>(cell_to_input_weights_),
                            GetOptionalBuffer<const _Float16>(cell_to_forget_weights_),
                            GetOptionalBuffer<const _Float16>(cell_to_output_weights_),
                            /*aux_input_buffer=*/nullptr,
                            /*aux_input_to_input_weights_buffer=*/nullptr,
                            /*aux_input_to_forget_weights_buffer=*/nullptr,
                            /*aux_input_to_cell_weights_buffer=*/nullptr,
                            /*aux_input_to_output_weights_buffer=*/nullptr,
                            GetOptionalBuffer<const _Float16>(input_gate_bias_),
                            GetBuffer<const _Float16>(forget_gate_bias_),
                            GetBuffer<const _Float16>(cell_bias_),
                            GetBuffer<const _Float16>(output_gate_bias_),
                            GetOptionalBuffer<const _Float16>(projection_weights_),
                            GetOptionalBuffer<const _Float16>(projection_bias_),
                            GetBuffer<const _Float16>(output_state_in_),
                            GetBuffer<const _Float16>(cell_state_in_),
                            GetOptionalBuffer<const _Float16>(input_layer_norm_weights_),
                            GetOptionalBuffer<const _Float16>(forget_layer_norm_weights_),
                            GetOptionalBuffer<const _Float16>(cell_layer_norm_weights_),
                            GetOptionalBuffer<const _Float16>(output_layer_norm_weights_),
                            GetBuffer<_Float16>(output_state_out_),
                            GetBuffer<_Float16>(cell_state_out_), GetBuffer<_Float16>(output_),
                            GetBuffer<_Float16>(scratch_buffer_));
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
            return false;
        }
    }
    return true;
}
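
// Rough usage sketch (illustrative only, not part of this file): the CPU
// executor constructs an LSTMCell from the Operation and its operand table,
// calls Prepare() to validate operand shapes and size the scratch and output
// tensors, and then Eval() to run one time step. The Prepare() parameter list
// shown here is assumed from LSTM.h; names other than LSTMCell, Prepare, and
// Eval are placeholders.
//
//   LSTMCell cell(operation, operands);
//   Shape scratchShape, outputStateShape, cellStateShape, outputShape;
//   if (!cell.Prepare(operation, operands, &scratchShape, &outputStateShape,
//                     &cellStateShape, &outputShape)) {
//       return false;
//   }
//   // ...allocate the output operands from the returned shapes...
//   return cell.Eval();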

}  // namespace nn
}  // namespace android