1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "LSTM.h"

#include "NeuralNetworksWrapper.h"

#include <android-base/logging.h>
#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
26
27 namespace android {
28 namespace nn {
29 namespace wrapper {
30
31 using ::testing::Each;
32 using ::testing::FloatNear;
33 using ::testing::Matcher;
34
35 namespace {
36
ArrayFloatNear(const std::vector<float> & values,float max_abs_error=1.e-6)37 std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
38 float max_abs_error = 1.e-6) {
39 std::vector<Matcher<float>> matchers;
40 matchers.reserve(values.size());
41 for (const float& v : values) {
42 matchers.emplace_back(FloatNear(v, max_abs_error));
43 }
44 return matchers;
45 }
46
47 } // anonymous namespace
48
// Expands ACTION(X) once per model input operand: the data tensor, every
// weight/bias tensor, and the recurrent state carried in from the previous
// step. The expansion order determines the operand order passed to
// model_.addOperand() in the constructor below.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Input)                                \
    ACTION(InputToInputWeights)                  \
    ACTION(InputToCellWeights)                   \
    ACTION(InputToForgetWeights)                 \
    ACTION(InputToOutputWeights)                 \
    ACTION(RecurrentToInputWeights)              \
    ACTION(RecurrentToCellWeights)               \
    ACTION(RecurrentToForgetWeights)             \
    ACTION(RecurrentToOutputWeights)             \
    ACTION(CellToInputWeights)                   \
    ACTION(CellToForgetWeights)                  \
    ACTION(CellToOutputWeights)                  \
    ACTION(InputGateBias)                        \
    ACTION(CellGateBias)                         \
    ACTION(ForgetGateBias)                       \
    ACTION(OutputGateBias)                       \
    ACTION(ProjectionWeights)                    \
    ACTION(ProjectionBias)                       \
    ACTION(OutputStateIn)                        \
    ACTION(CellStateIn)

// Expands ACTION(X) once per layer-normalization weight tensor. These are
// appended after the scalar parameters (activation, cell clip, proj clip).
#define FOR_ALL_LAYER_NORM_WEIGHTS(ACTION) \
    ACTION(InputLayerNormWeights)          \
    ACTION(ForgetLayerNormWeights)         \
    ACTION(CellLayerNormWeights)           \
    ACTION(OutputLayerNormWeights)

// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(ScratchBuffer)              \
    ACTION(OutputStateOut)             \
    ACTION(CellStateOut)               \
    ACTION(Output)
83
84 class LayerNormLSTMOpModel {
85 public:
LayerNormLSTMOpModel(uint32_t n_batch,uint32_t n_input,uint32_t n_cell,uint32_t n_output,bool use_cifg,bool use_peephole,bool use_projection_weights,bool use_projection_bias,float cell_clip,float proj_clip,const std::vector<std::vector<uint32_t>> & input_shapes0)86 LayerNormLSTMOpModel(uint32_t n_batch, uint32_t n_input, uint32_t n_cell, uint32_t n_output,
87 bool use_cifg, bool use_peephole, bool use_projection_weights,
88 bool use_projection_bias, float cell_clip, float proj_clip,
89 const std::vector<std::vector<uint32_t>>& input_shapes0)
90 : n_input_(n_input),
91 n_output_(n_output),
92 use_cifg_(use_cifg),
93 use_peephole_(use_peephole),
94 use_projection_weights_(use_projection_weights),
95 use_projection_bias_(use_projection_bias),
96 activation_(ActivationFn::kActivationTanh),
97 cell_clip_(cell_clip),
98 proj_clip_(proj_clip) {
99 std::vector<uint32_t> inputs;
100 std::vector<std::vector<uint32_t>> input_shapes(input_shapes0);
101
102 auto it = input_shapes.begin();
103
104 // Input and weights
105 #define AddInput(X) \
106 CHECK(it != input_shapes.end()); \
107 OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \
108 inputs.push_back(model_.addOperand(&X##OpndTy));
109
110 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput);
111
112 // Parameters
113 OperandType ActivationOpndTy(Type::INT32, {});
114 inputs.push_back(model_.addOperand(&ActivationOpndTy));
115 OperandType CellClipOpndTy(Type::FLOAT32, {});
116 inputs.push_back(model_.addOperand(&CellClipOpndTy));
117 OperandType ProjClipOpndTy(Type::FLOAT32, {});
118 inputs.push_back(model_.addOperand(&ProjClipOpndTy));
119
120 FOR_ALL_LAYER_NORM_WEIGHTS(AddInput);
121
122 #undef AddOperand
123
124 // Output and other intermediate state
125 std::vector<std::vector<uint32_t>> output_shapes{
126 {n_batch, n_cell * (use_cifg ? 3 : 4)},
127 {n_batch, n_output},
128 {n_batch, n_cell},
129 {n_batch, n_output},
130 };
131 std::vector<uint32_t> outputs;
132
133 auto it2 = output_shapes.begin();
134
135 #define AddOutput(X) \
136 CHECK(it2 != output_shapes.end()); \
137 OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \
138 outputs.push_back(model_.addOperand(&X##OpndTy));
139
140 FOR_ALL_OUTPUT_TENSORS(AddOutput);
141
142 #undef AddOutput
143
144 model_.addOperation(ANEURALNETWORKS_LSTM, inputs, outputs);
145 model_.identifyInputsAndOutputs(inputs, outputs);
146
147 Input_.insert(Input_.end(), n_batch * n_input, 0.f);
148 OutputStateIn_.insert(OutputStateIn_.end(), n_batch * n_output, 0.f);
149 CellStateIn_.insert(CellStateIn_.end(), n_batch * n_cell, 0.f);
150
151 auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
152 uint32_t sz = 1;
153 for (uint32_t d : dims) {
154 sz *= d;
155 }
156 return sz;
157 };
158
159 it2 = output_shapes.begin();
160
161 #define ReserveOutput(X) X##_.insert(X##_.end(), multiAll(*it2++), 0.f);
162
163 FOR_ALL_OUTPUT_TENSORS(ReserveOutput);
164
165 #undef ReserveOutput
166
167 model_.finish();
168 }
169
170 #define DefineSetter(X) \
171 void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
172
173 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
174 FOR_ALL_LAYER_NORM_WEIGHTS(DefineSetter);
175
176 #undef DefineSetter
177
ResetOutputState()178 void ResetOutputState() {
179 std::fill(OutputStateIn_.begin(), OutputStateIn_.end(), 0.f);
180 std::fill(OutputStateOut_.begin(), OutputStateOut_.end(), 0.f);
181 }
182
ResetCellState()183 void ResetCellState() {
184 std::fill(CellStateIn_.begin(), CellStateIn_.end(), 0.f);
185 std::fill(CellStateOut_.begin(), CellStateOut_.end(), 0.f);
186 }
187
SetInput(int offset,const float * begin,const float * end)188 void SetInput(int offset, const float* begin, const float* end) {
189 for (; begin != end; begin++, offset++) {
190 Input_[offset] = *begin;
191 }
192 }
193
num_inputs() const194 uint32_t num_inputs() const { return n_input_; }
num_outputs() const195 uint32_t num_outputs() const { return n_output_; }
196
GetOutput() const197 const std::vector<float>& GetOutput() const { return Output_; }
198
Invoke()199 void Invoke() {
200 ASSERT_TRUE(model_.isValid());
201
202 OutputStateIn_.swap(OutputStateOut_);
203 CellStateIn_.swap(CellStateOut_);
204
205 Compilation compilation(&model_);
206 compilation.finish();
207 Execution execution(&compilation);
208 #define SetInputOrWeight(X) \
209 ASSERT_EQ( \
210 execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
211 Result::NO_ERROR);
212
213 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
214 FOR_ALL_LAYER_NORM_WEIGHTS(SetInputOrWeight);
215
216 #undef SetInputOrWeight
217
218 #define SetOutput(X) \
219 ASSERT_EQ( \
220 execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
221 Result::NO_ERROR);
222
223 FOR_ALL_OUTPUT_TENSORS(SetOutput);
224
225 #undef SetOutput
226
227 if (use_cifg_) {
228 execution.setInput(LSTMCell::kInputToInputWeightsTensor, nullptr, 0);
229 execution.setInput(LSTMCell::kRecurrentToInputWeightsTensor, nullptr, 0);
230 }
231
232 if (use_peephole_) {
233 if (use_cifg_) {
234 execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
235 }
236 } else {
237 execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
238 execution.setInput(LSTMCell::kCellToForgetWeightsTensor, nullptr, 0);
239 execution.setInput(LSTMCell::kCellToOutputWeightsTensor, nullptr, 0);
240 }
241
242 if (use_projection_weights_) {
243 if (!use_projection_bias_) {
244 execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
245 }
246 } else {
247 execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0);
248 execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
249 }
250
251 ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam, &activation_, sizeof(activation_)),
252 Result::NO_ERROR);
253 ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam, &cell_clip_, sizeof(cell_clip_)),
254 Result::NO_ERROR);
255 ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam, &proj_clip_, sizeof(proj_clip_)),
256 Result::NO_ERROR);
257
258 ASSERT_EQ(execution.compute(), Result::NO_ERROR);
259 }
260
261 private:
262 Model model_;
263 // Execution execution_;
264 const uint32_t n_input_;
265 const uint32_t n_output_;
266
267 const bool use_cifg_;
268 const bool use_peephole_;
269 const bool use_projection_weights_;
270 const bool use_projection_bias_;
271
272 const int activation_;
273 const float cell_clip_;
274 const float proj_clip_;
275
276 #define DefineTensor(X) std::vector<float> X##_;
277
278 FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
279 FOR_ALL_LAYER_NORM_WEIGHTS(DefineTensor);
280 FOR_ALL_OUTPUT_TENSORS(DefineTensor);
281
282 #undef DefineTensor
283 };
284
// End-to-end check of a layer-norm LSTM with peephole connections and a
// projection layer (no CIFG, no clipping), against TFLite-derived golden
// outputs for a 3-step sequence with batch size 2.
TEST(LSTMOpTest, LayerNormNoCifgPeepholeProjectionNoClipping) {
    const int n_batch = 2;
    const int n_input = 5;
    // n_cell and n_output have the same size when there is no projection.
    const int n_cell = 4;
    const int n_output = 3;

    LayerNormLSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                              /*use_cifg=*/false, /*use_peephole=*/true,
                              /*use_projection_weights=*/true,
                              /*use_projection_bias=*/false,
                              /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                              {
                                      {n_batch, n_input},  // input tensor

                                      {n_cell, n_input},  // input_to_input_weight tensor
                                      {n_cell, n_input},  // input_to_forget_weight tensor
                                      {n_cell, n_input},  // input_to_cell_weight tensor
                                      {n_cell, n_input},  // input_to_output_weight tensor

                                      {n_cell, n_output},  // recurrent_to_input_weight tensor
                                      {n_cell, n_output},  // recurrent_to_forget_weight tensor
                                      {n_cell, n_output},  // recurrent_to_cell_weight tensor
                                      {n_cell, n_output},  // recurrent_to_output_weight tensor

                                      {n_cell},  // cell_to_input_weight tensor
                                      {n_cell},  // cell_to_forget_weight tensor
                                      {n_cell},  // cell_to_output_weight tensor

                                      {n_cell},  // input_gate_bias tensor
                                      {n_cell},  // forget_gate_bias tensor
                                      {n_cell},  // cell_bias tensor
                                      {n_cell},  // output_gate_bias tensor

                                      {n_output, n_cell},  // projection_weight tensor
                                      {0},                 // projection_bias tensor

                                      {n_batch, n_output},  // output_state_in tensor
                                      {n_batch, n_cell},    // cell_state_in tensor

                                      {n_cell},  // input_layer_norm_weights tensor
                                      {n_cell},  // forget_layer_norm_weights tensor
                                      {n_cell},  // cell_layer_norm_weights tensor
                                      {n_cell},  // output_layer_norm_weights tensor
                              });

    lstm.SetInputToInputWeights({0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5,
                                 -0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1});

    lstm.SetInputToForgetWeights({-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8,
                                  -0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5});

    lstm.SetInputToCellWeights({-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
                                0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6});

    lstm.SetInputToOutputWeights({-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
                                  0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4});

    lstm.SetInputGateBias({0.03, 0.15, 0.22, 0.38});

    lstm.SetForgetGateBias({0.1, -0.3, -0.2, 0.1});

    lstm.SetCellGateBias({-0.05, 0.72, 0.25, 0.08});

    lstm.SetOutputGateBias({0.05, -0.01, 0.2, 0.1});

    lstm.SetRecurrentToInputWeights(
            {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6});

    lstm.SetRecurrentToCellWeights(
            {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2});

    lstm.SetRecurrentToForgetWeights(
            {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2});

    lstm.SetRecurrentToOutputWeights(
            {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2});

    lstm.SetCellToInputWeights({0.05, 0.1, 0.25, 0.15});
    lstm.SetCellToForgetWeights({-0.02, -0.15, -0.25, -0.03});
    lstm.SetCellToOutputWeights({0.1, -0.1, -0.5, 0.05});

    lstm.SetProjectionWeights({-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2});

    lstm.SetInputLayerNormWeights({0.1, 0.2, 0.3, 0.5});
    lstm.SetForgetLayerNormWeights({0.2, 0.2, 0.4, 0.3});
    lstm.SetCellLayerNormWeights({0.7, 0.2, 0.3, 0.8});
    lstm.SetOutputLayerNormWeights({0.6, 0.2, 0.2, 0.5});

    const std::vector<std::vector<float>> lstm_input = {
            {  // Batch0: 3 (input_sequence_size) * 5 (n_input)
             0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
             0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
             0.2, 0.7, 0.7, 0.1, 0.7},  // seq 2

            {  // Batch1: 3 (input_sequence_size) * 5 (n_input)
             0.3, 0.2, 0.9, 0.8, 0.1,   // seq 0
             0.1, 0.5, 0.2, 0.4, 0.2,   // seq 1
             0.6, 0.9, 0.2, 0.5, 0.7},  // seq 2
    };

    const std::vector<std::vector<float>> lstm_golden_output = {
            {
                    // Batch0: 3 (input_sequence_size) * 3 (n_output)
                    0.0244077, 0.128027, -0.00170918,  // seq 0
                    0.0137642, 0.140751, 0.0395835,    // seq 1
                    -0.00459231, 0.155278, 0.0837377,  // seq 2
            },
            {
                    // Batch1: 3 (input_sequence_size) * 3 (n_output)
                    -0.00692428, 0.0848741, 0.063445,  // seq 0
                    -0.00403912, 0.139963, 0.072681,   // seq 1
                    0.00752706, 0.161903, 0.0561371,   // seq 2
            }};

    // Resetting cell_state and output_state
    lstm.ResetCellState();
    lstm.ResetOutputState();

    const int input_sequence_size = lstm_input[0].size() / n_input;
    for (int i = 0; i < input_sequence_size; i++) {
        // Load step i of each batch into the model's input buffer.
        for (int b = 0; b < n_batch; ++b) {
            const float* batch_start = lstm_input[b].data() + i * n_input;
            const float* batch_end = batch_start + n_input;

            lstm.SetInput(b * n_input, batch_start, batch_end);
        }

        lstm.Invoke();

        std::vector<float> expected;
        for (int b = 0; b < n_batch; ++b) {
            const float* golden_start = lstm_golden_output[b].data() + i * n_output;
            const float* golden_end = golden_start + n_output;
            expected.insert(expected.end(), golden_start, golden_end);
        }
        // Qualified: the file only has using-declarations for Each/FloatNear/
        // Matcher, so ElementsAreArray must be named explicitly.
        EXPECT_THAT(lstm.GetOutput(), ::testing::ElementsAreArray(ArrayFloatNear(expected)));
    }
}
424
425 } // namespace wrapper
426 } // namespace nn
427 } // namespace android
428