/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <vector>

#include "RNN.h"

#include "CpuExecutor.h"
#include "CpuOperationUtils.h"
#include "HalInterfaces.h"

#include "Tracing.h"

namespace android {
namespace nn {

using namespace hal;

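// Wraps a basic fully-connected recurrent layer. The constructor caches
// pointers to the operand buffers so that Eval() can run the step function
// without re-resolving them.
//
// Typical flow (a sketch; the actual call sites live in CpuExecutor):
// construct RNN(operation, operands), call Prepare() to derive the output
// shapes, then call Eval() once the output buffers have been allocated.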
RNN::RNN(const Operation& operation, RunTimeOperandInfo* operands) {
    NNTRACE_TRANS("RNN::RNN");
    input_ = GetInput(operation, operands, kInputTensor);
    weights_ = GetInput(operation, operands, kWeightsTensor);
    recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
    hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
    bias_ = GetInput(operation, operands, kBiasTensor);

    activation_ = static_cast<ActivationFn>(
            getScalarData<int32_t>(operands[operation.inputs[kActivationParam]]));

    hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);
}

bool RNN::Prepare(const Operation& operation, RunTimeOperandInfo* operands, Shape* hiddenStateShape,
                  Shape* outputShape) {
    NNTRACE_TRANS("RNN::Prepare");
    // Check we have all the inputs and outputs we need.
    const int num_inputs = NumInputsWithValues(operation, operands);
    NN_CHECK(num_inputs == 6);
    NN_CHECK_EQ(NumOutputs(operation), 2);
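    // The six inputs are input, weights, recurrent_weights, hidden_state,
    // bias, and the activation scalar; the two outputs are hidden_state_out
    // and output.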

    const RunTimeOperandInfo* input = GetInput(operation, operands, kInputTensor);
    const RunTimeOperandInfo* input_weights = GetInput(operation, operands, kWeightsTensor);
    const RunTimeOperandInfo* recurrent_weights =
            GetInput(operation, operands, kRecurrentWeightsTensor);
    const RunTimeOperandInfo* bias = GetInput(operation, operands, kBiasTensor);

    // Check that the tensor dimensions are consistent with each other and
    // with the input configuration.
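    // Expected shapes (all matrices row-major):
    //   input:             [batch_size, input_size]
    //   weights:           [num_units, input_size]
    //   recurrent_weights: [num_units, num_units]
    //   bias:              [num_units]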
    const uint32_t batch_size = SizeOfDimension(input, 0);
    const uint32_t num_units = SizeOfDimension(input_weights, 0);
    NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1));
    NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0));
    NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0));
    NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0));

    const Shape& inputShape = input->shape();

    // Resize state.
    hiddenStateShape->type = inputShape.type;
    hiddenStateShape->dimensions = {batch_size, num_units};

    // Resize output.
    outputShape->type = inputShape.type;
    outputShape->dimensions = {batch_size, num_units};

    return true;
}

bool RNN::Eval() {
    switch (input_->type) {
        case OperandType::TENSOR_FLOAT16: {
            RNNStep<_Float16>(reinterpret_cast<_Float16*>(input_->buffer), input_->shape(),
                              reinterpret_cast<_Float16*>(hidden_state_in_->buffer),
                              reinterpret_cast<_Float16*>(bias_->buffer),
                              reinterpret_cast<_Float16*>(weights_->buffer), weights_->shape(),
                              reinterpret_cast<_Float16*>(recurrent_weights_->buffer),
                              recurrent_weights_->shape(), activation_,
                              reinterpret_cast<_Float16*>(output_->buffer));
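            // For a basic RNN cell the new hidden state is identical to the
            // output, so the hidden state output is a copy of the output.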
            memcpy(hidden_state_out_->buffer, output_->buffer,
                   sizeof(_Float16) * getNumberOfElements(output_->shape()));
            break;
        }
        case OperandType::TENSOR_FLOAT32: {
            RNNStep<float>(reinterpret_cast<float*>(input_->buffer), input_->shape(),
                           reinterpret_cast<float*>(hidden_state_in_->buffer),
                           reinterpret_cast<float*>(bias_->buffer),
                           reinterpret_cast<float*>(weights_->buffer), weights_->shape(),
                           reinterpret_cast<float*>(recurrent_weights_->buffer),
                           recurrent_weights_->shape(), activation_,
                           reinterpret_cast<float*>(output_->buffer));
            memcpy(hidden_state_out_->buffer, output_->buffer,
                   sizeof(float) * getNumberOfElements(output_->shape()));
            break;
        }
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
            return false;
        }
    }
    return true;
}

template <typename T>
bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* hiddenStateInputData,
                  const T* biasData, const T* weightsData, const Shape& weightsShape,
                  const T* recurrentWeightsData, const Shape& recurrentWeightsShape,
                  const int32_t activation, T* outputData) {
    NNTRACE_COMP("RNN::Eval");

    Shape dummyShape;
    uint32_t numUnits = weightsShape.dimensions[0];
    return RNNStep<T>(inputData, inputShape, /*auxInputData=*/nullptr, /*auxInputShape=*/dummyShape,
                      hiddenStateInputData, biasData, weightsData, weightsShape,
                      /*auxWeightsData=*/nullptr, /*auxWeightsShape=*/dummyShape,
                      recurrentWeightsData, recurrentWeightsShape, activation,
                      /*outputBatchStride=*/numUnits, /*outputBatchOffset=*/0, outputData);
}

// A more general version of the RNNStep function.
// The auxiliary input is treated as if it were concatenated to the regular
// input, with the weights matrix correspondingly concatenated with the
// auxiliary weights.
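// For batch b and output unit o the computation is:
//   output[b][o] = act(bias[o]
//                      + dot(input[b], weights[o])
//                      + dot(aux_input[b], aux_weights[o])   // only if present
//                      + dot(hidden_state_in[b], recurrent_weights[o]))
// where act() is the activation function selected by 'activation'.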
template <typename T>
bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* auxInputData,
                  const Shape& auxInputShape, const T* hiddenStateInputData, const T* biasData,
                  const T* weightsData, const Shape& weightsShape, const T* auxWeightsData,
                  const Shape& auxWeightsShape, const T* recurrentWeightsData,
                  const Shape& recurrentWeightsShape, const int32_t activation,
                  const uint32_t outputBatchStride, const uint32_t outputBatchOffset, T* outputData,
                  T* hiddenStateOutput) {
    NNTRACE_COMP("RNN::Eval");

    const uint32_t batch_size = inputShape.dimensions[0];
    const uint32_t num_units = weightsShape.dimensions[0];
    const uint32_t input_size = inputShape.dimensions[1];
    const uint32_t input_weights_stride = weightsShape.dimensions[1];
    const uint32_t recurrent_weights_stride = recurrentWeightsShape.dimensions[1];
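    // The weight matrices are dense and row-major, so the stride from one row
    // (output unit) to the next is the second dimension of each shape.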

    uint32_t aux_input_size = 0;
    uint32_t aux_input_weights_stride = 0;
    bool hasAuxInput = (auxInputData != nullptr);
    if (hasAuxInput) {
        aux_input_size = auxInputShape.dimensions[1];
        aux_input_weights_stride = auxWeightsShape.dimensions[1];
    }

    // For each batch
    for (uint32_t b = 0; b < batch_size; b++) {
        // Initialize the per-batch pointers to the input, hidden state, and
        // output.
        const T* input_ptr_batch = inputData + b * input_size;
        const T* hidden_state_in_ptr_batch = hiddenStateInputData + b * num_units;
        const T* aux_input_ptr_batch = nullptr;
        if (hasAuxInput) {
            aux_input_ptr_batch = auxInputData + b * aux_input_size;
        }
        T* output_ptr_batch = outputData + b * outputBatchStride + outputBatchOffset;

        // Initialize input_weights and recurrent_weights.
        const T* input_weights_ptr = weightsData;
        const T* recurrent_weights_ptr = recurrentWeightsData;
        const T* aux_input_weights_ptr = nullptr;
        if (hasAuxInput) {
            aux_input_weights_ptr = auxWeightsData;
        }

        // Output = bias
        for (uint32_t o = 0; o < num_units; o++) {
            output_ptr_batch[o] = biasData[o];
        }

        // Output += input * input_weights
        for (uint32_t o = 0; o < num_units; o++) {
            for (uint32_t i = 0; i < input_size; i++) {
                output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
            }
            input_weights_ptr += input_weights_stride;
        }
        if (hasAuxInput) {
            // Output += aux_input * aux_input_weights
            for (uint32_t o = 0; o < num_units; o++) {
                for (uint32_t i = 0; i < aux_input_size; i++) {
                    output_ptr_batch[o] += aux_input_ptr_batch[i] * aux_input_weights_ptr[i];
                }
                aux_input_weights_ptr += aux_input_weights_stride;
            }
        }

        // Output += recurrent_weights * hidden_state
        for (uint32_t o = 0; o < num_units; o++) {
            for (uint32_t h = 0; h < num_units; h++) {
                output_ptr_batch[o] += hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h];
            }
            recurrent_weights_ptr += recurrent_weights_stride;
        }

        // Output = activation(Output)
        for (uint32_t o = 0; o < num_units; o++) {
            output_ptr_batch[o] =
                    (ActivationFunctor(static_cast<ActivationFn>(activation)))(output_ptr_batch[o]);
            if (hiddenStateOutput != nullptr) {
                *hiddenStateOutput = output_ptr_batch[o];
                ++hiddenStateOutput;
            }
        }
    }

    return true;
}

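// Explicit instantiations for the two tensor types dispatched in Eval().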
template bool RNN::RNNStep<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                     const _Float16* hiddenStateInputData, const _Float16* biasData,
                                     const _Float16* weightsData, const Shape& weightsShape,
                                     const _Float16* recurrentWeightsData,
                                     const Shape& recurrentWeightsShape, int32_t activation,
                                     _Float16* outputData);
template bool RNN::RNNStep<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                     const _Float16* auxInputData, const Shape& auxInputShape,
                                     const _Float16* hiddenStateInputData, const _Float16* biasData,
                                     const _Float16* weightsData, const Shape& weightsShape,
                                     const _Float16* auxWeightsData, const Shape& auxWeightsShape,
                                     const _Float16* recurrentWeightsData,
                                     const Shape& recurrentWeightsShape, const int32_t activation,
                                     const uint32_t outputBatchStride,
                                     const uint32_t outputBatchOffset, _Float16* outputData,
                                     _Float16* hiddenStateOutput);
template bool RNN::RNNStep<float>(const float* inputData, const Shape& inputShape,
                                  const float* hiddenStateInputData, const float* biasData,
                                  const float* weightsData, const Shape& weightsShape,
                                  const float* recurrentWeightsData,
                                  const Shape& recurrentWeightsShape, int32_t activation,
                                  float* outputData);
template bool RNN::RNNStep<float>(const float* inputData, const Shape& inputShape,
                                  const float* auxInputData, const Shape& auxInputShape,
                                  const float* hiddenStateInputData, const float* biasData,
                                  const float* weightsData, const Shape& weightsShape,
                                  const float* auxWeightsData, const Shape& auxWeightsShape,
                                  const float* recurrentWeightsData,
                                  const Shape& recurrentWeightsShape, int32_t activation,
                                  uint32_t outputBatchStride, uint32_t outputBatchOffset,
                                  float* outputData, float* hiddenStateOutput);

}  // namespace nn
}  // namespace android