1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  */
17 
18 #define LOG_TAG "resolv"
19 
20 #include "DnsStats.h"
21 
22 #include <android-base/logging.h>
23 #include <android-base/stringprintf.h>
24 
25 namespace android::net {
26 
27 using base::StringPrintf;
28 using netdutils::DumpWriter;
29 using netdutils::IPAddress;
30 using netdutils::IPSockAddr;
31 using netdutils::ScopedIndent;
32 using std::chrono::duration_cast;
33 using std::chrono::microseconds;
34 using std::chrono::milliseconds;
35 using std::chrono::seconds;
36 
37 namespace {
38 
39 static constexpr IPAddress INVALID_IPADDRESS = IPAddress();
40 
rcodeToName(int rcode)41 std::string rcodeToName(int rcode) {
42     // clang-format off
43     switch (rcode) {
44         case NS_R_NO_ERROR: return "NOERROR";
45         case NS_R_FORMERR: return "FORMERR";
46         case NS_R_SERVFAIL: return "SERVFAIL";
47         case NS_R_NXDOMAIN: return "NXDOMAIN";
48         case NS_R_NOTIMPL: return "NOTIMP";
49         case NS_R_REFUSED: return "REFUSED";
50         case NS_R_YXDOMAIN: return "YXDOMAIN";
51         case NS_R_YXRRSET: return "YXRRSET";
52         case NS_R_NXRRSET: return "NXRRSET";
53         case NS_R_NOTAUTH: return "NOTAUTH";
54         case NS_R_NOTZONE: return "NOTZONE";
55         case NS_R_INTERNAL_ERROR: return "INTERNAL_ERROR";
56         case NS_R_TIMEOUT: return "TIMEOUT";
57         default: return StringPrintf("UNKNOWN(%d)", rcode);
58     }
59     // clang-format on
60 }
61 
ensureNoInvalidIp(const std::vector<IPSockAddr> & servers)62 bool ensureNoInvalidIp(const std::vector<IPSockAddr>& servers) {
63     for (const auto& server : servers) {
64         if (server.ip() == INVALID_IPADDRESS || server.port() == 0) {
65             LOG(WARNING) << "Invalid server: " << server;
66             return false;
67         }
68     }
69     return true;
70 }
71 
72 }  // namespace
73 
74 // The comparison ignores the last update time.
operator ==(const StatsData & o) const75 bool StatsData::operator==(const StatsData& o) const {
76     return std::tie(serverSockAddr, total, rcodeCounts, latencyUs) ==
77            std::tie(o.serverSockAddr, o.total, o.rcodeCounts, o.latencyUs);
78 }
79 
averageLatencyMs() const80 int StatsData::averageLatencyMs() const {
81     return (total == 0) ? 0 : duration_cast<milliseconds>(latencyUs).count() / total;
82 }
83 
toString() const84 std::string StatsData::toString() const {
85     if (total == 0) return StringPrintf("%s <no data>", serverSockAddr.ip().toString().c_str());
86 
87     const auto now = std::chrono::steady_clock::now();
88     const int lastUpdateSec = duration_cast<seconds>(now - lastUpdate).count();
89     std::string buf;
90     for (const auto& [rcode, counts] : rcodeCounts) {
91         if (counts != 0) {
92             buf += StringPrintf("%s:%d ", rcodeToName(rcode).c_str(), counts);
93         }
94     }
95     return StringPrintf("%s (%d, %dms, [%s], %ds)", serverSockAddr.ip().toString().c_str(), total,
96                         averageLatencyMs(), buf.c_str(), lastUpdateSec);
97 }
98 
StatsRecords(const IPSockAddr & ipSockAddr,size_t size)99 StatsRecords::StatsRecords(const IPSockAddr& ipSockAddr, size_t size)
100     : mCapacity(size), mStatsData(ipSockAddr) {}
101 
push(const Record & record)102 void StatsRecords::push(const Record& record) {
103     updateStatsData(record, true);
104     mRecords.push_back(record);
105 
106     if (mRecords.size() > mCapacity) {
107         updateStatsData(mRecords.front(), false);
108         mRecords.pop_front();
109     }
110 
111     // Update the quality factors.
112     mSkippedCount = 0;
113     updatePenalty(record);
114 }
115 
updateStatsData(const Record & record,const bool add)116 void StatsRecords::updateStatsData(const Record& record, const bool add) {
117     const int rcode = record.rcode;
118     if (add) {
119         mStatsData.total += 1;
120         mStatsData.rcodeCounts[rcode] += 1;
121         mStatsData.latencyUs += record.latencyUs;
122     } else {
123         mStatsData.total -= 1;
124         mStatsData.rcodeCounts[rcode] -= 1;
125         mStatsData.latencyUs -= record.latencyUs;
126     }
127     mStatsData.lastUpdate = std::chrono::steady_clock::now();
128 }
129 
updatePenalty(const Record & record)130 void StatsRecords::updatePenalty(const Record& record) {
131     switch (record.rcode) {
132         case NS_R_NO_ERROR:
133         case NS_R_NXDOMAIN:
134         case NS_R_NOTAUTH:
135             mPenalty = 0;
136             return;
137         default:
138             // NS_R_TIMEOUT and NS_R_INTERNAL_ERROR are in this case.
139             if (mPenalty == 0) {
140                 mPenalty = 100;
141             } else {
142                 // The evaluated quality drops more quickly when continuous failures happen.
143                 mPenalty = std::min(mPenalty * 2, kMaxQuality);
144             }
145             return;
146     }
147 }
148 
score() const149 double StatsRecords::score() const {
150     const int avgRtt = mStatsData.averageLatencyMs();
151 
152     // Set the lower bound to -1 in case of "avgRtt + mPenalty < mSkippedCount"
153     //   1) when the server doesn't have any stats yet.
154     //   2) when the sorting has been disabled while it was enabled before.
155     int quality = std::clamp(avgRtt + mPenalty - mSkippedCount, -1, kMaxQuality);
156 
157     // Normalization.
158     return static_cast<double>(kMaxQuality - quality) * 100 / kMaxQuality;
159 }
160 
incrementSkippedCount()161 void StatsRecords::incrementSkippedCount() {
162     mSkippedCount = std::min(mSkippedCount + 1, kMaxQuality);
163 }
164 
setServers(const std::vector<netdutils::IPSockAddr> & servers,Protocol protocol)165 bool DnsStats::setServers(const std::vector<netdutils::IPSockAddr>& servers, Protocol protocol) {
166     if (!ensureNoInvalidIp(servers)) return false;
167 
168     ServerStatsMap& statsMap = mStats[protocol];
169     for (const auto& server : servers) {
170         statsMap.try_emplace(server, StatsRecords(server, kLogSize));
171     }
172 
173     // Clean up the map to eliminate the nodes not belonging to the given list of servers.
174     const auto cleanup = [&](ServerStatsMap* statsMap) {
175         ServerStatsMap tmp;
176         for (const auto& server : servers) {
177             if (statsMap->find(server) != statsMap->end()) {
178                 tmp.insert(statsMap->extract(server));
179             }
180         }
181         statsMap->swap(tmp);
182     };
183 
184     cleanup(&statsMap);
185 
186     return true;
187 }
188 
addStats(const IPSockAddr & ipSockAddr,const DnsQueryEvent & record)189 bool DnsStats::addStats(const IPSockAddr& ipSockAddr, const DnsQueryEvent& record) {
190     if (ipSockAddr.ip() == INVALID_IPADDRESS) return false;
191 
192     bool added = false;
193     for (auto& [serverSockAddr, statsRecords] : mStats[record.protocol()]) {
194         if (serverSockAddr == ipSockAddr) {
195             const StatsRecords::Record rec = {
196                     .rcode = record.rcode(),
197                     .latencyUs = microseconds(record.latency_micros()),
198             };
199             statsRecords.push(rec);
200             added = true;
201         } else {
202             statsRecords.incrementSkippedCount();
203         }
204     }
205 
206     return added;
207 }
208 
getSortedServers(Protocol protocol) const209 std::vector<IPSockAddr> DnsStats::getSortedServers(Protocol protocol) const {
210     // DoT unsupported. The handshake overhead is expensive, and the connection will hang for a
211     // while. Need to figure out if it is worth doing for DoT servers.
212     if (protocol == PROTO_DOT) return {};
213 
214     auto it = mStats.find(protocol);
215     if (it == mStats.end()) return {};
216 
217     // Sorting on insertion in decreasing order.
218     std::multimap<double, IPSockAddr, std::greater<double>> sortedData;
219     for (const auto& [ip, statsRecords] : it->second) {
220         sortedData.insert({statsRecords.score(), ip});
221     }
222 
223     std::vector<IPSockAddr> ret;
224     ret.reserve(sortedData.size());
225     for (auto& [_, v] : sortedData) {
226         ret.push_back(v);  // IPSockAddr is trivially-copyable.
227     }
228 
229     return ret;
230 }
231 
getStats(Protocol protocol) const232 std::vector<StatsData> DnsStats::getStats(Protocol protocol) const {
233     std::vector<StatsData> ret;
234 
235     if (mStats.find(protocol) != mStats.end()) {
236         for (const auto& [_, statsRecords] : mStats.at(protocol)) {
237             ret.push_back(statsRecords.getStatsData());
238         }
239     }
240     return ret;
241 }
242 
dump(DumpWriter & dw)243 void DnsStats::dump(DumpWriter& dw) {
244     const auto dumpStatsMap = [&](ServerStatsMap& statsMap) {
245         ScopedIndent indentLog(dw);
246         if (statsMap.size() == 0) {
247             dw.println("<no server>");
248             return;
249         }
250         for (const auto& [_, statsRecords] : statsMap) {
251             const StatsData& data = statsRecords.getStatsData();
252             std::string str = data.toString();
253             str += StringPrintf(" score{%.1f}", statsRecords.score());
254             dw.println("%s", str.c_str());
255         }
256     };
257 
258     dw.println("Server statistics: (total, RTT avg, {rcode:counts}, last update)");
259     ScopedIndent indentStats(dw);
260 
261     dw.println("over UDP");
262     dumpStatsMap(mStats[PROTO_UDP]);
263 
264     dw.println("over TLS");
265     dumpStatsMap(mStats[PROTO_DOT]);
266 
267     dw.println("over TCP");
268     dumpStatsMap(mStats[PROTO_TCP]);
269 }
270 
271 }  // namespace android::net
272