1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include <vector>
20
21 #include "RNN.h"
22
23 #include "CpuExecutor.h"
24 #include "CpuOperationUtils.h"
25 #include "HalInterfaces.h"
26
27 #include "Tracing.h"
28
29 namespace android {
30 namespace nn {
31
32 using namespace hal;
33
RNN(const Operation & operation,RunTimeOperandInfo * operands)34 RNN::RNN(const Operation& operation, RunTimeOperandInfo* operands) {
35 NNTRACE_TRANS("RNN::RNN");
36 input_ = GetInput(operation, operands, kInputTensor);
37 weights_ = GetInput(operation, operands, kWeightsTensor);
38 recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
39 hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
40 bias_ = GetInput(operation, operands, kBiasTensor);
41
42 activation_ = static_cast<ActivationFn>(
43 getScalarData<int32_t>(operands[operation.inputs[kActivationParam]]));
44
45 hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
46 output_ = GetOutput(operation, operands, kOutputTensor);
47 }
48
Prepare(const Operation & operation,RunTimeOperandInfo * operands,Shape * hiddenStateShape,Shape * outputShape)49 bool RNN::Prepare(const Operation& operation, RunTimeOperandInfo* operands, Shape* hiddenStateShape,
50 Shape* outputShape) {
51 NNTRACE_TRANS("RNN::Prepare");
52 // Check we have all the inputs and outputs we need.
53 const int num_inputs = NumInputsWithValues(operation, operands);
54 NN_CHECK(num_inputs == 6);
55 NN_CHECK_EQ(NumOutputs(operation), 2);
56
57 const RunTimeOperandInfo* input = GetInput(operation, operands, kInputTensor);
58 const RunTimeOperandInfo* input_weights = GetInput(operation, operands, kWeightsTensor);
59 const RunTimeOperandInfo* recurrent_weights =
60 GetInput(operation, operands, kRecurrentWeightsTensor);
61 const RunTimeOperandInfo* bias = GetInput(operation, operands, kBiasTensor);
62
63 // Check all the parameters of tensor match within themselves and match the
64 // input configuration.
65 const uint32_t batch_size = SizeOfDimension(input, 0);
66 const uint32_t num_units = SizeOfDimension(input_weights, 0);
67 NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1));
68 NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0));
69 NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0));
70 NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0));
71
72 const Shape& inputShape = input->shape();
73
74 // Resize state.
75 hiddenStateShape->type = inputShape.type;
76 hiddenStateShape->dimensions = {batch_size, num_units};
77
78 // Resize output.
79 outputShape->type = inputShape.type;
80 outputShape->dimensions = {batch_size, num_units};
81
82 return true;
83 }
84
Eval()85 bool RNN::Eval() {
86 switch (input_->type) {
87 case OperandType::TENSOR_FLOAT16: {
88 RNNStep<_Float16>(reinterpret_cast<_Float16*>(input_->buffer), input_->shape(),
89 reinterpret_cast<_Float16*>(hidden_state_in_->buffer),
90 reinterpret_cast<_Float16*>(bias_->buffer),
91 reinterpret_cast<_Float16*>(weights_->buffer), weights_->shape(),
92 reinterpret_cast<_Float16*>(recurrent_weights_->buffer),
93 recurrent_weights_->shape(), activation_,
94 reinterpret_cast<_Float16*>(output_->buffer));
95 memcpy(hidden_state_out_->buffer, output_->buffer,
96 sizeof(_Float16) * getNumberOfElements(output_->shape()));
97 break;
98 }
99 case OperandType::TENSOR_FLOAT32: {
100 RNNStep<float>(reinterpret_cast<float*>(input_->buffer), input_->shape(),
101 reinterpret_cast<float*>(hidden_state_in_->buffer),
102 reinterpret_cast<float*>(bias_->buffer),
103 reinterpret_cast<float*>(weights_->buffer), weights_->shape(),
104 reinterpret_cast<float*>(recurrent_weights_->buffer),
105 recurrent_weights_->shape(), activation_,
106 reinterpret_cast<float*>(output_->buffer));
107 memcpy(hidden_state_out_->buffer, output_->buffer,
108 sizeof(float) * getNumberOfElements(output_->shape()));
109 break;
110 }
111 default: {
112 LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
113 return false;
114 }
115 }
116 return true;
117 }
118
119 template <typename T>
RNNStep(const T * inputData,const Shape & inputShape,const T * hiddenStateInputData,const T * biasData,const T * weightsData,const Shape & weightsShape,const T * recurrentWeightsData,const Shape & recurrentWeightsShape,const int32_t activation,T * outputData)120 bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* hiddenStateInputData,
121 const T* biasData, const T* weightsData, const Shape& weightsShape,
122 const T* recurrentWeightsData, const Shape& recurrentWeightsShape,
123 const int32_t activation, T* outputData) {
124 NNTRACE_COMP("RNN::Eval");
125
126 Shape dummyShape;
127 uint32_t numUnits = weightsShape.dimensions[0];
128 return RNNStep<T>(inputData, inputShape, /*auxInputData=*/nullptr, /*auxInputShape=*/dummyShape,
129 hiddenStateInputData, biasData, weightsData, weightsShape,
130 /*auxWeightsData=*/nullptr, /*auxWeightsShape=*/dummyShape,
131 recurrentWeightsData, recurrentWeightsShape, activation,
132 /*outputBatchStride=*/numUnits, /*outputBatchOffset=*/0, outputData);
133 }
134
135 // A more general version of the RNNStep function.
136 // Auxiliary input is treated as if it was concatenated to a regular input and
137 // the result was multiplied by the weights matrix which was also concatenated
138 // with auxiliary weights.
139 template <typename T>
RNNStep(const T * inputData,const Shape & inputShape,const T * auxInputData,const Shape & auxInputShape,const T * hiddenStateInputData,const T * biasData,const T * weightsData,const Shape & weightsShape,const T * auxWeightsData,const Shape & auxWeightsShape,const T * recurrentWeightsData,const Shape & recurrentWeightsShape,const int32_t activation,const uint32_t outputBatchStride,const uint32_t outputBatchOffset,T * outputData,T * hiddenStateOutput)140 bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* auxInputData,
141 const Shape& auxInputShape, const T* hiddenStateInputData, const T* biasData,
142 const T* weightsData, const Shape& weightsShape, const T* auxWeightsData,
143 const Shape& auxWeightsShape, const T* recurrentWeightsData,
144 const Shape& recurrentWeightsShape, const int32_t activation,
145 const uint32_t outputBatchStride, const uint32_t outputBatchOffset, T* outputData,
146 T* hiddenStateOutput) {
147 NNTRACE_COMP("RNN::Eval");
148
149 const uint32_t batch_size = inputShape.dimensions[0];
150 const uint32_t num_units = weightsShape.dimensions[0];
151 const uint32_t input_size = inputShape.dimensions[1];
152 const uint32_t input_weights_stride = weightsShape.dimensions[1];
153 const uint32_t recurrent_weights_stride = recurrentWeightsShape.dimensions[1];
154
155 uint32_t aux_input_size = 0;
156 uint32_t aux_input_weights_stride = 0;
157 bool hasAuxInput = (auxInputData != nullptr);
158 if (hasAuxInput) {
159 aux_input_size = auxInputShape.dimensions[1];
160 aux_input_weights_stride = auxWeightsShape.dimensions[1];
161 }
162
163 // For each batch
164 for (uint32_t b = 0; b < batch_size; b++) {
165 // Initialize the pointer to input, output and bias.
166 const T* input_ptr_batch = inputData + b * input_size;
167 const T* hidden_state_in_ptr_batch = hiddenStateInputData + b * num_units;
168 const T* aux_input_ptr_batch = nullptr;
169 if (hasAuxInput) {
170 aux_input_ptr_batch = auxInputData + b * aux_input_size;
171 }
172 T* output_ptr_batch = outputData + b * outputBatchStride + outputBatchOffset;
173
174 // Initialize input_weights and recurrent_weights.
175 const T* input_weights_ptr = weightsData;
176 const T* recurrent_weights_ptr = recurrentWeightsData;
177 const T* aux_input_weights_ptr = nullptr;
178 if (hasAuxInput) {
179 aux_input_weights_ptr = auxWeightsData;
180 }
181
182 // Output = bias
183 for (uint32_t o = 0; o < num_units; o++) {
184 output_ptr_batch[o] = biasData[o];
185 }
186
187 // Output += input * input_weights
188 for (uint32_t o = 0; o < num_units; o++) {
189 for (uint32_t i = 0; i < input_size; i++) {
190 output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
191 }
192 input_weights_ptr += input_weights_stride;
193 }
194
195 if (hasAuxInput) {
196 // Output += aux_input * aux_input_weights
197 for (uint32_t o = 0; o < num_units; o++) {
198 for (uint32_t i = 0; i < input_size; i++) {
199 output_ptr_batch[o] += aux_input_ptr_batch[i] * aux_input_weights_ptr[i];
200 }
201 aux_input_weights_ptr += aux_input_weights_stride;
202 }
203 }
204
205 // Output += recurrent_weights * hidden_state
206 for (uint32_t o = 0; o < num_units; o++) {
207 for (uint32_t h = 0; h < num_units; h++) {
208 output_ptr_batch[o] += hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h];
209 }
210 recurrent_weights_ptr += recurrent_weights_stride;
211 }
212
213 // Output = activation(Output)
214 for (uint32_t o = 0; o < num_units; o++) {
215 output_ptr_batch[o] =
216 (ActivationFunctor(static_cast<ActivationFn>(activation)))(output_ptr_batch[o]);
217 if (hiddenStateOutput != nullptr) {
218 *hiddenStateOutput = output_ptr_batch[o];
219 ++hiddenStateOutput;
220 }
221 }
222 }
223
224 return true;
225 }
226
// Explicit instantiations of both RNNStep overloads for the two element types
// dispatched by Eval() (TENSOR_FLOAT16 -> _Float16, TENSOR_FLOAT32 -> float),
// so the template definitions above can live in this .cpp and still be linked
// from other translation units. Keep in sync with the types handled in Eval().
template bool RNN::RNNStep<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                     const _Float16* hiddenStateInputData, const _Float16* biasData,
                                     const _Float16* weightsData, const Shape& weightsShape,
                                     const _Float16* recurrentWeightsData,
                                     const Shape& recurrentWeightsShape, int32_t activation,
                                     _Float16* outputData);
template bool RNN::RNNStep<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                     const _Float16* auxInputData, const Shape& auxInputShape,
                                     const _Float16* hiddenStateInputData, const _Float16* biasData,
                                     const _Float16* weightsData, const Shape& weightsShape,
                                     const _Float16* auxWeightsData, const Shape& auxWeightsShape,
                                     const _Float16* recurrentWeightsData,
                                     const Shape& recurrentWeightsShape, const int32_t activation,
                                     const uint32_t outputBatchStride,
                                     const uint32_t outputBatchOffset, _Float16* outputData,
                                     _Float16* hiddenStateOutput);
template bool RNN::RNNStep<float>(const float* inputData, const Shape& inputShape,
                                  const float* hiddenStateInputData, const float* biasData,
                                  const float* weightsData, const Shape& weightsShape,
                                  const float* recurrentWeightsData,
                                  const Shape& recurrentWeightsShape, int32_t activation,
                                  float* outputData);
template bool RNN::RNNStep<float>(const float* inputData, const Shape& inputShape,
                                  const float* auxInputData, const Shape& auxInputShape,
                                  const float* hiddenStateInputData, const float* biasData,
                                  const float* weightsData, const Shape& weightsShape,
                                  const float* auxWeightsData, const Shape& auxWeightsShape,
                                  const float* recurrentWeightsData,
                                  const Shape& recurrentWeightsShape, int32_t activation,
                                  uint32_t outputBatchStride, uint32_t outputBatchStep,
                                  float* outputData, float* hiddenStateOutput);
258
259 } // namespace nn
260 } // namespace android
261