1#!/usr/bin/env python3
2#
3#   Copyright 2020 - The Android Open Source Project
4#
5#   Licensed under the Apache License, Version 2.0 (the "License");
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an "AS IS" BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16
17from cert.closable import Closable
18from cert.closable import safeClose
19from cert.py_le_acl_manager import PyLeAclManager
20from cert.truth import assertThat
21import bluetooth_packets_python3 as bt_packets
22from bluetooth_packets_python3 import l2cap_packets
23from bluetooth_packets_python3.l2cap_packets import LeCommandCode
24from bluetooth_packets_python3.l2cap_packets import LeCreditBasedConnectionResponseResult
25from cert.event_stream import FilteringEventStream
26from cert.event_stream import IEventStream
27from cert.matchers import L2capMatchers
28from cert.captures import L2capCaptures
29from mobly import asserts
30
31
class CertLeL2capChannel(IEventStream):
    """Cert-side endpoint of one L2CAP channel carried over an LE ACL link.

    Incoming basic frames addressed to our local (source) CID are exposed
    through the IEventStream interface. Outgoing frames are addressed to the
    peer's (destination) CID.

    A local credit counter tracks how many frames may still be sent. It starts
    at ``initial_credits`` and is decremented on every send. The owning
    ``CertLeL2cap`` adds to it when flow-control credits arrive.

    Args:
        device: cert device handle; stored but not used by this class.
        scid: our local CID; only frames addressed to it are surfaced.
        dcid: the peer's CID; every outgoing frame is addressed to it.
        acl_stream: event stream of raw ACL data for the link.
        acl: ACL connection object used to send serialized frames.
        control_channel: channel used to send signaling commands for this
            channel. It is None for the signaling channel itself and for
            fixed channels.
        initial_credits: starting value of the local credit counter.
    """

    def __init__(self, device, scid, dcid, acl_stream, acl, control_channel, initial_credits=0):
        self._device = device
        self._scid = scid
        self._dcid = dcid
        self._acl_stream = acl_stream
        self._acl = acl
        self._control_channel = control_channel
        # Only basic frames addressed to our local CID show up on this stream.
        self._our_acl_view = FilteringEventStream(acl_stream, L2capMatchers.ExtractBasicFrame(scid))
        self._credits_left = initial_credits

    def get_event_queue(self):
        """Return the event queue of frames received on this channel."""
        return self._our_acl_view.get_event_queue()

    def _send_frame(self, frame):
        """Serialize *frame*, send it over the ACL link and consume one credit."""
        self._acl.send(frame.Serialize())
        self._credits_left -= 1

    def send(self, packet):
        """Send *packet* as a basic frame to the peer CID (consumes one credit)."""
        self._send_frame(l2cap_packets.BasicFrameBuilder(self._dcid, packet))

    def send_first_le_i_frame(self, sdu_size, packet):
        """Send the first LE information frame of an SDU of *sdu_size* bytes.

        Consumes one credit.
        """
        self._send_frame(l2cap_packets.FirstLeInformationFrameBuilder(self._dcid, sdu_size, packet))

    def disconnect_and_verify(self, signal_id=1):
        """Send an LE disconnection request and assert the peer answers it.

        Args:
            signal_id: identifier placed in the disconnection request
                (defaults to 1, the previously hard-coded value).
        """
        # NOTE(review): CID 1 is the classic signaling CID, while this file uses
        # CID 5 for LE signaling. Confirm which CID this guard is meant to
        # protect.
        assertThat(self._scid).isNotEqualTo(1)
        self._control_channel.send(l2cap_packets.LeDisconnectionRequestBuilder(signal_id, self._dcid, self._scid))

        assertThat(self._control_channel).emits(L2capMatchers.LeDisconnectionResponse(self._scid, self._dcid))

    def verify_disconnect_request(self):
        """Assert that the peer sends an LE disconnection request for this channel."""
        assertThat(self._control_channel).emits(L2capMatchers.LeDisconnectionRequest(self._dcid, self._scid))

    def send_credits(self, num_credits, signal_id=2):
        """Grant *num_credits* to the peer with an LE flow control credit packet.

        The packet carries our local CID. The local send counter is not changed.

        Args:
            num_credits: number of credits to grant.
            signal_id: identifier of the credit packet (defaults to 2, the
                previously hard-coded value).
        """
        self._control_channel.send(l2cap_packets.LeFlowControlCreditBuilder(signal_id, self._scid, num_credits))

    def credits_left(self):
        """Return the local count of frames that may still be sent."""
        return self._credits_left
71
72
class CertLeL2cap(Closable):
    """Cert-side LE L2CAP helper driving a single LE ACL connection.

    Responsibilities:
    - Owns the LE ACL manager.
    - Once a link exists, exposes the LE signaling channel (CID 5) as
      ``control_channel``.
    - Dispatches incoming signaling commands through ``control_table``.
    - Opens and accepts credit-based channels on top of the link.
    """

    # CID of the LE signaling channel; control packets are only handled on it.
    _LE_SIGNALING_CID = 5

    def __init__(self, device):
        self._device = device
        self._le_acl_manager = PyLeAclManager(device)
        self._le_acl = None
        # Populated by connect_le_acl() / wait_for_connection(). It starts as
        # None so get_control_channel() does not raise before a link exists.
        self.control_channel = None

        # Handlers for signaling commands received from the peer, keyed by code.
        self.control_table = {
            LeCommandCode.DISCONNECTION_REQUEST: self._on_disconnection_request_default,
            LeCommandCode.DISCONNECTION_RESPONSE: self._on_disconnection_response_default,
            LeCommandCode.LE_FLOW_CONTROL_CREDIT: self._on_credit,
        }

        # Credit-based channels opened so far, keyed by our local (source) CID.
        self._cid_to_cert_channels = {}

    def close(self):
        """Close the LE ACL manager and, if present, the LE ACL connection."""
        self._le_acl_manager.close()
        safeClose(self._le_acl)

    def connect_le_acl(self, remote_addr):
        """Initiate an LE ACL connection to *remote_addr* and set up signaling."""
        self._le_acl = self._le_acl_manager.connect_to_remote(remote_addr)
        self._setup_control_channel()

    def wait_for_connection(self):
        """Wait for an incoming LE ACL connection and set up signaling."""
        self._le_acl = self._le_acl_manager.wait_for_connection()
        self._setup_control_channel()

    def _setup_control_channel(self):
        """Create the signaling channel for the current link and hook control dispatch."""
        cid = self._LE_SIGNALING_CID
        self.control_channel = CertLeL2capChannel(
            self._device, cid, cid, self._get_acl_stream(), self._le_acl, control_channel=None)
        self._get_acl_stream().register_callback(self._handle_control_packet)

    def open_fixed_channel(self, cid=4):
        """Return a channel object for fixed channel *cid*.

        The channel is not registered for credit tracking and starts with
        zero credits.
        """
        return CertLeL2capChannel(self._device, cid, cid, self._get_acl_stream(), self._le_acl, None, 0)

    def open_channel(self, signal_id, psm, scid, mtu=1000, mps=100, initial_credit=6):
        """Request an LE credit-based channel on *psm* and return it.

        Sends the connection request and asserts that a connection response
        arrives. The peer's destination CID and initial credits are taken from
        that response. The response result is not checked here.
        """
        self.control_channel.send(
            l2cap_packets.LeCreditBasedConnectionRequestBuilder(signal_id, psm, scid, mtu, mps, initial_credit))

        response = L2capCaptures.CreditBasedConnectionResponse()
        assertThat(self.control_channel).emits(response)
        response_view = response.get()
        channel = CertLeL2capChannel(self._device, scid, response_view.GetDestinationCid(), self._get_acl_stream(),
                                     self._le_acl, self.control_channel, response_view.GetInitialCredits())
        self._cid_to_cert_channels[scid] = channel
        return channel

    def open_channel_with_expected_result(self, psm=0x33, result=LeCreditBasedConnectionResponseResult.SUCCESS):
        """Request a channel on *psm* and assert the response carries *result*.

        Uses fixed request parameters: signal id 1, scid 0x40, mtu 1000,
        mps 100 and 6 initial credits. No channel object is created.
        """
        self.control_channel.send(l2cap_packets.LeCreditBasedConnectionRequestBuilder(1, psm, 0x40, 1000, 100, 6))

        response = L2capMatchers.CreditBasedConnectionResponse(result)
        assertThat(self.control_channel).emits(response)

    def verify_and_respond_open_channel_from_remote(self,
                                                    psm=0x33,
                                                    result=LeCreditBasedConnectionResponseResult.SUCCESS,
                                                    our_scid=None):
        """Expect a peer connection request on *psm*, answer it and return the channel.

        The channel's send credits are seeded from the peer's initial credits.
        The channel is registered even when *result* is not SUCCESS.
        """
        request = L2capCaptures.CreditBasedConnectionRequest(psm)
        assertThat(self.control_channel).emits(request)
        request_view = request.get()
        (scid, dcid) = self._respond_connection_request_default(request_view, result, our_scid)
        channel = CertLeL2capChannel(self._device, scid, dcid, self._get_acl_stream(), self._le_acl,
                                     self.control_channel, request_view.GetInitialCredits())
        self._cid_to_cert_channels[scid] = channel
        return channel

    def verify_and_reject_open_channel_from_remote(self, psm=0x33):
        """Expect a peer connection request on *psm* and answer with a command reject."""
        request = L2capCaptures.CreditBasedConnectionRequest(psm)
        assertThat(self.control_channel).emits(request)
        sid = request.get().GetIdentifier()
        reject = l2cap_packets.LeCommandRejectNotUnderstoodBuilder(sid)
        self.control_channel.send(reject)

    def verify_le_flow_control_credit(self, channel):
        """Assert that the peer sends an LE flow control credit for *channel*.

        The expected CID is the channel's destination CID, i.e. the peer's
        local CID.
        """
        assertThat(self.control_channel).emits(L2capMatchers.LeFlowControlCredit(channel._dcid))

    def _respond_connection_request_default(self,
                                            request,
                                            result=LeCreditBasedConnectionResponseResult.SUCCESS,
                                            our_scid=None):
        """Answer a peer connection request, echoing its MTU, MPS and initial credits.

        Returns:
            (our_scid, our_dcid): our local CID and the peer's CID.
        """
        sid = request.GetIdentifier()
        their_scid = request.GetSourceCid()
        mtu = request.GetMtu()
        mps = request.GetMps()
        initial_credits = request.GetInitialCredits()
        # If our_scid is not specified, reuse the peer's source CID as our own source CID.
        if our_scid is None:
            our_scid = their_scid
        our_dcid = their_scid
        response = l2cap_packets.LeCreditBasedConnectionResponseBuilder(sid, our_scid, mtu, mps, initial_credits,
                                                                        result)
        self.control_channel.send(response)
        return (our_scid, our_dcid)

    def get_control_channel(self):
        """Return the signaling channel, or None if no link has been set up."""
        return self.control_channel

    def _get_acl_stream(self):
        """Return the ACL data stream of the current LE link."""
        return self._le_acl.acl_stream

    def _on_disconnection_request_default(self, request):
        """Answer a peer disconnection request by echoing its identifier and CIDs."""
        disconnection_request = l2cap_packets.LeDisconnectionRequestView(request)
        sid = disconnection_request.GetIdentifier()
        scid = disconnection_request.GetSourceCid()
        dcid = disconnection_request.GetDestinationCid()
        response = l2cap_packets.LeDisconnectionResponseBuilder(sid, dcid, scid)
        self.control_channel.send(response)

    def _on_disconnection_response_default(self, request):
        """Accept a peer disconnection response; no reply is needed."""
        # The view is built only for its side effect. It presumably raises on a
        # malformed packet (TODO confirm in the packet bindings); the result is
        # not otherwise used.
        l2cap_packets.LeDisconnectionResponseView(request)

    def _on_credit(self, l2cap_le_control_view):
        """Add credits from a peer LE flow control credit packet to the matching channel.

        The CID in a received credit packet is the peer's local CID, which is
        our channel's destination CID. This matches verify_le_flow_control_credit,
        which expects ``channel._dcid``, and send_credits, which sends our own
        ``_scid``. Channels are therefore matched on ``_dcid``, not on the dict
        key (our ``_scid``). Credits for unknown CIDs are ignored.
        """
        credit_view = l2cap_packets.LeFlowControlCreditView(l2cap_le_control_view)
        cid = credit_view.GetCid()
        channel = next((ch for ch in self._cid_to_cert_channels.values() if ch._dcid == cid), None)
        if channel is None:
            return
        channel._credits_left += credit_view.GetCredits()

    def _handle_control_packet(self, l2cap_packet):
        """ACL stream callback: dispatch signaling packets through control_table."""
        packet_bytes = l2cap_packet.payload
        l2cap_view = l2cap_packets.BasicFrameView(bt_packets.PacketViewLittleEndian(list(packet_bytes)))
        if l2cap_view.GetChannelId() != self._LE_SIGNALING_CID:
            return
        request = l2cap_packets.LeControlView(l2cap_view.GetPayload())
        fn = self.control_table.get(request.GetCode())
        if fn is not None:
            fn(request)
202