1#!/usr/bin/env python3
2#
3#   Copyright 2019 - The Android Open Source Project
4#
5#   Licensed under the Apache License, Version 2.0 (the "License");
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an "AS IS" BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16
17from abc import ABC, abstractmethod
18from concurrent.futures import ThreadPoolExecutor
19from datetime import datetime, timedelta
20import logging
21from queue import SimpleQueue, Empty
22
23from mobly import asserts
24
25from google.protobuf import text_format
26
27from grpc import RpcError
28
29from cert.closable import Closable
30
31
class IEventStream(ABC):
    """Interface for objects that expose a queue of received events."""

    @abstractmethod
    def get_event_queue(self):
        """Return the queue from which received events can be consumed."""
37
38
39class FilteringEventStream(IEventStream):
40
41    def __init__(self, stream, filter_fn):
42        self.filter_fn = filter_fn if filter_fn else lambda x: x
43        self.event_queue = SimpleQueue()
44        self.stream = stream
45
46        self.stream.register_callback(self.__event_callback, lambda packet: self.filter_fn(packet) is not None)
47
48    def __event_callback(self, event):
49        self.event_queue.put(self.filter_fn(event))
50
51    def get_event_queue(self):
52        return self.event_queue
53
54    def unregister(self):
55        self.stream.unregister(self.__event_callback)
56
57
def pretty_print(proto_event):
    """
    Render a protobuf message as '<MessageTypeName> <single-line text format>'.
    """
    type_name = type(proto_event).__name__
    one_line_body = text_format.MessageToString(proto_event, as_one_line=True)
    return '{} {}'.format(type_name, one_line_body)
60
61
# Default wait, in seconds, used by the event assert helpers and by
# EventStream.close() when waiting for the stream to terminate.
DEFAULT_TIMEOUT_SECONDS: int = 3
63
64
class EventStream(IEventStream, Closable):
    """
    A class that streams events from a gRPC stream, which you can assert on.

    Every event received from the stream is put on an internal queue (see
    get_event_queue()) and is also dispatched to every registered callback
    whose matcher accepts it.

    Don't use these asserts directly, use the ones from cert.truth.
    """

    def __init__(self, server_stream_call):
        """
        :param server_stream_call: an iterable server-streaming call that also
                                   exposes cancel() and cancelled()
        :raises ValueError: if server_stream_call is None
        """
        if server_stream_call is None:
            raise ValueError("server_stream_call cannot be None")

        self.server_stream_call = server_stream_call
        self.event_queue = SimpleQueue()
        self.handlers = []
        self.executor = ThreadPoolExecutor()
        # The event loop runs on the executor's worker thread until the
        # underlying call is cancelled or fails.
        self.future = self.executor.submit(EventStream.__event_loop, self)

    def get_event_queue(self):
        return self.event_queue

    def close(self):
        """
        Stop the gRPC lambda so that event_callback will not be invoked after
        the method returns.

        This object will be useless after this call as there is no way to
        restart the gRPC callback. You would have to create a new EventStream.

        :raises: nothing on success; otherwise the same exception raised by
                 __event_loop(), or concurrent.futures.TimeoutError if the
                 underlying stream failed to terminate within
                 DEFAULT_TIMEOUT_SECONDS
        """
        # Try to cancel the execution, don't care the result, non-blocking
        self.server_stream_call.cancel()
        try:
            # Cancelling the gRPC stream should cause __event_loop() to quit.
            # future.result() re-raises whatever __event_loop() raised, or
            # raises concurrent.futures.TimeoutError after the timeout.
            self.future.result(timeout=DEFAULT_TIMEOUT_SECONDS)
        finally:
            # Make sure we force shutdown the executor regardless of the result
            self.executor.shutdown(wait=False)

    def register_callback(self, callback, matcher_fn=None):
        """
        Register a callback to handle events. Event will be handled by callback
        if matcher_fn(event) returns True

        callback and matcher are registered as a tuple. Hence the same callback
        with different matcher are considered two different handler units. Same
        matcher, but different callback are also considered different handling
        unit

        Callback will be invoked on a ThreadPoolExecutor owned by this
        EventStream

        :param callback: Will be called as callback(event)
        :param matcher_fn: A boolean function that returns True or False when
                           calling matcher_fn(event), if None, all event will
                           be matched
        :raises ValueError: if callback is None
        """
        if callback is None:
            raise ValueError("callback must not be None")
        self.handlers.append((callback, matcher_fn))

    def unregister_callback(self, callback, matcher_fn=None):
        """
        Unregister callback and matcher_fn from the event stream. Both objects
        must match exactly the ones when calling register_callback()

        :param callback: callback used in register_callback()
        :param matcher_fn: matcher_fn used in register_callback()
        :raises ValueError: when callback is None or the (callback, matcher_fn)
                            tuple is not found
        """
        if callback is None:
            raise ValueError("callback must not be None")
        self.handlers.remove((callback, matcher_fn))

    def __event_loop(self):
        """
        Main loop for consuming the gRPC stream events.
        Blocks until computation is cancelled

        :raises grpc.RpcError: if the stream fails for any reason other than
                               being cancelled
        """
        try:
            for event in self.server_stream_call:
                self.event_queue.put(event)
                # Iterate over a snapshot: callbacks may be (un)registered from
                # another thread while this loop runs, and mutating a list while
                # iterating it can silently skip a handler.
                for callback, matcher_fn in list(self.handlers):
                    if not matcher_fn or matcher_fn(event):
                        callback(event)
        except RpcError:
            # Underlying gRPC stream should run indefinitely until cancelled
            # Hence any other reason besides CANCELLED is raised as an error
            if self.server_stream_call.cancelled():
                logging.debug("Cancelled")
            else:
                raise

    def assert_event_occurs(self, match_fn, at_least_times=1, timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        """
        Assert at least |at_least_times| instances of events happen where
        match_fn(event) returns True within timeout period

        :param match_fn: returns True/False on match_fn(event)
        :param at_least_times: how many times at least a matching event should
                               happen
        :param timeout: a timedelta object
        """
        NOT_FOR_YOU_assert_event_occurs(self, match_fn, at_least_times, timeout)

    def assert_event_occurs_at_most(self, match_fn, at_most_times, timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        """
        Assert at most |at_most_times| instances of events happen where
        match_fn(event) returns True within timeout period

        Events are consumed from the queue; waiting stops early as soon as the
        limit has been exceeded.

        :param match_fn: returns True/False on match_fn(event)
        :param at_most_times: how many times at most a matching event should
                              happen
        :param timeout: a timedelta object
        """
        logging.debug("assert_event_occurs_at_most")
        event_list = []
        end_time = datetime.now() + timeout
        while len(event_list) <= at_most_times and datetime.now() < end_time:
            remaining = static_remaining_time_delta(end_time)
            logging.debug("Waiting for event iteration (%fs remaining)", remaining.total_seconds())
            try:
                current_event = self.event_queue.get(timeout=remaining.total_seconds())
                if match_fn(current_event):
                    event_list.append(current_event)
            except Empty:
                continue
        logging.debug("Done waiting, got %d events", len(event_list))
        asserts.assert_true(
            len(event_list) <= at_most_times,
            msg=("Expected at most %d events, but got %d" % (at_most_times, len(event_list))))
203
204
def static_remaining_time_delta(end_time):
    """
    Return how much time is left until |end_time|, clamped to never be negative.

    :param end_time: a datetime deadline, compared against datetime.now()
    :return: a non-negative timedelta
    """
    return max(end_time - datetime.now(), timedelta(0))
210
211
def NOT_FOR_YOU_assert_event_occurs(istream,
                                    match_fn,
                                    at_least_times=1,
                                    timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    """
    Assert that at least |at_least_times| events satisfying |match_fn| are
    taken from |istream|'s event queue before |timeout| elapses.

    Non-matching events are consumed and discarded. Waiting stops as soon as
    enough matches have been seen.

    :param istream: an IEventStream
    :param match_fn: predicate called as match_fn(event)
    :param at_least_times: minimum number of matching events required
    :param timeout: a timedelta bounding the total wait
    """
    logging.debug("assert_event_occurs %d %fs", at_least_times, timeout.total_seconds())
    matches = []
    deadline = datetime.now() + timeout
    while len(matches) < at_least_times and datetime.now() < deadline:
        wait_seconds = static_remaining_time_delta(deadline).total_seconds()
        logging.debug("Waiting for event (%fs remaining)", wait_seconds)
        try:
            candidate = istream.get_event_queue().get(timeout=wait_seconds)
            logging.debug("current_event: %s", candidate)
            if match_fn(candidate):
                matches.append(candidate)
        except Empty:
            pass
    match_count = len(matches)
    logging.debug("Done waiting for event, received %d", match_count)
    asserts.assert_true(
        match_count >= at_least_times,
        msg=("Expected at least %d events, but got %d" % (at_least_times, match_count)))
233
234
def NOT_FOR_YOU_assert_all_events_occur(istream,
                                        match_fns,
                                        order_matters,
                                        timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    """
    Assert that every matcher in |match_fns| is satisfied by an event taken
    from |istream|'s event queue before |timeout| elapses.

    Each received event satisfies at most one matcher: the earliest
    still-pending matcher (in |match_fns| order) that accepts it. Events that
    match no pending matcher are consumed and discarded.

    :param istream: an IEventStream
    :param match_fns: iterable of predicates, each called as match_fn(event)
    :param order_matters: if True, additionally assert that the matchers were
                          satisfied in the same order as given in |match_fns|
    :param timeout: a timedelta bounding the total wait
    """
    # Materialize once so any iterable works with the len()/order checks below.
    expected = list(match_fns)
    logging.debug("assert_all_events_occur %fs", timeout.total_seconds())
    pending_matches = list(expected)
    matched_order = []
    end_time = datetime.now() + timeout
    while pending_matches and datetime.now() < end_time:
        remaining = static_remaining_time_delta(end_time)
        logging.debug("Waiting for event (%fs remaining)", remaining.total_seconds())
        try:
            current_event = istream.get_event_queue().get(timeout=remaining.total_seconds())
        except Empty:
            continue
        # Find the first pending matcher that accepts this event, then remove
        # it. Removing inside a loop over the same list would skip the matcher
        # right after each removed one.
        for index, match_fn in enumerate(pending_matches):
            if match_fn(current_event):
                del pending_matches[index]
                matched_order.append(match_fn)
                break
    logging.debug("Done waiting for event")
    asserts.assert_true(
        len(matched_order) == len(expected),
        msg=("Expected at least %d events, but got %d" % (len(expected), len(matched_order))))
    if order_matters:
        # Matchers are compared by identity, as the callers' objects are.
        correct_order = all(want is got for want, got in zip(expected, matched_order))
        asserts.assert_true(correct_order, "Events not received in correct order %s %s" % (expected, matched_order))
267
268
def NOT_FOR_YOU_assert_none_matching(istream, match_fn, timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    """
    Assert that no event satisfying |match_fn| is taken from |istream|'s
    event queue before |timeout| elapses.

    Non-matching events are consumed and discarded. Fails on the first
    matching (non-None) event; otherwise returns once the timeout expires.

    :param istream: an IEventStream
    :param match_fn: predicate called as match_fn(event)
    :param timeout: a timedelta bounding the total wait
    """
    logging.debug("assert_none_matching %fs", timeout.total_seconds())
    deadline = datetime.now() + timeout
    offending = None
    while offending is None and datetime.now() < deadline:
        wait_seconds = static_remaining_time_delta(deadline).total_seconds()
        logging.debug("Waiting for event (%fs remaining)", wait_seconds)
        try:
            candidate = istream.get_event_queue().get(timeout=wait_seconds)
            if match_fn(candidate):
                offending = candidate
        except Empty:
            pass
    logging.debug("Done waiting for an event")
    # Only format the failure when there is one: pretty_print(None) would
    # break inside MessageToString.
    if offending is not None:
        asserts.assert_true(offending is None, msg='Expected None matching, but got {}'.format(pretty_print(offending)))
286
287
def NOT_FOR_YOU_assert_none(istream, timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    """
    Assert that no event at all is taken from |istream|'s event queue before
    |timeout| elapses.

    :param istream: an IEventStream
    :param timeout: a timedelta bounding the wait
    """
    logging.debug("assert_none %fs", timeout.total_seconds())
    event_queue = istream.get_event_queue()
    try:
        unexpected = event_queue.get(timeout=timeout.total_seconds())
    except Empty:
        # Nothing arrived in time: the assertion holds.
        return
    asserts.assert_true(unexpected is None, msg='Expected None, but got {}'.format(pretty_print(unexpected)))
295