1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.net;
18 
19 import static android.app.usage.NetworkStatsManager.MIN_THRESHOLD_BYTES;
20 
21 import static com.android.internal.util.Preconditions.checkArgument;
22 
23 import android.app.usage.NetworkStatsManager;
24 import android.net.DataUsageRequest;
25 import android.net.NetworkStats;
26 import android.net.NetworkStatsHistory;
27 import android.net.NetworkTemplate;
28 import android.os.Bundle;
29 import android.os.Handler;
30 import android.os.HandlerThread;
31 import android.os.IBinder;
32 import android.os.Looper;
33 import android.os.Message;
34 import android.os.Messenger;
35 import android.os.Process;
36 import android.os.RemoteException;
37 import android.util.ArrayMap;
38 import android.util.Slog;
39 import android.util.SparseArray;
40 
41 import com.android.internal.annotations.VisibleForTesting;
42 
43 import java.util.concurrent.atomic.AtomicInteger;
44 
45 /**
46  * Manages observers of {@link NetworkStats}. Allows observers to be notified when
47  * data usage has been reported in {@link NetworkStatsService}. An observer can set
48  * a threshold of how much data it cares about to be notified.
49  */
50 class NetworkStatsObservers {
51     private static final String TAG = "NetworkStatsObservers";
52     private static final boolean LOGV = false;
53 
54     private static final int MSG_REGISTER = 1;
55     private static final int MSG_UNREGISTER = 2;
56     private static final int MSG_UPDATE_STATS = 3;
57 
58     // All access to this map must be done from the handler thread.
59     // indexed by DataUsageRequest#requestId
60     private final SparseArray<RequestInfo> mDataUsageRequests = new SparseArray<>();
61 
62     // Sequence number of DataUsageRequests
63     private final AtomicInteger mNextDataUsageRequestId = new AtomicInteger();
64 
65     // Lazily instantiated when an observer is registered.
66     private volatile Handler mHandler;
67 
68     /**
69      * Creates a wrapper that contains the caller context and a normalized request.
70      * The request should be returned to the caller app, and the wrapper should be sent to this
71      * object through #addObserver by the service handler.
72      *
73      * <p>It will register the observer asynchronously, so it is safe to call from any thread.
74      *
75      * @return the normalized request wrapped within {@link RequestInfo}.
76      */
register(DataUsageRequest inputRequest, Messenger messenger, IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel)77     public DataUsageRequest register(DataUsageRequest inputRequest, Messenger messenger,
78                 IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel) {
79         DataUsageRequest request = buildRequest(inputRequest);
80         RequestInfo requestInfo = buildRequestInfo(request, messenger, binder, callingUid,
81                 accessLevel);
82 
83         if (LOGV) Slog.v(TAG, "Registering observer for " + request);
84         getHandler().sendMessage(mHandler.obtainMessage(MSG_REGISTER, requestInfo));
85         return request;
86     }
87 
88     /**
89      * Unregister a data usage observer.
90      *
91      * <p>It will unregister the observer asynchronously, so it is safe to call from any thread.
92      */
unregister(DataUsageRequest request, int callingUid)93     public void unregister(DataUsageRequest request, int callingUid) {
94         getHandler().sendMessage(mHandler.obtainMessage(MSG_UNREGISTER, callingUid, 0 /* ignore */,
95                 request));
96     }
97 
98     /**
99      * Updates data usage statistics of registered observers and notifies if limits are reached.
100      *
101      * <p>It will update stats asynchronously, so it is safe to call from any thread.
102      */
updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot, ArrayMap<String, NetworkIdentitySet> activeIfaces, ArrayMap<String, NetworkIdentitySet> activeUidIfaces, long currentTime)103     public void updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
104                 ArrayMap<String, NetworkIdentitySet> activeIfaces,
105                 ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
106                 long currentTime) {
107         StatsContext statsContext = new StatsContext(xtSnapshot, uidSnapshot, activeIfaces,
108                 activeUidIfaces, currentTime);
109         getHandler().sendMessage(mHandler.obtainMessage(MSG_UPDATE_STATS, statsContext));
110     }
111 
getHandler()112     private Handler getHandler() {
113         if (mHandler == null) {
114             synchronized (this) {
115                 if (mHandler == null) {
116                     if (LOGV) Slog.v(TAG, "Creating handler");
117                     mHandler = new Handler(getHandlerLooperLocked(), mHandlerCallback);
118                 }
119             }
120         }
121         return mHandler;
122     }
123 
124     @VisibleForTesting
getHandlerLooperLocked()125     protected Looper getHandlerLooperLocked() {
126         HandlerThread handlerThread = new HandlerThread(TAG);
127         handlerThread.start();
128         return handlerThread.getLooper();
129     }
130 
131     private Handler.Callback mHandlerCallback = new Handler.Callback() {
132         @Override
133         public boolean handleMessage(Message msg) {
134             switch (msg.what) {
135                 case MSG_REGISTER: {
136                     handleRegister((RequestInfo) msg.obj);
137                     return true;
138                 }
139                 case MSG_UNREGISTER: {
140                     handleUnregister((DataUsageRequest) msg.obj, msg.arg1 /* callingUid */);
141                     return true;
142                 }
143                 case MSG_UPDATE_STATS: {
144                     handleUpdateStats((StatsContext) msg.obj);
145                     return true;
146                 }
147                 default: {
148                     return false;
149                 }
150             }
151         }
152     };
153 
154     /**
155      * Adds a {@link RequestInfo} as an observer.
156      * Should only be called from the handler thread otherwise there will be a race condition
157      * on mDataUsageRequests.
158      */
handleRegister(RequestInfo requestInfo)159     private void handleRegister(RequestInfo requestInfo) {
160         mDataUsageRequests.put(requestInfo.mRequest.requestId, requestInfo);
161     }
162 
163     /**
164      * Removes a {@link DataUsageRequest} if the calling uid is authorized.
165      * Should only be called from the handler thread otherwise there will be a race condition
166      * on mDataUsageRequests.
167      */
handleUnregister(DataUsageRequest request, int callingUid)168     private void handleUnregister(DataUsageRequest request, int callingUid) {
169         RequestInfo requestInfo;
170         requestInfo = mDataUsageRequests.get(request.requestId);
171         if (requestInfo == null) {
172             if (LOGV) Slog.v(TAG, "Trying to unregister unknown request " + request);
173             return;
174         }
175         if (Process.SYSTEM_UID != callingUid && requestInfo.mCallingUid != callingUid) {
176             Slog.w(TAG, "Caller uid " + callingUid + " is not owner of " + request);
177             return;
178         }
179 
180         if (LOGV) Slog.v(TAG, "Unregistering " + request);
181         mDataUsageRequests.remove(request.requestId);
182         requestInfo.unlinkDeathRecipient();
183         requestInfo.callCallback(NetworkStatsManager.CALLBACK_RELEASED);
184     }
185 
handleUpdateStats(StatsContext statsContext)186     private void handleUpdateStats(StatsContext statsContext) {
187         if (mDataUsageRequests.size() == 0) {
188             return;
189         }
190 
191         for (int i = 0; i < mDataUsageRequests.size(); i++) {
192             RequestInfo requestInfo = mDataUsageRequests.valueAt(i);
193             requestInfo.updateStats(statsContext);
194         }
195     }
196 
buildRequest(DataUsageRequest request)197     private DataUsageRequest buildRequest(DataUsageRequest request) {
198         // Cap the minimum threshold to a safe default to avoid too many callbacks
199         long thresholdInBytes = Math.max(MIN_THRESHOLD_BYTES, request.thresholdInBytes);
200         if (thresholdInBytes < request.thresholdInBytes) {
201             Slog.w(TAG, "Threshold was too low for " + request
202                     + ". Overriding to a safer default of " + thresholdInBytes + " bytes");
203         }
204         return new DataUsageRequest(mNextDataUsageRequestId.incrementAndGet(),
205                 request.template, thresholdInBytes);
206     }
207 
buildRequestInfo(DataUsageRequest request, Messenger messenger, IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel)208     private RequestInfo buildRequestInfo(DataUsageRequest request,
209                 Messenger messenger, IBinder binder, int callingUid,
210                 @NetworkStatsAccess.Level int accessLevel) {
211         if (accessLevel <= NetworkStatsAccess.Level.USER) {
212             return new UserUsageRequestInfo(this, request, messenger, binder, callingUid,
213                     accessLevel);
214         } else {
215             // Safety check in case a new access level is added and we forgot to update this
216             checkArgument(accessLevel >= NetworkStatsAccess.Level.DEVICESUMMARY);
217             return new NetworkUsageRequestInfo(this, request, messenger, binder, callingUid,
218                     accessLevel);
219         }
220     }
221 
222     /**
223      * Tracks information relevant to a data usage observer.
224      * It will notice when the calling process dies so we can self-expire.
225      */
226     private abstract static class RequestInfo implements IBinder.DeathRecipient {
227         private final NetworkStatsObservers mStatsObserver;
228         protected final DataUsageRequest mRequest;
229         private final Messenger mMessenger;
230         private final IBinder mBinder;
231         protected final int mCallingUid;
232         protected final @NetworkStatsAccess.Level int mAccessLevel;
233         protected NetworkStatsRecorder mRecorder;
234         protected NetworkStatsCollection mCollection;
235 
RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, Messenger messenger, IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel)236         RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
237                     Messenger messenger, IBinder binder, int callingUid,
238                     @NetworkStatsAccess.Level int accessLevel) {
239             mStatsObserver = statsObserver;
240             mRequest = request;
241             mMessenger = messenger;
242             mBinder = binder;
243             mCallingUid = callingUid;
244             mAccessLevel = accessLevel;
245 
246             try {
247                 mBinder.linkToDeath(this, 0);
248             } catch (RemoteException e) {
249                 binderDied();
250             }
251         }
252 
253         @Override
binderDied()254         public void binderDied() {
255             if (LOGV) Slog.v(TAG, "RequestInfo binderDied("
256                     + mRequest + ", " + mBinder + ")");
257             mStatsObserver.unregister(mRequest, Process.SYSTEM_UID);
258             callCallback(NetworkStatsManager.CALLBACK_RELEASED);
259         }
260 
261         @Override
toString()262         public String toString() {
263             return "RequestInfo from uid:" + mCallingUid
264                     + " for " + mRequest + " accessLevel:" + mAccessLevel;
265         }
266 
unlinkDeathRecipient()267         private void unlinkDeathRecipient() {
268             if (mBinder != null) {
269                 mBinder.unlinkToDeath(this, 0);
270             }
271         }
272 
273         /**
274          * Update stats given the samples and interface to identity mappings.
275          */
updateStats(StatsContext statsContext)276         private void updateStats(StatsContext statsContext) {
277             if (mRecorder == null) {
278                 // First run; establish baseline stats
279                 resetRecorder();
280                 recordSample(statsContext);
281                 return;
282             }
283             recordSample(statsContext);
284 
285             if (checkStats()) {
286                 resetRecorder();
287                 callCallback(NetworkStatsManager.CALLBACK_LIMIT_REACHED);
288             }
289         }
290 
callCallback(int callbackType)291         private void callCallback(int callbackType) {
292             Bundle bundle = new Bundle();
293             bundle.putParcelable(DataUsageRequest.PARCELABLE_KEY, mRequest);
294             Message msg = Message.obtain();
295             msg.what = callbackType;
296             msg.setData(bundle);
297             try {
298                 if (LOGV) {
299                     Slog.v(TAG, "sending notification " + callbackTypeToName(callbackType)
300                             + " for " + mRequest);
301                 }
302                 mMessenger.send(msg);
303             } catch (RemoteException e) {
304                 // May occur naturally in the race of binder death.
305                 Slog.w(TAG, "RemoteException caught trying to send a callback msg for " + mRequest);
306             }
307         }
308 
resetRecorder()309         private void resetRecorder() {
310             mRecorder = new NetworkStatsRecorder();
311             mCollection = mRecorder.getSinceBoot();
312         }
313 
checkStats()314         protected abstract boolean checkStats();
315 
recordSample(StatsContext statsContext)316         protected abstract void recordSample(StatsContext statsContext);
317 
callbackTypeToName(int callbackType)318         private String callbackTypeToName(int callbackType) {
319             switch (callbackType) {
320                 case NetworkStatsManager.CALLBACK_LIMIT_REACHED:
321                     return "LIMIT_REACHED";
322                 case NetworkStatsManager.CALLBACK_RELEASED:
323                     return "RELEASED";
324                 default:
325                     return "UNKNOWN";
326             }
327         }
328     }
329 
330     private static class NetworkUsageRequestInfo extends RequestInfo {
NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, Messenger messenger, IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel)331         NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
332                     Messenger messenger, IBinder binder, int callingUid,
333                     @NetworkStatsAccess.Level int accessLevel) {
334             super(statsObserver, request, messenger, binder, callingUid, accessLevel);
335         }
336 
337         @Override
checkStats()338         protected boolean checkStats() {
339             long bytesSoFar = getTotalBytesForNetwork(mRequest.template);
340             if (LOGV) {
341                 Slog.v(TAG, bytesSoFar + " bytes so far since notification for "
342                         + mRequest.template);
343             }
344             if (bytesSoFar > mRequest.thresholdInBytes) {
345                 return true;
346             }
347             return false;
348         }
349 
350         @Override
recordSample(StatsContext statsContext)351         protected void recordSample(StatsContext statsContext) {
352             // Recorder does not need to be locked in this context since only the handler
353             // thread will update it. We pass a null VPN array because usage is aggregated by uid
354             // for this snapshot, so VPN traffic can't be reattributed to responsible apps.
355             mRecorder.recordSnapshotLocked(statsContext.mXtSnapshot, statsContext.mActiveIfaces,
356                     statsContext.mCurrentTime);
357         }
358 
359         /**
360          * Reads stats matching the given template. {@link NetworkStatsCollection} will aggregate
361          * over all buckets, which in this case should be only one since we built it big enough
362          * that it will outlive the caller. If it doesn't, then there will be multiple buckets.
363          */
getTotalBytesForNetwork(NetworkTemplate template)364         private long getTotalBytesForNetwork(NetworkTemplate template) {
365             NetworkStats stats = mCollection.getSummary(template,
366                     Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
367                     mAccessLevel, mCallingUid);
368             return stats.getTotalBytes();
369         }
370     }
371 
372     private static class UserUsageRequestInfo extends RequestInfo {
UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request, Messenger messenger, IBinder binder, int callingUid, @NetworkStatsAccess.Level int accessLevel)373         UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
374                     Messenger messenger, IBinder binder, int callingUid,
375                     @NetworkStatsAccess.Level int accessLevel) {
376             super(statsObserver, request, messenger, binder, callingUid, accessLevel);
377         }
378 
379         @Override
checkStats()380         protected boolean checkStats() {
381             int[] uidsToMonitor = mCollection.getRelevantUids(mAccessLevel, mCallingUid);
382 
383             for (int i = 0; i < uidsToMonitor.length; i++) {
384                 long bytesSoFar = getTotalBytesForNetworkUid(mRequest.template, uidsToMonitor[i]);
385                 if (bytesSoFar > mRequest.thresholdInBytes) {
386                     return true;
387                 }
388             }
389             return false;
390         }
391 
392         @Override
recordSample(StatsContext statsContext)393         protected void recordSample(StatsContext statsContext) {
394             // Recorder does not need to be locked in this context since only the handler
395             // thread will update it. We pass the VPN info so VPN traffic is reattributed to
396             // responsible apps.
397             mRecorder.recordSnapshotLocked(statsContext.mUidSnapshot, statsContext.mActiveUidIfaces,
398                     statsContext.mCurrentTime);
399         }
400 
401         /**
402          * Reads all stats matching the given template and uid. Ther history will likely only
403          * contain one bucket per ident since we build it big enough that it will outlive the
404          * caller lifetime.
405          */
getTotalBytesForNetworkUid(NetworkTemplate template, int uid)406         private long getTotalBytesForNetworkUid(NetworkTemplate template, int uid) {
407             try {
408                 NetworkStatsHistory history = mCollection.getHistory(template, null, uid,
409                         NetworkStats.SET_ALL, NetworkStats.TAG_NONE,
410                         NetworkStatsHistory.FIELD_ALL,
411                         Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
412                         mAccessLevel, mCallingUid);
413                 return history.getTotalBytes();
414             } catch (SecurityException e) {
415                 if (LOGV) {
416                     Slog.w(TAG, "CallerUid " + mCallingUid + " may have lost access to uid "
417                             + uid);
418                 }
419                 return 0;
420             }
421         }
422     }
423 
424     private static class StatsContext {
425         NetworkStats mXtSnapshot;
426         NetworkStats mUidSnapshot;
427         ArrayMap<String, NetworkIdentitySet> mActiveIfaces;
428         ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces;
429         long mCurrentTime;
430 
StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot, ArrayMap<String, NetworkIdentitySet> activeIfaces, ArrayMap<String, NetworkIdentitySet> activeUidIfaces, long currentTime)431         StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
432                 ArrayMap<String, NetworkIdentitySet> activeIfaces,
433                 ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
434                 long currentTime) {
435             mXtSnapshot = xtSnapshot;
436             mUidSnapshot = uidSnapshot;
437             mActiveIfaces = activeIfaces;
438             mActiveUidIfaces = activeUidIfaces;
439             mCurrentTime = currentTime;
440         }
441     }
442 }
443