1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "ModelArgumentInfo"
18
19 #include "ModelArgumentInfo.h"
20
21 #include <algorithm>
22 #include <utility>
23 #include <vector>
24
25 #include "HalInterfaces.h"
26 #include "NeuralNetworks.h"
27 #include "TypeManager.h"
28 #include "Utils.h"
29
30 namespace android {
31 namespace nn {
32
33 using namespace hal;
34
35 static const std::pair<int, ModelArgumentInfo> kBadDataModelArgumentInfo{ANEURALNETWORKS_BAD_DATA,
36 {}};
37
createFromPointer(const Operand & operand,const ANeuralNetworksOperandType * type,void * data,uint32_t length)38 std::pair<int, ModelArgumentInfo> ModelArgumentInfo::createFromPointer(
39 const Operand& operand, const ANeuralNetworksOperandType* type, void* data,
40 uint32_t length) {
41 if ((data == nullptr) != (length == 0)) {
42 const char* dataPtrMsg = data ? "NOT_NULLPTR" : "NULLPTR";
43 LOG(ERROR) << "Data pointer must be nullptr if and only if length is zero (data = "
44 << dataPtrMsg << ", length = " << length << ")";
45 return kBadDataModelArgumentInfo;
46 }
47
48 ModelArgumentInfo ret;
49 if (data == nullptr) {
50 ret.mState = ModelArgumentInfo::HAS_NO_VALUE;
51 } else {
52 if (int n = ret.updateDimensionInfo(operand, type)) {
53 return {n, ModelArgumentInfo()};
54 }
55 if (operand.type != OperandType::OEM) {
56 uint32_t neededLength =
57 TypeManager::get()->getSizeOfData(operand.type, ret.mDimensions);
58 if (neededLength != length && neededLength != 0) {
59 LOG(ERROR) << "Setting argument with invalid length: " << length
60 << ", expected length: " << neededLength;
61 return kBadDataModelArgumentInfo;
62 }
63 }
64 ret.mState = ModelArgumentInfo::POINTER;
65 }
66 ret.mBuffer = data;
67 ret.mLocationAndLength = {.poolIndex = 0, .offset = 0, .length = length};
68 return {ANEURALNETWORKS_NO_ERROR, ret};
69 }
70
createFromMemory(const Operand & operand,const ANeuralNetworksOperandType * type,uint32_t poolIndex,uint32_t offset,uint32_t length)71 std::pair<int, ModelArgumentInfo> ModelArgumentInfo::createFromMemory(
72 const Operand& operand, const ANeuralNetworksOperandType* type, uint32_t poolIndex,
73 uint32_t offset, uint32_t length) {
74 ModelArgumentInfo ret;
75 if (int n = ret.updateDimensionInfo(operand, type)) {
76 return {n, ModelArgumentInfo()};
77 }
78 const bool isMemorySizeKnown = offset != 0 || length != 0;
79 if (isMemorySizeKnown && operand.type != OperandType::OEM) {
80 const uint32_t neededLength =
81 TypeManager::get()->getSizeOfData(operand.type, ret.mDimensions);
82 if (neededLength != length && neededLength != 0) {
83 LOG(ERROR) << "Setting argument with invalid length: " << length
84 << " (offset: " << offset << "), expected length: " << neededLength;
85 return kBadDataModelArgumentInfo;
86 }
87 }
88
89 ret.mState = ModelArgumentInfo::MEMORY;
90 ret.mLocationAndLength = {.poolIndex = poolIndex, .offset = offset, .length = length};
91 ret.mBuffer = nullptr;
92 return {ANEURALNETWORKS_NO_ERROR, ret};
93 }
94
updateDimensionInfo(const Operand & operand,const ANeuralNetworksOperandType * newType)95 int ModelArgumentInfo::updateDimensionInfo(const Operand& operand,
96 const ANeuralNetworksOperandType* newType) {
97 if (newType == nullptr) {
98 mDimensions = operand.dimensions;
99 } else {
100 const uint32_t count = newType->dimensionCount;
101 mDimensions = hidl_vec<uint32_t>(count);
102 std::copy(&newType->dimensions[0], &newType->dimensions[count], mDimensions.begin());
103 }
104 return ANEURALNETWORKS_NO_ERROR;
105 }
106
createRequestArguments(const std::vector<ModelArgumentInfo> & argumentInfos,const std::vector<DataLocation> & ptrArgsLocations)107 hidl_vec<RequestArgument> createRequestArguments(
108 const std::vector<ModelArgumentInfo>& argumentInfos,
109 const std::vector<DataLocation>& ptrArgsLocations) {
110 const size_t count = argumentInfos.size();
111 hidl_vec<RequestArgument> ioInfos(count);
112 uint32_t ptrArgsIndex = 0;
113 for (size_t i = 0; i < count; i++) {
114 const auto& info = argumentInfos[i];
115 switch (info.state()) {
116 case ModelArgumentInfo::POINTER:
117 ioInfos[i] = {.hasNoValue = false,
118 .location = ptrArgsLocations[ptrArgsIndex++],
119 .dimensions = info.dimensions()};
120 break;
121 case ModelArgumentInfo::MEMORY:
122 ioInfos[i] = {.hasNoValue = false,
123 .location = info.locationAndLength(),
124 .dimensions = info.dimensions()};
125 break;
126 case ModelArgumentInfo::HAS_NO_VALUE:
127 ioInfos[i] = {.hasNoValue = true};
128 break;
129 default:
130 CHECK(false);
131 };
132 }
133 return ioInfos;
134 }
135
136 } // namespace nn
137 } // namespace android
138