1#!/usr/bin/env python3
2#
3#   Copyright 2018 - The Android Open Source Project
4#
5#   Licensed under the Apache License, Version 2.0 (the "License");
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an "AS IS" BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16
17import os
18import re
19import select
20import subprocess
21import sys
22import time
23import uuid
24from threading import Thread
25
26import serial
27from serial.tools import list_ports
28
29from acts import tracelogger
30from logging import Logger
31
# Module-level logger used by every class in this file. Note that it shadows
# the standard-library ``logging`` module name; LogSerial.set_logger()
# rebinds it (through ``global logging``) when a serial_logger is supplied,
# which affects all users of this module.
logging = tracelogger.TakoTraceLogger(Logger(__file__))

# Retry count. LogSerial.__init__ overwrites it with its ``retries`` argument.
# NOTE(review): nothing in this file reads it; the @retry decorator that
# referenced it on LogSerial.open() is commented out.
RETRIES = 0
35
36
class LogSerialException(Exception):
    """Error raised by LogSerial when a serial port or output path is missing."""
    pass
39
40
class PortCheck(object):
    """Helpers to enumerate serial ports and look up their properties."""

    # Matches a Windows hardware-id string shaped like
    # '<TYPE> VID:PID=<vid>:<pid> <KEY=VALUE ...>'. Raw strings keep the
    # backslash escapes from being (mis)read as string escapes.
    _WIN_HWID_REGEX = (r'(?P<type>[A-Z]*)\sVID\:PID\=(?P<vid>\w*)'
                       r'\:(?P<pid>\w*)\s+(?P<adprop>.*$)')

    def get_serial_ports(self):
        """Gets the computer available serial ports.

        Returns:
            Dictionary mapping each serial port name to a
            (description, hardware address) tuple, as reported by
            serial.tools.list_ports.comports().
        """
        return {
            port_name: (description, address)
            for port_name, description, address in list_ports.comports()
        }

    @staticmethod
    def _parse_key_value_pairs(tokens):
        """Builds a dictionary from an iterable of 'KEY=VALUE' strings.

        Only the first '=' separates key from value, so values that contain
        '=' are kept intact. Empty strings and tokens without '=' are skipped
        instead of raising IndexError.
        """
        result = {}
        for token in tokens:
            key, separator, value = token.partition('=')
            if separator:
                result[key] = value
        return result

    def search_port_by_property(self, search_params):
        """Search ports by a dictionary of the search parameters.

        Args:
            search_params: Dictionary object with the parameters
                           to search. i.e:
                           {'ID_SERIAL_SHORT':'213213',
                           'ID_USB_INTERFACE_NUM': '01'}
        Returns:
            List with the names of the ports whose properties contain every
            key of search_params with an equal value. Ports without readable
            properties never match; an empty search_params matches every
            port that has properties.
        """
        ports_result = []
        for port in self.get_serial_ports():
            properties = self.get_port_properties(port=port)
            # None (port vanished) or {} (nothing readable) never matches.
            if not properties:
                continue
            if all(key in properties and properties[key] == value
                   for key, value in search_params.items()):
                ports_result.append(port)
        return ports_result

    def get_port_properties(self, port):
        """Get all the properties from a given port.

        Args:
            port: String object with the port name. i.e. '/dev/ttyACM1'

        Returns:
            Dictionary object with all the properties, or None when the port
            does not exist. On Linux/Cygwin the keys come from
            `udevadm info -q property`; on Windows they are parsed from the
            port's hardware-id string. Other platforms yield an empty
            dictionary.
        """
        ports = self.get_serial_ports()
        if port not in ports:
            return None
        if sys.platform.startswith(('linux', 'cygwin')):
            return self._linux_port_properties(port)
        if sys.platform.startswith('win'):
            return self._windows_port_properties(ports[port][1])
        return {}

    def _linux_port_properties(self, port):
        """Reads the udev properties of a port; {} when udevadm fails."""
        try:
            # Argument list instead of a shell string: the port name is
            # passed verbatim and never interpreted by a shell.
            output = subprocess.check_output(
                ['udevadm', 'info', '-q', 'property', '-n', port])
        except (subprocess.CalledProcessError, OSError) as error:
            # OSError covers a missing udevadm binary.
            logging.error(error)
            return {}
        return self._parse_key_value_pairs(
            output.decode(errors='replace').split('\n'))

    def _windows_port_properties(self, port_address):
        """Parses a Windows hardware-id string into a property dictionary.

        Besides the parsed 'type', 'vid', 'pid' and KEY=VALUE pairs, adds
        ID_USB_INTERFACE_NUM (from LOCATION) and ID_SERIAL_SHORT (from a
        powershell WMI query) when they can be derived.
        """
        result = {}
        m = re.search(self._WIN_HWID_REGEX, port_address)
        if not m:
            return result
        result['type'] = m.group('type')
        result['vid'] = m.group('vid')
        result['pid'] = m.group('pid')
        result.update(
            self._parse_key_value_pairs(m.group('adprop').strip().split(' ')))

        if 'LOCATION' in result:
            # The interface number is taken as the text after the last '.',
            # e.g. '0' in '1-1:x.0'. Using the last piece (not the second)
            # stays correct when the port path itself contains dots.
            # NOTE(review): assumes pyserial's '<bus>-<path>[:x.<iface>]'
            # LOCATION format; confirm.
            interface = result['LOCATION'].rpartition('.')[2]
            if interface.isdecimal():
                # Two-digit, zero-padded form, e.g. '01'.
                result['ID_USB_INTERFACE_NUM'] = '{:02d}'.format(
                    int(interface))

        win_vid_pid = '*VID_{}*PID_{}*'.format(result['vid'], result['pid'])
        command = (
                'powershell gwmi "Win32_USBControllerDevice |' +
                ' %{[wmi]($_.Dependent)} |' +
                ' Where-Object -Property PNPDeviceID -Like "' +
                win_vid_pid + '" |' +
                ' Where-Object -Property Service -Eq "usbccgp" |' +
                ' Select-Object -Property PNPDeviceID"')
        try:
            res = subprocess.check_output(command, shell=True)
        except (subprocess.CalledProcessError, OSError) as error:
            # Same best-effort handling as the Linux udevadm lookup.
            logging.error(error)
            return result
        match = re.search(r'USB\\.*', res.decode('ascii', errors='replace'))
        if match:
            # Expected shape: USB\<vid/pid part>\<serial>.
            parts = match.group().strip().split('\\')
            if len(parts) > 2:
                result['ID_SERIAL_SHORT'] = parts[2]
        return result

    def port_exists(self, port):
        """Check if a serial port exists in the computer by the port name.

        Args:
            port: String object with the port name. i.e. '/dev/ttyACM1'

        Returns:
            True if it was found, False if not.
        """
        return port in self.get_serial_ports()
165
166
class LogSerial(object):
    """Serial connection whose received lines are collected by a daemon thread.

    Every non-empty chunk returned by readline() is stored in ``self.log`` as
    a ``[time.time(), stripped_text]`` pair.
    """

    def __init__(self,
                 port,
                 baudrate,
                 bytesize=8,
                 parity='N',
                 stopbits=1,
                 timeout=0.15,
                 retries=0,
                 flush_output=True,
                 terminator='\n',
                 output_path=None,
                 serial_logger=None):
        """Configures and opens the serial port, then starts the read thread.

        Args:
            port: String object with the port name. i.e. '/dev/ttyACM1'
            baudrate: Baud rate of the connection.
            bytesize: Data bits; falsy leaves the serial.Serial() default.
            parity: Parity setting; falsy leaves the serial.Serial() default.
            stopbits: Stop bits; falsy leaves the serial.Serial() default.
            timeout: Read timeout in seconds; falsy leaves the
                serial.Serial() default.
            retries: Stored in the module-level RETRIES.
            flush_output: Whether open() flushes pending input.
            terminator: String appended to every command sent by write().
            output_path: Existing directory used by flush_log().
            serial_logger: Optional logger replacing the module logger.

        Raises:
            LogSerialException: if the port or output_path does not exist.
            A failure while opening the port is logged, not raised.
        """
        global RETRIES
        self.set_log = False
        self.output_path = None
        self.set_output_path(output_path)
        if serial_logger:
            self.set_logger(serial_logger)
        self.monitor_port = PortCheck()
        if not self.monitor_port.port_exists(port=port):
            raise LogSerialException(
                'The port {} does not exist'.format(port))
        self.connection_handle = serial.Serial()
        RETRIES = retries
        self.reading = True
        self.log = []
        # Placeholder until start_reading() creates the real read thread.
        self.log_thread = Thread()
        self.command_ini_index = None
        self.is_logging = False
        self.flush_output = flush_output
        self.terminator = terminator
        if port:
            self.connection_handle.port = port
        if baudrate:
            self.connection_handle.baudrate = baudrate
        if bytesize:
            self.connection_handle.bytesize = bytesize
        if parity:
            self.connection_handle.parity = parity
        if stopbits:
            self.connection_handle.stopbits = stopbits
        if timeout:
            self.connection_handle.timeout = timeout
        try:
            self.open()
        except Exception as e:
            # Best effort: release resources, log, and return normally.
            self.close()
            logging.error(e)

    def set_logger(self, serial_logger):
        """Replaces the module-level logger (for every user of this module).

        Args:
            serial_logger: Logger object. Its ``output_path`` attribute, or
                '/tmp' when absent, becomes the flush_log() directory.
        """
        global logging
        logging = serial_logger
        self.set_output_path(getattr(logging, 'output_path', '/tmp'))
        # Tells close() not to call flush_log() on the external logger.
        self.set_log = True

    def set_output_path(self, output_path):
        """Set the output path for the flushed log.

        Args:
            output_path: String object with the path. Falsy values are
                ignored.

        Raises:
            LogSerialException: if output_path does not exist.
        """
        if output_path:
            if os.path.exists(output_path):
                self.output_path = output_path
            else:
                raise LogSerialException('The output path does not exist.')

    def refresh_port_connection(self, port):
        """Will update the port connection without closing the read thread.

        Args:
            port: String object with the new port name. i.e. '/dev/ttyACM1'

        Raises:
            LogSerialException if the port is not alive.
        """
        if self.monitor_port.port_exists(port=port):
            self.connection_handle.port = port
            self.open()
        else:
            raise LogSerialException(
                'The port {} does not exist'.format(port))

    def is_port_alive(self):
        """Verify if the current port is alive in the computer.

        Returns:
            True if its alive, False if its missing.
        """
        return self.monitor_port.port_exists(port=self.connection_handle.port)

    # @retry(Exception, tries=RETRIES, delay=1, backoff=2)
    def open(self):
        """Will open the connection with the current port settings.

        Closes any open handle first, optionally flushes pending input, and
        starts the read thread if it is not already running.
        """
        while self.connection_handle.isOpen():
            self.connection_handle.close()
            time.sleep(0.5)
        self.connection_handle.open()
        if self.flush_output:
            self.flush()
        self.start_reading()
        logging.info('Connection Open')

    def close(self):
        """Will close the connection and the read thread, then flush logs."""
        self.stop_reading()
        if self.connection_handle:
            self.connection_handle.close()
        if not self.set_log:
            logging.flush_log()
        self.flush_log()
        logging.info('Connection Closed')

    def flush(self):
        """Will flush any input from the serial connection.

        Sends a newline, discards buffered input, then drains bytes that
        select() still reports as readable, one at a time.
        """
        self.write('\n')
        self.connection_handle.flushInput()
        self.connection_handle.flush()
        flushed = 0
        while True:
            ready_r, _, ready_x = select.select(
                [self.connection_handle], [], [self.connection_handle], 0)
            if ready_x:
                # Not inside an except block, so there is no traceback for
                # logging.exception() to attach; log a plain error.
                logging.error('exception from serial port')
                return
            if not ready_r:
                break
            flushed += 1
            # This may cause underlying buffering.
            self.connection_handle.read(1)
            # Flush the underlying buffer too.
            self.connection_handle.flush()
        # Report once after draining instead of once per dropped byte.
        if flushed > 0:
            logging.debug('dropped >{} bytes'.format(flushed))

    def write(self, command, wait_time=0.2):
        """Will write into the serial connection.

        Args:
            command: String object with the text to write. Falsy values
                are ignored.
            wait_time: Float object with the seconds to wait after the
                       command was issued.
        """
        if command:
            if self.terminator:
                command += self.terminator
            # read() returns log lines collected from this index onwards.
            self.command_ini_index = len(self.log)
            self.connection_handle.write(command.encode())
            if wait_time:
                time.sleep(wait_time)
            logging.info('cmd [{}] sent.'.format(command.strip()))

    def flush_log(self):
        """Will output the log into a comma-separated text file.

        Writes one '<timestamp>, <text>' line per entry to a new
        '<uuid4>_serial.log' file under output_path, falling back to the
        current working directory when output_path is unset or missing.
        Does nothing when the log is empty.
        """
        if not self.log:
            return
        if not self.output_path or not os.path.exists(self.output_path):
            self.output_path = os.getcwd()
        path = os.path.join(self.output_path,
                            str(uuid.uuid4()) + '_serial.log')
        # Explicit UTF-8 so non-ASCII serial output can always be written.
        with open(path, 'a', encoding='utf-8') as log_file:
            for info in self.log:
                log_file.write('{}, {}\n'.format(info[0], info[1]))

    def read(self):
        """Will read from the log the output from the serial connection
        after a write command was issued, using the log index recorded by
        write() (or by the previous read()) as the starting point. Before
        any write() the whole log is returned.

        Returns:
            List object with the log lines (text only).
        """
        command_end_index = len(self.log)
        info = self.query_serial_log(self.command_ini_index, command_end_index)
        self.command_ini_index = command_end_index
        return [line[1] for line in info]

    def get_all_log(self):
        """Gets the log object that collects the logs.

        Returns:
            List of [timestamp, text] entries (the live object, not a copy).
        """
        return self.log

    def query_serial_log(self, from_index, to_index):
        """Will query the session log between two indexes of the log list.

        Args:
            from_index: Index of the first entry to return; None means the
                        start of the log.
            to_index: Index one past the last entry to return.

        Returns:
            List with the matching [timestamp, text] entries; empty when
            from_index is not before to_index.
        """
        if from_index is None:
            from_index = 0
        if from_index < to_index:
            return self.log[from_index:to_index]
        return []

    def _start_reading_thread(self):
        """Read loop run by the log thread while ``self.reading`` is True."""
        if not self.connection_handle.isOpen():
            return
        while self.reading:
            try:
                # 'replace' keeps lines with non-UTF-8 bytes instead of
                # dropping them through the exception handler below.
                data = self.connection_handle.readline().decode(
                    'utf-8', errors='replace')
                if data:
                    self.is_logging = True
                    # strip() removes the trailing '\r'/'\n' terminators.
                    self.log.append([time.time(), data.strip()])
                else:
                    self.is_logging = False
            except Exception:
                # Tolerate transient port errors (e.g. while the port is
                # being refreshed): wait and retry while still reading.
                time.sleep(1)
        logging.info('Read thread closed')

    def start_reading(self):
        """Method to start the log collection."""
        if self.log_thread.is_alive():
            logging.warning('Not running log thread, is already alive')
            return
        # Set before starting so an early stop_reading() is not overridden.
        self.reading = True
        self.log_thread = Thread(target=self._start_reading_thread, args=())
        self.log_thread.daemon = True
        try:
            self.log_thread.start()
        except (KeyboardInterrupt, SystemExit):
            self.close()

    def stop_reading(self):
        """Method to stop the log collection."""
        self.reading = False
        # join() raises RuntimeError for a thread that was never started,
        # e.g. the placeholder from __init__ when open() failed.
        if self.log_thread.is_alive():
            self.log_thread.join(timeout=600)
411