#!/usr/bin/env python3
#
# Copyright 2019 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
import logging
from queue import SimpleQueue, Empty

from mobly import asserts

from google.protobuf import text_format

from grpc import RpcError

from cert.closable import Closable


class IEventStream(ABC):
    """
    Interface for objects that expose a queue of received events.
    """

    @abstractmethod
    def get_event_queue(self):
        """
        :return: the queue object from which events are consumed
        """
        pass


class FilteringEventStream(IEventStream):
    """
    A view on top of another stream that only queues the events accepted by
    filter_fn.

    filter_fn(event) returns the value to put on this stream's queue, or None
    to drop the event. A falsy filter_fn lets every event through unchanged.
    """

    def __init__(self, stream, filter_fn):
        """
        :param stream: object providing register_callback(callback, matcher_fn)
                       and unregister_callback(callback, matcher_fn)
        :param filter_fn: filter/transform function, see class description
        """
        self.filter_fn = filter_fn if filter_fn else lambda x: x
        self.event_queue = SimpleQueue()
        self.stream = stream
        # Handlers are stored as (callback, matcher_fn) tuples, so the very
        # same matcher object is needed again in unregister()
        self._matcher_fn = lambda packet: self.filter_fn(packet) is not None

        self.stream.register_callback(self.__event_callback, self._matcher_fn)

    def __event_callback(self, event):
        self.event_queue.put(self.filter_fn(event))

    def get_event_queue(self):
        return self.event_queue

    def unregister(self):
        """
        Stop receiving events from the underlying stream.

        :raises ValueError when the handler is not (or no longer) registered
        """
        self.stream.unregister_callback(self.__event_callback, self._matcher_fn)


def pretty_print(proto_event):
    """
    Format a protobuf message as "<TypeName> <single line text format>"

    :param proto_event: a protobuf message
    :return: the formatted string
    """
    return '{} {}'.format(type(proto_event).__name__, text_format.MessageToString(proto_event, as_one_line=True))


# Default time budget, in seconds, for assertions and for closing a stream
DEFAULT_TIMEOUT_SECONDS = 3


class EventStream(IEventStream, Closable):
    """
    A class that streams events from a gRPC stream, which you can assert on.

    Don't use these asserts directly, use the ones from cert.truth.
    """

    def __init__(self, server_stream_call):
        """
        :param server_stream_call: iterable gRPC stream call; it is consumed
                                   on a worker thread started here
        :raises ValueError when server_stream_call is None
        """
        if server_stream_call is None:
            raise ValueError("server_stream_call cannot be None")

        self.server_stream_call = server_stream_call
        self.event_queue = SimpleQueue()
        self.handlers = []
        self.executor = ThreadPoolExecutor()
        self.future = self.executor.submit(EventStream.__event_loop, self)

    def get_event_queue(self):
        return self.event_queue

    def close(self):
        """
        Stop the gRPC lambda so that event_callback will not be invoked after
        the method returns.

        This object will be useless after this call as there is no way to
        restart the gRPC callback. You would have to create a new EventStream

        :raise None on success, or the same exception as __event_loop(), or
               concurrent.futures.TimeoutError if underlying stream failed to
               terminate within DEFAULT_TIMEOUT_SECONDS
        """
        # Try to cancel the execution, don't care the result, non-blocking
        self.server_stream_call.cancel()
        try:
            # cancelling gRPC stream should cause __event_loop() to quit
            # same exception will be raised by future.result() or
            # concurrent.futures.TimeoutError will be raised after timeout
            self.future.result(timeout=DEFAULT_TIMEOUT_SECONDS)
        finally:
            # Make sure we force shutdown the executor regardless of the result
            self.executor.shutdown(wait=False)

    def register_callback(self, callback, matcher_fn=None):
        """
        Register a callback to handle events. Event will be handled by callback
        if matcher_fn(event) returns True

        callback and matcher are registered as a tuple. Hence the same callback
        with different matcher are considered two different handler units. Same
        matcher, but different callback are also considered different handling
        unit

        Callback will be invoked on a ThreadPoolExecutor owned by this
        EventStream

        :param callback: Will be called as callback(event)
        :param matcher_fn: A boolean function that returns True or False when
                           calling matcher_fn(event), if None, all event will
                           be matched
        """
        if callback is None:
            raise ValueError("callback must not be None")
        self.handlers.append((callback, matcher_fn))

    def unregister_callback(self, callback, matcher_fn=None):
        """
        Unregister callback and matcher_fn from the event stream. Both objects
        must match exactly the ones when calling register_callback()

        :param callback: callback used in register_callback()
        :param matcher_fn: matcher_fn used in register_callback()
        :raises ValueError when (callback, matcher_fn) tuple is not found
        """
        if callback is None:
            raise ValueError("callback must not be None")
        self.handlers.remove((callback, matcher_fn))

    def __event_loop(self):
        """
        Main loop for consuming the gRPC stream events.
        Blocks until computation is cancelled
        :raise grpc.Error on failure
        """
        try:
            for event in self.server_stream_call:
                self.event_queue.put(event)
                # Dispatch over a snapshot so that handlers registered or
                # unregistered during dispatch do not disturb this iteration
                for (callback, matcher_fn) in list(self.handlers):
                    if not matcher_fn or matcher_fn(event):
                        callback(event)
        except RpcError:
            # Underlying gRPC stream should run indefinitely until cancelled
            # Hence any other reason besides CANCELLED is raised as an error
            if self.server_stream_call.cancelled():
                logging.debug("Cancelled")
            else:
                raise

    def assert_event_occurs(self, match_fn, at_least_times=1, timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        """
        Assert at least |at_least_times| instances of events happen where
        match_fn(event) returns True within timeout period

        :param match_fn: returns True/False on match_fn(event)
        :param timeout: a timedelta object
        :param at_least_times: how many times at least a matching event should
                               happen
        :return:
        """
        NOT_FOR_YOU_assert_event_occurs(self, match_fn, at_least_times, timeout)

    def assert_event_occurs_at_most(self, match_fn, at_most_times, timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        """
        Assert at most |at_most_times| instances of events happen where
        match_fn(event) returns True within timeout period

        :param match_fn: returns True/False on match_fn(event)
        :param at_most_times: how many times at most a matching event should
                              happen
        :param timeout:a timedelta object
        :return:
        """
        logging.debug("assert_event_occurs_at_most")
        event_list = []
        end_time = datetime.now() + timeout
        # Stop early once the limit is exceeded: the assertion already failed
        while len(event_list) <= at_most_times and datetime.now() < end_time:
            remaining = static_remaining_time_delta(end_time)
            logging.debug("Waiting for event iteration (%fs remaining)", remaining.total_seconds())
            try:
                current_event = self.event_queue.get(timeout=remaining.total_seconds())
                if match_fn(current_event):
                    event_list.append(current_event)
            except Empty:
                continue
        logging.debug("Done waiting, got %d events", len(event_list))
        asserts.assert_true(
            len(event_list) <= at_most_times,
            msg=("Expected at most %d events, but got %d" % (at_most_times, len(event_list))))


def static_remaining_time_delta(end_time):
    """
    :param end_time: a datetime deadline
    :return: timedelta left until end_time, clamped at zero once it has passed
    """
    return max(end_time - datetime.now(), timedelta(milliseconds=0))


def NOT_FOR_YOU_assert_event_occurs(istream,
                                    match_fn,
                                    at_least_times=1,
                                    timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    """
    Assert at least |at_least_times| events for which match_fn(event) returns
    True are read from the stream's queue within timeout. Events that are read
    and do not match are discarded.

    :param istream: object providing get_event_queue()
    :param match_fn: returns True/False on match_fn(event)
    :param at_least_times: how many times at least a matching event should
                           happen
    :param timeout: a timedelta object
    """
    logging.debug("assert_event_occurs %d %fs", at_least_times, timeout.total_seconds())
    event_list = []
    end_time = datetime.now() + timeout
    while len(event_list) < at_least_times and datetime.now() < end_time:
        remaining = static_remaining_time_delta(end_time)
        logging.debug("Waiting for event (%fs remaining)", remaining.total_seconds())
        try:
            current_event = istream.get_event_queue().get(timeout=remaining.total_seconds())
            logging.debug("current_event: %s", current_event)
            if match_fn(current_event):
                event_list.append(current_event)
        except Empty:
            continue
    logging.debug("Done waiting for event, received %d", len(event_list))
    asserts.assert_true(
        len(event_list) >= at_least_times,
        msg=("Expected at least %d events, but got %d" % (at_least_times, len(event_list))))


def NOT_FOR_YOU_assert_all_events_occur(istream,
                                        match_fns,
                                        order_matters,
                                        timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    """
    Assert every matcher in match_fns is satisfied by some event read from the
    stream's queue within timeout. A single event may satisfy several pending
    matchers; each matcher is satisfied at most once.

    :param istream: object providing get_event_queue()
    :param match_fns: iterable of functions returning True/False on
                      match_fn(event)
    :param order_matters: when True, additionally assert the matchers were
                          satisfied in the order given by match_fns
    :param timeout: a timedelta object
    """
    logging.debug("assert_all_events_occur %fs", timeout.total_seconds())
    # Materialize once so that any iterable can be counted and compared later
    expected_matches = list(match_fns)
    pending_matches = list(expected_matches)
    matched_order = []
    end_time = datetime.now() + timeout
    while len(pending_matches) > 0 and datetime.now() < end_time:
        remaining = static_remaining_time_delta(end_time)
        logging.debug("Waiting for event (%fs remaining)", remaining.total_seconds())
        try:
            current_event = istream.get_event_queue().get(timeout=remaining.total_seconds())
        except Empty:
            continue
        # Walk a snapshot: removing from the list being iterated would skip
        # the matcher that follows each matched one
        for match_fn in list(pending_matches):
            if match_fn(current_event):
                pending_matches.remove(match_fn)
                matched_order.append(match_fn)
    logging.debug("Done waiting for event")
    asserts.assert_true(
        len(matched_order) == len(expected_matches),
        msg=("Expected at least %d events, but got %d" % (len(expected_matches), len(matched_order))))
    if order_matters:
        # Lengths are equal here, otherwise the assertion above has raised
        correct_order = all(expected is matched for expected, matched in zip(expected_matches, matched_order))
        asserts.assert_true(correct_order, "Events not received in correct order %s %s" % (match_fns, matched_order))


def NOT_FOR_YOU_assert_none_matching(istream, match_fn, timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    """
    Assert no event for which match_fn(event) returns True is read from the
    stream's queue within timeout. Waiting stops at the first matching event.

    :param istream: object providing get_event_queue()
    :param match_fn: returns True/False on match_fn(event)
    :param timeout: a timedelta object
    """
    logging.debug("assert_none_matching %fs", timeout.total_seconds())
    event = None
    end_time = datetime.now() + timeout
    while event is None and datetime.now() < end_time:
        remaining = static_remaining_time_delta(end_time)
        logging.debug("Waiting for event (%fs remaining)", remaining.total_seconds())
        try:
            current_event = istream.get_event_queue().get(timeout=remaining.total_seconds())
            if match_fn(current_event):
                event = current_event
        except Empty:
            continue
    logging.debug("Done waiting for an event")
    if event is None:
        return  # Avoid an assert in MessageToString(None, ...)
    asserts.assert_true(event is None, msg='Expected None matching, but got {}'.format(pretty_print(event)))


def NOT_FOR_YOU_assert_none(istream, timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    """
    Assert no event at all is read from the stream's queue within timeout.

    :param istream: object providing get_event_queue()
    :param timeout: a timedelta object
    """
    logging.debug("assert_none %fs", timeout.total_seconds())
    try:
        event = istream.get_event_queue().get(timeout=timeout.total_seconds())
    except Empty:
        return
    asserts.assert_true(event is None, msg='Expected None, but got {}'.format(pretty_print(event)))