1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22
23 #include "HalInterfaces.h"
24 #include "OperationResolver.h"
25 #include "RNN.h"
26
27 namespace android {
28 namespace nn {
29 namespace unidirectional_sequence_rnn {
30
// Number of input operands, followed by the index of each one as passed to the
// context's getInput*() accessors.
constexpr uint32_t kNumInputs{7};
constexpr uint32_t kInputTensor{0};
constexpr uint32_t kWeightsTensor{1};
constexpr uint32_t kRecurrentWeightsTensor{2};
constexpr uint32_t kBiasTensor{3};
constexpr uint32_t kHiddenStateTensor{4};
constexpr uint32_t kActivationParam{5};
constexpr uint32_t kTimeMajorParam{6};

// Accepted output operand counts (without / with the state output), followed by
// the index of each output operand.
constexpr uint32_t kNumOutputs{1};
constexpr uint32_t kNumOutputsWithState{2};
constexpr uint32_t kOutputTensor{0};
constexpr uint32_t kStateOutputTensor{1};
44
45 namespace {
46
47 using namespace hal;
48
49 template <typename T>
transposeFirstTwoDims(const T * input,const Shape & inputShape,T * output)50 void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
51 const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
52 const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
53 const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
54 for (int f = 0; f < firstDimSize; ++f) {
55 for (int s = 0; s < secondDimSize; ++s) {
56 for (int i = 0; i < inputSize; ++i) {
57 const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
58 const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
59 output[outputIndex] = input[inputIndex];
60 }
61 }
62 }
63 }
64
65 template <typename T>
executeTyped(IOperationExecutionContext * context)66 bool executeTyped(IOperationExecutionContext* context) {
67 const T* input = context->getInputBuffer<T>(kInputTensor);
68 Shape inputShape = context->getInputShape(kInputTensor);
69 const T* weights = context->getInputBuffer<T>(kWeightsTensor);
70 Shape weightsShape = context->getInputShape(kWeightsTensor);
71 const T* recurrentWeights = context->getInputBuffer<T>(kRecurrentWeightsTensor);
72 Shape recurrentWeightsShape = context->getInputShape(kRecurrentWeightsTensor);
73 const T* bias = context->getInputBuffer<T>(kBiasTensor);
74 const T* hiddenState = context->getInputBuffer<T>(kHiddenStateTensor);
75 int32_t activation = context->getInputValue<int32_t>(kActivationParam);
76
77 T* output = context->getOutputBuffer<T>(kOutputTensor);
78 Shape outputShape = context->getOutputShape(kOutputTensor);
79
80 int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
81 // If the input tensors are not in time major format, we transpose the first
82 // two dimensions, and set input and output pointers to temporary vectors
83 // which are transposed back after the RNN is applied.
84 std::vector<T> inputTransposed;
85 std::vector<T> outputTransposed;
86 if (!timeMajor) {
87 // Convert input and output to time major format.
88 inputTransposed.resize(getNumberOfElements(inputShape));
89 outputTransposed.resize(getNumberOfElements(outputShape));
90 transposeFirstTwoDims(input, inputShape, inputTransposed.data());
91 input = inputTransposed.data();
92 output = outputTransposed.data();
93 std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
94 std::swap(outputShape.dimensions[0], outputShape.dimensions[1]);
95 }
96
97 const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
98 const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
99 const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
100 const uint32_t numUnits = getSizeOfDimension(weightsShape, 0);
101
102 // A shape at a fixed step (removed time dimension).
103 Shape fixedTimeInputShape = inputShape;
104 fixedTimeInputShape.dimensions.resize(2);
105 fixedTimeInputShape.dimensions[0] = inputShape.dimensions[1];
106 fixedTimeInputShape.dimensions[1] = inputShape.dimensions[2];
107
108 for (int i = 0; i < maxTime; ++i) {
109 RNN::RNNStep<T>(input, fixedTimeInputShape, hiddenState, bias, weights, weightsShape,
110 recurrentWeights, recurrentWeightsShape, activation, output);
111 input += batchSize * inputSize;
112 hiddenState = output;
113 output += batchSize * numUnits;
114 }
115
116 if (!timeMajor) {
117 transposeFirstTwoDims(outputTransposed.data(), outputShape,
118 context->getOutputBuffer<T>(kOutputTensor));
119 }
120
121 if (context->getNumOutputs() == kNumOutputsWithState) {
122 // We checked that the state output is not omitted during preparation.
123 T* stateOutput = context->getOutputBuffer<T>(kStateOutputTensor);
124 std::copy(hiddenState, hiddenState + batchSize * numUnits, stateOutput);
125 }
126 return true;
127 }
128
129 } // namespace
130
validate(const IOperationValidationContext * context)131 bool validate(const IOperationValidationContext* context) {
132 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
133 const int numOutputs = context->getNumOutputs();
134 NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
135 OperandType inputType = context->getInputType(kInputTensor);
136 if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
137 LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
138 << toString(inputType);
139 return false;
140 }
141 NN_RET_CHECK(validateInputTypes(context, {inputType, inputType, inputType, inputType, inputType,
142 OperandType::INT32, OperandType::INT32}));
143 std::vector<OperandType> outputTypes = {inputType};
144 HalVersion minHalVersionSupported = HalVersion::V1_2;
145 if (numOutputs == kNumOutputsWithState) {
146 minHalVersionSupported = HalVersion::V1_3;
147 outputTypes.push_back(inputType);
148 }
149 NN_RET_CHECK(validateOutputTypes(context, outputTypes));
150 return validateHalVersion(context, minHalVersionSupported);
151 }
152
prepare(IOperationExecutionContext * context)153 bool prepare(IOperationExecutionContext* context) {
154 Shape input = context->getInputShape(kInputTensor);
155 Shape weights = context->getInputShape(kWeightsTensor);
156 Shape recurrentWeights = context->getInputShape(kRecurrentWeightsTensor);
157 Shape bias = context->getInputShape(kBiasTensor);
158 Shape hiddenState = context->getInputShape(kHiddenStateTensor);
159
160 int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
161 NN_RET_CHECK(timeMajor == 0 || timeMajor == 1);
162 const uint32_t batchSize =
163 timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
164 const uint32_t maxTime =
165 timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
166 const uint32_t numUnits = getSizeOfDimension(weights, 0);
167 const uint32_t inputSize = getSizeOfDimension(input, 2);
168
169 NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3);
170 NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2);
171 NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentWeights), 2);
172 NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1);
173 NN_RET_CHECK_EQ(getNumberOfDimensions(hiddenState), 2);
174
175 NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(weights, 1));
176 NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(bias, 0));
177 NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 0));
178 NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 1));
179 NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(hiddenState, 0));
180 NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(hiddenState, 1));
181
182 Shape output = context->getOutputShape(kOutputTensor);
183 output.dimensions.resize(3);
184 output.dimensions[0] = timeMajor ? maxTime : batchSize;
185 output.dimensions[1] = timeMajor ? batchSize : maxTime;
186 output.dimensions[2] = numUnits;
187
188 if (context->getNumOutputs() == kNumOutputsWithState) {
189 NN_RET_CHECK(!context->isOmittedOutput(kStateOutputTensor));
190 Shape outputStateShape = context->getInputShape(kHiddenStateTensor);
191 outputStateShape.dimensions.resize(2);
192 outputStateShape.dimensions[0] = batchSize;
193 outputStateShape.dimensions[1] = numUnits;
194 NN_RET_CHECK(context->setOutputShape(kStateOutputTensor, outputStateShape));
195 }
196
197 return context->setOutputShape(kOutputTensor, output);
198 }
199
execute(IOperationExecutionContext * context)200 bool execute(IOperationExecutionContext* context) {
201 if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
202 executeTyped<_Float16>(context);
203 } else {
204 executeTyped<float>(context);
205 }
206 return true;
207 }
208
209 } // namespace unidirectional_sequence_rnn
210
// Associates the UNIDIRECTIONAL_SEQUENCE_RNN operation with the validate, prepare
// and execute functions defined in the unidirectional_sequence_rnn namespace.
// NOTE(review): presumably this makes the operation visible to the resolver
// declared in OperationResolver.h — confirm against the macro's definition.
NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_RNN, "UNIDIRECTIONAL_SEQUENCE_RNN",
                      unidirectional_sequence_rnn::validate, unidirectional_sequence_rnn::prepare,
                      unidirectional_sequence_rnn::execute);
214
215 } // namespace nn
216 } // namespace android
217