/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_QUANTIZED_LSTM_H
#define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_QUANTIZED_LSTM_H

#include <vector>

#include "OperationsUtils.h"

namespace android {
namespace nn {

struct RunTimeOperandInfo;

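// Implements a single step of the quantized LSTM cell (the QUANTIZED_16BIT_LSTM
// operation): activations and weights are 8-bit quantized values and the cell
// state is kept in 16 bits.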
class QuantizedLSTMCell {
   public:
    QuantizedLSTMCell(const hal::Operation& operation, RunTimeOperandInfo* operands);

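    // Validates operand types and dimensions and derives the shapes of the two
    // outputs (cell state and output) from the inputs.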
    static bool prepare(const hal::Operation& operation, RunTimeOperandInfo* operands,
                        Shape* cellStateShape, Shape* outputShape);
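    // Runs one LSTM step over the operands captured in the constructor; returns
    // false on failure.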
    bool eval();

    // Inputs:
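    // Input tensor of size {n_batch, n_input}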
    static constexpr int kInputTensor = 0;
    // Input weight tensors of size {n_cell, n_input}
    static constexpr int kInputToInputWeightsTensor = 1;
    static constexpr int kInputToForgetWeightsTensor = 2;
    static constexpr int kInputToCellWeightsTensor = 3;
    static constexpr int kInputToOutputWeightsTensor = 4;

    // Recurrent weight tensors of size {n_cell, n_output}
    static constexpr int kRecurrentToInputWeightsTensor = 5;
    static constexpr int kRecurrentToForgetWeightsTensor = 6;
    static constexpr int kRecurrentToCellWeightsTensor = 7;
    static constexpr int kRecurrentToOutputWeightsTensor = 8;

    // Gate bias tensors of size {n_cell}
    static constexpr int kInputGateBiasTensor = 9;
    static constexpr int kForgetGateBiasTensor = 10;
    static constexpr int kCellGateBiasTensor = 11;
    static constexpr int kOutputGateBiasTensor = 12;

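    // State tensors from the previous time step, of size {n_batch, n_cell} and
    // {n_batch, n_output} respectively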
    static constexpr int kPrevCellStateTensor = 13;
    static constexpr int kPrevOutputTensor = 14;

    // Outputs:
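    // Updated cell state of size {n_batch, n_cell} and new output of size
    // {n_batch, n_output}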
    static constexpr int kCellStateOutTensor = 0;
    static constexpr int kOutputTensor = 1;

   private:
    const RunTimeOperandInfo* input_;

    const RunTimeOperandInfo* inputToInputWeights_;
    const RunTimeOperandInfo* inputToForgetWeights_;
    const RunTimeOperandInfo* inputToCellWeights_;
    const RunTimeOperandInfo* inputToOutputWeights_;

    const RunTimeOperandInfo* recurrentToInputWeights_;
    const RunTimeOperandInfo* recurrentToForgetWeights_;
    const RunTimeOperandInfo* recurrentToCellWeights_;
    const RunTimeOperandInfo* recurrentToOutputWeights_;

    const RunTimeOperandInfo* inputGateBias_;
    const RunTimeOperandInfo* forgetGateBias_;
    const RunTimeOperandInfo* cellGateBias_;
    const RunTimeOperandInfo* outputGateBias_;

    const RunTimeOperandInfo* prevCellState_;
    const RunTimeOperandInfo* prevOutput_;

    RunTimeOperandInfo* cellStateOut_;
    RunTimeOperandInfo* output_;

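    // Helpers that pack the per-gate weight tensors and gate biases into single
    // contiguous buffers in the layout expected by the LSTM kernel.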
    void concatenateWeights(const std::vector<uint32_t>& weightsDims, uint8_t* weights);
    void concatenateBiases(uint32_t outputSize, int32_t* bias);
};

}  // namespace nn
}  // namespace android

#endif  // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_QUANTIZED_LSTM_H