1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22 
23 #include "HalInterfaces.h"
24 #include "OperationResolver.h"
25 #include "RNN.h"
26 
27 namespace android {
28 namespace nn {
29 namespace bidirectional_sequence_rnn {
30 
// Total number of input operands expected by BIDIRECTIONAL_SEQUENCE_RNN.
constexpr uint32_t kNumInputs = 15;
constexpr uint32_t kInputTensor = 0;
// Forward cell tensors
constexpr uint32_t kFwWeightsTensor = 1;
constexpr uint32_t kFwRecurrentWeightsTensor = 2;
constexpr uint32_t kFwBiasTensor = 3;
constexpr uint32_t kFwHiddenStateTensor = 4;
// Backward cell tensors
constexpr uint32_t kBwWeightsTensor = 5;
constexpr uint32_t kBwRecurrentWeightsTensor = 6;
constexpr uint32_t kBwBiasTensor = 7;
constexpr uint32_t kBwHiddenStateTensor = 8;
// Auxiliary inputs
constexpr uint32_t kAuxInputTensor = 9;       // optional
constexpr uint32_t kFwAuxWeightsTensor = 10;  // optional
constexpr uint32_t kBwAuxWeightsTensor = 11;  // optional
// Cell parameters
constexpr uint32_t kActivationParam = 12;
constexpr uint32_t kTimeMajorParam = 13;
constexpr uint32_t kMergeOutputsParam = 14;

// Possible output counts: separate fw/bw outputs (2), merged output (1), and
// the same two configurations with final hidden-state outputs appended (4/3).
constexpr uint32_t kNumOutputs = 2;
constexpr uint32_t kNumOutputsMerged = 1;
constexpr uint32_t kNumOutputsWithState = 4;
constexpr uint32_t kNumOutputsMergedWithState = 3;

constexpr uint32_t kFwOutputTensor = 0;
constexpr uint32_t kBwOutputTensor = 1;  // Only if mergeOutputs parameter is false
// Hidden-state output indices; shifted down by one when outputs are merged
// (see the `delta` adjustment in prepare() and executeTyped()).
constexpr uint32_t kFwOutputHiddenStateTensor = 2;
constexpr uint32_t kBwOutputHiddenStateTensor = 3;
61 
62 namespace {
63 
64 using namespace hal;
65 
66 template <typename T>
transposeFirstTwoDims(const T * input,const Shape & inputShape,T * output)67 void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
68     const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
69     const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
70     const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
71     for (int f = 0; f < firstDimSize; ++f) {
72         for (int s = 0; s < secondDimSize; ++s) {
73             for (int i = 0; i < inputSize; ++i) {
74                 const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
75                 const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
76                 output[outputIndex] = input[inputIndex];
77             }
78         }
79     }
80 }
81 
removeFirstDim(const Shape & input)82 Shape removeFirstDim(const Shape& input) {
83     Shape output = input;
84     output.dimensions.resize(input.dimensions.size() - 1);
85     for (int i = 0; i < input.dimensions.size() - 1; ++i) {
86         output.dimensions[i] = input.dimensions[i + 1];
87     }
88     return output;
89 }
90 
// How the auxiliary input (operand 9) is wired into the two RNN cells; see
// getLinkingMode() for the operand combinations that select each mode.
enum class LinkingMode {
    NO_LINKING,
    PARALLEL_LINKING,
    CROSS_LINKING,
};
96 
getLinkingMode(IOperationExecutionContext * context,LinkingMode * linkingMode)97 bool getLinkingMode(IOperationExecutionContext* context, LinkingMode* linkingMode) {
98     const bool hasAuxInput = !context->isOmittedInput(kAuxInputTensor);
99     const bool hasFwAuxWeights = !context->isOmittedInput(kFwAuxWeightsTensor);
100     const bool hasBwAuxWeights = !context->isOmittedInput(kBwAuxWeightsTensor);
101 
102     // Three possible configurations for three possible linking modes:
103     // 1) NO_LINKING -- no auxiliary tensors at all
104     // 2) PARALLEL_LINKING -- auxiliary input is provided and used as a regular
105     //    input to the backward network, so the auxiliary weights are omitted.
106     // 3) CROSS_LINKING -- auxiliary input is provided and multiplied by
107     //    auxiliary weights.
108     if (!hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) {
109         *linkingMode = LinkingMode::NO_LINKING;
110     } else if (hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) {
111         *linkingMode = LinkingMode::PARALLEL_LINKING;
112     } else if (hasAuxInput && hasFwAuxWeights && hasBwAuxWeights) {
113         *linkingMode = LinkingMode::CROSS_LINKING;
114     } else {
115         NN_RET_CHECK_FAIL()
116                 << "Unsupported auxiliary tensors configuration for BIDIRECTIONAL_SEQUENCE_RNN.";
117     }
118 
119     return true;
120 }
121 
// Executes the bidirectional sequence RNN for element type T (float or
// _Float16): runs the forward cell over time steps [0, maxTime) and the
// backward cell over [maxTime-1, 0], stepping through RNN::RNNStep once per
// time step. Handles batch-major layouts by transposing into temporary
// time-major buffers and transposing the results back at the end.
template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);

    // Forward cell tensors.
    const T* fwWeights = context->getInputBuffer<T>(kFwWeightsTensor);
    Shape fwWeightsShape = context->getInputShape(kFwWeightsTensor);
    const T* fwRecurrentWeights = context->getInputBuffer<T>(kFwRecurrentWeightsTensor);
    Shape fwRecurrentWeightsShape = context->getInputShape(kFwRecurrentWeightsTensor);
    const T* fwBias = context->getInputBuffer<T>(kFwBiasTensor);
    const T* fwHiddenState = context->getInputBuffer<T>(kFwHiddenStateTensor);

    // Backward cell tensors.
    const T* bwWeights = context->getInputBuffer<T>(kBwWeightsTensor);
    Shape bwWeightsShape = context->getInputShape(kBwWeightsTensor);
    const T* bwRecurrentWeights = context->getInputBuffer<T>(kBwRecurrentWeightsTensor);
    Shape bwRecurrentWeightsShape = context->getInputShape(kBwRecurrentWeightsTensor);
    const T* bwBias = context->getInputBuffer<T>(kBwBiasTensor);
    const T* bwHiddenState = context->getInputBuffer<T>(kBwHiddenStateTensor);

    // Fetch auxiliary buffers only for the modes that use them; the buffers
    // stay nullptr otherwise.
    const T* auxInput = nullptr;
    const T* fwAuxWeights = nullptr;
    const T* bwAuxWeights = nullptr;
    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(context, &linkingMode));
    if (linkingMode == LinkingMode::CROSS_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
        fwAuxWeights = context->getInputBuffer<T>(kFwAuxWeightsTensor);
        bwAuxWeights = context->getInputBuffer<T>(kBwAuxWeightsTensor);
    } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
    }
    const bool hasAuxInput = (linkingMode == LinkingMode::CROSS_LINKING ||
                              linkingMode == LinkingMode::PARALLEL_LINKING);
    const bool hasAuxWeights = (linkingMode == LinkingMode::CROSS_LINKING);
    // NOTE: these shapes may describe omitted operands when the corresponding
    // pointers above are nullptr; they are only consumed in modes that set
    // the pointers.
    Shape auxInputShape = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeightsShape = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeightsShape = context->getInputShape(kBwAuxWeightsTensor);

    const int32_t activation = context->getInputValue<int32_t>(kActivationParam);
    const bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);

    T* fwOutput = context->getOutputBuffer<T>(kFwOutputTensor);
    Shape fwOutputShape = context->getOutputShape(kFwOutputTensor);
    T* bwOutput = nullptr;
    Shape bwOutputShape;
    if (!mergeOutputs) {
        bwOutputShape = context->getOutputShape(kBwOutputTensor);
        bwOutput = context->getOutputBuffer<T>(kBwOutputTensor);
    }

    // If the input tensors are not in time major format, we transpose the first
    // two dimensions, and set input and output pointers to temporary vectors
    // which are transposed back after the RNN is applied.
    std::vector<T> inputTransposed;
    std::vector<T> auxInputTransposed;
    std::vector<T> fwOutputTransposed;
    std::vector<T> bwOutputTransposed;
    if (!timeMajor) {
        // First, resize temporary buffers to accommodate for transposed tensors.
        inputTransposed.resize(getNumberOfElements(inputShape));
        if (hasAuxInput) {
            auxInputTransposed.resize(getNumberOfElements(auxInputShape));
        }
        fwOutputTransposed.resize(getNumberOfElements(fwOutputShape));
        if (!mergeOutputs) {
            bwOutputTransposed.resize(getNumberOfElements(bwOutputShape));
        }

        // Transpose the input tensors.
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        if (hasAuxInput) {
            transposeFirstTwoDims(auxInput, auxInputShape, auxInputTransposed.data());
        }

        // Change input and output pointers to the temporary buffers.
        input = inputTransposed.data();
        if (hasAuxInput) {
            auxInput = auxInputTransposed.data();
        }
        fwOutput = fwOutputTransposed.data();
        if (!mergeOutputs) {
            bwOutput = bwOutputTransposed.data();
        }

        // Swap the first two dimensions in the Shapes to reflect the
        // transposition.
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        if (hasAuxInput) {
            std::swap(auxInputShape.dimensions[0], auxInputShape.dimensions[1]);
        }
        std::swap(fwOutputShape.dimensions[0], fwOutputShape.dimensions[1]);
        if (!mergeOutputs) {
            std::swap(bwOutputShape.dimensions[0], bwOutputShape.dimensions[1]);
        }
    }

    // From here on all data is time-major: [maxTime, batchSize, ...].
    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    uint32_t auxInputSize = 0;
    if (hasAuxInput) {
        auxInputSize = getSizeOfDimension(auxInputShape, 2);
    }
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeightsShape, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeightsShape, 0);

    // RNNStep operates on one time slice, so drop the time dimension from the
    // shapes it receives.
    Shape fixedTimeInputShape = removeFirstDim(inputShape);
    Shape fixedTimeAuxInputShape = auxInputShape;
    if (hasAuxInput) {
        fixedTimeAuxInputShape = removeFirstDim(auxInputShape);
    }

    // In PARALLEL_LINKING the auxiliary input feeds the backward cell as its
    // regular input (and is then no longer treated as auxiliary data).
    const T* bwInput = input;
    if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        bwInput = auxInput;
        auxInput = nullptr;
    }

    const bool outputState = (context->getNumOutputs() == kNumOutputsWithState ||
                              context->getNumOutputs() == kNumOutputsMergedWithState);
    T* fwOutputHiddenState = nullptr;
    T* bwOutputHiddenState = nullptr;
    // Create an additional buffer to store a hidden state between steps.
    std::vector<T> tempHiddenState;
    if (outputState) {
        // With merged outputs there is one fewer output tensor before the
        // hidden states, hence the index shift by `delta`.
        const int delta = mergeOutputs ? 1 : 0;
        fwOutputHiddenState = context->getOutputBuffer<T>(kFwOutputHiddenStateTensor - delta);
        bwOutputHiddenState = context->getOutputBuffer<T>(kBwOutputHiddenStateTensor - delta);
    } else {
        // Scratch buffer shared by both passes (they never run concurrently).
        tempHiddenState.resize(std::max(batchSize * fwNumUnits, batchSize * bwNumUnits));
        fwOutputHiddenState = tempHiddenState.data();
        bwOutputHiddenState = tempHiddenState.data();
    }

    // Forward pass
    for (int i = 0; i < maxTime; ++i) {
        const T* inputBatchPtr = input + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        // Only CROSS_LINKING feeds auxiliary data into the cell directly.
        if (hasAuxWeights) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        // Merged output interleaves fw and bw results per batch element.
        const uint32_t fwOutputBatchStride = mergeOutputs ? (fwNumUnits + bwNumUnits) : fwNumUnits;
        T* fwOutputBatchPtr = fwOutput + i * batchSize * fwOutputBatchStride;

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, fwHiddenState, fwBias, fwWeights, fwWeightsShape,
                        fwAuxWeights, fwAuxWeightsShape, fwRecurrentWeights,
                        fwRecurrentWeightsShape, activation, fwOutputBatchStride,
                        /*outputBatchOffset=*/0, fwOutputBatchPtr, fwOutputHiddenState);

        // Feed this step's hidden state into the next step.
        fwHiddenState = fwOutputHiddenState;
    }

    // Backward pass
    for (int i = maxTime - 1; i >= 0; --i) {
        const T* inputBatchPtr = bwInput + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxWeights) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        T* bwOutputBatchPtr;
        uint32_t bwOutputBatchOffset = 0;
        uint32_t bwOutputBatchStride;
        if (mergeOutputs) {
            // Write into the fw output buffer, offset past the fw units.
            bwOutputBatchStride = fwNumUnits + bwNumUnits;
            bwOutputBatchOffset = fwNumUnits;
            bwOutputBatchPtr = fwOutput + i * batchSize * bwOutputBatchStride;
        } else {
            bwOutputBatchStride = bwNumUnits;
            bwOutputBatchPtr = bwOutput + i * batchSize * bwOutputBatchStride;
        }

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, bwHiddenState, bwBias, bwWeights, bwWeightsShape,
                        bwAuxWeights, bwAuxWeightsShape, bwRecurrentWeights,
                        bwRecurrentWeightsShape, activation, bwOutputBatchStride,
                        bwOutputBatchOffset, bwOutputBatchPtr, bwOutputHiddenState);

        bwHiddenState = bwOutputHiddenState;
    }

    // If the inputs were in batch major format, transpose data in temporary
    // buffers and write to the output(s).
    if (!timeMajor) {
        transposeFirstTwoDims(fwOutputTransposed.data(), fwOutputShape,
                              context->getOutputBuffer<T>(kFwOutputTensor));
        if (!mergeOutputs) {
            transposeFirstTwoDims(bwOutputTransposed.data(), bwOutputShape,
                                  context->getOutputBuffer<T>(kBwOutputTensor));
        }
    }
    return true;
}
316 
317 }  // namespace
318 
validate(const IOperationValidationContext * context)319 bool validate(const IOperationValidationContext* context) {
320     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
321     // Exact number is dependent on the mergeOutputs parameter and checked
322     // during preparation.
323     const uint32_t numOutputs = context->getNumOutputs();
324     NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsMerged ||
325                  numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState);
326 
327     OperandType inputType = context->getInputType(kInputTensor);
328     if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
329         LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
330                    << toString(inputType);
331         return false;
332     }
333     NN_RET_CHECK(validateInputTypes(
334             context, {inputType, inputType, inputType, inputType, inputType, inputType, inputType,
335                       inputType, inputType, inputType, inputType, inputType, OperandType::INT32,
336                       OperandType::BOOL, OperandType::BOOL}));
337 
338     std::vector<OperandType> outExpectedTypes(numOutputs, inputType);
339     NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
340 
341     HalVersion minSupportedHalVersion = HalVersion::V1_2;
342     if (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState) {
343         minSupportedHalVersion = HalVersion::V1_3;
344     }
345     return validateHalVersion(context, minSupportedHalVersion);
346 }
347 
// Validates operand shapes against each other (dimensions, linking-mode
// constraints) and computes/sets the output shapes. Returns false on any
// inconsistency.
bool prepare(IOperationExecutionContext* context) {
    const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);
    const int32_t numOutputs = context->getNumOutputs();
    // The allowed output count depends on whether outputs are merged and
    // whether the final hidden states are requested.
    if (mergeOutputs) {
        NN_RET_CHECK(numOutputs == kNumOutputsMerged || numOutputs == kNumOutputsMergedWithState);
    } else {
        NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
    }

    // Check that none of the required inputs are omitted.
    const std::vector<int> requiredInputs = {
            kInputTensor,         kFwWeightsTensor, kFwRecurrentWeightsTensor, kFwBiasTensor,
            kFwHiddenStateTensor, kBwWeightsTensor, kBwRecurrentWeightsTensor, kBwBiasTensor,
            kBwHiddenStateTensor, kActivationParam, kTimeMajorParam,           kMergeOutputsParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    Shape input = context->getInputShape(kInputTensor);
    Shape fwWeights = context->getInputShape(kFwWeightsTensor);
    Shape fwRecurrentWeights = context->getInputShape(kFwRecurrentWeightsTensor);
    Shape fwBias = context->getInputShape(kFwBiasTensor);
    Shape fwHiddenState = context->getInputShape(kFwHiddenStateTensor);
    Shape bwWeights = context->getInputShape(kBwWeightsTensor);
    Shape bwRecurrentWeights = context->getInputShape(kBwRecurrentWeightsTensor);
    Shape bwBias = context->getInputShape(kBwBiasTensor);
    Shape bwHiddenState = context->getInputShape(kBwHiddenStateTensor);

    // These shapes may describe omitted operands; they are only checked in
    // the linking modes that require them.
    Shape auxInput = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeights = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeights = context->getInputShape(kBwAuxWeightsTensor);

    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(context, &linkingMode));

    // Input layout is [maxTime, batchSize, inputSize] when time-major,
    // [batchSize, maxTime, inputSize] otherwise.
    bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const uint32_t batchSize =
            timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
    const uint32_t maxTime =
            timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeights, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeights, 0);
    const uint32_t inputSize = getSizeOfDimension(input, 2);

    // Rank checks.
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwRecurrentWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwBias), 1);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwHiddenState), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwRecurrentWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwBias), 1);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwHiddenState), 2);

    // Forward cell dimension consistency.
    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(fwWeights, 1));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwBias, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(fwHiddenState, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwHiddenState, 1));

    // Backward cell dimension consistency. In PARALLEL_LINKING the backward
    // cell consumes the auxiliary input instead of the main input, so its
    // weights are checked against auxInput further below.
    if (linkingMode != LinkingMode::PARALLEL_LINKING) {
        NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(bwWeights, 1));
    }
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwBias, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(bwHiddenState, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwHiddenState, 1));

    // Auxiliary tensor consistency for the linking modes that use them.
    if (linkingMode == LinkingMode::CROSS_LINKING) {
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3);
        NN_RET_CHECK_EQ(getNumberOfDimensions(fwAuxWeights), 2);
        NN_RET_CHECK_EQ(getNumberOfDimensions(bwAuxWeights), 2);

        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 0), fwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 0), bwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
    } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3);

        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 2), getSizeOfDimension(bwWeights, 1));
    }

    // Output shapes follow the input layout; merged output concatenates the
    // fw and bw units along the last dimension.
    Shape fwOutput = context->getOutputShape(kFwOutputTensor);
    fwOutput.dimensions.resize(3);
    fwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
    fwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
    fwOutput.dimensions[2] = mergeOutputs ? fwNumUnits + bwNumUnits : fwNumUnits;
    NN_RET_CHECK(context->setOutputShape(kFwOutputTensor, fwOutput));
    if (!mergeOutputs) {
        Shape bwOutput = context->getOutputShape(kBwOutputTensor);
        bwOutput.dimensions.resize(3);
        bwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
        bwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
        bwOutput.dimensions[2] = bwNumUnits;
        NN_RET_CHECK(context->setOutputShape(kBwOutputTensor, bwOutput));
    }

    const bool outputState =
            (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState);
    if (outputState) {
        // With merged outputs the hidden-state outputs shift down by one
        // index; their shapes mirror the corresponding input hidden states.
        const int delta = mergeOutputs ? 1 : 0;
        NN_RET_CHECK(context->setOutputShape(kFwOutputHiddenStateTensor - delta,
                                             context->getInputShape(kFwHiddenStateTensor)));
        NN_RET_CHECK(context->setOutputShape(kBwOutputHiddenStateTensor - delta,
                                             context->getInputShape(kBwHiddenStateTensor)));
    }

    return true;
}
466 
execute(IOperationExecutionContext * context)467 bool execute(IOperationExecutionContext* context) {
468     if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
469         executeTyped<_Float16>(context);
470     } else {
471         executeTyped<float>(context);
472     }
473     return true;
474 }
475 
476 }  // namespace bidirectional_sequence_rnn
477 
// Register with the operation resolver. allowOmittedOperand is set because
// the auxiliary input/weights operands (indices 9-11) are optional.
NN_REGISTER_OPERATION(BIDIRECTIONAL_SEQUENCE_RNN, "BIDIRECTIONAL_SEQUENCE_RNN",
                      bidirectional_sequence_rnn::validate, bidirectional_sequence_rnn::prepare,
                      bidirectional_sequence_rnn::execute, .allowOmittedOperand = true);
481 
482 }  // namespace nn
483 }  // namespace android
484