1#!/usr/bin/env python3
2#
3#   Copyright 2020 - The Android Open Source Project
4#
5#   Licensed under the Apache License, Version 2.0 (the "License");
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an "AS IS" BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16
17from abc import ABC, abstractmethod
18from datetime import datetime, timedelta
19from mobly import signals
20from threading import Condition
21
22from cert.event_stream import static_remaining_time_delta
23from cert.truth import assertThat
24
25
class IHasBehaviors(ABC):
    """Interface for objects that expose a behaviors container.

    `when()` and `wait_until()` call `get_behaviors()` on such objects.
    """

    @abstractmethod
    def get_behaviors(self):
        """Return the behaviors object; concrete subclasses must override."""
31
32
def anything():
    """Return a matcher that accepts every object it is given."""

    def match_all(obj):
        return True

    return match_all
35
36
def when(has_behaviors):
    """Return the behaviors of `has_behaviors` so rules can be registered.

    The `isinstance` check is routed through `assertThat(...).isTrue()`, which
    is expected to fail the test when `has_behaviors` is not an IHasBehaviors.
    """
    is_behavior_holder = isinstance(has_behaviors, IHasBehaviors)
    assertThat(is_behavior_holder).isTrue()
    behaviors = has_behaviors.get_behaviors()
    return behaviors
40
41
def IGNORE_UNHANDLED(obj):
    """Default handler that deliberately does nothing with `obj`."""
44
45
class SingleArgumentBehavior(object):
    """Dispatches single-argument objects to registered behavior instances.

    Registered instances are tried in the order they were appended; the first
    whose `try_run` returns True handles the object. If none does, the default
    function runs, or `signals.TestFailure` is raised when no default is set.
    Every handled object is recorded so that `wait_until_invoked` can block
    until enough matching objects have been seen.
    """

    def __init__(self, reply_stage_factory):
        self._reply_stage_factory = reply_stage_factory
        # Candidates tried in order by run().
        self._instances = []
        # Every object handled by run(); guarded by _invoked_condition.
        self._invoked_obj = []
        self._invoked_condition = Condition()
        self.set_default_to_crash()

    def begin(self, matcher):
        """Start a fluent rule definition for objects accepted by `matcher`."""
        return PersistenceStage(self, matcher, self._reply_stage_factory)

    def append(self, behavior_instance):
        """Register a handler; it is tried after all previously appended ones."""
        self._instances.append(behavior_instance)

    def set_default(self, fn):
        """Use `fn` for objects that no registered instance handles."""
        assertThat(fn).isNotNone()
        self._default_fn = fn

    def set_default_to_crash(self):
        """Make unhandled objects raise `signals.TestFailure` in run()."""
        self._default_fn = None

    def set_default_to_ignore(self):
        """Accept unhandled objects without action (they are still recorded)."""
        self._default_fn = IGNORE_UNHANDLED

    def run(self, obj):
        """Handle `obj` with the first accepting instance, else the default.

        Raises:
            signals.TestFailure: if no instance handles `obj` and no default
                function is set.
        """
        for instance in self._instances:
            if instance.try_run(obj):
                self.__obj_invoked(obj)
                return
        if self._default_fn is not None:
            # IGNORE_UNHANDLED is also a default fn
            self._default_fn(obj)
            self.__obj_invoked(obj)
        else:
            raise signals.TestFailure(
                "%s: behavior for %s went unhandled" % (self._reply_stage_factory().__class__.__name__, obj),
                extras=None)

    def __obj_invoked(self, obj):
        """Record `obj` as handled and wake every thread in wait_until_invoked."""
        with self._invoked_condition:
            self._invoked_obj.append(obj)
            # notify_all: several waiters (possibly with different matchers)
            # may be blocked; notify() would wake only one of them.
            self._invoked_condition.notify_all()

    def wait_until_invoked(self, matcher, times, timeout):
        """Block until `times` recorded objects match `matcher`, or time out.

        Args:
            matcher: predicate applied to each recorded object.
            times: expected number of matching recorded objects.
            timeout: datetime.timedelta bounding the wait.

        Returns:
            True if, when the wait ends, exactly `times` recorded objects
            match; False otherwise (too few before the timeout, or more than
            `times`). Returns immediately when at least `times` already match,
            so `times=0` reports whether nothing has matched yet.
        """

        def matched_count():
            return sum(1 for invoked in self._invoked_obj if matcher(invoked))

        with self._invoked_condition:
            # Checking and waiting under the same lock means a notification
            # cannot slip in between them; wait_for also skips waiting when
            # the condition already holds and uses a monotonic clock.
            self._invoked_condition.wait_for(lambda: matched_count() >= times, timeout.total_seconds())
            # Recount so matches recorded during the final wait are reported.
            return matched_count() == times
101
102
class PersistenceStage(object):
    """Fluent step that decides how many times a behavior rule applies."""

    def __init__(self, behavior, matcher, reply_stage_factory):
        self._behavior, self._matcher, self._reply_stage_factory = behavior, matcher, reply_stage_factory

    def then(self, times=1):
        """Create a reply stage bound to this rule, limited to `times` uses."""
        stage = self._reply_stage_factory()
        stage.init(self._behavior, self._matcher, times)
        return stage

    def always(self):
        """Like then(), but with persistence -1, i.e. no usage limit."""
        return self.then(times=-1)
117
118
class ReplyStage(object):
    """Last fluent step; binds a rule and registers its action on commit.

    `_commit` is not called in this file; presumably subclass reply methods
    use it to register their action.
    """

    def init(self, behavior, matcher, persistence):
        """Bind the behavior, matcher and persistence chosen earlier."""
        self._behavior, self._matcher, self._persistence = behavior, matcher, persistence

    def _commit(self, fn):
        """Append a BehaviorInstance running `fn` to the bound behavior."""
        instance = BehaviorInstance(self._matcher, self._persistence, fn)
        self._behavior.append(instance)
128
129
class BehaviorInstance(object):
    """One registered rule: a matcher, a usage limit and the action to run."""

    def __init__(self, matcher, persistence, fn):
        self._matcher = matcher
        # Maximum number of runs; a negative value means unlimited.
        self._persistence = persistence
        self._fn = fn
        self._called_count = 0

    def try_run(self, obj):
        """Run the action on `obj` if it matches and the limit allows.

        Returns True when the action ran, False otherwise. The call count is
        bumped before the action runs.
        """
        if self._matcher(obj) and not (0 <= self._persistence <= self._called_count):
            self._called_count += 1
            self._fn(obj)
            return True
        return False
147
148
class BoundVerificationStage(object):
    """A verification bound to one behavior, matcher and timeout."""

    def __init__(self, behavior, matcher, timeout):
        self._behavior = behavior
        self._matcher = matcher
        self._timeout = timeout

    def times(self, times=1):
        """Return the behavior's wait_until_invoked result for `times` matches."""
        wait = self._behavior.wait_until_invoked
        return wait(self._matcher, times, self._timeout)
158
159
class WaitForBehaviorSubject(object):
    """Maps `subject.<name>(matcher)` to a verification on `<name>_behavior`."""

    def __init__(self, behaviors, timeout):
        self._behaviors = behaviors
        self._timeout = timeout

    def __getattr__(self, item):
        # The behavior is resolved eagerly, so an unknown name raises
        # AttributeError here rather than when the returned callable is used.
        attribute_name = item + "_behavior"
        target = getattr(self._behaviors, attribute_name)
        timeout = self._timeout

        def bind(matcher):
            return BoundVerificationStage(target, matcher, timeout)

        return bind
170
171
def wait_until(i_has_behaviors, timeout=timedelta(seconds=3)):
    """Start a fluent wait on the behaviors of `i_has_behaviors`.

    Each verification created from the returned subject waits at most
    `timeout` (a timedelta, three seconds by default).
    """
    behaviors = i_has_behaviors.get_behaviors()
    return WaitForBehaviorSubject(behaviors, timeout)
174