1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "dns_responder_client"
18 
19 #include "dns_responder_client_ndk.h"
20 
21 #include <android-base/logging.h>
22 #include <android-base/stringprintf.h>
23 
24 #include <android/binder_manager.h>
25 #include "NetdClient.h"
26 
// TODO: make this dynamic and stop depending on implementation details.
// TEST_NETID is the netId that SetupOemNetwork() creates and selects for this process and that
// the SetResolvers*() helpers configure.
// NOTE(review): TEST_OEM_NETWORK is not referenced in this file; confirm it has no other users
// before removing it.
#define TEST_OEM_NETWORK "oem29"
#define TEST_NETID 30

// TODO: move this somewhere shared.
// Name of the environment variable that SetUp() sets to "" so that resolutions go via the proxy.
// A constexpr array (rather than a mutable `const char*` global) so it cannot be reassigned.
static constexpr char ANDROID_DNS_MODE[] = "ANDROID_DNS_MODE";
33 
34 using aidl::android::net::IDnsResolver;
35 using aidl::android::net::INetd;
36 using aidl::android::net::ResolverParamsParcel;
37 using android::base::StringPrintf;
38 using android::net::ResolverStats;
39 
SetupMappings(unsigned numHosts,const std::vector<std::string> & domains,std::vector<Mapping> * mappings)40 void DnsResponderClient::SetupMappings(unsigned numHosts, const std::vector<std::string>& domains,
41                                        std::vector<Mapping>* mappings) {
42     mappings->resize(numHosts * domains.size());
43     auto mappingsIt = mappings->begin();
44     for (unsigned i = 0; i < numHosts; ++i) {
45         for (const auto& domain : domains) {
46             mappingsIt->host = StringPrintf("host%u", i);
47             mappingsIt->entry = StringPrintf("%s.%s.", mappingsIt->host.c_str(), domain.c_str());
48             mappingsIt->ip4 = StringPrintf("192.0.2.%u", i % 253 + 1);
49             mappingsIt->ip6 = StringPrintf("2001:db8::%x", i % 65534 + 1);
50             ++mappingsIt;
51         }
52     }
53 }
54 
55 // TODO: Use SetResolverConfiguration() with ResolverParamsParcel struct directly.
56 // DEPRECATED: Use SetResolverConfiguration() in new code
makeResolverParamsParcel(int netId,const std::vector<int> & params,const std::vector<std::string> & servers,const std::vector<std::string> & domains,const std::string & tlsHostname,const std::vector<std::string> & tlsServers,const std::string & caCert)57 ResolverParamsParcel DnsResponderClient::makeResolverParamsParcel(
58         int netId, const std::vector<int>& params, const std::vector<std::string>& servers,
59         const std::vector<std::string>& domains, const std::string& tlsHostname,
60         const std::vector<std::string>& tlsServers, const std::string& caCert) {
61     ResolverParamsParcel paramsParcel;
62 
63     paramsParcel.netId = netId;
64     paramsParcel.sampleValiditySeconds = params[IDnsResolver::RESOLVER_PARAMS_SAMPLE_VALIDITY];
65     paramsParcel.successThreshold = params[IDnsResolver::RESOLVER_PARAMS_SUCCESS_THRESHOLD];
66     paramsParcel.minSamples = params[IDnsResolver::RESOLVER_PARAMS_MIN_SAMPLES];
67     paramsParcel.maxSamples = params[IDnsResolver::RESOLVER_PARAMS_MAX_SAMPLES];
68     if (params.size() > IDnsResolver::RESOLVER_PARAMS_BASE_TIMEOUT_MSEC) {
69         paramsParcel.baseTimeoutMsec = params[IDnsResolver::RESOLVER_PARAMS_BASE_TIMEOUT_MSEC];
70     } else {
71         paramsParcel.baseTimeoutMsec = 0;
72     }
73     if (params.size() > IDnsResolver::RESOLVER_PARAMS_RETRY_COUNT) {
74         paramsParcel.retryCount = params[IDnsResolver::RESOLVER_PARAMS_RETRY_COUNT];
75     } else {
76         paramsParcel.retryCount = 0;
77     }
78     paramsParcel.servers = servers;
79     paramsParcel.domains = domains;
80     paramsParcel.tlsName = tlsHostname;
81     paramsParcel.tlsServers = tlsServers;
82     paramsParcel.tlsFingerprints = {};
83     paramsParcel.caCertificate = caCert;
84 
85     // Note, do not remove this otherwise the ResolverTest#ConnectTlsServerTimeout won't pass in M4
86     // module.
87     // TODO: remove after 2020-01 rolls out.
88     paramsParcel.tlsConnectTimeoutMs = 1000;
89 
90     return paramsParcel;
91 }
92 
GetResolverInfo(aidl::android::net::IDnsResolver * dnsResolverService,unsigned netId,std::vector<std::string> * servers,std::vector<std::string> * domains,std::vector<std::string> * tlsServers,res_params * params,std::vector<ResolverStats> * stats,int * waitForPendingReqTimeoutCount)93 bool DnsResponderClient::GetResolverInfo(aidl::android::net::IDnsResolver* dnsResolverService,
94                                          unsigned netId, std::vector<std::string>* servers,
95                                          std::vector<std::string>* domains,
96                                          std::vector<std::string>* tlsServers, res_params* params,
97                                          std::vector<ResolverStats>* stats,
98                                          int* waitForPendingReqTimeoutCount) {
99     using aidl::android::net::IDnsResolver;
100     std::vector<int32_t> params32;
101     std::vector<int32_t> stats32;
102     std::vector<int32_t> waitForPendingReqTimeoutCount32{0};
103     auto rv = dnsResolverService->getResolverInfo(netId, servers, domains, tlsServers, &params32,
104                                                   &stats32, &waitForPendingReqTimeoutCount32);
105 
106     if (!rv.isOk() || params32.size() != static_cast<size_t>(IDnsResolver::RESOLVER_PARAMS_COUNT)) {
107         return false;
108     }
109     *params = res_params{
110             .sample_validity =
111                     static_cast<uint16_t>(params32[IDnsResolver::RESOLVER_PARAMS_SAMPLE_VALIDITY]),
112             .success_threshold =
113                     static_cast<uint8_t>(params32[IDnsResolver::RESOLVER_PARAMS_SUCCESS_THRESHOLD]),
114             .min_samples =
115                     static_cast<uint8_t>(params32[IDnsResolver::RESOLVER_PARAMS_MIN_SAMPLES]),
116             .max_samples =
117                     static_cast<uint8_t>(params32[IDnsResolver::RESOLVER_PARAMS_MAX_SAMPLES]),
118             .base_timeout_msec = params32[IDnsResolver::RESOLVER_PARAMS_BASE_TIMEOUT_MSEC],
119             .retry_count = params32[IDnsResolver::RESOLVER_PARAMS_RETRY_COUNT],
120     };
121     *waitForPendingReqTimeoutCount = waitForPendingReqTimeoutCount32[0];
122     return ResolverStats::decodeAll(stats32, stats);
123 }
124 
isRemoteVersionSupported(aidl::android::net::IDnsResolver * dnsResolverService,int requiredVersion)125 bool DnsResponderClient::isRemoteVersionSupported(
126         aidl::android::net::IDnsResolver* dnsResolverService, int requiredVersion) {
127     int remoteVersion = 0;
128     if (!dnsResolverService->getInterfaceVersion(&remoteVersion).isOk()) {
129         LOG(FATAL) << "Can't get 'dnsresolver' remote version";
130     }
131     if (remoteVersion < requiredVersion) {
132         LOG(WARNING) << StringPrintf("Remote version: %d < Required version: %d", remoteVersion,
133                                      requiredVersion);
134         return false;
135     }
136     return true;
137 }
138 
SetResolversForNetwork(const std::vector<std::string> & servers,const std::vector<std::string> & domains,const std::vector<int> & params)139 bool DnsResponderClient::SetResolversForNetwork(const std::vector<std::string>& servers,
140                                                 const std::vector<std::string>& domains,
141                                                 const std::vector<int>& params) {
142     const auto& resolverParams =
143             makeResolverParamsParcel(TEST_NETID, params, servers, domains, "", {}, "");
144     const auto rv = mDnsResolvSrv->setResolverConfiguration(resolverParams);
145     return rv.isOk();
146 }
147 
SetResolversWithTls(const std::vector<std::string> & servers,const std::vector<std::string> & domains,const std::vector<int> & params,const std::vector<std::string> & tlsServers,const std::string & name)148 bool DnsResponderClient::SetResolversWithTls(const std::vector<std::string>& servers,
149                                              const std::vector<std::string>& domains,
150                                              const std::vector<int>& params,
151                                              const std::vector<std::string>& tlsServers,
152                                              const std::string& name) {
153     const auto& resolverParams = makeResolverParamsParcel(TEST_NETID, params, servers, domains,
154                                                           name, tlsServers, kCaCert);
155     const auto rv = mDnsResolvSrv->setResolverConfiguration(resolverParams);
156     if (!rv.isOk()) LOG(ERROR) << "SetResolversWithTls() -> " << rv.getMessage();
157     return rv.isOk();
158 }
159 
SetResolversFromParcel(const ResolverParamsParcel & resolverParams)160 bool DnsResponderClient::SetResolversFromParcel(const ResolverParamsParcel& resolverParams) {
161     const auto rv = mDnsResolvSrv->setResolverConfiguration(resolverParams);
162     if (!rv.isOk()) LOG(ERROR) << "SetResolversFromParcel() -> " << rv.getMessage();
163     return rv.isOk();
164 }
165 
GetDefaultResolverParamsParcel()166 ResolverParamsParcel DnsResponderClient::GetDefaultResolverParamsParcel() {
167     return makeResolverParamsParcel(TEST_NETID, kDefaultParams, kDefaultServers,
168                                     kDefaultSearchDomains, {} /* tlsHostname */, kDefaultServers,
169                                     kCaCert);
170 }
171 
SetupDNSServers(unsigned numServers,const std::vector<Mapping> & mappings,std::vector<std::unique_ptr<test::DNSResponder>> * dns,std::vector<std::string> * servers)172 void DnsResponderClient::SetupDNSServers(unsigned numServers, const std::vector<Mapping>& mappings,
173                                          std::vector<std::unique_ptr<test::DNSResponder>>* dns,
174                                          std::vector<std::string>* servers) {
175     const char* listenSrv = "53";
176     dns->resize(numServers);
177     servers->resize(numServers);
178     for (unsigned i = 0; i < numServers; ++i) {
179         auto& server = (*servers)[i];
180         auto& d = (*dns)[i];
181         server = StringPrintf("127.0.0.%u", i + 100);
182         d = std::make_unique<test::DNSResponder>(server, listenSrv, ns_rcode::ns_r_servfail);
183         for (const auto& mapping : mappings) {
184             d->addMapping(mapping.entry.c_str(), ns_type::ns_t_a, mapping.ip4.c_str());
185             d->addMapping(mapping.entry.c_str(), ns_type::ns_t_aaaa, mapping.ip6.c_str());
186         }
187         d->startServer();
188     }
189 }
190 
SetupOemNetwork()191 int DnsResponderClient::SetupOemNetwork() {
192     mNetdSrv->networkDestroy(TEST_NETID);
193     mDnsResolvSrv->destroyNetworkCache(TEST_NETID);
194     auto ret = mNetdSrv->networkCreatePhysical(TEST_NETID, INetd::PERMISSION_NONE);
195     if (!ret.isOk()) {
196         fprintf(stderr, "Creating physical network %d failed, %s\n", TEST_NETID, ret.getMessage());
197         return -1;
198     }
199     ret = mDnsResolvSrv->createNetworkCache(TEST_NETID);
200     if (!ret.isOk()) {
201         fprintf(stderr, "Creating network cache %d failed, %s\n", TEST_NETID, ret.getMessage());
202         return -1;
203     }
204     setNetworkForProcess(TEST_NETID);
205     if ((unsigned)TEST_NETID != getNetworkForProcess()) {
206         return -1;
207     }
208     return TEST_NETID;
209 }
210 
TearDownOemNetwork(int oemNetId)211 void DnsResponderClient::TearDownOemNetwork(int oemNetId) {
212     if (oemNetId != -1) {
213         mNetdSrv->networkDestroy(oemNetId);
214         mDnsResolvSrv->destroyNetworkCache(oemNetId);
215     }
216 }
217 
SetUp()218 void DnsResponderClient::SetUp() {
219     // binder setup
220     ndk::SpAIBinder netdBinder = ndk::SpAIBinder(AServiceManager_getService("netd"));
221     mNetdSrv = INetd::fromBinder(netdBinder);
222     if (mNetdSrv.get() == nullptr) {
223         LOG(FATAL) << "Can't connect to service 'netd'. Missing root privileges? uid=" << getuid();
224     }
225 
226     ndk::SpAIBinder resolvBinder = ndk::SpAIBinder(AServiceManager_getService("dnsresolver"));
227     mDnsResolvSrv = IDnsResolver::fromBinder(resolvBinder);
228     if (mDnsResolvSrv.get() == nullptr) {
229         LOG(FATAL) << "Can't connect to service 'dnsresolver'. Missing root privileges? uid="
230                    << getuid();
231     }
232 
233     // Ensure resolutions go via proxy.
234     setenv(ANDROID_DNS_MODE, "", 1);
235     mOemNetId = SetupOemNetwork();
236 }
237 
TearDown()238 void DnsResponderClient::TearDown() {
239     TearDownOemNetwork(mOemNetId);
240 }
241