1#!/usr/bin/env python3
2#
3#   Copyright 2018 - Google, Inc.
4#
5#   Licensed under the Apache License, Version 2.0 (the "License");
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an "AS IS" BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16import json
17import socket
18import threading
19import time
20from concurrent import futures
21
22from acts import error
23from acts import logger
24from acts.metrics.loggers import usage_metadata_logger
25
# The default socket timeout, in seconds, used when an RPC does not supply its
# own timeout. RpcClient.rpc restores this value on a connection after a call
# that used a custom timeout.
SOCKET_TIMEOUT = 60

# The Session UID when a UID has not been received yet.
UNKNOWN_UID = -1
31
class Sl4aException(error.ActsError):
    """The base class for all SL4A exceptions.

    Catch this type to handle any SL4A error defined in this module.
    """
34
35
class Sl4aStartError(Sl4aException):
    """Raised when sl4a is not able to be started.

    RpcClient also raises this when an RPC is requested through attribute
    access after the session has been terminated.
    """
38
39
class Sl4aApiError(Sl4aException):
    """Raised when the remote API reports an error for an RPC.

    The fields follow the Error Response object of the JSON-RPC 2.0 spec.

    Attributes:
        message: The error message returned by SL4A.
        code: The error code returned by SL4A. This is unrelated to
            ActsError's error_code.
        data: Extra data returned by SL4A; an empty dict when none was given.
        rpc_name: The name of the RPC that produced the error.
    """

    def __init__(self, message, code=-1, data=None, rpc_name=''):
        super().__init__()
        self.rpc_name = rpc_name
        self.code = code
        self.message = message
        # Give every instance its own dict rather than sharing a default.
        self.data = {} if data is None else data

    def __str__(self):
        # The data field is only appended when there is something to show.
        fields = [self.code, self.message]
        if self.data:
            fields.append(self.data)
        return 'Error in RPC %s %s' % (self.rpc_name,
                                       ':'.join(map(str, fields)))
69
70
class Sl4aConnectionError(Sl4aException):
    """An error raised upon failure to connect to SL4A.

    RpcClient.rpc also raises this when the pipe to the device breaks in the
    middle of an RPC call.
    """
73
74
class Sl4aProtocolError(Sl4aException):
    """Raised on an error in exchanging data with the server on the device.

    The class attributes are the messages this error is raised with.
    """
    # NOTE(review): not used in this module; presumably raised by the
    # connection handshake code elsewhere -- verify against callers.
    NO_RESPONSE_FROM_HANDSHAKE = 'No response from handshake.'
    # Used by RpcClient.rpc when every attempt returned an empty response.
    NO_RESPONSE_FROM_SERVER = 'No response from server.'
    # Used by RpcClient.rpc when the response id differs from the request id.
    MISMATCHED_API_ID = 'Mismatched API id.'
80
81
class Sl4aNotInstalledError(Sl4aException):
    """An error raised when an Sl4aClient is created without SL4A installed.

    This module only defines the type; it is raised elsewhere.
    """
84
85
class Sl4aRpcTimeoutError(Sl4aException):
    """An error raised when an SL4A RPC has timed out.

    RpcClient.rpc raises this in place of socket.timeout, after closing the
    timed out connection and dropping it from the connection pool.
    """
88
89
class RpcClient(object):
    """An RPC client capable of processing multiple RPCs concurrently.

    Every in-flight RPC borrows one connection from a pool. Idle connections
    are reused, and new ones are created on demand until max_connections
    exist; after that, callers wait for a connection to be released.

    Attributes:
        uid: The session uid, read from the first connection created.
        on_error: The callback invoked with the affected connection when an
            RPC gets no response or the device disconnects.
        is_alive: True until terminate() is called.
        max_connections: The maximum number of RpcConnections at a time.
            Increasing or decreasing the number of max connections does NOT
            modify the thread pool size being used for self.future RPC calls.
        _free_connections: A list of all idle RpcConnections.
        _working_connections: A list of all working RpcConnections.
        _lock: A lock guarding the two connection lists.
        _log: The logger for this RpcClient.
    """

    # The default value for the maximum amount of connections for a client.
    DEFAULT_MAX_CONNECTION = 15

    class AsyncClient(object):
        """An object that allows RPC calls to be called asynchronously.

        Attributes:
            _rpc_client: The RpcClient to use when making calls.
            _executor: The ThreadPoolExecutor used to keep track of workers
        """

        def __init__(self, rpc_client):
            """Creates the worker pool for the given client.

            Args:
                rpc_client: The RpcClient whose RPCs are run asynchronously.
                    Its max_connections value sizes the thread pool.
            """
            self._rpc_client = rpc_client
            # Leave two connections for synchronous callers, but always keep
            # at least one worker.
            self._executor = futures.ThreadPoolExecutor(
                max_workers=max(rpc_client.max_connections - 2, 1))

        def rpc(self, name, *args, **kwargs):
            """Runs RpcClient.rpc on a worker thread.

            Args:
                name: The name of the RPC method to execute. A callable is
                    also accepted and is submitted to the executor as-is.
                *args: The positional arguments for the call.
                **kwargs: The keyword arguments for the call (for an RPC,
                    these are the keyword arguments of RpcClient.rpc).

            Returns:
                A concurrent.futures.Future for the call.
            """
            if callable(name):
                # Earlier versions handed `name` straight to the executor,
                # which only ever worked for callables. Keep that behavior.
                return self._executor.submit(name, *args, **kwargs)
            return self._executor.submit(self._rpc_client.rpc, name, *args,
                                         **kwargs)

        def __getattr__(self, name):
            """Wrapper for python magic to turn method calls into RPC calls.

            Returns:
                A function that submits the RPC called `name` to the worker
                pool and returns its Future.
            """

            def rpc_call(*args, **kwargs):
                # __getattr__ is called directly (not getattr()) so that
                # every name is treated as an RPC name, even one that
                # matches a real RpcClient attribute.
                return self._executor.submit(
                    self._rpc_client.__getattr__(name), *args, **kwargs)

            return rpc_call

    def __init__(self,
                 uid,
                 serial,
                 on_error_callback,
                 _create_connection_func,
                 max_connections=None):
        """Creates a new RpcClient object.

        Args:
            uid: The session uid this client is a part of.
            serial: The serial of the Android device. Used for logging.
            on_error_callback: A callback for when a connection error is raised.
            _create_connection_func: A reference to the function that creates a
                new session.
            max_connections: The maximum number of connections the RpcClient
                can have. Defaults to DEFAULT_MAX_CONNECTION when None.
        """
        self._serial = serial
        self.on_error = on_error_callback
        self._create_connection_func = _create_connection_func
        self._free_connections = [self._create_connection_func(uid)]

        # The uid reported by the connection is the one used from here on.
        self.uid = self._free_connections[0].uid
        self._lock = threading.Lock()

        def _log_formatter(message):
            """Formats the message to be logged."""
            return '[RPC Service|%s|%s] %s' % (self._serial, self.uid, message)

        self._log = logger.create_logger(_log_formatter)

        self._working_connections = []
        if max_connections is None:
            self.max_connections = RpcClient.DEFAULT_MAX_CONNECTION
        else:
            self.max_connections = max_connections

        self._async_client = RpcClient.AsyncClient(self)
        self.is_alive = True

    def terminate(self):
        """Terminates all connections to the SL4A server."""
        # Mark the client dead before closing any socket so that an RPC whose
        # pipe breaks because of this shutdown is reported as cleanup rather
        # than as a device disconnect.
        self.is_alive = False
        if self._working_connections:
            self._log.warning(
                '%s connections are still active, and waiting on '
                'responses.Closing these connections now.' % len(
                    self._working_connections))
        connections = self._free_connections + self._working_connections
        for connection in connections:
            self._log.debug(
                'Closing connection over ports %s' % connection.ports)
            connection.close()
        self._free_connections = []
        self._working_connections = []

    def _get_free_connection(self):
        """Returns a free connection to be used for an RPC call.

        This function also adds the client to the working set to prevent
        multiple users from obtaining the same client. When the pool is
        exhausted it polls until a connection is released.
        """
        while True:
            # Both lists are only inspected while holding the lock, so the
            # connection count cannot change underneath the checks.
            with self._lock:
                if self._free_connections:
                    client = self._free_connections.pop()
                    self._working_connections.append(client)
                    return client

                client_count = (len(self._free_connections) +
                                len(self._working_connections))
                if client_count < self.max_connections:
                    client = self._create_connection_func(self.uid)
                    self._working_connections.append(client)
                    return client
            time.sleep(.01)

    def _release_working_connection(self, connection):
        """Marks a working client as free.

        Args:
            connection: The client to mark as free.
        Raises:
            A ValueError if the client is not a known working connection.
        """
        # We need to keep this code atomic because the client count is based on
        # the length of the free and working connection list lengths.
        with self._lock:
            self._working_connections.remove(connection)
            self._free_connections.append(connection)

    def _drop_working_connection(self, connection):
        """Removes a connection from the working set without freeing it.

        Unlike _release_working_connection, this does not raise when the
        connection is no longer tracked (e.g. terminate() cleared the pool).

        Args:
            connection: The connection to forget.
        """
        with self._lock:
            if connection in self._working_connections:
                self._working_connections.remove(connection)

    def rpc(self, method, *args, timeout=None, retries=3):
        """Sends an rpc to sl4a.

        Sends an rpc call to sl4a over this RpcClient's corresponding session.

        Args:
            method: str, The name of the method to execute.
            args: any, The args to send to sl4a.
            timeout: The amount of time to wait for a response.
            retries: Misnomer, is actually the number of tries.

        Returns:
            The result of the rpc.

        Raises:
            Sl4aProtocolError: Something went wrong with the sl4a protocol.
            Sl4aApiError: The rpc went through, however executed with errors.
            Sl4aConnectionError: The pipe to the device broke during the call.
            Sl4aRpcTimeoutError: The socket timed out waiting on the call.
        """
        connection = self._get_free_connection()
        ticket = None
        timed_out = False
        response = ''
        try:
            # Everything after borrowing the connection runs inside the try,
            # so any failure (e.g. args that cannot be serialized) still
            # returns the connection to the pool.
            ticket = connection.get_new_ticket()
            if timeout:
                connection.set_timeout(timeout)
            request = json.dumps({
                'id': ticket,
                'method': method,
                'params': args
            })
            for attempt in range(1, retries + 1):
                connection.send_request(request)

                response = connection.get_response()
                if response:
                    break
                if attempt < retries:
                    self._log.warning(
                        'No response for RPC method %s on iteration %s',
                        method, attempt)
                    continue
                # No exception is being handled at this point, so there is
                # no traceback to attach; log a plain error.
                self._log.error(
                    'No response for RPC method %s on iteration %s', method,
                    attempt)
                self.on_error(connection)
                raise Sl4aProtocolError(
                    Sl4aProtocolError.NO_RESPONSE_FROM_SERVER)
        except BrokenPipeError as e:
            if self.is_alive:
                self._log.exception('The device disconnected during RPC call '
                                    '%s. Please check the logcat for a crash '
                                    'or disconnect.', method)
                self.on_error(connection)
            else:
                self._log.warning('The connection was killed during cleanup:')
                self._log.warning(e)
            raise Sl4aConnectionError(e) from e
        except socket.timeout as err:
            # If a socket connection has timed out, the socket can no longer be
            # used. Close it out and remove the socket from the connection pool.
            timed_out = True
            self._log.warning('RPC "%s" (id: %s) timed out after %s seconds.',
                              method, ticket, timeout or SOCKET_TIMEOUT)
            self._log.debug(
                'Closing timed out connection over %s' % connection.ports)
            connection.close()
            self._drop_working_connection(connection)
            # Re-raise the error as an SL4A Error so end users can process it.
            raise Sl4aRpcTimeoutError(err) from err
        finally:
            if not timed_out:
                if timeout:
                    connection.set_timeout(SOCKET_TIMEOUT)
                try:
                    self._release_working_connection(connection)
                except ValueError:
                    # terminate() already emptied the pool. Raising here
                    # would hide the error (if any) that is propagating.
                    self._log.debug(
                        'Connection over ports %s is no longer in the '
                        'working set; not returning it to the pool.',
                        connection.ports)

        if not response:
            # Only reachable when retries < 1: nothing was ever sent.
            raise Sl4aProtocolError(Sl4aProtocolError.NO_RESPONSE_FROM_SERVER)
        result = json.loads(str(response, encoding='utf8'))

        # JSON-RPC 2.0 omits the 'error' member on success, so do not require
        # the key to be present.
        error_object = result.get('error')
        if error_object:
            if isinstance(error_object, dict):
                # Uses JSON-RPC 2.0 Format
                sl4a_api_error = Sl4aApiError(error_object.get('message', None),
                                              error_object.get('code', -1),
                                              error_object.get('data', {}),
                                              rpc_name=method)
            else:
                # Fallback on JSON-RPC 1.0 Format
                sl4a_api_error = Sl4aApiError(error_object, rpc_name=method)
            self._log.warning(sl4a_api_error)
            raise sl4a_api_error
        if result.get('id') != ticket:
            self._log.error('RPC method %s with mismatched api id %s', method,
                            result.get('id'))
            raise Sl4aProtocolError(Sl4aProtocolError.MISMATCHED_API_ID)
        return result['result']

    @property
    def future(self):
        """Returns a magic function that returns a future running an RPC call.

        This function effectively allows the idiom:

        >>> rpc_client = RpcClient(...)
        >>> # returns after call finishes
        >>> rpc_client.someRpcCall()
        >>> # Immediately returns a reference to the RPC's future, running
        >>> # the lengthy RPC call on another thread.
        >>> future = rpc_client.future.someLengthyRpcCall()
        >>> rpc_client.doOtherThings()
        >>> ...
        >>> # Wait for and get the returned value of the lengthy RPC.
        >>> # Can specify a timeout as well.
        >>> value = future.result()

        The number of concurrent calls to this method is limited to
        (max_connections - 2), to prevent future calls from exhausting all free
        connections.
        """
        return self._async_client

    def __getattr__(self, name):
        """Wrapper for python magic to turn method calls into RPC calls.

        Returns:
            A function that logs the usage and sends the RPC called `name`.

        Raises:
            Sl4aStartError: The session has already been terminated.
            AttributeError: The client has not finished initializing.
        """
        # __getattr__ only runs when normal lookup fails. Reading a missing
        # is_alive through `self` would re-enter this method forever, so look
        # it up in the instance dict directly.
        try:
            alive = self.__dict__['is_alive']
        except KeyError:
            raise AttributeError(name) from None

        if not alive:
            raise Sl4aStartError(
                'This SL4A session has already been terminated. You must '
                'create a new session to continue.')

        def rpc_call(*args, **kwargs):
            usage_metadata_logger.log_usage(self.__module__, name)
            return self.rpc(name, *args, **kwargs)

        return rpc_call
361