1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 */
17
18 #define LOG_TAG "resolv"
19
20 #include "DnsStats.h"
21
22 #include <android-base/logging.h>
23 #include <android-base/stringprintf.h>
24
25 namespace android::net {
26
27 using base::StringPrintf;
28 using netdutils::DumpWriter;
29 using netdutils::IPAddress;
30 using netdutils::IPSockAddr;
31 using netdutils::ScopedIndent;
32 using std::chrono::duration_cast;
33 using std::chrono::microseconds;
34 using std::chrono::milliseconds;
35 using std::chrono::seconds;
36
37 namespace {
38
39 static constexpr IPAddress INVALID_IPADDRESS = IPAddress();
40
rcodeToName(int rcode)41 std::string rcodeToName(int rcode) {
42 // clang-format off
43 switch (rcode) {
44 case NS_R_NO_ERROR: return "NOERROR";
45 case NS_R_FORMERR: return "FORMERR";
46 case NS_R_SERVFAIL: return "SERVFAIL";
47 case NS_R_NXDOMAIN: return "NXDOMAIN";
48 case NS_R_NOTIMPL: return "NOTIMP";
49 case NS_R_REFUSED: return "REFUSED";
50 case NS_R_YXDOMAIN: return "YXDOMAIN";
51 case NS_R_YXRRSET: return "YXRRSET";
52 case NS_R_NXRRSET: return "NXRRSET";
53 case NS_R_NOTAUTH: return "NOTAUTH";
54 case NS_R_NOTZONE: return "NOTZONE";
55 case NS_R_INTERNAL_ERROR: return "INTERNAL_ERROR";
56 case NS_R_TIMEOUT: return "TIMEOUT";
57 default: return StringPrintf("UNKNOWN(%d)", rcode);
58 }
59 // clang-format on
60 }
61
ensureNoInvalidIp(const std::vector<IPSockAddr> & servers)62 bool ensureNoInvalidIp(const std::vector<IPSockAddr>& servers) {
63 for (const auto& server : servers) {
64 if (server.ip() == INVALID_IPADDRESS || server.port() == 0) {
65 LOG(WARNING) << "Invalid server: " << server;
66 return false;
67 }
68 }
69 return true;
70 }
71
72 } // namespace
73
74 // The comparison ignores the last update time.
operator ==(const StatsData & o) const75 bool StatsData::operator==(const StatsData& o) const {
76 return std::tie(serverSockAddr, total, rcodeCounts, latencyUs) ==
77 std::tie(o.serverSockAddr, o.total, o.rcodeCounts, o.latencyUs);
78 }
79
averageLatencyMs() const80 int StatsData::averageLatencyMs() const {
81 return (total == 0) ? 0 : duration_cast<milliseconds>(latencyUs).count() / total;
82 }
83
toString() const84 std::string StatsData::toString() const {
85 if (total == 0) return StringPrintf("%s <no data>", serverSockAddr.ip().toString().c_str());
86
87 const auto now = std::chrono::steady_clock::now();
88 const int lastUpdateSec = duration_cast<seconds>(now - lastUpdate).count();
89 std::string buf;
90 for (const auto& [rcode, counts] : rcodeCounts) {
91 if (counts != 0) {
92 buf += StringPrintf("%s:%d ", rcodeToName(rcode).c_str(), counts);
93 }
94 }
95 return StringPrintf("%s (%d, %dms, [%s], %ds)", serverSockAddr.ip().toString().c_str(), total,
96 averageLatencyMs(), buf.c_str(), lastUpdateSec);
97 }
98
StatsRecords(const IPSockAddr & ipSockAddr,size_t size)99 StatsRecords::StatsRecords(const IPSockAddr& ipSockAddr, size_t size)
100 : mCapacity(size), mStatsData(ipSockAddr) {}
101
push(const Record & record)102 void StatsRecords::push(const Record& record) {
103 updateStatsData(record, true);
104 mRecords.push_back(record);
105
106 if (mRecords.size() > mCapacity) {
107 updateStatsData(mRecords.front(), false);
108 mRecords.pop_front();
109 }
110
111 // Update the quality factors.
112 mSkippedCount = 0;
113 updatePenalty(record);
114 }
115
updateStatsData(const Record & record,const bool add)116 void StatsRecords::updateStatsData(const Record& record, const bool add) {
117 const int rcode = record.rcode;
118 if (add) {
119 mStatsData.total += 1;
120 mStatsData.rcodeCounts[rcode] += 1;
121 mStatsData.latencyUs += record.latencyUs;
122 } else {
123 mStatsData.total -= 1;
124 mStatsData.rcodeCounts[rcode] -= 1;
125 mStatsData.latencyUs -= record.latencyUs;
126 }
127 mStatsData.lastUpdate = std::chrono::steady_clock::now();
128 }
129
updatePenalty(const Record & record)130 void StatsRecords::updatePenalty(const Record& record) {
131 switch (record.rcode) {
132 case NS_R_NO_ERROR:
133 case NS_R_NXDOMAIN:
134 case NS_R_NOTAUTH:
135 mPenalty = 0;
136 return;
137 default:
138 // NS_R_TIMEOUT and NS_R_INTERNAL_ERROR are in this case.
139 if (mPenalty == 0) {
140 mPenalty = 100;
141 } else {
142 // The evaluated quality drops more quickly when continuous failures happen.
143 mPenalty = std::min(mPenalty * 2, kMaxQuality);
144 }
145 return;
146 }
147 }
148
score() const149 double StatsRecords::score() const {
150 const int avgRtt = mStatsData.averageLatencyMs();
151
152 // Set the lower bound to -1 in case of "avgRtt + mPenalty < mSkippedCount"
153 // 1) when the server doesn't have any stats yet.
154 // 2) when the sorting has been disabled while it was enabled before.
155 int quality = std::clamp(avgRtt + mPenalty - mSkippedCount, -1, kMaxQuality);
156
157 // Normalization.
158 return static_cast<double>(kMaxQuality - quality) * 100 / kMaxQuality;
159 }
160
incrementSkippedCount()161 void StatsRecords::incrementSkippedCount() {
162 mSkippedCount = std::min(mSkippedCount + 1, kMaxQuality);
163 }
164
setServers(const std::vector<netdutils::IPSockAddr> & servers,Protocol protocol)165 bool DnsStats::setServers(const std::vector<netdutils::IPSockAddr>& servers, Protocol protocol) {
166 if (!ensureNoInvalidIp(servers)) return false;
167
168 ServerStatsMap& statsMap = mStats[protocol];
169 for (const auto& server : servers) {
170 statsMap.try_emplace(server, StatsRecords(server, kLogSize));
171 }
172
173 // Clean up the map to eliminate the nodes not belonging to the given list of servers.
174 const auto cleanup = [&](ServerStatsMap* statsMap) {
175 ServerStatsMap tmp;
176 for (const auto& server : servers) {
177 if (statsMap->find(server) != statsMap->end()) {
178 tmp.insert(statsMap->extract(server));
179 }
180 }
181 statsMap->swap(tmp);
182 };
183
184 cleanup(&statsMap);
185
186 return true;
187 }
188
addStats(const IPSockAddr & ipSockAddr,const DnsQueryEvent & record)189 bool DnsStats::addStats(const IPSockAddr& ipSockAddr, const DnsQueryEvent& record) {
190 if (ipSockAddr.ip() == INVALID_IPADDRESS) return false;
191
192 bool added = false;
193 for (auto& [serverSockAddr, statsRecords] : mStats[record.protocol()]) {
194 if (serverSockAddr == ipSockAddr) {
195 const StatsRecords::Record rec = {
196 .rcode = record.rcode(),
197 .latencyUs = microseconds(record.latency_micros()),
198 };
199 statsRecords.push(rec);
200 added = true;
201 } else {
202 statsRecords.incrementSkippedCount();
203 }
204 }
205
206 return added;
207 }
208
getSortedServers(Protocol protocol) const209 std::vector<IPSockAddr> DnsStats::getSortedServers(Protocol protocol) const {
210 // DoT unsupported. The handshake overhead is expensive, and the connection will hang for a
211 // while. Need to figure out if it is worth doing for DoT servers.
212 if (protocol == PROTO_DOT) return {};
213
214 auto it = mStats.find(protocol);
215 if (it == mStats.end()) return {};
216
217 // Sorting on insertion in decreasing order.
218 std::multimap<double, IPSockAddr, std::greater<double>> sortedData;
219 for (const auto& [ip, statsRecords] : it->second) {
220 sortedData.insert({statsRecords.score(), ip});
221 }
222
223 std::vector<IPSockAddr> ret;
224 ret.reserve(sortedData.size());
225 for (auto& [_, v] : sortedData) {
226 ret.push_back(v); // IPSockAddr is trivially-copyable.
227 }
228
229 return ret;
230 }
231
getStats(Protocol protocol) const232 std::vector<StatsData> DnsStats::getStats(Protocol protocol) const {
233 std::vector<StatsData> ret;
234
235 if (mStats.find(protocol) != mStats.end()) {
236 for (const auto& [_, statsRecords] : mStats.at(protocol)) {
237 ret.push_back(statsRecords.getStatsData());
238 }
239 }
240 return ret;
241 }
242
dump(DumpWriter & dw)243 void DnsStats::dump(DumpWriter& dw) {
244 const auto dumpStatsMap = [&](ServerStatsMap& statsMap) {
245 ScopedIndent indentLog(dw);
246 if (statsMap.size() == 0) {
247 dw.println("<no server>");
248 return;
249 }
250 for (const auto& [_, statsRecords] : statsMap) {
251 const StatsData& data = statsRecords.getStatsData();
252 std::string str = data.toString();
253 str += StringPrintf(" score{%.1f}", statsRecords.score());
254 dw.println("%s", str.c_str());
255 }
256 };
257
258 dw.println("Server statistics: (total, RTT avg, {rcode:counts}, last update)");
259 ScopedIndent indentStats(dw);
260
261 dw.println("over UDP");
262 dumpStatsMap(mStats[PROTO_UDP]);
263
264 dw.println("over TLS");
265 dumpStatsMap(mStats[PROTO_DOT]);
266
267 dw.println("over TCP");
268 dumpStatsMap(mStats[PROTO_TCP]);
269 }
270
271 } // namespace android::net
272