1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H 18 #define ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H 19 20 #include <hwbinder/IPCThreadState.h> 21 22 #include <memory> 23 #include <string> 24 #include <utility> 25 #include <vector> 26 27 #include "BufferTracker.h" 28 #include "CpuExecutor.h" 29 #include "HalInterfaces.h" 30 #include "NeuralNetworks.h" 31 32 namespace android { 33 namespace nn { 34 namespace sample_driver { 35 36 using hardware::MQDescriptorSync; 37 38 // Manages the data buffer for an operand. 39 class SampleBuffer : public hal::IBuffer { 40 public: SampleBuffer(std::shared_ptr<ManagedBuffer> buffer,std::unique_ptr<BufferTracker::Token> token)41 SampleBuffer(std::shared_ptr<ManagedBuffer> buffer, std::unique_ptr<BufferTracker::Token> token) 42 : kBuffer(std::move(buffer)), kToken(std::move(token)) { 43 CHECK(kBuffer != nullptr); 44 CHECK(kToken != nullptr); 45 } 46 hal::Return<hal::ErrorStatus> copyTo(const hal::hidl_memory& dst) override; 47 hal::Return<hal::ErrorStatus> copyFrom(const hal::hidl_memory& src, 48 const hal::hidl_vec<uint32_t>& dimensions) override; 49 50 private: 51 const std::shared_ptr<ManagedBuffer> kBuffer; 52 const std::unique_ptr<BufferTracker::Token> kToken; 53 }; 54 55 // Base class used to create sample drivers for the NN HAL. This class 56 // provides some implementation of the more common functions. 57 // 58 // Since these drivers simulate hardware, they must run the computations 59 // on the CPU. An actual driver would not do that. 60 class SampleDriver : public hal::IDevice { 61 public: 62 SampleDriver(const char* name, 63 const IOperationResolver* operationResolver = BuiltinOperationResolver::get()) mName(name)64 : mName(name), 65 mOperationResolver(operationResolver), 66 mBufferTracker(BufferTracker::create()) { 67 android::nn::initVLogMask(); 68 } 69 hal::Return<void> getCapabilities(getCapabilities_cb cb) override; 70 hal::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb cb) override; 71 hal::Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override; 72 hal::Return<void> getVersionString(getVersionString_cb cb) override; 73 hal::Return<void> getType(getType_cb cb) override; 74 hal::Return<void> getSupportedExtensions(getSupportedExtensions_cb) override; 75 hal::Return<void> getSupportedOperations(const hal::V1_0::Model& model, 76 getSupportedOperations_cb cb) override; 77 hal::Return<void> getSupportedOperations_1_1(const hal::V1_1::Model& model, 78 getSupportedOperations_1_1_cb cb) override; 79 hal::Return<void> getSupportedOperations_1_2(const hal::V1_2::Model& model, 80 getSupportedOperations_1_2_cb cb) override; 81 hal::Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override; 82 hal::Return<hal::V1_0::ErrorStatus> prepareModel( 83 const hal::V1_0::Model& model, 84 const sp<hal::V1_0::IPreparedModelCallback>& callback) override; 85 hal::Return<hal::V1_0::ErrorStatus> prepareModel_1_1( 86 const hal::V1_1::Model& model, hal::ExecutionPreference preference, 87 const sp<hal::V1_0::IPreparedModelCallback>& callback) override; 88 hal::Return<hal::V1_0::ErrorStatus> prepareModel_1_2( 89 const hal::V1_2::Model& model, hal::ExecutionPreference preference, 90 const hal::hidl_vec<hal::hidl_handle>& modelCache, 91 const hal::hidl_vec<hal::hidl_handle>& dataCache, const hal::CacheToken& token, 92 const sp<hal::V1_2::IPreparedModelCallback>& callback) override; 93 hal::Return<hal::V1_3::ErrorStatus> prepareModel_1_3( 94 const hal::V1_3::Model& model, hal::ExecutionPreference preference, 95 hal::Priority priority, const hal::OptionalTimePoint& deadline, 96 const hal::hidl_vec<hal::hidl_handle>& modelCache, 97 const hal::hidl_vec<hal::hidl_handle>& dataCache, const hal::CacheToken& token, 98 const sp<hal::V1_3::IPreparedModelCallback>& callback) override; 99 hal::Return<hal::V1_0::ErrorStatus> prepareModelFromCache( 100 const hal::hidl_vec<hal::hidl_handle>& modelCache, 101 const hal::hidl_vec<hal::hidl_handle>& dataCache, const hal::CacheToken& token, 102 const sp<hal::V1_2::IPreparedModelCallback>& callback) override; 103 hal::Return<hal::V1_3::ErrorStatus> prepareModelFromCache_1_3( 104 const hal::OptionalTimePoint& deadline, 105 const hal::hidl_vec<hal::hidl_handle>& modelCache, 106 const hal::hidl_vec<hal::hidl_handle>& dataCache, const hal::CacheToken& token, 107 const sp<hal::V1_3::IPreparedModelCallback>& callback) override; 108 hal::Return<hal::DeviceStatus> getStatus() override; 109 hal::Return<void> allocate(const hal::V1_3::BufferDesc& desc, 110 const hal::hidl_vec<sp<hal::V1_3::IPreparedModel>>& preparedModels, 111 const hal::hidl_vec<hal::V1_3::BufferRole>& inputRoles, 112 const hal::hidl_vec<hal::V1_3::BufferRole>& outputRoles, 113 allocate_cb cb) override; 114 115 // Starts and runs the driver service. Typically called from main(). 116 // This will return only once the service shuts down. 117 int run(); 118 getExecutor()119 CpuExecutor getExecutor() const { return CpuExecutor(mOperationResolver); } getBufferTracker()120 const std::shared_ptr<BufferTracker>& getBufferTracker() const { return mBufferTracker; } 121 122 protected: 123 std::string mName; 124 const IOperationResolver* mOperationResolver; 125 const std::shared_ptr<BufferTracker> mBufferTracker; 126 }; 127 128 class SamplePreparedModel : public hal::IPreparedModel { 129 public: SamplePreparedModel(const hal::Model & model,const SampleDriver * driver,hal::ExecutionPreference preference,uid_t userId,hal::Priority priority)130 SamplePreparedModel(const hal::Model& model, const SampleDriver* driver, 131 hal::ExecutionPreference preference, uid_t userId, hal::Priority priority) 132 : mModel(model), 133 mDriver(driver), 134 kPreference(preference), 135 kUserId(userId), 136 kPriority(priority) { 137 (void)kUserId; 138 (void)kPriority; 139 } 140 bool initialize(); 141 hal::Return<hal::V1_0::ErrorStatus> execute( 142 const hal::V1_0::Request& request, 143 const sp<hal::V1_0::IExecutionCallback>& callback) override; 144 hal::Return<hal::V1_0::ErrorStatus> execute_1_2( 145 const hal::V1_0::Request& request, hal::MeasureTiming measure, 146 const sp<hal::V1_2::IExecutionCallback>& callback) override; 147 hal::Return<hal::V1_3::ErrorStatus> execute_1_3( 148 const hal::V1_3::Request& request, hal::MeasureTiming measure, 149 const hal::OptionalTimePoint& deadline, 150 const hal::OptionalTimeoutDuration& loopTimeoutDuration, 151 const sp<hal::V1_3::IExecutionCallback>& callback) override; 152 hal::Return<void> executeSynchronously(const hal::V1_0::Request& request, 153 hal::MeasureTiming measure, 154 executeSynchronously_cb cb) override; 155 hal::Return<void> executeSynchronously_1_3( 156 const hal::V1_3::Request& request, hal::MeasureTiming measure, 157 const hal::OptionalTimePoint& deadline, 158 const hal::OptionalTimeoutDuration& loopTimeoutDuration, 159 executeSynchronously_1_3_cb cb) override; 160 hal::Return<void> configureExecutionBurst( 161 const sp<hal::V1_2::IBurstCallback>& callback, 162 const MQDescriptorSync<hal::V1_2::FmqRequestDatum>& requestChannel, 163 const MQDescriptorSync<hal::V1_2::FmqResultDatum>& resultChannel, 164 configureExecutionBurst_cb cb) override; 165 hal::Return<void> executeFenced(const hal::Request& request, 166 const hal::hidl_vec<hal::hidl_handle>& wait_for, 167 hal::MeasureTiming measure, 168 const hal::OptionalTimePoint& deadline, 169 const hal::OptionalTimeoutDuration& loopTimeoutDuration, 170 const hal::OptionalTimeoutDuration& duration, 171 executeFenced_cb callback) override; getModel()172 const hal::Model* getModel() const { return &mModel; } 173 174 private: 175 hal::Model mModel; 176 const SampleDriver* mDriver; 177 std::vector<RunTimePoolInfo> mPoolInfos; 178 const hal::ExecutionPreference kPreference; 179 const uid_t kUserId; 180 const hal::Priority kPriority; 181 }; 182 183 class SampleFencedExecutionCallback : public hal::IFencedExecutionCallback { 184 public: SampleFencedExecutionCallback(hal::Timing timingSinceLaunch,hal::Timing timingAfterFence,hal::ErrorStatus error)185 SampleFencedExecutionCallback(hal::Timing timingSinceLaunch, hal::Timing timingAfterFence, 186 hal::ErrorStatus error) 187 : kTimingSinceLaunch(timingSinceLaunch), 188 kTimingAfterFence(timingAfterFence), 189 kErrorStatus(error) {} getExecutionInfo(getExecutionInfo_cb callback)190 hal::Return<void> getExecutionInfo(getExecutionInfo_cb callback) override { 191 callback(kErrorStatus, kTimingSinceLaunch, kTimingAfterFence); 192 return hal::Void(); 193 } 194 195 private: 196 const hal::Timing kTimingSinceLaunch; 197 const hal::Timing kTimingAfterFence; 198 const hal::ErrorStatus kErrorStatus; 199 }; 200 201 } // namespace sample_driver 202 } // namespace nn 203 } // namespace android 204 205 #endif // ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H 206