#include <uds/client_channel.h>

#include <sys/socket.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <thread>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <pdx/client.h>
#include <pdx/rpc/remote_method.h>
#include <pdx/service.h>
#include <pdx/service_dispatcher.h>

#include <uds/client_channel_factory.h>
#include <uds/service_endpoint.h>
20 
21 using testing::Return;
22 using testing::_;
23 
24 using android::pdx::ClientBase;
25 using android::pdx::LocalChannelHandle;
26 using android::pdx::LocalHandle;
27 using android::pdx::Message;
28 using android::pdx::ServiceBase;
29 using android::pdx::ServiceDispatcher;
30 using android::pdx::Status;
31 using android::pdx::rpc::DispatchRemoteMethod;
32 using android::pdx::uds::ClientChannel;
33 using android::pdx::uds::ClientChannelFactory;
34 using android::pdx::uds::Endpoint;
35 
36 namespace {
37 
38 struct TestProtocol {
39   using DataType = int8_t;
40   enum {
41     kOpSum = 0,
42   };
43   PDX_REMOTE_METHOD(Sum, kOpSum, int64_t(const std::vector<DataType>&));
44 };
45 
46 class TestService : public ServiceBase<TestService> {
47  public:
TestService(std::unique_ptr<Endpoint> endpoint)48   explicit TestService(std::unique_ptr<Endpoint> endpoint)
49       : ServiceBase{"TestService", std::move(endpoint)} {}
50 
HandleMessage(Message & message)51   Status<void> HandleMessage(Message& message) override {
52     switch (message.GetOp()) {
53       case TestProtocol::kOpSum:
54         DispatchRemoteMethod<TestProtocol::Sum>(*this, &TestService::OnSum,
55                                                 message);
56         return {};
57 
58       default:
59         return Service::HandleMessage(message);
60     }
61   }
62 
OnSum(Message &,const std::vector<TestProtocol::DataType> & data)63   int64_t OnSum(Message& /*message*/,
64                 const std::vector<TestProtocol::DataType>& data) {
65     return std::accumulate(data.begin(), data.end(), int64_t{0});
66   }
67 };
68 
69 class TestClient : public ClientBase<TestClient> {
70  public:
71   using ClientBase::ClientBase;
72 
Sum(const std::vector<TestProtocol::DataType> & data)73   int64_t Sum(const std::vector<TestProtocol::DataType>& data) {
74     auto status = InvokeRemoteMethod<TestProtocol::Sum>(data);
75     return status ? status.get() : -1;
76   }
77 };
78 
79 class TestServiceRunner {
80  public:
TestServiceRunner(LocalHandle channel_socket)81   explicit TestServiceRunner(LocalHandle channel_socket) {
82     auto endpoint = Endpoint::CreateFromSocketFd(LocalHandle{});
83     endpoint->RegisterNewChannelForTests(std::move(channel_socket));
84     service_ = TestService::Create(std::move(endpoint));
85     dispatcher_ = ServiceDispatcher::Create();
86     dispatcher_->AddService(service_);
87     dispatch_thread_ = std::thread(
88         std::bind(&ServiceDispatcher::EnterDispatchLoop, dispatcher_.get()));
89   }
90 
~TestServiceRunner()91   ~TestServiceRunner() {
92     dispatcher_->SetCanceled(true);
93     dispatch_thread_.join();
94     dispatcher_->RemoveService(service_);
95   }
96 
97  private:
98   std::shared_ptr<TestService> service_;
99   std::unique_ptr<ServiceDispatcher> dispatcher_;
100   std::thread dispatch_thread_;
101 };
102 
103 class ClientChannelTest : public testing::Test {
104  public:
SetUp()105   void SetUp() override {
106     int channel_sockets[2] = {};
107     ASSERT_EQ(
108         0, socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_sockets));
109     LocalHandle service_channel{channel_sockets[0]};
110     LocalHandle client_channel{channel_sockets[1]};
111 
112     service_runner_.reset(new TestServiceRunner{std::move(service_channel)});
113     auto factory = ClientChannelFactory::Create(std::move(client_channel));
114     auto status = factory->Connect(android::pdx::Client::kInfiniteTimeout);
115     ASSERT_TRUE(status);
116     client_ = TestClient::Create(status.take());
117   }
118 
TearDown()119   void TearDown() override {
120     service_runner_.reset();
121     client_.reset();
122   }
123 
124  protected:
125   std::unique_ptr<TestServiceRunner> service_runner_;
126   std::shared_ptr<TestClient> client_;
127 };
128 
// Shares one TestClient among kNumTestThreads threads. Each thread issues
// kMaxIterations Sum RPCs with its own random payload and checks that every
// reply equals the locally computed sum of that payload.
TEST_F(ClientChannelTest, MultithreadedClient) {
  constexpr int kNumTestThreads = 8;
  constexpr size_t kDataSize = 1000;  // Try to keep RPC buffer size below 4K.

  std::random_device rd;
  std::mt19937 gen{rd()};
  // std::uniform_int_distribution is only specified for short, int, long,
  // long long and their unsigned counterparts. Instantiating it with an 8-bit
  // type such as int8_t is undefined behavior, and some standard libraries
  // (e.g. recent libc++) reject it at compile time. Instead, draw ints over
  // DataType's full range and narrow each sample.
  std::uniform_int_distribution<int> dist{
      std::numeric_limits<TestProtocol::DataType>::min(),
      std::numeric_limits<TestProtocol::DataType>::max()};

  // Each worker receives its own copies of the client pointer and payload,
  // so the only object the threads share is the client itself.
  auto worker = [](std::shared_ptr<TestClient> client,
                   std::vector<TestProtocol::DataType> data) {
    constexpr int kMaxIterations = 500;
    int64_t expected = std::accumulate(data.begin(), data.end(), int64_t{0});
    for (int i = 0; i < kMaxIterations; i++) {
      ASSERT_EQ(expected, client->Sum(data));
    }
  };

  // Start client threads, each with a freshly randomized payload.
  std::vector<TestProtocol::DataType> data(kDataSize);
  std::vector<std::thread> threads;
  threads.reserve(kNumTestThreads);
  for (int i = 0; i < kNumTestThreads; i++) {
    std::generate(data.begin(), data.end(), [&dist, &gen]() {
      return static_cast<TestProtocol::DataType>(dist(gen));
    });
    threads.emplace_back(worker, client_, data);
  }

  // Wait for threads to finish.
  for (auto& thread : threads)
    thread.join();
}
162 
163 }  // namespace
164