/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "LSTM.h"

#include "NeuralNetworksWrapper.h"

#include <android-base/logging.h>
#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
#include <sstream>
#include <string>
#include <vector>

namespace android {
namespace nn {
namespace wrapper {

using ::testing::Each;
using ::testing::FloatNear;
using ::testing::Matcher;

namespace {

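// Wraps each expected value in a FloatNear matcher so that golden float outputs
// can be compared element-wise, with a small absolute tolerance, via ElementsAreArray.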
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
                                           float max_abs_error = 1.e-6) {
    std::vector<Matcher<float>> matchers;
    matchers.reserve(values.size());
    for (const float& v : values) {
        matchers.emplace_back(FloatNear(v, max_abs_error));
    }
    return matchers;
}

}  // anonymous namespace

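// Data tensors fed to ANEURALNETWORKS_LSTM (input, weights, biases and the
// incoming states), expanded in the order the operands are added to the model.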
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Input)                                \
    ACTION(InputToInputWeights)                  \
    ACTION(InputToCellWeights)                   \
    ACTION(InputToForgetWeights)                 \
    ACTION(InputToOutputWeights)                 \
    ACTION(RecurrentToInputWeights)              \
    ACTION(RecurrentToCellWeights)               \
    ACTION(RecurrentToForgetWeights)             \
    ACTION(RecurrentToOutputWeights)             \
    ACTION(CellToInputWeights)                   \
    ACTION(CellToForgetWeights)                  \
    ACTION(CellToOutputWeights)                  \
    ACTION(InputGateBias)                        \
    ACTION(CellGateBias)                         \
    ACTION(ForgetGateBias)                       \
    ACTION(OutputGateBias)                       \
    ACTION(ProjectionWeights)                    \
    ACTION(ProjectionBias)                       \
    ACTION(OutputStateIn)                        \
    ACTION(CellStateIn)

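// Layer normalization weights, added to the model after the scalar parameters.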
#define FOR_ALL_LAYER_NORM_WEIGHTS(ACTION) \
    ACTION(InputLayerNormWeights)          \
    ACTION(ForgetLayerNormWeights)         \
    ACTION(CellLayerNormWeights)           \
    ACTION(OutputLayerNormWeights)

// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(ScratchBuffer)              \
    ACTION(OutputStateOut)             \
    ACTION(CellStateOut)               \
    ACTION(Output)

class LayerNormLSTMOpModel {
   public:
    LayerNormLSTMOpModel(uint32_t n_batch, uint32_t n_input, uint32_t n_cell, uint32_t n_output,
                         bool use_cifg, bool use_peephole, bool use_projection_weights,
                         bool use_projection_bias, float cell_clip, float proj_clip,
                         const std::vector<std::vector<uint32_t>>& input_shapes0)
        : n_input_(n_input),
          n_output_(n_output),
          use_cifg_(use_cifg),
          use_peephole_(use_peephole),
          use_projection_weights_(use_projection_weights),
          use_projection_bias_(use_projection_bias),
          activation_(ActivationFn::kActivationTanh),
          cell_clip_(cell_clip),
          proj_clip_(proj_clip) {
        std::vector<uint32_t> inputs;
        std::vector<std::vector<uint32_t>> input_shapes(input_shapes0);

        auto it = input_shapes.begin();

        // Input and weights
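        // AddInput declares one TENSOR_FLOAT32 operand per tensor in the macro
        // list, taking its dimensions from the next entry of input_shapes.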
#define AddInput(X)                                     \
    CHECK(it != input_shapes.end());                    \
    OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \
    inputs.push_back(model_.addOperand(&X##OpndTy));

        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput);

        // Parameters
        OperandType ActivationOpndTy(Type::INT32, {});
        inputs.push_back(model_.addOperand(&ActivationOpndTy));
        OperandType CellClipOpndTy(Type::FLOAT32, {});
        inputs.push_back(model_.addOperand(&CellClipOpndTy));
        OperandType ProjClipOpndTy(Type::FLOAT32, {});
        inputs.push_back(model_.addOperand(&ProjClipOpndTy));

        FOR_ALL_LAYER_NORM_WEIGHTS(AddInput);

#undef AddInput

        // Output and other intermediate state
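        // ScratchBuffer is sized for the four gate buffers (three when CIFG
        // couples the input gate), followed by output state, cell state and
        // the final output.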
        std::vector<std::vector<uint32_t>> output_shapes{
                {n_batch, n_cell * (use_cifg ? 3 : 4)},
                {n_batch, n_output},
                {n_batch, n_cell},
                {n_batch, n_output},
        };
        std::vector<uint32_t> outputs;

        auto it2 = output_shapes.begin();

#define AddOutput(X)                                     \
    CHECK(it2 != output_shapes.end());                   \
    OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \
    outputs.push_back(model_.addOperand(&X##OpndTy));

        FOR_ALL_OUTPUT_TENSORS(AddOutput);

#undef AddOutput

        model_.addOperation(ANEURALNETWORKS_LSTM, inputs, outputs);
        model_.identifyInputsAndOutputs(inputs, outputs);

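        // Zero-initialize the run-time input buffers; weight buffers are
        // filled later through the Set* helpers.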
        Input_.insert(Input_.end(), n_batch * n_input, 0.f);
        OutputStateIn_.insert(OutputStateIn_.end(), n_batch * n_output, 0.f);
        CellStateIn_.insert(CellStateIn_.end(), n_batch * n_cell, 0.f);

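        // Element count of a tensor is the product of its dimensions.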
        auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
            uint32_t sz = 1;
            for (uint32_t d : dims) {
                sz *= d;
            }
            return sz;
        };

        it2 = output_shapes.begin();

#define ReserveOutput(X) X##_.insert(X##_.end(), multiAll(*it2++), 0.f);

        FOR_ALL_OUTPUT_TENSORS(ReserveOutput);

#undef ReserveOutput

        model_.finish();
    }

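    // Generates a Set<Tensor>() helper per tensor that appends data to the
    // corresponding host-side buffer.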
#define DefineSetter(X) \
    void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
    FOR_ALL_LAYER_NORM_WEIGHTS(DefineSetter);

#undef DefineSetter

    void ResetOutputState() {
        std::fill(OutputStateIn_.begin(), OutputStateIn_.end(), 0.f);
        std::fill(OutputStateOut_.begin(), OutputStateOut_.end(), 0.f);
    }

    void ResetCellState() {
        std::fill(CellStateIn_.begin(), CellStateIn_.end(), 0.f);
        std::fill(CellStateOut_.begin(), CellStateOut_.end(), 0.f);
    }

    void SetInput(int offset, const float* begin, const float* end) {
        for (; begin != end; begin++, offset++) {
            Input_[offset] = *begin;
        }
    }

    uint32_t num_inputs() const { return n_input_; }
    uint32_t num_outputs() const { return n_output_; }

    const std::vector<float>& GetOutput() const { return Output_; }

    void Invoke() {
        ASSERT_TRUE(model_.isValid());

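        // Feed the previous invocation's output and cell state back in as this
        // step's input state.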
        OutputStateIn_.swap(OutputStateOut_);
        CellStateIn_.swap(CellStateOut_);

        Compilation compilation(&model_);
        compilation.finish();
        Execution execution(&compilation);
#define SetInputOrWeight(X)                                                                       \
    ASSERT_EQ(                                                                                    \
            execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
            Result::NO_ERROR);

        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
        FOR_ALL_LAYER_NORM_WEIGHTS(SetInputOrWeight);

#undef SetInputOrWeight

#define SetOutput(X)                                                                               \
    ASSERT_EQ(                                                                                     \
            execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
            Result::NO_ERROR);

        FOR_ALL_OUTPUT_TENSORS(SetOutput);

#undef SetOutput

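        // Operands that this configuration does not use are marked as omitted
        // by passing a null buffer with zero length.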
        if (use_cifg_) {
            execution.setInput(LSTMCell::kInputToInputWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kRecurrentToInputWeightsTensor, nullptr, 0);
        }

        if (use_peephole_) {
            if (use_cifg_) {
                execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
            }
        } else {
            execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kCellToForgetWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kCellToOutputWeightsTensor, nullptr, 0);
        }

        if (use_projection_weights_) {
            if (!use_projection_bias_) {
                execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
            }
        } else {
            execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
        }

        ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam, &activation_, sizeof(activation_)),
                  Result::NO_ERROR);
        ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam, &cell_clip_, sizeof(cell_clip_)),
                  Result::NO_ERROR);
        ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam, &proj_clip_, sizeof(proj_clip_)),
                  Result::NO_ERROR);

        ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    }

   private:
    Model model_;
    // Execution execution_;
    const uint32_t n_input_;
    const uint32_t n_output_;

    const bool use_cifg_;
    const bool use_peephole_;
    const bool use_projection_weights_;
    const bool use_projection_bias_;

    const int activation_;
    const float cell_clip_;
    const float proj_clip_;

#define DefineTensor(X) std::vector<float> X##_;

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
    FOR_ALL_LAYER_NORM_WEIGHTS(DefineTensor);
    FOR_ALL_OUTPUT_TENSORS(DefineTensor);

#undef DefineTensor
};

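// End-to-end check of a layer-norm LSTM cell with peephole connections and a
// projection layer (no CIFG, no clipping): two batches are stepped through a
// three-element input sequence and compared against precomputed golden outputs.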
TEST(LSTMOpTest, LayerNormNoCifgPeepholeProjectionNoClipping) {
    const int n_batch = 2;
    const int n_input = 5;
    // n_cell and n_output have the same size when there is no projection.
    const int n_cell = 4;
    const int n_output = 3;

    LayerNormLSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                              /*use_cifg=*/false, /*use_peephole=*/true,
                              /*use_projection_weights=*/true,
                              /*use_projection_bias=*/false,
                              /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                              {
                                      {n_batch, n_input},  // input tensor

                                      {n_cell, n_input},  // input_to_input_weight tensor
                                      {n_cell, n_input},  // input_to_forget_weight tensor
                                      {n_cell, n_input},  // input_to_cell_weight tensor
                                      {n_cell, n_input},  // input_to_output_weight tensor

                                      {n_cell, n_output},  // recurrent_to_input_weight tensor
                                      {n_cell, n_output},  // recurrent_to_forget_weight tensor
                                      {n_cell, n_output},  // recurrent_to_cell_weight tensor
                                      {n_cell, n_output},  // recurrent_to_output_weight tensor

                                      {n_cell},  // cell_to_input_weight tensor
                                      {n_cell},  // cell_to_forget_weight tensor
                                      {n_cell},  // cell_to_output_weight tensor

                                      {n_cell},  // input_gate_bias tensor
                                      {n_cell},  // forget_gate_bias tensor
                                      {n_cell},  // cell_bias tensor
                                      {n_cell},  // output_gate_bias tensor

                                      {n_output, n_cell},  // projection_weight tensor
                                      {0},                 // projection_bias tensor

                                      {n_batch, n_output},  // output_state_in tensor
                                      {n_batch, n_cell},    // cell_state_in tensor

                                      {n_cell},  // input_layer_norm_weights tensor
                                      {n_cell},  // forget_layer_norm_weights tensor
                                      {n_cell},  // cell_layer_norm_weights tensor
                                      {n_cell},  // output_layer_norm_weights tensor
                              });

    lstm.SetInputToInputWeights({0.5,  0.6, 0.7,  -0.8, -0.9, 0.1,  0.2,  0.3,  -0.4, 0.5,
                                 -0.8, 0.7, -0.6, 0.5,  -0.4, -0.5, -0.4, -0.3, -0.2, -0.1});

    lstm.SetInputToForgetWeights({-0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2, -0.4, 0.3,  -0.8,
                                  -0.4, 0.3,  -0.5, -0.4, -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5});

    lstm.SetInputToCellWeights({-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
                                0.6,  -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8,  0.6});

    lstm.SetInputToOutputWeights({-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
                                  0.6,  -0.2, 0.4,  -0.7, -0.3, -0.5, 0.1, 0.5,  -0.6, -0.4});

    lstm.SetInputGateBias({0.03, 0.15, 0.22, 0.38});

    lstm.SetForgetGateBias({0.1, -0.3, -0.2, 0.1});

    lstm.SetCellGateBias({-0.05, 0.72, 0.25, 0.08});

    lstm.SetOutputGateBias({0.05, -0.01, 0.2, 0.1});

    lstm.SetRecurrentToInputWeights(
            {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6});

    lstm.SetRecurrentToCellWeights(
            {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2});

    lstm.SetRecurrentToForgetWeights(
            {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2});

    lstm.SetRecurrentToOutputWeights(
            {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2});

    lstm.SetCellToInputWeights({0.05, 0.1, 0.25, 0.15});
    lstm.SetCellToForgetWeights({-0.02, -0.15, -0.25, -0.03});
    lstm.SetCellToOutputWeights({0.1, -0.1, -0.5, 0.05});

    lstm.SetProjectionWeights({-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2});

    lstm.SetInputLayerNormWeights({0.1, 0.2, 0.3, 0.5});
    lstm.SetForgetLayerNormWeights({0.2, 0.2, 0.4, 0.3});
    lstm.SetCellLayerNormWeights({0.7, 0.2, 0.3, 0.8});
    lstm.SetOutputLayerNormWeights({0.6, 0.2, 0.2, 0.5});

    const std::vector<std::vector<float>> lstm_input = {
            {                           // Batch0: 3 (input_sequence_size) * 5 (n_input)
             0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
             0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
             0.2, 0.7, 0.7, 0.1, 0.7},  // seq 2

            {                           // Batch1: 3 (input_sequence_size) * 5 (n_input)
             0.3, 0.2, 0.9, 0.8, 0.1,   // seq 0
             0.1, 0.5, 0.2, 0.4, 0.2,   // seq 1
             0.6, 0.9, 0.2, 0.5, 0.7},  // seq 2
    };

    const std::vector<std::vector<float>> lstm_golden_output = {
            {
                    // Batch0: 3 (input_sequence_size) * 3 (n_output)
                    0.0244077, 0.128027, -0.00170918,  // seq 0
                    0.0137642, 0.140751, 0.0395835,    // seq 1
                    -0.00459231, 0.155278, 0.0837377,  // seq 2
            },
            {
                    // Batch1: 3 (input_sequence_size) * 3 (n_output)
                    -0.00692428, 0.0848741, 0.063445,  // seq 0
                    -0.00403912, 0.139963, 0.072681,   // seq 1
                    0.00752706, 0.161903, 0.0561371,   // seq 2
            }};

    // Resetting cell_state and output_state
    lstm.ResetCellState();
    lstm.ResetOutputState();

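    // Feed the sequence one time step at a time for both batches; Invoke()
    // carries output and cell state forward between steps.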
    const int input_sequence_size = lstm_input[0].size() / n_input;
    for (int i = 0; i < input_sequence_size; i++) {
        for (int b = 0; b < n_batch; ++b) {
            const float* batch_start = lstm_input[b].data() + i * n_input;
            const float* batch_end = batch_start + n_input;

            lstm.SetInput(b * n_input, batch_start, batch_end);
        }

        lstm.Invoke();

        std::vector<float> expected;
        for (int b = 0; b < n_batch; ++b) {
            const float* golden_start = lstm_golden_output[b].data() + i * n_output;
            const float* golden_end = golden_start + n_output;
            expected.insert(expected.end(), golden_start, golden_end);
        }
        EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
    }
}

}  // namespace wrapper
}  // namespace nn
}  // namespace android