1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_TEST_GENERATED_TEST_UTILS_H
18 #define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_TEST_GENERATED_TEST_UTILS_H
19 
20 #include <gtest/gtest.h>
21 
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "TestHarness.h"
28 #include "TestNeuralNetworksWrapper.h"
29 
30 namespace android::nn::generated_tests {
31 
32 class GeneratedTestBase
33     : public ::testing::TestWithParam<test_helper::TestModelManager::TestParam> {
34    protected:
35     const std::string& kTestName = GetParam().first;
36     const test_helper::TestModel& testModel = *GetParam().second;
37 };
38 
// Instantiates the parameterized suite `TestSuite` once per generated TestModel returned by
// TestModelManager::getTestModels(filter). Each instance is named after TestParam::first, the
// generated test's name.
#define INSTANTIATE_GENERATED_TEST(TestSuite, filter)                              \
    INSTANTIATE_TEST_SUITE_P(                                                      \
            TestGenerated, TestSuite,                                              \
            ::testing::ValuesIn(                                                   \
                    ::test_helper::TestModelManager::get().getTestModels(filter)), \
            [](const auto& paramInfo) { return paramInfo.param.first; })
44 
45 // A generated NDK model.
46 class GeneratedModel : public test_wrapper::Model {
47    public:
48     // A helper method to simplify referenced model lifetime management.
49     //
50     // Usage:
51     //     GeneratedModel model;
52     //     std::vector<Model> refModels;
53     //     createModel(&model, &refModels);
54     //     model.setRefModels(std::move(refModels));
55     //
56     // This makes sure referenced models live as long as the main model.
57     //
setRefModels(std::vector<test_wrapper::Model> refModels)58     void setRefModels(std::vector<test_wrapper::Model> refModels) {
59         mRefModels = std::move(refModels);
60     }
61 
62     // A helper method to simplify CONSTANT_REFERENCE memory lifetime management.
setConstantReferenceMemory(std::unique_ptr<test_wrapper::Memory> memory)63     void setConstantReferenceMemory(std::unique_ptr<test_wrapper::Memory> memory) {
64         mConstantReferenceMemory = std::move(memory);
65     }
66 
67    private:
68     std::vector<test_wrapper::Model> mRefModels;
69     std::unique_ptr<test_wrapper::Memory> mConstantReferenceMemory;
70 };
71 
72 // Convert TestModel to NDK model.
73 void createModel(const test_helper::TestModel& testModel, bool testDynamicOutputShape,
74                  GeneratedModel* model);
createModel(const test_helper::TestModel & testModel,GeneratedModel * model)75 inline void createModel(const test_helper::TestModel& testModel, GeneratedModel* model) {
76     createModel(testModel, /*testDynamicOutputShape=*/false, model);
77 }
78 
79 void createRequest(const test_helper::TestModel& testModel, test_wrapper::Execution* execution,
80                    std::vector<test_helper::TestBuffer>* outputs);
81 
82 }  // namespace android::nn::generated_tests
83 
84 #endif  // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_TEST_GENERATED_TEST_UTILS_H
85