1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "TypeManager"
18 
19 #include "TypeManager.h"
20 
#include <PackageInfo.h>
#include <android-base/file.h>
#include <android-base/properties.h>
#include <binder/IServiceManager.h>
#include <procpartition/procpartition.h>

#include <algorithm>
#include <limits>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <string_view>
#include <tuple>
#include <utility>
#include <vector>
34 
35 #include "Utils.h"
36 
37 namespace android {
38 namespace nn {
39 
// Stand-in for std::string_view::starts_with(), which only becomes available in C++20.
#if __cplusplus >= 202000L
#error "When upgrading to C++20, remove this error and file a bug to remove this workaround."
#endif
// Returns true when `sv` begins with `prefix`. An empty `prefix` matches every string, and a
// `prefix` longer than `sv` never matches.
inline bool StartsWith(std::string_view sv, std::string_view prefix) {
    if (prefix.size() > sv.size()) {
        return false;
    }
    return sv.compare(0, prefix.size(), prefix) == 0;
}
48 
49 namespace {
50 
51 using namespace hal;
52 
53 const uint8_t kLowBitsType = static_cast<uint8_t>(ExtensionTypeEncoding::LOW_BITS_TYPE);
54 const uint32_t kMaxPrefix =
55         (1 << static_cast<uint8_t>(ExtensionTypeEncoding::HIGH_BITS_PREFIX)) - 1;
56 
// Checks if the two structures contain the same information: the same name and the same
// operand types. This function compares operandTypes element-wise, so the order of operand
// types "does not matter" only because TypeManager::registerExtension sorts operandTypes
// before storing or comparing an Extension.
//
// Returns true when both fields match. On the first mismatch the NN_RET_CHECK* macro makes
// this function return false (presumably logging the failed check -- see the macro
// definitions, likely in Utils.h).
bool equal(const Extension& a, const Extension& b) {
    NN_RET_CHECK_EQ(a.name, b.name);
    // Relies on the fact that TypeManager sorts operandTypes.
    NN_RET_CHECK(a.operandTypes == b.operandTypes);
    return true;
}
65 
66 // Property for disabling NNAPI vendor extensions on product image (used on GSI /product image,
67 // which can't use NNAPI vendor extensions).
68 const char kVExtProductDeny[] = "ro.nnapi.extensions.deny_on_product";
isNNAPIVendorExtensionsUseAllowedInProductImage()69 bool isNNAPIVendorExtensionsUseAllowedInProductImage() {
70     const std::string vExtProductDeny = android::base::GetProperty(kVExtProductDeny, "");
71     return vExtProductDeny.empty();
72 }
73 
74 // The file containing the list of Android apps and binaries allowed to use vendor extensions.
75 // Each line of the file contains new entry. If entry is prefixed by
76 // '/' slash, then it's a native binary path (e.g. '/data/foo'). If not, it's a name
77 // of Android app package (e.g. 'com.foo.bar').
78 const char kAppAllowlistPath[] = "/vendor/etc/nnapi_extensions_app_allowlist";
79 const char kCtsAllowlist[] = "/data/local/tmp/CTSNNAPITestCases";
getVendorExtensionAllowlistedApps()80 std::vector<std::string> getVendorExtensionAllowlistedApps() {
81     std::string data;
82     // Allowlist CTS by default.
83     std::vector<std::string> allowlist = {kCtsAllowlist};
84 
85     if (!android::base::ReadFileToString(kAppAllowlistPath, &data)) {
86         // Return default allowlist (no app can use extensions).
87         LOG(INFO) << "Failed to read " << kAppAllowlistPath
88                   << " ; No app allowlisted for vendor extensions use.";
89         return allowlist;
90     }
91 
92     std::istringstream streamData(data);
93     std::string line;
94     while (std::getline(streamData, line)) {
95         // Do some basic validity check on entry, it's either
96         // fs path or package name.
97         if (StartsWith(line, "/") || line.find('.') != std::string::npos) {
98             allowlist.push_back(line);
99         } else {
100             LOG(ERROR) << kAppAllowlistPath << " - Invalid entry: " << line;
101         }
102     }
103     return allowlist;
104 }
105 
106 // Query PackageManagerNative service about Android app properties.
107 // On success, it will populate appPackageInfo->app* fields.
fetchAppPackageLocationInfo(uid_t uid,TypeManager::AppPackageInfo * appPackageInfo)108 bool fetchAppPackageLocationInfo(uid_t uid, TypeManager::AppPackageInfo* appPackageInfo) {
109     ANeuralNetworks_PackageInfo packageInfo;
110     if (!ANeuralNetworks_fetch_PackageInfo(uid, &packageInfo)) {
111         return false;
112     }
113     appPackageInfo->appPackageName = packageInfo.appPackageName;
114     appPackageInfo->appIsSystemApp = packageInfo.appIsSystemApp;
115     appPackageInfo->appIsOnVendorImage = packageInfo.appIsOnVendorImage;
116     appPackageInfo->appIsOnProductImage = packageInfo.appIsOnProductImage;
117 
118     ANeuralNetworks_free_PackageInfo(&packageInfo);
119     return true;
120 }
121 
122 // Check if this process is allowed to use NNAPI Vendor extensions.
isNNAPIVendorExtensionsUseAllowed(const std::vector<std::string> & allowlist)123 bool isNNAPIVendorExtensionsUseAllowed(const std::vector<std::string>& allowlist) {
124     TypeManager::AppPackageInfo appPackageInfo = {
125             .binaryPath = ::android::procpartition::getExe(getpid()),
126             .appPackageName = "",
127             .appIsSystemApp = false,
128             .appIsOnVendorImage = false,
129             .appIsOnProductImage = false};
130 
131     if (appPackageInfo.binaryPath == "/system/bin/app_process64" ||
132         appPackageInfo.binaryPath == "/system/bin/app_process32") {
133         if (!fetchAppPackageLocationInfo(getuid(), &appPackageInfo)) {
134             LOG(ERROR) << "Failed to get app information from package_manager_native";
135             return false;
136         }
137     }
138     return TypeManager::isExtensionsUseAllowed(
139             appPackageInfo, isNNAPIVendorExtensionsUseAllowedInProductImage(), allowlist);
140 }
141 
142 }  // namespace
143 
TypeManager()144 TypeManager::TypeManager() {
145     VLOG(MANAGER) << "TypeManager::TypeManager";
146     mExtensionsAllowed = isNNAPIVendorExtensionsUseAllowed(getVendorExtensionAllowlistedApps());
147     VLOG(MANAGER) << "NNAPI Vendor extensions enabled: " << mExtensionsAllowed;
148     findAvailableExtensions();
149 }
150 
isExtensionsUseAllowed(const AppPackageInfo & appPackageInfo,bool useOnProductImageEnabled,const std::vector<std::string> & allowlist)151 bool TypeManager::isExtensionsUseAllowed(const AppPackageInfo& appPackageInfo,
152                                          bool useOnProductImageEnabled,
153                                          const std::vector<std::string>& allowlist) {
154     // Only selected partitions and user-installed apps (/data)
155     // are allowed to use extensions.
156     if (StartsWith(appPackageInfo.binaryPath, "/vendor/") ||
157         StartsWith(appPackageInfo.binaryPath, "/odm/") ||
158         StartsWith(appPackageInfo.binaryPath, "/data/") ||
159         (StartsWith(appPackageInfo.binaryPath, "/product/") && useOnProductImageEnabled)) {
160 #ifdef NN_DEBUGGABLE
161         // Only on userdebug and eng builds.
162         // When running tests with mma and adb push.
163         if (StartsWith(appPackageInfo.binaryPath, "/data/nativetest") ||
164             // When running tests with Atest.
165             StartsWith(appPackageInfo.binaryPath, "/data/local/tmp/NeuralNetworksTest_")) {
166             return true;
167         }
168 #endif  // NN_DEBUGGABLE
169 
170         return std::find(allowlist.begin(), allowlist.end(), appPackageInfo.binaryPath) !=
171                allowlist.end();
172     } else if (appPackageInfo.binaryPath == "/system/bin/app_process64" ||
173                appPackageInfo.binaryPath == "/system/bin/app_process32") {
174         // App is not system app OR vendor app OR (product app AND product enabled)
175         // AND app is on allowlist.
176         return (!appPackageInfo.appIsSystemApp || appPackageInfo.appIsOnVendorImage ||
177                 (appPackageInfo.appIsOnProductImage && useOnProductImageEnabled)) &&
178                std::find(allowlist.begin(), allowlist.end(), appPackageInfo.appPackageName) !=
179                        allowlist.end();
180     }
181     return false;
182 }
183 
findAvailableExtensions()184 void TypeManager::findAvailableExtensions() {
185     for (const std::shared_ptr<Device>& device : mDeviceManager->getDrivers()) {
186         for (const Extension& extension : device->getSupportedExtensions()) {
187             registerExtension(extension, device->getName());
188         }
189     }
190 }
191 
registerExtension(Extension extension,const std::string & deviceName)192 bool TypeManager::registerExtension(Extension extension, const std::string& deviceName) {
193     if (mDisabledExtensions.find(extension.name) != mDisabledExtensions.end()) {
194         LOG(ERROR) << "Extension " << extension.name << " is disabled";
195         return false;
196     }
197 
198     std::sort(extension.operandTypes.begin(), extension.operandTypes.end(),
199               [](const Extension::OperandTypeInformation& a,
200                  const Extension::OperandTypeInformation& b) {
201                   return static_cast<uint16_t>(a.type) < static_cast<uint16_t>(b.type);
202               });
203 
204     std::map<std::string, Extension>::iterator it;
205     bool isNew;
206     std::tie(it, isNew) = mExtensionNameToExtension.emplace(extension.name, extension);
207     if (isNew) {
208         VLOG(MANAGER) << "Registered extension " << extension.name;
209         mExtensionNameToFirstDevice.emplace(extension.name, deviceName);
210     } else if (!equal(extension, it->second)) {
211         LOG(ERROR) << "Devices " << mExtensionNameToFirstDevice[extension.name] << " and "
212                    << deviceName << " provide inconsistent information for extension "
213                    << extension.name << ", which is therefore disabled";
214         mExtensionNameToExtension.erase(it);
215         mDisabledExtensions.insert(extension.name);
216         return false;
217     }
218     return true;
219 }
220 
getExtensionPrefix(const std::string & extensionName,uint16_t * prefix)221 bool TypeManager::getExtensionPrefix(const std::string& extensionName, uint16_t* prefix) {
222     auto it = mExtensionNameToPrefix.find(extensionName);
223     if (it != mExtensionNameToPrefix.end()) {
224         *prefix = it->second;
225     } else {
226         NN_RET_CHECK_LE(mPrefixToExtension.size(), kMaxPrefix) << "Too many extensions in use";
227         *prefix = mPrefixToExtension.size();
228         mExtensionNameToPrefix[extensionName] = *prefix;
229         mPrefixToExtension.push_back(&mExtensionNameToExtension[extensionName]);
230     }
231     return true;
232 }
233 
getExtensionType(const char * extensionName,uint16_t typeWithinExtension,int32_t * type)234 bool TypeManager::getExtensionType(const char* extensionName, uint16_t typeWithinExtension,
235                                    int32_t* type) {
236     uint16_t prefix;
237     NN_RET_CHECK(getExtensionPrefix(extensionName, &prefix));
238     *type = (prefix << kLowBitsType) | typeWithinExtension;
239     return true;
240 }
241 
// Looks up the extension that `prefix` was assigned to and stores a pointer to it in
// *extension.
//
// Returns false, leaving *extension untouched, when `prefix` is 0 (which does not correspond
// to an extension) or is not below mPrefixToExtension.size() (never handed out by
// getExtensionPrefix).
// NOTE(review): since prefix 0 is rejected, index 0 of mPrefixToExtension is presumably a
// placeholder initialized in TypeManager.h so that real prefixes start at 1 -- confirm there.
bool TypeManager::getExtensionInfo(uint16_t prefix, const Extension** extension) const {
    NN_RET_CHECK_NE(prefix, 0u) << "prefix=0 does not correspond to an extension";
    NN_RET_CHECK_LT(prefix, mPrefixToExtension.size()) << "Unknown extension prefix";
    *extension = mPrefixToExtension[prefix];
    return true;
}
248 
getExtensionOperandTypeInfo(OperandType type,const Extension::OperandTypeInformation ** info) const249 bool TypeManager::getExtensionOperandTypeInfo(
250         OperandType type, const Extension::OperandTypeInformation** info) const {
251     uint32_t operandType = static_cast<uint32_t>(type);
252     uint16_t prefix = operandType >> kLowBitsType;
253     uint16_t typeWithinExtension = operandType & ((1 << kLowBitsType) - 1);
254     const Extension* extension;
255     NN_RET_CHECK(getExtensionInfo(prefix, &extension))
256             << "Cannot find extension corresponding to prefix " << prefix;
257     auto it = std::lower_bound(
258             extension->operandTypes.begin(), extension->operandTypes.end(), typeWithinExtension,
259             [](const Extension::OperandTypeInformation& info, uint32_t typeSought) {
260                 return static_cast<uint16_t>(info.type) < typeSought;
261             });
262     NN_RET_CHECK(it != extension->operandTypes.end() &&
263                  static_cast<uint16_t>(it->type) == typeWithinExtension)
264             << "Cannot find operand type " << typeWithinExtension << " in extension "
265             << extension->name;
266     *info = &*it;
267     return true;
268 }
269 
isTensorType(OperandType type) const270 bool TypeManager::isTensorType(OperandType type) const {
271     if (!isExtensionOperandType(type)) {
272         return !nonExtensionOperandTypeIsScalar(static_cast<int>(type));
273     }
274     const Extension::OperandTypeInformation* info;
275     CHECK(getExtensionOperandTypeInfo(type, &info));
276     return info->isTensor;
277 }
278 
getSizeOfData(OperandType type,const std::vector<uint32_t> & dimensions) const279 uint32_t TypeManager::getSizeOfData(OperandType type,
280                                     const std::vector<uint32_t>& dimensions) const {
281     if (!isExtensionOperandType(type)) {
282         return nonExtensionOperandSizeOfData(type, dimensions);
283     }
284     const Extension::OperandTypeInformation* info;
285     CHECK(getExtensionOperandTypeInfo(type, &info));
286     return info->isTensor ? sizeOfTensorData(info->byteSize, dimensions) : info->byteSize;
287 }
288 
sizeOfDataOverflowsUInt32(hal::OperandType type,const std::vector<uint32_t> & dimensions) const289 bool TypeManager::sizeOfDataOverflowsUInt32(hal::OperandType type,
290                                             const std::vector<uint32_t>& dimensions) const {
291     if (!isExtensionOperandType(type)) {
292         return nonExtensionOperandSizeOfDataOverflowsUInt32(type, dimensions);
293     }
294     const Extension::OperandTypeInformation* info;
295     CHECK(getExtensionOperandTypeInfo(type, &info));
296     return info->isTensor ? sizeOfTensorDataOverflowsUInt32(info->byteSize, dimensions) : false;
297 }
298 
299 }  // namespace nn
300 }  // namespace android
301