1 /*
2  * Copyright 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <binder/SafeInterface.h>
18 
19 #include <binder/IInterface.h>
20 #include <binder/IPCThreadState.h>
21 #include <binder/IServiceManager.h>
22 #include <binder/Parcel.h>
23 #include <binder/Parcelable.h>
24 #include <binder/ProcessState.h>
25 
26 #pragma clang diagnostic push
27 #pragma clang diagnostic ignored "-Weverything"
28 #include <gtest/gtest.h>
29 #pragma clang diagnostic pop
30 
31 #include <utils/LightRefBase.h>
32 #include <utils/NativeHandle.h>
33 
34 #include <cutils/native_handle.h>
35 
36 #include <optional>
37 
38 #include <sys/eventfd.h>
39 
40 using namespace std::chrono_literals; // NOLINT - google-build-using-namespace
41 
42 namespace android {
43 namespace tests {
44 
// Scoped enum with a fixed 32-bit unsigned underlying type. It is the argument and result type of
// ISafeInterfaceTest::modifyEnum, whose native implementation maps INITIAL to FINAL and any other
// input to INVALID.
enum class TestEnum : uint32_t {
    INVALID = 0,
    INITIAL, // == 1
    FINAL,   // == 2
};
50 
51 // This class serves two purposes:
52 //   1) It ensures that the implementation doesn't require copying or moving the data (for
53 //      efficiency purposes)
54 //   2) It tests that Parcelables can be passed correctly
55 class NoCopyNoMove : public Parcelable {
56 public:
57     NoCopyNoMove() = default;
NoCopyNoMove(int32_t value)58     explicit NoCopyNoMove(int32_t value) : mValue(value) {}
59     ~NoCopyNoMove() override = default;
60 
61     // Not copyable
62     NoCopyNoMove(const NoCopyNoMove&) = delete;
63     NoCopyNoMove& operator=(const NoCopyNoMove&) = delete;
64 
65     // Not movable
66     NoCopyNoMove(NoCopyNoMove&&) = delete;
67     NoCopyNoMove& operator=(NoCopyNoMove&&) = delete;
68 
69     // Parcelable interface
writeToParcel(Parcel * parcel) const70     status_t writeToParcel(Parcel* parcel) const override { return parcel->writeInt32(mValue); }
readFromParcel(const Parcel * parcel)71     status_t readFromParcel(const Parcel* parcel) override { return parcel->readInt32(&mValue); }
72 
getValue() const73     int32_t getValue() const { return mValue; }
setValue(int32_t value)74     void setValue(int32_t value) { mValue = value; }
75 
76 private:
77     int32_t mValue = 0;
78     __attribute__((unused)) uint8_t mPadding[4] = {}; // Avoids a warning from -Wpadded
79 };
80 
81 struct TestFlattenable : Flattenable<TestFlattenable> {
82     TestFlattenable() = default;
TestFlattenableandroid::tests::TestFlattenable83     explicit TestFlattenable(int32_t v) : value(v) {}
84 
85     // Flattenable protocol
getFlattenedSizeandroid::tests::TestFlattenable86     size_t getFlattenedSize() const { return sizeof(value); }
getFdCountandroid::tests::TestFlattenable87     size_t getFdCount() const { return 0; }
flattenandroid::tests::TestFlattenable88     status_t flatten(void*& buffer, size_t& size, int*& /*fds*/, size_t& /*count*/) const {
89         FlattenableUtils::write(buffer, size, value);
90         return NO_ERROR;
91     }
unflattenandroid::tests::TestFlattenable92     status_t unflatten(void const*& buffer, size_t& size, int const*& /*fds*/, size_t& /*count*/) {
93         FlattenableUtils::read(buffer, size, value);
94         return NO_ERROR;
95     }
96 
97     int32_t value = 0;
98 };
99 
100 struct TestLightFlattenable : LightFlattenablePod<TestLightFlattenable> {
101     TestLightFlattenable() = default;
TestLightFlattenableandroid::tests::TestLightFlattenable102     explicit TestLightFlattenable(int32_t v) : value(v) {}
103     int32_t value = 0;
104 };
105 
106 // It seems like this should be able to inherit from TestFlattenable (to avoid duplicating code),
107 // but the SafeInterface logic can't easily be extended to find an indirect Flattenable<T>
108 // base class
109 class TestLightRefBaseFlattenable : public Flattenable<TestLightRefBaseFlattenable>,
110                                     public LightRefBase<TestLightRefBaseFlattenable> {
111 public:
112     TestLightRefBaseFlattenable() = default;
TestLightRefBaseFlattenable(int32_t v)113     explicit TestLightRefBaseFlattenable(int32_t v) : value(v) {}
114 
115     // Flattenable protocol
getFlattenedSize() const116     size_t getFlattenedSize() const { return sizeof(value); }
getFdCount() const117     size_t getFdCount() const { return 0; }
flatten(void * & buffer,size_t & size,int * &,size_t &) const118     status_t flatten(void*& buffer, size_t& size, int*& /*fds*/, size_t& /*count*/) const {
119         FlattenableUtils::write(buffer, size, value);
120         return NO_ERROR;
121     }
unflatten(void const * & buffer,size_t & size,int const * &,size_t &)122     status_t unflatten(void const*& buffer, size_t& size, int const*& /*fds*/, size_t& /*count*/) {
123         FlattenableUtils::read(buffer, size, value);
124         return NO_ERROR;
125     }
126 
127     int32_t value = 0;
128 };
129 
130 class TestParcelable : public Parcelable {
131 public:
132     TestParcelable() = default;
TestParcelable(int32_t value)133     explicit TestParcelable(int32_t value) : mValue(value) {}
TestParcelable(const TestParcelable & other)134     TestParcelable(const TestParcelable& other) : TestParcelable(other.mValue) {}
TestParcelable(TestParcelable && other)135     TestParcelable(TestParcelable&& other) : TestParcelable(other.mValue) {}
136 
137     // Parcelable interface
writeToParcel(Parcel * parcel) const138     status_t writeToParcel(Parcel* parcel) const override { return parcel->writeInt32(mValue); }
readFromParcel(const Parcel * parcel)139     status_t readFromParcel(const Parcel* parcel) override { return parcel->readInt32(&mValue); }
140 
getValue() const141     int32_t getValue() const { return mValue; }
setValue(int32_t value)142     void setValue(int32_t value) { mValue = value; }
143 
144 private:
145     int32_t mValue = 0;
146 };
147 
148 class ExitOnDeath : public IBinder::DeathRecipient {
149 public:
150     ~ExitOnDeath() override = default;
151 
binderDied(const wp<IBinder> &)152     void binderDied(const wp<IBinder>& /*who*/) override {
153         ALOG(LOG_INFO, "ExitOnDeath", "Exiting");
154         exit(0);
155     }
156 };
157 
// This callback class is used to test both one-way transactions and that sp<IInterface> can be
// passed correctly
class ICallback : public IInterface {
public:
    DECLARE_META_INTERFACE(Callback)

    // Transaction codes for this interface, numbered upwards from
    // IBinder::FIRST_CALL_TRANSACTION
    enum class Tag : uint32_t {
        OnCallback = IBinder::FIRST_CALL_TRANSACTION,
        // Sentinel that is one past the final real code; BnCallback::onTransact expects every
        // incoming code to be strictly less than this
        Last,
    };

    // Called by BnSafeInterfaceTest::callMeBack with its integer argument plus one; BpCallback
    // forwards the call through callRemoteAsync
    virtual void onCallback(int32_t aPlusOne) = 0;
};
171 
172 class BpCallback : public SafeBpInterface<ICallback> {
173 public:
BpCallback(const sp<IBinder> & impl)174     explicit BpCallback(const sp<IBinder>& impl) : SafeBpInterface<ICallback>(impl, getLogTag()) {}
175 
onCallback(int32_t aPlusOne)176     void onCallback(int32_t aPlusOne) override {
177         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
178         return callRemoteAsync<decltype(&ICallback::onCallback)>(Tag::OnCallback, aPlusOne);
179     }
180 
181 private:
getLogTag()182     static constexpr const char* getLogTag() { return "BpCallback"; }
183 };
184 
// Emits the out-of-line definitions that pair with DECLARE_META_INTERFACE(Callback) in ICallback,
// registering the interface under the descriptor string given below.
// NOTE(review): the macro presumably defines static objects with non-trivial destructors, which
// is why -Wexit-time-destructors is silenced around it — confirm against binder/IInterface.h
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wexit-time-destructors"
IMPLEMENT_META_INTERFACE(Callback, "android.gfx.tests.ICallback");
#pragma clang diagnostic pop
189 
190 class BnCallback : public SafeBnInterface<ICallback> {
191 public:
BnCallback()192     BnCallback() : SafeBnInterface("BnCallback") {}
193 
onTransact(uint32_t code,const Parcel & data,Parcel * reply,uint32_t)194     status_t onTransact(uint32_t code, const Parcel& data, Parcel* reply,
195                         uint32_t /*flags*/) override {
196         EXPECT_GE(code, IBinder::FIRST_CALL_TRANSACTION);
197         EXPECT_LT(code, static_cast<uint32_t>(ICallback::Tag::Last));
198         ICallback::Tag tag = static_cast<ICallback::Tag>(code);
199         switch (tag) {
200             case ICallback::Tag::OnCallback: {
201                 return callLocalAsync(data, reply, &ICallback::onCallback);
202             }
203             case ICallback::Tag::Last:
204                 // Should not be possible because of the asserts at the beginning of the method
205                 [&]() { FAIL(); }();
206                 return UNKNOWN_ERROR;
207         }
208     }
209 };
210 
// Interface under test. It declares one method per parameter kind so that each one can be sent
// through BpSafeInterfaceTest and handled by BnSafeInterfaceTest.
class ISafeInterfaceTest : public IInterface {
public:
    DECLARE_META_INTERFACE(SafeInterfaceTest)

    // Transaction codes, numbered upwards from IBinder::FIRST_CALL_TRANSACTION. Each enumerator
    // corresponds to exactly one method (or one increment() overload) declared below.
    enum class Tag : uint32_t {
        SetDeathToken = IBinder::FIRST_CALL_TRANSACTION,
        ReturnsNoMemory,
        LogicalNot,
        ModifyEnum,
        IncrementFlattenable,
        IncrementLightFlattenable,
        IncrementLightRefBaseFlattenable,
        IncrementNativeHandle,
        IncrementNoCopyNoMove,
        IncrementParcelableVector,
        ToUpper,
        CallMeBack,
        IncrementInt32,
        IncrementUint32,
        IncrementInt64,
        IncrementUint64,
        IncrementFloat,
        IncrementTwo,
        // Sentinel that is one past the final real code; BnSafeInterfaceTest::onTransact expects
        // every incoming code to be strictly less than this
        Last,
    };

    // This is primarily so that the remote service dies when the test does, but it also serves to
    // test the handling of sp<IBinder> and non-const methods
    virtual status_t setDeathToken(const sp<IBinder>& token) = 0;

    // This is the most basic test since it doesn't require parceling any arguments
    virtual status_t returnsNoMemory() const = 0;

    // These are ordered according to their corresponding methods in SafeInterface::ParcelHandler
    // In each method the const/by-value parameters are inputs and the pointer parameters are
    // outputs
    virtual status_t logicalNot(bool a, bool* notA) const = 0;
    virtual status_t modifyEnum(TestEnum a, TestEnum* b) const = 0;
    virtual status_t increment(const TestFlattenable& a, TestFlattenable* aPlusOne) const = 0;
    virtual status_t increment(const TestLightFlattenable& a,
                               TestLightFlattenable* aPlusOne) const = 0;
    virtual status_t increment(const sp<TestLightRefBaseFlattenable>& a,
                               sp<TestLightRefBaseFlattenable>* aPlusOne) const = 0;
    virtual status_t increment(const sp<NativeHandle>& a, sp<NativeHandle>* aPlusOne) const = 0;
    virtual status_t increment(const NoCopyNoMove& a, NoCopyNoMove* aPlusOne) const = 0;
    virtual status_t increment(const std::vector<TestParcelable>& a,
                               std::vector<TestParcelable>* aPlusOne) const = 0;
    virtual status_t toUpper(const String8& str, String8* upperStr) const = 0;
    // As mentioned above, sp<IBinder> is already tested by setDeathToken
    // Returns void because it is sent asynchronously (callRemoteAsync / callLocalAsync)
    virtual void callMeBack(const sp<ICallback>& callback, int32_t a) const = 0;
    virtual status_t increment(int32_t a, int32_t* aPlusOne) const = 0;
    virtual status_t increment(uint32_t a, uint32_t* aPlusOne) const = 0;
    virtual status_t increment(int64_t a, int64_t* aPlusOne) const = 0;
    virtual status_t increment(uint64_t a, uint64_t* aPlusOne) const = 0;
    virtual status_t increment(float a, float* aPlusOne) const = 0;

    // This tests that input/output parameter interleaving works correctly
    virtual status_t increment(int32_t a, int32_t* aPlusOne, int32_t b,
                               int32_t* bPlusOne) const = 0;
};
269 
270 class BpSafeInterfaceTest : public SafeBpInterface<ISafeInterfaceTest> {
271 public:
BpSafeInterfaceTest(const sp<IBinder> & impl)272     explicit BpSafeInterfaceTest(const sp<IBinder>& impl)
273           : SafeBpInterface<ISafeInterfaceTest>(impl, getLogTag()) {}
274 
setDeathToken(const sp<IBinder> & token)275     status_t setDeathToken(const sp<IBinder>& token) override {
276         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
277         return callRemote<decltype(&ISafeInterfaceTest::setDeathToken)>(Tag::SetDeathToken, token);
278     }
returnsNoMemory() const279     status_t returnsNoMemory() const override {
280         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
281         return callRemote<decltype(&ISafeInterfaceTest::returnsNoMemory)>(Tag::ReturnsNoMemory);
282     }
logicalNot(bool a,bool * notA) const283     status_t logicalNot(bool a, bool* notA) const override {
284         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
285         return callRemote<decltype(&ISafeInterfaceTest::logicalNot)>(Tag::LogicalNot, a, notA);
286     }
modifyEnum(TestEnum a,TestEnum * b) const287     status_t modifyEnum(TestEnum a, TestEnum* b) const override {
288         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
289         return callRemote<decltype(&ISafeInterfaceTest::modifyEnum)>(Tag::ModifyEnum, a, b);
290     }
increment(const TestFlattenable & a,TestFlattenable * aPlusOne) const291     status_t increment(const TestFlattenable& a, TestFlattenable* aPlusOne) const override {
292         using Signature =
293                 status_t (ISafeInterfaceTest::*)(const TestFlattenable&, TestFlattenable*) const;
294         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
295         return callRemote<Signature>(Tag::IncrementFlattenable, a, aPlusOne);
296     }
increment(const TestLightFlattenable & a,TestLightFlattenable * aPlusOne) const297     status_t increment(const TestLightFlattenable& a,
298                        TestLightFlattenable* aPlusOne) const override {
299         using Signature = status_t (ISafeInterfaceTest::*)(const TestLightFlattenable&,
300                                                            TestLightFlattenable*) const;
301         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
302         return callRemote<Signature>(Tag::IncrementLightFlattenable, a, aPlusOne);
303     }
increment(const sp<TestLightRefBaseFlattenable> & a,sp<TestLightRefBaseFlattenable> * aPlusOne) const304     status_t increment(const sp<TestLightRefBaseFlattenable>& a,
305                        sp<TestLightRefBaseFlattenable>* aPlusOne) const override {
306         using Signature = status_t (ISafeInterfaceTest::*)(const sp<TestLightRefBaseFlattenable>&,
307                                                            sp<TestLightRefBaseFlattenable>*) const;
308         return callRemote<Signature>(Tag::IncrementLightRefBaseFlattenable, a, aPlusOne);
309     }
increment(const sp<NativeHandle> & a,sp<NativeHandle> * aPlusOne) const310     status_t increment(const sp<NativeHandle>& a, sp<NativeHandle>* aPlusOne) const override {
311         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
312         using Signature =
313                 status_t (ISafeInterfaceTest::*)(const sp<NativeHandle>&, sp<NativeHandle>*) const;
314         return callRemote<Signature>(Tag::IncrementNativeHandle, a, aPlusOne);
315     }
increment(const NoCopyNoMove & a,NoCopyNoMove * aPlusOne) const316     status_t increment(const NoCopyNoMove& a, NoCopyNoMove* aPlusOne) const override {
317         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
318         using Signature = status_t (ISafeInterfaceTest::*)(const NoCopyNoMove& a,
319                                                            NoCopyNoMove* aPlusOne) const;
320         return callRemote<Signature>(Tag::IncrementNoCopyNoMove, a, aPlusOne);
321     }
increment(const std::vector<TestParcelable> & a,std::vector<TestParcelable> * aPlusOne) const322     status_t increment(const std::vector<TestParcelable>& a,
323                        std::vector<TestParcelable>* aPlusOne) const override {
324         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
325         using Signature = status_t (ISafeInterfaceTest::*)(const std::vector<TestParcelable>&,
326                                                            std::vector<TestParcelable>*);
327         return callRemote<Signature>(Tag::IncrementParcelableVector, a, aPlusOne);
328     }
toUpper(const String8 & str,String8 * upperStr) const329     status_t toUpper(const String8& str, String8* upperStr) const override {
330         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
331         return callRemote<decltype(&ISafeInterfaceTest::toUpper)>(Tag::ToUpper, str, upperStr);
332     }
callMeBack(const sp<ICallback> & callback,int32_t a) const333     void callMeBack(const sp<ICallback>& callback, int32_t a) const override {
334         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
335         return callRemoteAsync<decltype(&ISafeInterfaceTest::callMeBack)>(Tag::CallMeBack, callback,
336                                                                           a);
337     }
increment(int32_t a,int32_t * aPlusOne) const338     status_t increment(int32_t a, int32_t* aPlusOne) const override {
339         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
340         using Signature = status_t (ISafeInterfaceTest::*)(int32_t, int32_t*) const;
341         return callRemote<Signature>(Tag::IncrementInt32, a, aPlusOne);
342     }
increment(uint32_t a,uint32_t * aPlusOne) const343     status_t increment(uint32_t a, uint32_t* aPlusOne) const override {
344         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
345         using Signature = status_t (ISafeInterfaceTest::*)(uint32_t, uint32_t*) const;
346         return callRemote<Signature>(Tag::IncrementUint32, a, aPlusOne);
347     }
increment(int64_t a,int64_t * aPlusOne) const348     status_t increment(int64_t a, int64_t* aPlusOne) const override {
349         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
350         using Signature = status_t (ISafeInterfaceTest::*)(int64_t, int64_t*) const;
351         return callRemote<Signature>(Tag::IncrementInt64, a, aPlusOne);
352     }
increment(uint64_t a,uint64_t * aPlusOne) const353     status_t increment(uint64_t a, uint64_t* aPlusOne) const override {
354         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
355         using Signature = status_t (ISafeInterfaceTest::*)(uint64_t, uint64_t*) const;
356         return callRemote<Signature>(Tag::IncrementUint64, a, aPlusOne);
357     }
increment(float a,float * aPlusOne) const358     status_t increment(float a, float* aPlusOne) const override {
359         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
360         using Signature = status_t (ISafeInterfaceTest::*)(float, float*) const;
361         return callRemote<Signature>(Tag::IncrementFloat, a, aPlusOne);
362     }
increment(int32_t a,int32_t * aPlusOne,int32_t b,int32_t * bPlusOne) const363     status_t increment(int32_t a, int32_t* aPlusOne, int32_t b, int32_t* bPlusOne) const override {
364         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
365         using Signature =
366                 status_t (ISafeInterfaceTest::*)(int32_t, int32_t*, int32_t, int32_t*) const;
367         return callRemote<Signature>(Tag::IncrementTwo, a, aPlusOne, b, bPlusOne);
368     }
369 
370 private:
getLogTag()371     static constexpr const char* getLogTag() { return "BpSafeInterfaceTest"; }
372 };
373 
374 #pragma clang diagnostic push
375 #pragma clang diagnostic ignored "-Wexit-time-destructors"
376 IMPLEMENT_META_INTERFACE(SafeInterfaceTest, "android.gfx.tests.ISafeInterfaceTest");
377 
getDeathRecipient()378 static sp<IBinder::DeathRecipient> getDeathRecipient() {
379     static sp<IBinder::DeathRecipient> recipient = new ExitOnDeath;
380     return recipient;
381 }
382 #pragma clang diagnostic pop
383 
384 class BnSafeInterfaceTest : public SafeBnInterface<ISafeInterfaceTest> {
385 public:
BnSafeInterfaceTest()386     BnSafeInterfaceTest() : SafeBnInterface(getLogTag()) {}
387 
setDeathToken(const sp<IBinder> & token)388     status_t setDeathToken(const sp<IBinder>& token) override {
389         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
390         token->linkToDeath(getDeathRecipient());
391         return NO_ERROR;
392     }
returnsNoMemory() const393     status_t returnsNoMemory() const override {
394         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
395         return NO_MEMORY;
396     }
logicalNot(bool a,bool * notA) const397     status_t logicalNot(bool a, bool* notA) const override {
398         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
399         *notA = !a;
400         return NO_ERROR;
401     }
modifyEnum(TestEnum a,TestEnum * b) const402     status_t modifyEnum(TestEnum a, TestEnum* b) const override {
403         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
404         *b = (a == TestEnum::INITIAL) ? TestEnum::FINAL : TestEnum::INVALID;
405         return NO_ERROR;
406     }
increment(const TestFlattenable & a,TestFlattenable * aPlusOne) const407     status_t increment(const TestFlattenable& a, TestFlattenable* aPlusOne) const override {
408         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
409         aPlusOne->value = a.value + 1;
410         return NO_ERROR;
411     }
increment(const TestLightFlattenable & a,TestLightFlattenable * aPlusOne) const412     status_t increment(const TestLightFlattenable& a,
413                        TestLightFlattenable* aPlusOne) const override {
414         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
415         aPlusOne->value = a.value + 1;
416         return NO_ERROR;
417     }
increment(const sp<TestLightRefBaseFlattenable> & a,sp<TestLightRefBaseFlattenable> * aPlusOne) const418     status_t increment(const sp<TestLightRefBaseFlattenable>& a,
419                        sp<TestLightRefBaseFlattenable>* aPlusOne) const override {
420         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
421         *aPlusOne = new TestLightRefBaseFlattenable(a->value + 1);
422         return NO_ERROR;
423     }
increment(const sp<NativeHandle> & a,sp<NativeHandle> * aPlusOne) const424     status_t increment(const sp<NativeHandle>& a, sp<NativeHandle>* aPlusOne) const override {
425         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
426         native_handle* rawHandle = native_handle_create(1 /*numFds*/, 1 /*numInts*/);
427         if (rawHandle == nullptr) return NO_MEMORY;
428 
429         // Copy the fd over directly
430         rawHandle->data[0] = dup(a->handle()->data[0]);
431 
432         // Increment the int
433         rawHandle->data[1] = a->handle()->data[1] + 1;
434 
435         // This cannot fail, as it is just the sp<NativeHandle> taking responsibility for closing
436         // the native_handle when it goes out of scope
437         *aPlusOne = NativeHandle::create(rawHandle, true);
438         return NO_ERROR;
439     }
increment(const NoCopyNoMove & a,NoCopyNoMove * aPlusOne) const440     status_t increment(const NoCopyNoMove& a, NoCopyNoMove* aPlusOne) const override {
441         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
442         aPlusOne->setValue(a.getValue() + 1);
443         return NO_ERROR;
444     }
increment(const std::vector<TestParcelable> & a,std::vector<TestParcelable> * aPlusOne) const445     status_t increment(const std::vector<TestParcelable>& a,
446                        std::vector<TestParcelable>* aPlusOne) const override {
447         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
448         aPlusOne->resize(a.size());
449         for (size_t i = 0; i < a.size(); ++i) {
450             (*aPlusOne)[i].setValue(a[i].getValue() + 1);
451         }
452         return NO_ERROR;
453     }
toUpper(const String8 & str,String8 * upperStr) const454     status_t toUpper(const String8& str, String8* upperStr) const override {
455         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
456         *upperStr = str;
457         upperStr->toUpper();
458         return NO_ERROR;
459     }
callMeBack(const sp<ICallback> & callback,int32_t a) const460     void callMeBack(const sp<ICallback>& callback, int32_t a) const override {
461         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
462         callback->onCallback(a + 1);
463     }
increment(int32_t a,int32_t * aPlusOne) const464     status_t increment(int32_t a, int32_t* aPlusOne) const override {
465         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
466         *aPlusOne = a + 1;
467         return NO_ERROR;
468     }
increment(uint32_t a,uint32_t * aPlusOne) const469     status_t increment(uint32_t a, uint32_t* aPlusOne) const override {
470         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
471         *aPlusOne = a + 1;
472         return NO_ERROR;
473     }
increment(int64_t a,int64_t * aPlusOne) const474     status_t increment(int64_t a, int64_t* aPlusOne) const override {
475         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
476         *aPlusOne = a + 1;
477         return NO_ERROR;
478     }
increment(uint64_t a,uint64_t * aPlusOne) const479     status_t increment(uint64_t a, uint64_t* aPlusOne) const override {
480         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
481         *aPlusOne = a + 1;
482         return NO_ERROR;
483     }
increment(float a,float * aPlusOne) const484     status_t increment(float a, float* aPlusOne) const override {
485         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
486         *aPlusOne = a + 1.0f;
487         return NO_ERROR;
488     }
increment(int32_t a,int32_t * aPlusOne,int32_t b,int32_t * bPlusOne) const489     status_t increment(int32_t a, int32_t* aPlusOne, int32_t b, int32_t* bPlusOne) const override {
490         ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
491         *aPlusOne = a + 1;
492         *bPlusOne = b + 1;
493         return NO_ERROR;
494     }
495 
496     // BnInterface
onTransact(uint32_t code,const Parcel & data,Parcel * reply,uint32_t)497     status_t onTransact(uint32_t code, const Parcel& data, Parcel* reply,
498                         uint32_t /*flags*/) override {
499         EXPECT_GE(code, IBinder::FIRST_CALL_TRANSACTION);
500         EXPECT_LT(code, static_cast<uint32_t>(Tag::Last));
501         ISafeInterfaceTest::Tag tag = static_cast<ISafeInterfaceTest::Tag>(code);
502         switch (tag) {
503             case ISafeInterfaceTest::Tag::SetDeathToken: {
504                 return callLocal(data, reply, &ISafeInterfaceTest::setDeathToken);
505             }
506             case ISafeInterfaceTest::Tag::ReturnsNoMemory: {
507                 return callLocal(data, reply, &ISafeInterfaceTest::returnsNoMemory);
508             }
509             case ISafeInterfaceTest::Tag::LogicalNot: {
510                 return callLocal(data, reply, &ISafeInterfaceTest::logicalNot);
511             }
512             case ISafeInterfaceTest::Tag::ModifyEnum: {
513                 return callLocal(data, reply, &ISafeInterfaceTest::modifyEnum);
514             }
515             case ISafeInterfaceTest::Tag::IncrementFlattenable: {
516                 using Signature = status_t (ISafeInterfaceTest::*)(const TestFlattenable& a,
517                                                                    TestFlattenable* aPlusOne) const;
518                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
519             }
520             case ISafeInterfaceTest::Tag::IncrementLightFlattenable: {
521                 using Signature =
522                         status_t (ISafeInterfaceTest::*)(const TestLightFlattenable& a,
523                                                          TestLightFlattenable* aPlusOne) const;
524                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
525             }
526             case ISafeInterfaceTest::Tag::IncrementLightRefBaseFlattenable: {
527                 using Signature =
528                         status_t (ISafeInterfaceTest::*)(const sp<TestLightRefBaseFlattenable>&,
529                                                          sp<TestLightRefBaseFlattenable>*) const;
530                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
531             }
532             case ISafeInterfaceTest::Tag::IncrementNativeHandle: {
533                 using Signature = status_t (ISafeInterfaceTest::*)(const sp<NativeHandle>&,
534                                                                    sp<NativeHandle>*) const;
535                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
536             }
537             case ISafeInterfaceTest::Tag::IncrementNoCopyNoMove: {
538                 using Signature = status_t (ISafeInterfaceTest::*)(const NoCopyNoMove& a,
539                                                                    NoCopyNoMove* aPlusOne) const;
540                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
541             }
542             case ISafeInterfaceTest::Tag::IncrementParcelableVector: {
543                 using Signature =
544                         status_t (ISafeInterfaceTest::*)(const std::vector<TestParcelable>&,
545                                                          std::vector<TestParcelable>*) const;
546                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
547             }
548             case ISafeInterfaceTest::Tag::ToUpper: {
549                 return callLocal(data, reply, &ISafeInterfaceTest::toUpper);
550             }
551             case ISafeInterfaceTest::Tag::CallMeBack: {
552                 return callLocalAsync(data, reply, &ISafeInterfaceTest::callMeBack);
553             }
554             case ISafeInterfaceTest::Tag::IncrementInt32: {
555                 using Signature = status_t (ISafeInterfaceTest::*)(int32_t, int32_t*) const;
556                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
557             }
558             case ISafeInterfaceTest::Tag::IncrementUint32: {
559                 using Signature = status_t (ISafeInterfaceTest::*)(uint32_t, uint32_t*) const;
560                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
561             }
562             case ISafeInterfaceTest::Tag::IncrementInt64: {
563                 using Signature = status_t (ISafeInterfaceTest::*)(int64_t, int64_t*) const;
564                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
565             }
566             case ISafeInterfaceTest::Tag::IncrementUint64: {
567                 using Signature = status_t (ISafeInterfaceTest::*)(uint64_t, uint64_t*) const;
568                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
569             }
570             case ISafeInterfaceTest::Tag::IncrementFloat: {
571                 using Signature = status_t (ISafeInterfaceTest::*)(float, float*) const;
572                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
573             }
574             case ISafeInterfaceTest::Tag::IncrementTwo: {
575                 using Signature = status_t (ISafeInterfaceTest::*)(int32_t, int32_t*, int32_t,
576                                                                    int32_t*) const;
577                 return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
578             }
579             case ISafeInterfaceTest::Tag::Last:
580                 // Should not be possible because of the asserts at the beginning of the method
581                 [&]() { FAIL(); }();
582                 return UNKNOWN_ERROR;
583         }
584     }
585 
586 private:
getLogTag()587     static constexpr const char* getLogTag() { return "BnSafeInterfaceTest"; }
588 };
589 
590 class SafeInterfaceTest : public ::testing::Test {
591 public:
SafeInterfaceTest()592     SafeInterfaceTest() : mSafeInterfaceTest(getRemoteService()) {
593         ProcessState::self()->startThreadPool();
594     }
595     ~SafeInterfaceTest() override = default;
596 
597 protected:
598     sp<ISafeInterfaceTest> mSafeInterfaceTest;
599 
600 private:
getLogTag()601     static constexpr const char* getLogTag() { return "SafeInterfaceTest"; }
602 
getRemoteService()603     sp<ISafeInterfaceTest> getRemoteService() {
604 #pragma clang diagnostic push
605 #pragma clang diagnostic ignored "-Wexit-time-destructors"
606         static std::mutex sMutex;
607         static sp<ISafeInterfaceTest> sService;
608         static sp<IBinder> sDeathToken = new BBinder;
609 #pragma clang diagnostic pop
610 
611         std::unique_lock<decltype(sMutex)> lock;
612         if (sService == nullptr) {
613             ALOG(LOG_INFO, getLogTag(), "Forking remote process");
614             pid_t forkPid = fork();
615             EXPECT_NE(forkPid, -1);
616 
617             const String16 serviceName("SafeInterfaceTest");
618 
619             if (forkPid == 0) {
620                 ALOG(LOG_INFO, getLogTag(), "Remote process checking in");
621                 sp<ISafeInterfaceTest> nativeService = new BnSafeInterfaceTest;
622                 defaultServiceManager()->addService(serviceName,
623                                                     IInterface::asBinder(nativeService));
624                 ProcessState::self()->startThreadPool();
625                 IPCThreadState::self()->joinThreadPool();
626                 // We shouldn't get to this point
627                 [&]() { FAIL(); }();
628             }
629 
630             sp<IBinder> binder = defaultServiceManager()->getService(serviceName);
631             sService = interface_cast<ISafeInterfaceTest>(binder);
632             EXPECT_TRUE(sService != nullptr);
633 
634             sService->setDeathToken(sDeathToken);
635         }
636 
637         return sService;
638     }
639 };
640 
// A status_t returned by the remote implementation must reach the caller unchanged.
TEST_F(SafeInterfaceTest, TestReturnsNoMemory) {
    const status_t status = mSafeInterfaceTest->returnsNoMemory();
    ASSERT_EQ(NO_MEMORY, status);
}
645 
// Round-trips a bool through logicalNot() for both possible inputs.
TEST_F(SafeInterfaceTest, TestLogicalNot) {
    // Test both since we don't want to accidentally catch a default false somewhere
    for (const bool input : {true, false}) {
        // Seed the output with the wrong answer so an untouched output is detected
        bool negated = input;
        const status_t status = mSafeInterfaceTest->logicalNot(input, &negated);
        ASSERT_EQ(NO_ERROR, status);
        ASSERT_EQ(!input, negated);
    }
}
659 
// Sends TestEnum::INITIAL through modifyEnum() and expects TestEnum::FINAL back.
TEST_F(SafeInterfaceTest, TestModifyEnum) {
    const TestEnum input = TestEnum::INITIAL;
    // Start from INVALID so that an unwritten output cannot pass by accident
    TestEnum output = TestEnum::INVALID;
    const status_t status = mSafeInterfaceTest->modifyEnum(input, &output);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(TestEnum::FINAL, output);
}
667 
// increment() on a TestFlattenable must yield a value one greater than the input.
TEST_F(SafeInterfaceTest, TestIncrementFlattenable) {
    const TestFlattenable input{1};
    TestFlattenable incremented{0};
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(input.value + 1, incremented.value);
}
675 
// increment() on a TestLightFlattenable must yield a value one greater than the input.
TEST_F(SafeInterfaceTest, TestIncrementLightFlattenable) {
    const TestLightFlattenable input{1};
    TestLightFlattenable incremented{0};
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(input.value + 1, incremented.value);
}
683 
// increment() on an sp<TestLightRefBaseFlattenable> must produce a non-null object whose value
// is one greater than the input's.
TEST_F(SafeInterfaceTest, TestIncrementLightRefBaseFlattenable) {
    sp<TestLightRefBaseFlattenable> input = new TestLightRefBaseFlattenable{1};
    sp<TestLightRefBaseFlattenable> incremented;
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    // The output starts out null, so make sure it was filled in before dereferencing it
    ASSERT_NE(nullptr, incremented.get());
    ASSERT_EQ(input->value + 1, incremented->value);
}
692 
693 namespace { // Anonymous namespace
694 
// Returns true if both file descriptors can be fstat'ed and report the same device and inode
// numbers; returns false if either fstat call fails or the identifiers differ.
bool fdsAreEquivalent(int a, int b) {
    struct stat infoA {};
    struct stat infoB {};
    // Short-circuiting keeps the original order: b is only queried if a could be stat'ed
    const bool bothStatted = (fstat(a, &infoA) == 0) && (fstat(b, &infoB) == 0);
    if (!bothStatted) {
        return false;
    }
    const bool sameDevice = infoA.st_dev == infoB.st_dev;
    const bool sameInode = infoA.st_ino == infoB.st_ino;
    return sameDevice && sameInode;
}
702 
703 } // Anonymous namespace
704 
// Sends a NativeHandle holding one fd and one int through increment() and checks that the
// returned handle refers to the same underlying file and carries the int plus one. The
// round trip is repeated more times than the process may have fds open, so a leaked fd on
// any iteration eventually makes the test fail.
TEST_F(SafeInterfaceTest, TestIncrementNativeHandle) {
    // Create an fd we can use to send and receive from the remote process
    base::unique_fd eventFd{eventfd(0 /*initval*/, 0 /*flags*/)};
    ASSERT_NE(-1, eventFd);

    // Determine the maximum number of fds this process can have open
    struct rlimit limit {};
    ASSERT_EQ(0, getrlimit(RLIMIT_NOFILE, &limit));
    uint32_t maxFds = static_cast<uint32_t>(limit.rlim_cur);

    // Perform this test enough times to rule out fd leaks
    for (uint32_t iter = 0; iter < (2 * maxFds); ++iter) {
        native_handle* handle = native_handle_create(1 /*numFds*/, 1 /*numInts*/);
        ASSERT_NE(nullptr, handle);
        handle->data[0] = dup(eventFd.get());
        handle->data[1] = 1;

        // This cannot fail, as it is just the sp<NativeHandle> taking responsibility for closing
        // the native_handle when it goes out of scope
        sp<NativeHandle> a = NativeHandle::create(handle, true);

        // dup() reports failure (e.g. fd exhaustion caused by a leak) by returning -1. Check it
        // only after the sp<NativeHandle> owns the handle so that bailing out here cannot leak it
        ASSERT_NE(-1, a->handle()->data[0]);

        sp<NativeHandle> aPlusOne;
        status_t result = mSafeInterfaceTest->increment(a, &aPlusOne);
        ASSERT_EQ(NO_ERROR, result);
        // The output starts out null; fail cleanly instead of dereferencing a null pointer
        ASSERT_NE(nullptr, aPlusOne.get());
        ASSERT_TRUE(fdsAreEquivalent(a->handle()->data[0], aPlusOne->handle()->data[0]));
        ASSERT_EQ(a->handle()->data[1] + 1, aPlusOne->handle()->data[1]);
    }
}
733 
// increment() must work on NoCopyNoMove, a Parcelable whose copy operations are deleted.
TEST_F(SafeInterfaceTest, TestIncrementNoCopyNoMove) {
    const NoCopyNoMove input{1};
    NoCopyNoMove incremented{0};
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(input.getValue() + 1, incremented.getValue());
}
741 
// increment() on a vector of TestParcelable must return a vector of the same size in which
// every element's value is one greater than the corresponding input element's.
TEST_F(SafeInterfaceTest, TestIncremementParcelableVector) {
    const std::vector<TestParcelable> inputs{TestParcelable{1}, TestParcelable{2}};
    std::vector<TestParcelable> outputs;
    const status_t status = mSafeInterfaceTest->increment(inputs, &outputs);
    ASSERT_EQ(NO_ERROR, status);
    // The sizes must match before the element-wise comparison indexes into outputs
    ASSERT_EQ(inputs.size(), outputs.size());
    const size_t count = inputs.size();
    for (size_t index = 0; index < count; ++index) {
        ASSERT_EQ(inputs[index].getValue() + 1, outputs[index].getValue());
    }
}
752 
// toUpper() must return the upper-cased form of the String8 it was given.
TEST_F(SafeInterfaceTest, TestToUpper) {
    const String8 original{"Hello, world!"};
    String8 converted;
    const status_t status = mSafeInterfaceTest->toUpper(original, &converted);
    ASSERT_EQ(NO_ERROR, status);
    const bool matchesExpected = (converted == String8{"HELLO, WORLD!"});
    ASSERT_TRUE(matchesExpected);
}
760 
// Passes a callback object to callMeBack() and expects it to be invoked with the input plus
// one within 100ms.
TEST_F(SafeInterfaceTest, TestCallMeBack) {
    // Callback implementation that stores the delivered value and wakes up the waiting test
    class CallbackReceiver : public BnCallback {
    public:
        void onCallback(int32_t aPlusOne) override {
            ALOG(LOG_INFO, "CallbackReceiver", "%s", __PRETTY_FUNCTION__);
            std::unique_lock<std::mutex> guard(mMutex);
            mValue = aPlusOne;
            mCondition.notify_one();
        }

        // Blocks for at most 100ms; returns the delivered value, or nullopt on timeout
        std::optional<int32_t> waitForCallback() {
            std::unique_lock<std::mutex> guard(mMutex);
            const auto valueArrived = [this]() { return mValue.has_value(); };
            if (!mCondition.wait_for(guard, 100ms, valueArrived)) {
                return std::nullopt;
            }
            return mValue;
        }

    private:
        std::mutex mMutex;
        std::condition_variable mCondition;
        std::optional<int32_t> mValue;
    };

    sp<CallbackReceiver> callback = new CallbackReceiver;
    const int32_t input = 1;
    mSafeInterfaceTest->callMeBack(callback, input);
    const auto received = callback->waitForCallback();
    ASSERT_TRUE(received.has_value());
    ASSERT_EQ(input + 1, *received);
}
791 
// Exercises the int32_t overload of increment().
TEST_F(SafeInterfaceTest, TestIncrementInt32) {
    const int32_t input = 1;
    int32_t incremented = 0;
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(input + 1, incremented);
}
799 
// Exercises the uint32_t overload of increment().
TEST_F(SafeInterfaceTest, TestIncrementUint32) {
    const uint32_t input = 1;
    uint32_t incremented = 0;
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(input + 1, incremented);
}
807 
// Exercises the int64_t overload of increment().
TEST_F(SafeInterfaceTest, TestIncrementInt64) {
    const int64_t input = 1;
    int64_t incremented = 0;
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(input + 1, incremented);
}
815 
// Exercises the uint64_t overload of increment().
TEST_F(SafeInterfaceTest, TestIncrementUint64) {
    const uint64_t input = 1;
    uint64_t incremented = 0;
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    ASSERT_EQ(input + 1, incremented);
}
823 
// Exercises the float overload of increment().
TEST_F(SafeInterfaceTest, TestIncrementFloat) {
    const float input = 1.0f;
    float incremented = 0.0f;
    const status_t status = mSafeInterfaceTest->increment(input, &incremented);
    ASSERT_EQ(NO_ERROR, status);
    // 1.0f + 1.0f is exactly representable, so an exact comparison is fine here
    ASSERT_EQ(input + 1.0f, incremented);
}
831 
// Exercises the increment() overload that takes two inputs and two outputs in a single call,
// checking that each output is paired with the right input.
TEST_F(SafeInterfaceTest, TestIncrementTwo) {
    const int32_t a = 1;
    int32_t aPlusOne = 0;
    const int32_t b = 2;
    int32_t bPlusOne = 0;
    // Pass the named inputs (rather than repeating their literal values) so that the values
    // sent to the service and the expectations below cannot drift apart
    status_t result = mSafeInterfaceTest->increment(a, &aPlusOne, b, &bPlusOne);
    ASSERT_EQ(NO_ERROR, result);
    ASSERT_EQ(a + 1, aPlusOne);
    ASSERT_EQ(b + 1, bPlusOne);
}
842 
843 } // namespace tests
844 } // namespace android
845