1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22
23 #include "HalInterfaces.h"
24 #include "OperationResolver.h"
25 #include "RNN.h"
26
27 namespace android {
28 namespace nn {
29 namespace bidirectional_sequence_rnn {
30
// Indices of the operation's inputs within the operand list.
constexpr uint32_t kNumInputs = 15;
constexpr uint32_t kInputTensor = 0;
// Forward cell tensors
constexpr uint32_t kFwWeightsTensor = 1;
constexpr uint32_t kFwRecurrentWeightsTensor = 2;
constexpr uint32_t kFwBiasTensor = 3;
constexpr uint32_t kFwHiddenStateTensor = 4;
// Backward cell tensors
constexpr uint32_t kBwWeightsTensor = 5;
constexpr uint32_t kBwRecurrentWeightsTensor = 6;
constexpr uint32_t kBwBiasTensor = 7;
constexpr uint32_t kBwHiddenStateTensor = 8;
// Auxiliary inputs; their presence/absence selects the LinkingMode (see
// getLinkingMode() below).
constexpr uint32_t kAuxInputTensor = 9;       // optional
constexpr uint32_t kFwAuxWeightsTensor = 10;  // optional
constexpr uint32_t kBwAuxWeightsTensor = 11;  // optional
// Cell parameters
constexpr uint32_t kActivationParam = 12;
constexpr uint32_t kTimeMajorParam = 13;
constexpr uint32_t kMergeOutputsParam = 14;

// Four valid output counts: fw/bw outputs are separate (2) or merged into one
// tensor (1), optionally followed by the two final hidden states ("WithState").
constexpr uint32_t kNumOutputs = 2;
constexpr uint32_t kNumOutputsMerged = 1;
constexpr uint32_t kNumOutputsWithState = 4;
constexpr uint32_t kNumOutputsMergedWithState = 3;

constexpr uint32_t kFwOutputTensor = 0;
constexpr uint32_t kBwOutputTensor = 1;  // Only if mergeOutputs parameter is false
constexpr uint32_t kFwOutputHiddenStateTensor = 2;
constexpr uint32_t kBwOutputHiddenStateTensor = 3;
61
62 namespace {
63
64 using namespace hal;
65
66 template <typename T>
transposeFirstTwoDims(const T * input,const Shape & inputShape,T * output)67 void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
68 const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
69 const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
70 const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
71 for (int f = 0; f < firstDimSize; ++f) {
72 for (int s = 0; s < secondDimSize; ++s) {
73 for (int i = 0; i < inputSize; ++i) {
74 const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
75 const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
76 output[outputIndex] = input[inputIndex];
77 }
78 }
79 }
80 }
81
// Returns a copy of |input| with its outermost dimension dropped; used to
// describe a single fixed-time slice of a [time, batch, size] tensor.
// Precondition: |input| has at least one dimension (the original code's
// resize(size() - 1) would underflow on a rank-0 shape as well).
Shape removeFirstDim(const Shape& input) {
    Shape output = input;
    output.dimensions.erase(output.dimensions.begin());
    return output;
}
90
// How the optional auxiliary input is wired into the two cells; derived from
// which of the three optional operands are present (see getLinkingMode()).
enum class LinkingMode {
    NO_LINKING,
    PARALLEL_LINKING,
    CROSS_LINKING,
};
96
getLinkingMode(IOperationExecutionContext * context,LinkingMode * linkingMode)97 bool getLinkingMode(IOperationExecutionContext* context, LinkingMode* linkingMode) {
98 const bool hasAuxInput = !context->isOmittedInput(kAuxInputTensor);
99 const bool hasFwAuxWeights = !context->isOmittedInput(kFwAuxWeightsTensor);
100 const bool hasBwAuxWeights = !context->isOmittedInput(kBwAuxWeightsTensor);
101
102 // Three possible configurations for three possible linking modes:
103 // 1) NO_LINKING -- no auxiliary tensors at all
104 // 2) PARALLEL_LINKING -- auxiliary input is provided and used as a regular
105 // input to the backward network, so the auxiliary weights are omitted.
106 // 3) CROSS_LINKING -- auxiliary input is provided and multiplied by
107 // auxiliary weights.
108 if (!hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) {
109 *linkingMode = LinkingMode::NO_LINKING;
110 } else if (hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) {
111 *linkingMode = LinkingMode::PARALLEL_LINKING;
112 } else if (hasAuxInput && hasFwAuxWeights && hasBwAuxWeights) {
113 *linkingMode = LinkingMode::CROSS_LINKING;
114 } else {
115 NN_RET_CHECK_FAIL()
116 << "Unsupported auxiliary tensors configuration for BIDIRECTIONAL_SEQUENCE_RNN.";
117 }
118
119 return true;
120 }
121
// Runs the bidirectional RNN for element type T: the forward cell walks
// timesteps 0..maxTime-1 and the backward cell walks them in reverse. Outputs
// go to two separate tensors, or are concatenated along the last dimension
// into one tensor when mergeOutputs is true. Returns false on unsupported
// auxiliary-operand configurations.
template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);

    const T* fwWeights = context->getInputBuffer<T>(kFwWeightsTensor);
    Shape fwWeightsShape = context->getInputShape(kFwWeightsTensor);
    const T* fwRecurrentWeights = context->getInputBuffer<T>(kFwRecurrentWeightsTensor);
    Shape fwRecurrentWeightsShape = context->getInputShape(kFwRecurrentWeightsTensor);
    const T* fwBias = context->getInputBuffer<T>(kFwBiasTensor);
    const T* fwHiddenState = context->getInputBuffer<T>(kFwHiddenStateTensor);

    const T* bwWeights = context->getInputBuffer<T>(kBwWeightsTensor);
    Shape bwWeightsShape = context->getInputShape(kBwWeightsTensor);
    const T* bwRecurrentWeights = context->getInputBuffer<T>(kBwRecurrentWeightsTensor);
    Shape bwRecurrentWeightsShape = context->getInputShape(kBwRecurrentWeightsTensor);
    const T* bwBias = context->getInputBuffer<T>(kBwBiasTensor);
    const T* bwHiddenState = context->getInputBuffer<T>(kBwHiddenStateTensor);

    // Auxiliary buffers are only fetched for the linking modes that use them;
    // in NO_LINKING they stay null.
    const T* auxInput = nullptr;
    const T* fwAuxWeights = nullptr;
    const T* bwAuxWeights = nullptr;
    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(context, &linkingMode));
    if (linkingMode == LinkingMode::CROSS_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
        fwAuxWeights = context->getInputBuffer<T>(kFwAuxWeightsTensor);
        bwAuxWeights = context->getInputBuffer<T>(kBwAuxWeightsTensor);
    } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
    }
    const bool hasAuxInput = (linkingMode == LinkingMode::CROSS_LINKING ||
                              linkingMode == LinkingMode::PARALLEL_LINKING);
    const bool hasAuxWeights = (linkingMode == LinkingMode::CROSS_LINKING);
    Shape auxInputShape = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeightsShape = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeightsShape = context->getInputShape(kBwAuxWeightsTensor);

    const int32_t activation = context->getInputValue<int32_t>(kActivationParam);
    const bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);

    T* fwOutput = context->getOutputBuffer<T>(kFwOutputTensor);
    Shape fwOutputShape = context->getOutputShape(kFwOutputTensor);
    // When outputs are merged, the backward result is interleaved into
    // fwOutput instead of a separate tensor, so bwOutput stays null.
    T* bwOutput = nullptr;
    Shape bwOutputShape;
    if (!mergeOutputs) {
        bwOutputShape = context->getOutputShape(kBwOutputTensor);
        bwOutput = context->getOutputBuffer<T>(kBwOutputTensor);
    }

    // If the input tensors are not in time major format, we transpose the first
    // two dimensions, and set input and output pointers to temporary vectors
    // which are transposed back after the RNN is applied.
    std::vector<T> inputTransposed;
    std::vector<T> auxInputTransposed;
    std::vector<T> fwOutputTransposed;
    std::vector<T> bwOutputTransposed;
    if (!timeMajor) {
        // First, resize temporary buffers to accommodate for transposed tensors.
        inputTransposed.resize(getNumberOfElements(inputShape));
        if (hasAuxInput) {
            auxInputTransposed.resize(getNumberOfElements(auxInputShape));
        }
        fwOutputTransposed.resize(getNumberOfElements(fwOutputShape));
        if (!mergeOutputs) {
            bwOutputTransposed.resize(getNumberOfElements(bwOutputShape));
        }

        // Transpose the input tensors.
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        if (hasAuxInput) {
            transposeFirstTwoDims(auxInput, auxInputShape, auxInputTransposed.data());
        }

        // Change input and output pointers to the temporary buffers.
        input = inputTransposed.data();
        if (hasAuxInput) {
            auxInput = auxInputTransposed.data();
        }
        fwOutput = fwOutputTransposed.data();
        if (!mergeOutputs) {
            bwOutput = bwOutputTransposed.data();
        }

        // Swap the first two dimensions in the Shapes to reflect the
        // transposition.
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        if (hasAuxInput) {
            std::swap(auxInputShape.dimensions[0], auxInputShape.dimensions[1]);
        }
        std::swap(fwOutputShape.dimensions[0], fwOutputShape.dimensions[1]);
        if (!mergeOutputs) {
            std::swap(bwOutputShape.dimensions[0], bwOutputShape.dimensions[1]);
        }
    }

    // From here on all shapes are time major: [maxTime, batchSize, size].
    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    uint32_t auxInputSize = 0;
    if (hasAuxInput) {
        auxInputSize = getSizeOfDimension(auxInputShape, 2);
    }
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeightsShape, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeightsShape, 0);

    // Per-timestep (rank-2) shapes passed to RNNStep.
    Shape fixedTimeInputShape = removeFirstDim(inputShape);
    Shape fixedTimeAuxInputShape = auxInputShape;
    if (hasAuxInput) {
        fixedTimeAuxInputShape = removeFirstDim(auxInputShape);
    }

    // In PARALLEL_LINKING the auxiliary input becomes the regular input of the
    // backward cell and is no longer treated as auxiliary.
    const T* bwInput = input;
    if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        bwInput = auxInput;
        auxInput = nullptr;
    }

    const bool outputState = (context->getNumOutputs() == kNumOutputsWithState ||
                              context->getNumOutputs() == kNumOutputsMergedWithState);
    T* fwOutputHiddenState = nullptr;
    T* bwOutputHiddenState = nullptr;
    // Create an additional buffer to store a hidden state between steps.
    std::vector<T> tempHiddenState;
    if (outputState) {
        // With merged outputs the bw output tensor is absent, shifting the
        // hidden-state output indices down by one.
        const int delta = mergeOutputs ? 1 : 0;
        fwOutputHiddenState = context->getOutputBuffer<T>(kFwOutputHiddenStateTensor - delta);
        bwOutputHiddenState = context->getOutputBuffer<T>(kBwOutputHiddenStateTensor - delta);
    } else {
        // One scratch buffer is enough: the forward pass finishes before the
        // backward pass starts.
        tempHiddenState.resize(std::max(batchSize * fwNumUnits, batchSize * bwNumUnits));
        fwOutputHiddenState = tempHiddenState.data();
        bwOutputHiddenState = tempHiddenState.data();
    }

    // Forward pass
    for (int i = 0; i < maxTime; ++i) {
        const T* inputBatchPtr = input + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxWeights) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        // With merged outputs fw rows are strided to leave room for bw values.
        const uint32_t fwOutputBatchStride = mergeOutputs ? (fwNumUnits + bwNumUnits) : fwNumUnits;
        T* fwOutputBatchPtr = fwOutput + i * batchSize * fwOutputBatchStride;

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, fwHiddenState, fwBias, fwWeights, fwWeightsShape,
                        fwAuxWeights, fwAuxWeightsShape, fwRecurrentWeights,
                        fwRecurrentWeightsShape, activation, fwOutputBatchStride,
                        /*outputBatchOffset=*/0, fwOutputBatchPtr, fwOutputHiddenState);

        // Chain the produced hidden state into the next timestep.
        fwHiddenState = fwOutputHiddenState;
    }

    // Backward pass
    for (int i = maxTime - 1; i >= 0; --i) {
        const T* inputBatchPtr = bwInput + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxWeights) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        T* bwOutputBatchPtr;
        uint32_t bwOutputBatchOffset = 0;
        uint32_t bwOutputBatchStride;
        if (mergeOutputs) {
            // Write bw values into the second half of each merged output row.
            bwOutputBatchStride = fwNumUnits + bwNumUnits;
            bwOutputBatchOffset = fwNumUnits;
            bwOutputBatchPtr = fwOutput + i * batchSize * bwOutputBatchStride;
        } else {
            bwOutputBatchStride = bwNumUnits;
            bwOutputBatchPtr = bwOutput + i * batchSize * bwOutputBatchStride;
        }

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, bwHiddenState, bwBias, bwWeights, bwWeightsShape,
                        bwAuxWeights, bwAuxWeightsShape, bwRecurrentWeights,
                        bwRecurrentWeightsShape, activation, bwOutputBatchStride,
                        bwOutputBatchOffset, bwOutputBatchPtr, bwOutputHiddenState);

        bwHiddenState = bwOutputHiddenState;
    }

    // If the inputs were in batch major format, transpose data in temporary
    // buffers and write to the output(s).
    if (!timeMajor) {
        transposeFirstTwoDims(fwOutputTransposed.data(), fwOutputShape,
                              context->getOutputBuffer<T>(kFwOutputTensor));
        if (!mergeOutputs) {
            transposeFirstTwoDims(bwOutputTransposed.data(), bwOutputShape,
                                  context->getOutputBuffer<T>(kBwOutputTensor));
        }
    }
    return true;
}
316
317 } // namespace
318
validate(const IOperationValidationContext * context)319 bool validate(const IOperationValidationContext* context) {
320 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
321 // Exact number is dependent on the mergeOutputs parameter and checked
322 // during preparation.
323 const uint32_t numOutputs = context->getNumOutputs();
324 NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsMerged ||
325 numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState);
326
327 OperandType inputType = context->getInputType(kInputTensor);
328 if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
329 LOG(ERROR) << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
330 << toString(inputType);
331 return false;
332 }
333 NN_RET_CHECK(validateInputTypes(
334 context, {inputType, inputType, inputType, inputType, inputType, inputType, inputType,
335 inputType, inputType, inputType, inputType, inputType, OperandType::INT32,
336 OperandType::BOOL, OperandType::BOOL}));
337
338 std::vector<OperandType> outExpectedTypes(numOutputs, inputType);
339 NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
340
341 HalVersion minSupportedHalVersion = HalVersion::V1_2;
342 if (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState) {
343 minSupportedHalVersion = HalVersion::V1_3;
344 }
345 return validateHalVersion(context, minSupportedHalVersion);
346 }
347
// Checks operand presence and shape consistency, then computes and sets the
// output shapes. Returns false on any violated constraint.
bool prepare(IOperationExecutionContext* context) {
    // The output count must match the mergeOutputs parameter (with or without
    // the optional hidden-state outputs).
    const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);
    const int32_t numOutputs = context->getNumOutputs();
    if (mergeOutputs) {
        NN_RET_CHECK(numOutputs == kNumOutputsMerged || numOutputs == kNumOutputsMergedWithState);
    } else {
        NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
    }

    // Check that none of the required inputs are omitted.
    const std::vector<int> requiredInputs = {
            kInputTensor,         kFwWeightsTensor, kFwRecurrentWeightsTensor, kFwBiasTensor,
            kFwHiddenStateTensor, kBwWeightsTensor, kBwRecurrentWeightsTensor, kBwBiasTensor,
            kBwHiddenStateTensor, kActivationParam, kTimeMajorParam,           kMergeOutputsParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    Shape input = context->getInputShape(kInputTensor);
    Shape fwWeights = context->getInputShape(kFwWeightsTensor);
    Shape fwRecurrentWeights = context->getInputShape(kFwRecurrentWeightsTensor);
    Shape fwBias = context->getInputShape(kFwBiasTensor);
    Shape fwHiddenState = context->getInputShape(kFwHiddenStateTensor);
    Shape bwWeights = context->getInputShape(kBwWeightsTensor);
    Shape bwRecurrentWeights = context->getInputShape(kBwRecurrentWeightsTensor);
    Shape bwBias = context->getInputShape(kBwBiasTensor);
    Shape bwHiddenState = context->getInputShape(kBwHiddenStateTensor);

    Shape auxInput = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeights = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeights = context->getInputShape(kBwAuxWeightsTensor);

    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(context, &linkingMode));

    // The first two input dimensions are [maxTime, batchSize] in time-major
    // layout and [batchSize, maxTime] otherwise.
    bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const uint32_t batchSize =
            timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
    const uint32_t maxTime =
            timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeights, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeights, 0);
    const uint32_t inputSize = getSizeOfDimension(input, 2);

    // Rank checks.
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwRecurrentWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwBias), 1);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwHiddenState), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwRecurrentWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwBias), 1);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwHiddenState), 2);

    // Forward cell dimension consistency.
    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(fwWeights, 1));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwBias, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(fwHiddenState, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwHiddenState, 1));

    // Backward cell dimension consistency. In PARALLEL_LINKING the backward
    // cell consumes the auxiliary input, so its weights are checked against
    // the auxiliary input size below instead of inputSize.
    if (linkingMode != LinkingMode::PARALLEL_LINKING) {
        NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(bwWeights, 1));
    }
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwBias, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(bwHiddenState, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwHiddenState, 1));

    // Auxiliary operand consistency for the two linking modes that use them.
    if (linkingMode == LinkingMode::CROSS_LINKING) {
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3);
        NN_RET_CHECK_EQ(getNumberOfDimensions(fwAuxWeights), 2);
        NN_RET_CHECK_EQ(getNumberOfDimensions(bwAuxWeights), 2);

        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 0), fwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 0), bwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
    } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3);

        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 2), getSizeOfDimension(bwWeights, 1));
    }

    // Output shapes: [time, batch, units] or [batch, time, units] depending on
    // timeMajor; merged outputs concatenate fw and bw units.
    Shape fwOutput = context->getOutputShape(kFwOutputTensor);
    fwOutput.dimensions.resize(3);
    fwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
    fwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
    fwOutput.dimensions[2] = mergeOutputs ? fwNumUnits + bwNumUnits : fwNumUnits;
    NN_RET_CHECK(context->setOutputShape(kFwOutputTensor, fwOutput));
    if (!mergeOutputs) {
        Shape bwOutput = context->getOutputShape(kBwOutputTensor);
        bwOutput.dimensions.resize(3);
        bwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
        bwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
        bwOutput.dimensions[2] = bwNumUnits;
        NN_RET_CHECK(context->setOutputShape(kBwOutputTensor, bwOutput));
    }

    const bool outputState =
            (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState);
    if (outputState) {
        // With merged outputs the bw output tensor is absent, shifting the
        // hidden-state output indices down by one.
        const int delta = mergeOutputs ? 1 : 0;
        NN_RET_CHECK(context->setOutputShape(kFwOutputHiddenStateTensor - delta,
                                             context->getInputShape(kFwHiddenStateTensor)));
        NN_RET_CHECK(context->setOutputShape(kBwOutputHiddenStateTensor - delta,
                                             context->getInputShape(kBwHiddenStateTensor)));
    }

    return true;
}
466
execute(IOperationExecutionContext * context)467 bool execute(IOperationExecutionContext* context) {
468 if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
469 executeTyped<_Float16>(context);
470 } else {
471 executeTyped<float>(context);
472 }
473 return true;
474 }
475
476 } // namespace bidirectional_sequence_rnn
477
// Registers the operation with the resolver. allowOmittedOperand is set
// because the auxiliary inputs (and the trailing outputs) are optional.
NN_REGISTER_OPERATION(BIDIRECTIONAL_SEQUENCE_RNN, "BIDIRECTIONAL_SEQUENCE_RNN",
                      bidirectional_sequence_rnn::validate, bidirectional_sequence_rnn::prepare,
                      bidirectional_sequence_rnn::execute, .allowOmittedOperand = true);
481
482 } // namespace nn
483 } // namespace android
484