1#!/usr/bin/env python3.4
2#
3#   Copyright 2019 - The Android Open Source Project
4#
5#   Licensed under the Apache License, Version 2.0 (the 'License');
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an 'AS IS' BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16
import collections
import functools
import hashlib
import itertools
import json
import logging
import math
import os
import re
import statistics
import time
from concurrent.futures import ThreadPoolExecutor

import bokeh, bokeh.plotting

from acts import utils
from acts.controllers.android_device import AndroidDevice
from acts.controllers.utils_lib import ssh
from acts.test_utils.wifi import wifi_test_utils as wutils
33
# Sleep durations (in seconds) used between test steps and measurements.
SHORT_SLEEP = 1
MED_SLEEP = 6
# Generic timeout (in seconds) for test operations.
TEST_TIMEOUT = 10
# Shell commands used to query wifi state on the DUT.
STATION_DUMP = 'iw wlan0 station dump'
SCAN = 'wpa_cli scan'
SCAN_RESULTS = 'wpa_cli scan_results'
SIGNAL_POLL = 'wpa_cli signal_poll'
WPA_CLI_STATUS = 'wpa_cli status'
# 10*log10(2): dB offset corresponding to a factor-of-two power change.
CONST_3dB = 3.01029995664
# Sentinel for missing/invalid RSSI readings (NaN so it is easy to filter out).
RSSI_ERROR_VAL = float('nan')
# Parses ping replies of the form '[<timestamp>] ... time=<rtt>'.
RTT_REGEX = re.compile(r'^\[(?P<timestamp>\S+)\] .*? time=(?P<rtt>\S+)')
# Parses the packet loss percentage from the ping summary line.
LOSS_REGEX = re.compile(r'(?P<loss>\S+)% packet loss')
# Extracts the firmware version from a 'FW:<version> HW:' string.
FW_REGEX = re.compile(r'FW:(?P<firmware>\S+) HW:')
47
48
49# Threading decorator
def nonblocking(f):
    """Creates a decorator transforming function calls to non-blocking.

    The decorated function returns immediately with a
    concurrent.futures.Future that resolves to the wrapped function's return
    value (or re-raises its exception when .result() is called).
    """
    # functools.wraps preserves f's __name__/__doc__ on the wrapper so
    # decorated functions remain introspectable and log readably.
    @functools.wraps(f)
    def wrap(*args, **kwargs):
        executor = ThreadPoolExecutor(max_workers=1)
        thread_future = executor.submit(f, *args, **kwargs)
        # Ensure resources are freed up when executor returns or raises;
        # wait=False so this call itself does not block on the worker.
        executor.shutdown(wait=False)
        return thread_future

    return wrap
60
61
62# Link layer stats utilities
class LinkLayerStats():
    """Collects and tracks wifi link layer statistics from the DUT.

    Reads the driver's link-layer-stats debugfs entry, parses the per-MCS
    counters reported for the connected peer, and maintains both the latest
    cumulative counters and the increment since the previous update.
    """

    LLSTATS_CMD = 'cat /d/wlan0/ll_stats'
    # Marker for the per-peer stats section in the ll_stats dump.
    PEER_REGEX = 'LL_STATS_PEER_ALL'
    # Matches one per-rate counter line; one match per (mode, nss, bw, mcs).
    MCS_REGEX = re.compile(
        r'preamble: (?P<mode>\S+), nss: (?P<num_streams>\S+), bw: (?P<bw>\S+), '
        'mcs: (?P<mcs>\S+), bitrate: (?P<rate>\S+), txmpdu: (?P<txmpdu>\S+), '
        'rxmpdu: (?P<rxmpdu>\S+), mpdu_lost: (?P<mpdu_lost>\S+), '
        'retries: (?P<retries>\S+), retries_short: (?P<retries_short>\S+), '
        'retries_long: (?P<retries_long>\S+)')
    MCS_ID = collections.namedtuple(
        'mcs_id', ['mode', 'num_streams', 'bandwidth', 'mcs', 'rate'])
    # Maps the raw preamble/bandwidth codes from the dump to readable values.
    MODE_MAP = {'0': '11a/g', '1': '11b', '2': '11n', '3': '11ac'}
    BW_MAP = {'0': 20, '1': 40, '2': 80}

    def __init__(self, dut, llstats_enabled=True):
        """Initializes LinkLayerStats.

        Args:
            dut: device object from which to read link layer stats over adb
            llstats_enabled: flag to enable/disable stats collection
        """
        self.dut = dut
        self.llstats_enabled = llstats_enabled
        self.llstats_cumulative = self._empty_llstats()
        self.llstats_incremental = self._empty_llstats()

    def update_stats(self):
        """Pulls a fresh ll_stats dump from the DUT and updates counters."""
        if self.llstats_enabled:
            try:
                llstats_output = self.dut.adb.shell(self.LLSTATS_CMD,
                                                    timeout=0.1)
            # Stats collection is best-effort: any adb failure (e.g. timeout)
            # is treated as an empty dump rather than failing the test.
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit
            # still propagate.
            except Exception:
                llstats_output = ''
        else:
            llstats_output = ''
        self._update_stats(llstats_output)

    def reset_stats(self):
        """Clears both cumulative and incremental stats."""
        self.llstats_cumulative = self._empty_llstats()
        self.llstats_incremental = self._empty_llstats()

    def _empty_llstats(self):
        """Returns an empty llstats container (per-MCS stats + summary)."""
        return collections.OrderedDict(mcs_stats=collections.OrderedDict(),
                                       summary=collections.OrderedDict())

    def _empty_mcs_stat(self):
        """Returns a zeroed counter set for a single MCS."""
        return collections.OrderedDict(txmpdu=0,
                                       rxmpdu=0,
                                       mpdu_lost=0,
                                       retries=0,
                                       retries_short=0,
                                       retries_long=0)

    def _mcs_id_to_string(self, mcs_id):
        """Formats an MCS_ID namedtuple as a readable key string."""
        mcs_string = '{} {}MHz Nss{} MCS{} {}Mbps'.format(
            mcs_id.mode, mcs_id.bandwidth, mcs_id.num_streams, mcs_id.mcs,
            mcs_id.rate)
        return mcs_string

    def _parse_mcs_stats(self, llstats_output):
        """Parses per-MCS counters out of a raw ll_stats dump.

        Args:
            llstats_output: string containing the raw ll_stats dump
        Returns:
            dict mapping MCS description strings to their counter dicts;
            empty (with all stats reset) if no per-peer section is found.
        """
        llstats_dict = {}
        # Look for per-peer stats
        match = re.search(self.PEER_REGEX, llstats_output)
        if not match:
            self.reset_stats()
            return collections.OrderedDict()
        # Find and process all matches for per stream stats
        match_iter = re.finditer(self.MCS_REGEX, llstats_output)
        for match in match_iter:
            # nss is zero-based in the dump; rate is hex kbps -> Mbps.
            current_mcs = self.MCS_ID(self.MODE_MAP[match.group('mode')],
                                      int(match.group('num_streams')) + 1,
                                      self.BW_MAP[match.group('bw')],
                                      int(match.group('mcs')),
                                      int(match.group('rate'), 16) / 1000)
            current_stats = collections.OrderedDict(
                txmpdu=int(match.group('txmpdu')),
                rxmpdu=int(match.group('rxmpdu')),
                mpdu_lost=int(match.group('mpdu_lost')),
                retries=int(match.group('retries')),
                retries_short=int(match.group('retries_short')),
                retries_long=int(match.group('retries_long')))
            llstats_dict[self._mcs_id_to_string(current_mcs)] = current_stats
        return llstats_dict

    def _diff_mcs_stats(self, new_stats, old_stats):
        """Returns the element-wise difference of two MCS counter dicts."""
        stats_diff = collections.OrderedDict()
        for stat_key in new_stats.keys():
            stats_diff[stat_key] = new_stats[stat_key] - old_stats[stat_key]
        return stats_diff

    def _generate_stats_summary(self, llstats_dict):
        """Summarizes the most common tx/rx MCS and their relative frequency.

        Args:
            llstats_dict: llstats container with an 'mcs_stats' entry
        Returns:
            OrderedDict with the dominant tx/rx MCS, its mpdu count, and its
            share of total tx/rx mpdus.
        """
        llstats_summary = collections.OrderedDict(common_tx_mcs=None,
                                                  common_tx_mcs_count=0,
                                                  common_tx_mcs_freq=0,
                                                  common_rx_mcs=None,
                                                  common_rx_mcs_count=0,
                                                  common_rx_mcs_freq=0)
        txmpdu_count = 0
        rxmpdu_count = 0
        for mcs_id, mcs_stats in llstats_dict['mcs_stats'].items():
            if mcs_stats['txmpdu'] > llstats_summary['common_tx_mcs_count']:
                llstats_summary['common_tx_mcs'] = mcs_id
                llstats_summary['common_tx_mcs_count'] = mcs_stats['txmpdu']
            if mcs_stats['rxmpdu'] > llstats_summary['common_rx_mcs_count']:
                llstats_summary['common_rx_mcs'] = mcs_id
                llstats_summary['common_rx_mcs_count'] = mcs_stats['rxmpdu']
            txmpdu_count += mcs_stats['txmpdu']
            rxmpdu_count += mcs_stats['rxmpdu']
        if txmpdu_count:
            llstats_summary['common_tx_mcs_freq'] = (
                llstats_summary['common_tx_mcs_count'] / txmpdu_count)
        if rxmpdu_count:
            llstats_summary['common_rx_mcs_freq'] = (
                llstats_summary['common_rx_mcs_count'] / rxmpdu_count)
        return llstats_summary

    def _update_stats(self, llstats_output):
        """Updates cumulative and incremental stats from a raw dump.

        Args:
            llstats_output: string containing the raw ll_stats dump
        """
        # Parse stats
        new_llstats = self._empty_llstats()
        new_llstats['mcs_stats'] = self._parse_mcs_stats(llstats_output)
        # Save old stats and set new cumulative stats
        old_llstats = self.llstats_cumulative.copy()
        self.llstats_cumulative = new_llstats.copy()
        # Compute difference between new and old stats
        self.llstats_incremental = self._empty_llstats()
        for mcs_id, new_mcs_stats in new_llstats['mcs_stats'].items():
            # MCS entries absent from the old dump diff against zeros.
            old_mcs_stats = old_llstats['mcs_stats'].get(
                mcs_id, self._empty_mcs_stat())
            self.llstats_incremental['mcs_stats'][
                mcs_id] = self._diff_mcs_stats(new_mcs_stats, old_mcs_stats)
        # Generate llstats summary
        self.llstats_incremental['summary'] = self._generate_stats_summary(
            self.llstats_incremental)
        self.llstats_cumulative['summary'] = self._generate_stats_summary(
            self.llstats_cumulative)
192
193
194# JSON serializer
195def serialize_dict(input_dict):
196    """Function to serialize dicts to enable JSON output"""
197    output_dict = collections.OrderedDict()
198    for key, value in input_dict.items():
199        output_dict[_serialize_value(key)] = _serialize_value(value)
200    return output_dict
201
202
203def _serialize_value(value):
204    """Function to recursively serialize dict entries to enable JSON output"""
205    if isinstance(value, tuple):
206        return str(value)
207    if isinstance(value, list):
208        return [_serialize_value(x) for x in value]
209    elif isinstance(value, dict):
210        return serialize_dict(value)
211    else:
212        return value
213
214
215# Plotting Utilities
class BokehFigure():
    """Class enabling simplified Bokeh plotting.

    Line and scatter series are accumulated in self.figure_data as plain
    dicts, and figure-level settings live in self.fig_property, keeping the
    figure JSON-serializable (see _save_figure_json/load_from_json). Bokeh
    objects are only built when generate_figure is called.
    """

    COLORS = [
        'black',
        'blue',
        'blueviolet',
        'brown',
        'burlywood',
        'cadetblue',
        'cornflowerblue',
        'crimson',
        'cyan',
        'darkblue',
        'darkgreen',
        'darkmagenta',
        'darkorange',
        'darkred',
        'deepskyblue',
        'goldenrod',
        'green',
        'grey',
        'indigo',
        'navy',
        'olive',
        'orange',
        'red',
        'salmon',
        'teal',
        'yellow',
    ]
    MARKERS = [
        'asterisk', 'circle', 'circle_cross', 'circle_x', 'cross', 'diamond',
        'diamond_cross', 'hex', 'inverted_triangle', 'square', 'square_x',
        'square_cross', 'triangle', 'x'
    ]

    TOOLS = ('box_zoom,box_select,pan,crosshair,redo,undo,reset,hover,save')
    TOOLTIPS = [
        ('index', '$index'),
        ('(x,y)', '($x, $y)'),
        ('info', '@hover_text'),
    ]

    def __init__(self,
                 title=None,
                 x_label=None,
                 primary_y_label=None,
                 secondary_y_label=None,
                 height=700,
                 width=1100,
                 title_size='15pt',
                 axis_label_size='12pt',
                 json_file=None):
        """Initializes BokehFigure, optionally from a previously saved JSON.

        Args:
            title: string figure title
            x_label: string x-axis label
            primary_y_label: string label for the default (left) y-axis
            secondary_y_label: string label for the optional right y-axis
            height: integer plot height in pixels
            width: integer plot width in pixels
            title_size: string font size for the title
            axis_label_size: string font size for axis labels
            json_file: path to a file written by _save_figure_json; when
                given, all other arguments are ignored.
        """
        if json_file:
            self.load_from_json(json_file)
        else:
            self.figure_data = []
            self.fig_property = {
                'title': title,
                'x_label': x_label,
                'primary_y_label': primary_y_label,
                'secondary_y_label': secondary_y_label,
                'num_lines': 0,
                'height': height,
                'width': width,
                'title_size': title_size,
                'axis_label_size': axis_label_size
            }

    def init_plot(self):
        """Creates the underlying bokeh figure and configures its tools."""
        self.plot = bokeh.plotting.figure(
            sizing_mode='scale_both',
            plot_width=self.fig_property['width'],
            plot_height=self.fig_property['height'],
            title=self.fig_property['title'],
            tools=self.TOOLS,
            output_backend='webgl')
        self.plot.hover.tooltips = self.TOOLTIPS
        # Separate wheel-zoom tools for each axis direction.
        self.plot.add_tools(
            bokeh.models.tools.WheelZoomTool(dimensions='width'))
        self.plot.add_tools(
            bokeh.models.tools.WheelZoomTool(dimensions='height'))

    def _filter_line(self, x_data, y_data, hover_text=None):
        """Function to remove NaN points from bokeh plots.

        Args:
            x_data: list of x-axis values
            y_data: list of y-axis values; NaN entries are dropped along with
                their corresponding x and hover entries
            hover_text: optional list of hover labels; defaults to labeling
                each point with its y value (same default as add_line)
        Returns:
            tuple of filtered (x_data, y_data, hover_text) lists
        """
        # Guard: the original default of None would crash zip_longest.
        if hover_text is None:
            hover_text = ['y={}'.format(y) for y in y_data]
        x_data_filtered = []
        y_data_filtered = []
        hover_text_filtered = []
        for x, y, hover in itertools.zip_longest(x_data, y_data, hover_text):
            if not math.isnan(y):
                x_data_filtered.append(x)
                y_data_filtered.append(y)
                hover_text_filtered.append(hover)
        return x_data_filtered, y_data_filtered, hover_text_filtered

    def add_line(self,
                 x_data,
                 y_data,
                 legend,
                 hover_text=None,
                 color=None,
                 width=3,
                 style='solid',
                 marker=None,
                 marker_size=10,
                 shaded_region=None,
                 y_axis='default'):
        """Function to add line to existing BokehFigure.

        Args:
            x_data: list containing x-axis values for line
            y_data: list containing y_axis values for line
            legend: string containing line title
            hover_text: text to display when hovering over lines
            color: string describing line color
            width: integer line width
            style: string describing line style, e.g, solid or dashed
            marker: string specifying line marker, e.g., cross
            marker_size: integer marker size
            shaded_region: data describing shaded region to plot
            y_axis: identifier for y-axis to plot line against
        Raises:
            ValueError: if y_axis is not 'default' or 'secondary'
        """
        if y_axis not in ['default', 'secondary']:
            raise ValueError('y_axis must be default or secondary')
        if color is None:
            # Cycle through the palette so successive lines differ.
            color = self.COLORS[self.fig_property['num_lines'] %
                                len(self.COLORS)]
        if style == 'dashed':
            style = [5, 5]
        if not hover_text:
            hover_text = ['y={}'.format(y) for y in y_data]
        x_data_filter, y_data_filter, hover_text_filter = self._filter_line(
            x_data, y_data, hover_text)
        self.figure_data.append({
            'x_data': x_data_filter,
            'y_data': y_data_filter,
            'legend': legend,
            'hover_text': hover_text_filter,
            'color': color,
            'width': width,
            'style': style,
            'marker': marker,
            'marker_size': marker_size,
            'shaded_region': shaded_region,
            'y_axis': y_axis
        })
        self.fig_property['num_lines'] += 1

    def add_scatter(self,
                    x_data,
                    y_data,
                    legend,
                    hover_text=None,
                    color=None,
                    marker=None,
                    marker_size=10,
                    y_axis='default'):
        """Function to add scatter series to existing BokehFigure.

        Args:
            x_data: list containing x-axis values for line
            y_data: list containing y_axis values for line
            legend: string containing line title
            hover_text: text to display when hovering over lines
            color: string describing line color
            marker: string specifying marker, e.g., cross
            marker_size: integer marker size
            y_axis: identifier for y-axis to plot line against
        Raises:
            ValueError: if y_axis is not 'default' or 'secondary'
        """
        if y_axis not in ['default', 'secondary']:
            raise ValueError('y_axis must be default or secondary')
        if color is None:
            color = self.COLORS[self.fig_property['num_lines'] %
                                len(self.COLORS)]
        if marker is None:
            marker = self.MARKERS[self.fig_property['num_lines'] %
                                  len(self.MARKERS)]
        if not hover_text:
            hover_text = ['y={}'.format(y) for y in y_data]
        # width=0 marks this series as scatter-only in generate_figure.
        self.figure_data.append({
            'x_data': x_data,
            'y_data': y_data,
            'legend': legend,
            'hover_text': hover_text,
            'color': color,
            'width': 0,
            'style': 'solid',
            'marker': marker,
            'marker_size': marker_size,
            'shaded_region': None,
            'y_axis': y_axis
        })
        self.fig_property['num_lines'] += 1

    def generate_figure(self, output_file=None, save_json=True):
        """Function to generate and save BokehFigure.

        Args:
            output_file: string specifying output file path
            save_json: flag controlling json output of figure data
        Returns:
            the generated bokeh plot object
        """
        self.init_plot()
        two_axes = False
        for line in self.figure_data:
            source = bokeh.models.ColumnDataSource(
                data=dict(x=line['x_data'],
                          y=line['y_data'],
                          hover_text=line['hover_text']))
            if line['width'] > 0:
                self.plot.line(x='x',
                               y='y',
                               legend_label=line['legend'],
                               line_width=line['width'],
                               color=line['color'],
                               line_dash=line['style'],
                               name=line['y_axis'],
                               y_range_name=line['y_axis'],
                               source=source)
            if line['shaded_region']:
                # Build a closed polygon: x forward then reversed, lower
                # limit forward then upper limit reversed.
                band_x = line['shaded_region']['x_vector']
                band_x.extend(line['shaded_region']['x_vector'][::-1])
                band_y = line['shaded_region']['lower_limit']
                band_y.extend(line['shaded_region']['upper_limit'][::-1])
                self.plot.patch(band_x,
                                band_y,
                                color='#7570B3',
                                line_alpha=0.1,
                                fill_alpha=0.1)
            if line['marker'] in self.MARKERS:
                marker_func = getattr(self.plot, line['marker'])
                marker_func(x='x',
                            y='y',
                            size=line['marker_size'],
                            legend_label=line['legend'],
                            line_color=line['color'],
                            fill_color=line['color'],
                            name=line['y_axis'],
                            y_range_name=line['y_axis'],
                            source=source)
            if line['y_axis'] == 'secondary':
                two_axes = True

        #x-axis formatting
        self.plot.xaxis.axis_label = self.fig_property['x_label']
        self.plot.x_range.range_padding = 0
        self.plot.xaxis[0].axis_label_text_font_size = self.fig_property[
            'axis_label_size']
        #y-axis formatting
        self.plot.yaxis[0].axis_label = self.fig_property['primary_y_label']
        self.plot.yaxis[0].axis_label_text_font_size = self.fig_property[
            'axis_label_size']
        self.plot.y_range = bokeh.models.DataRange1d(names=['default'])
        if two_axes and 'secondary' not in self.plot.extra_y_ranges:
            self.plot.extra_y_ranges = {
                'secondary': bokeh.models.DataRange1d(names=['secondary'])
            }
            self.plot.add_layout(
                bokeh.models.LinearAxis(
                    y_range_name='secondary',
                    axis_label=self.fig_property['secondary_y_label'],
                    axis_label_text_font_size=self.
                    fig_property['axis_label_size']), 'right')
        # plot formatting
        self.plot.legend.location = 'top_right'
        self.plot.legend.click_policy = 'hide'
        self.plot.title.text_font_size = self.fig_property['title_size']

        if output_file is not None:
            self.save_figure(output_file, save_json)
        return self.plot

    def load_from_json(self, file_path):
        """Restores figure_data and fig_property from a saved JSON file."""
        with open(file_path, 'r') as json_file:
            fig_dict = json.load(json_file)
        self.fig_property = fig_dict['fig_property']
        self.figure_data = fig_dict['figure_data']

    def _save_figure_json(self, output_file):
        """Function to save a json format of a figure"""
        figure_dict = collections.OrderedDict(fig_property=self.fig_property,
                                              figure_data=self.figure_data)
        output_file = output_file.replace('.html', '_plot_data.json')
        with open(output_file, 'w') as outfile:
            json.dump(figure_dict, outfile, indent=4)

    def save_figure(self, output_file, save_json=True):
        """Function to save BokehFigure.

        Args:
            output_file: string specifying output file path
            save_json: flag controlling json outputs
        """
        bokeh.plotting.output_file(output_file)
        bokeh.plotting.save(self.plot)
        if save_json:
            self._save_figure_json(output_file)

    @staticmethod
    def save_figures(figure_array, output_file_path, save_json=True):
        """Function to save list of BokehFigures in one file.

        Args:
            figure_array: list of BokehFigure object to be plotted
            output_file_path: string specifying output file path
            save_json: flag controlling json outputs
        """
        for idx, figure in enumerate(figure_array):
            figure.generate_figure()
            if save_json:
                json_file_path = output_file_path.replace(
                    '.html', '{}-plot_data.json'.format(idx))
                figure._save_figure_json(json_file_path)
        plot_array = [figure.plot for figure in figure_array]
        # NOTE(review): bokeh.layouts is not explicitly imported at the top
        # of this file; this relies on bokeh exposing it after the existing
        # imports — confirm against the bokeh version in use.
        all_plots = bokeh.layouts.column(children=plot_array,
                                         sizing_mode='scale_width')
        bokeh.plotting.output_file(output_file_path)
        bokeh.plotting.save(all_plots)
530
531
532# Ping utilities
class PingResult(object):
    """An object that contains the results of running ping command.

    Attributes:
        connected: True if a connection was made. False otherwise.
        packet_loss_percentage: The total percentage of packets lost.
        transmission_times: The list of PingTransmissionTimes containing the
            timestamps gathered for transmitted packets.
        rtts: An list-like object enumerating all round-trip-times of
            transmitted packets.
        timestamps: A list-like object enumerating the beginning timestamps of
            each packet transmission.
        ping_interarrivals: A list-like object enumerating the amount of time
            between the beginning of each subsequent transmission.
    """
    def __init__(self, ping_output):
        self.packet_loss_percentage = 100
        self.transmission_times = []

        # These views share the transmission_times list, so entries appended
        # below are visible through them without copying.
        self.rtts = _ListWrap(self.transmission_times, lambda entry: entry.rtt)
        self.timestamps = _ListWrap(self.transmission_times,
                                    lambda entry: entry.timestamp)
        self.ping_interarrivals = _PingInterarrivals(self.transmission_times)

        self.start_time = 0
        for line in ping_output:
            if 'loss' in line:
                loss_match = re.search(LOSS_REGEX, line)
                self.packet_loss_percentage = float(loss_match.group('loss'))
            if 'time=' in line:
                reply_match = re.search(RTT_REGEX, line)
                # Timestamps are reported relative to the first reply seen.
                if self.start_time == 0:
                    self.start_time = float(reply_match.group('timestamp'))
                relative_time = (float(reply_match.group('timestamp')) -
                                 self.start_time)
                self.transmission_times.append(
                    PingTransmissionTimes(relative_time,
                                          float(reply_match.group('rtt'))))
        self.connected = (len(ping_output) > 1
                          and self.packet_loss_percentage < 100)

    def __getitem__(self, item):
        supported = {
            'rtt': self.rtts,
            'connected': self.connected,
            'packet_loss_percentage': self.packet_loss_percentage,
        }
        if item not in supported:
            raise ValueError('Invalid key. Please use an attribute instead.')
        return supported[item]

    def as_dict(self):
        return {
            'connected': 1 if self.connected else 0,
            'rtt': list(self.rtts),
            'time_stamp': list(self.timestamps),
            'ping_interarrivals': list(self.ping_interarrivals),
            'packet_loss_percentage': self.packet_loss_percentage
        }
590
591
class PingTransmissionTimes(object):
    """A class that holds the timestamps for a packet sent via the ping command.

    Attributes:
        rtt: The round trip time for the packet sent.
        timestamp: The timestamp the packet started its trip.
    """
    def __init__(self, timestamp, rtt):
        self.timestamp = timestamp
        self.rtt = rtt
602
603
604class _ListWrap(object):
605    """A convenient helper class for treating list iterators as native lists."""
606    def __init__(self, wrapped_list, func):
607        self.__wrapped_list = wrapped_list
608        self.__func = func
609
610    def __getitem__(self, key):
611        return self.__func(self.__wrapped_list[key])
612
613    def __iter__(self):
614        for item in self.__wrapped_list:
615            yield self.__func(item)
616
617    def __len__(self):
618        return len(self.__wrapped_list)
619
620
621class _PingInterarrivals(object):
622    """A helper class for treating ping interarrivals as a native list."""
623    def __init__(self, ping_entries):
624        self.__ping_entries = ping_entries
625
626    def __getitem__(self, key):
627        return (self.__ping_entries[key + 1].timestamp -
628                self.__ping_entries[key].timestamp)
629
630    def __iter__(self):
631        for index in range(len(self.__ping_entries) - 1):
632            yield self[index]
633
634    def __len__(self):
635        return max(0, len(self.__ping_entries) - 1)
636
637
def get_ping_stats(src_device, dest_address, ping_duration, ping_interval,
                   ping_size):
    """Run ping to or from the DUT.

    Pings dest_address from src_device, which may be an Android device
    (pinged over adb) or a remote machine reached over an ssh connection.

    Args:
        src_device: object representing device to ping from
        dest_address: ip address to ping
        ping_duration: timeout to set on the the ping process (in seconds)
        ping_interval: time between pings (in seconds)
        ping_size: size of ping packet payload
    Returns:
        ping_result: dict containing ping results and other meta data
    Raises:
        TypeError: if src_device is neither an AndroidDevice nor an
            SshConnection
    """
    num_pings = int(ping_duration / ping_interval)
    # Deadline gives ping slightly more than the nominal total duration.
    deadline = int(num_pings * ping_interval) + 1
    base_cmd = 'ping -c {} -w {} -i {} -s {} -D'.format(
        num_pings, deadline, ping_interval, ping_size)
    if isinstance(src_device, AndroidDevice):
        ping_output = src_device.adb.shell(
            '{} {}'.format(base_cmd, dest_address),
            timeout=deadline + SHORT_SLEEP,
            ignore_status=True)
    elif isinstance(src_device, ssh.connection.SshConnection):
        ping_output = src_device.run(
            'sudo {} {}'.format(base_cmd, dest_address),
            timeout=deadline + SHORT_SLEEP,
            ignore_status=True).stdout
    else:
        raise TypeError('Unable to ping using src_device of type %s.' %
                        type(src_device))
    return PingResult(ping_output.splitlines())
676
677
@nonblocking
def get_ping_stats_nb(src_device, dest_address, ping_duration, ping_interval,
                      ping_size):
    """Non-blocking variant of get_ping_stats; returns a Future."""
    return get_ping_stats(src_device=src_device,
                          dest_address=dest_address,
                          ping_duration=ping_duration,
                          ping_interval=ping_interval,
                          ping_size=ping_size)
683
684
@nonblocking
def start_iperf_client_nb(iperf_client, iperf_server_address, iperf_args, tag,
                          timeout):
    """Non-blocking wrapper around iperf_client.start; returns a Future."""
    return iperf_client.start(iperf_server_address, iperf_args, tag, timeout)
689
690
691# Rssi Utilities
def empty_rssi_result():
    """Returns an empty RSSI result container with data/mean/stdev fields."""
    result = collections.OrderedDict()
    result['data'] = []
    result['mean'] = None
    result['stdev'] = None
    return result
695
696
697def get_connected_rssi(dut,
698                       num_measurements=1,
699                       polling_frequency=SHORT_SLEEP,
700                       first_measurement_delay=0,
701                       disconnect_warning=True,
702                       ignore_samples=0):
703    """Gets all RSSI values reported for the connected access point/BSSID.
704
705    Args:
706        dut: android device object from which to get RSSI
707        num_measurements: number of scans done, and RSSIs collected
708        polling_frequency: time to wait between RSSI measurements
709        disconnect_warning: boolean controlling disconnection logging messages
710        ignore_samples: number of leading samples to ignore
711    Returns:
712        connected_rssi: dict containing the measurements results for
713        all reported RSSI values (signal_poll, per chain, etc.) and their
714        statistics
715    """
716    # yapf: disable
717    connected_rssi = collections.OrderedDict(
718        [('time_stamp', []),
719         ('bssid', []), ('frequency', []),
720         ('signal_poll_rssi', empty_rssi_result()),
721         ('signal_poll_avg_rssi', empty_rssi_result()),
722         ('chain_0_rssi', empty_rssi_result()),
723         ('chain_1_rssi', empty_rssi_result())])
724    # yapf: enable
725    previous_bssid = 'disconnected'
726    t0 = time.time()
727    time.sleep(first_measurement_delay)
728    for idx in range(num_measurements):
729        measurement_start_time = time.time()
730        connected_rssi['time_stamp'].append(measurement_start_time - t0)
731        # Get signal poll RSSI
732        status_output = dut.adb.shell(WPA_CLI_STATUS)
733        match = re.search('bssid=.*', status_output)
734        if match:
735            current_bssid = match.group(0).split('=')[1]
736            connected_rssi['bssid'].append(current_bssid)
737        else:
738            current_bssid = 'disconnected'
739            connected_rssi['bssid'].append(current_bssid)
740            if disconnect_warning and previous_bssid != 'disconnected':
741                logging.warning('WIFI DISCONNECT DETECTED!')
742        previous_bssid = current_bssid
743        signal_poll_output = dut.adb.shell(SIGNAL_POLL)
744        match = re.search('FREQUENCY=.*', signal_poll_output)
745        if match:
746            frequency = int(match.group(0).split('=')[1])
747            connected_rssi['frequency'].append(frequency)
748        else:
749            connected_rssi['frequency'].append(RSSI_ERROR_VAL)
750        match = re.search('RSSI=.*', signal_poll_output)
751        if match:
752            temp_rssi = int(match.group(0).split('=')[1])
753            if temp_rssi == -9999 or temp_rssi == 0:
754                connected_rssi['signal_poll_rssi']['data'].append(
755                    RSSI_ERROR_VAL)
756            else:
757                connected_rssi['signal_poll_rssi']['data'].append(temp_rssi)
758        else:
759            connected_rssi['signal_poll_rssi']['data'].append(RSSI_ERROR_VAL)
760        match = re.search('AVG_RSSI=.*', signal_poll_output)
761        if match:
762            connected_rssi['signal_poll_avg_rssi']['data'].append(
763                int(match.group(0).split('=')[1]))
764        else:
765            connected_rssi['signal_poll_avg_rssi']['data'].append(
766                RSSI_ERROR_VAL)
767        # Get per chain RSSI
768        per_chain_rssi = dut.adb.shell(STATION_DUMP)
769        match = re.search('.*signal avg:.*', per_chain_rssi)
770        if match:
771            per_chain_rssi = per_chain_rssi[per_chain_rssi.find('[') +
772                                            1:per_chain_rssi.find(']')]
773            per_chain_rssi = per_chain_rssi.split(', ')
774            connected_rssi['chain_0_rssi']['data'].append(
775                int(per_chain_rssi[0]))
776            connected_rssi['chain_1_rssi']['data'].append(
777                int(per_chain_rssi[1]))
778        else:
779            connected_rssi['chain_0_rssi']['data'].append(RSSI_ERROR_VAL)
780            connected_rssi['chain_1_rssi']['data'].append(RSSI_ERROR_VAL)
781        measurement_elapsed_time = time.time() - measurement_start_time
782        time.sleep(max(0, polling_frequency - measurement_elapsed_time))
783
784    # Compute mean RSSIs. Only average valid readings.
785    # Output RSSI_ERROR_VAL if no valid connected readings found.
786    for key, val in connected_rssi.copy().items():
787        if 'data' not in val:
788            continue
789        filtered_rssi_values = [x for x in val['data'] if not math.isnan(x)]
790        if len(filtered_rssi_values) > ignore_samples:
791            filtered_rssi_values = filtered_rssi_values[ignore_samples:]
792        if filtered_rssi_values:
793            connected_rssi[key]['mean'] = statistics.mean(filtered_rssi_values)
794            if len(filtered_rssi_values) > 1:
795                connected_rssi[key]['stdev'] = statistics.stdev(
796                    filtered_rssi_values)
797            else:
798                connected_rssi[key]['stdev'] = 0
799        else:
800            connected_rssi[key]['mean'] = RSSI_ERROR_VAL
801            connected_rssi[key]['stdev'] = RSSI_ERROR_VAL
802    return connected_rssi
803
804
@nonblocking
def get_connected_rssi_nb(dut,
                          num_measurements=1,
                          polling_frequency=SHORT_SLEEP,
                          first_measurement_delay=0,
                          disconnect_warning=True,
                          ignore_samples=0):
    """Non-blocking variant of get_connected_rssi.

    Returns a future whose result is the output of get_connected_rssi run
    with the same arguments.
    """
    return get_connected_rssi(dut,
                              num_measurements=num_measurements,
                              polling_frequency=polling_frequency,
                              first_measurement_delay=first_measurement_delay,
                              disconnect_warning=disconnect_warning,
                              ignore_samples=ignore_samples)
815
816
def get_scan_rssi(dut, tracked_bssids, num_measurements=1):
    """Gets scan RSSI for specified BSSIDs.

    Args:
        dut: android device object from which to get RSSI
        tracked_bssids: array of BSSIDs to gather RSSI data for
        num_measurements: number of scans done, and RSSIs collected
    Returns:
        scan_rssi: dict containing the measurement results as well as the
        statistics of the scan RSSI for all BSSIDs in tracked_bssids
    """
    scan_rssi = collections.OrderedDict()
    for bssid in tracked_bssids:
        scan_rssi[bssid] = empty_rssi_result()
    for idx in range(num_measurements):
        # Trigger a fresh scan, then wait for it to complete before
        # fetching the results.
        dut.adb.shell(SCAN)
        time.sleep(MED_SLEEP)
        scan_output = dut.adb.shell(SCAN_RESULTS)
        for bssid in tracked_bssids:
            # Escape the BSSID so any regex metacharacters in the string
            # are matched literally rather than interpreted as a pattern.
            bssid_result = re.search(re.escape(bssid) + '.*',
                                     scan_output,
                                     flags=re.IGNORECASE)
            if bssid_result:
                # Scan result lines are tab separated; RSSI is field 3.
                bssid_result = bssid_result.group(0).split('\t')
                scan_rssi[bssid]['data'].append(int(bssid_result[2]))
            else:
                scan_rssi[bssid]['data'].append(RSSI_ERROR_VAL)
    # Compute mean RSSIs. Only average valid readings.
    # Output RSSI_ERROR_VAL if no readings found.
    for key, val in scan_rssi.items():
        filtered_rssi_values = [x for x in val['data'] if not math.isnan(x)]
        if filtered_rssi_values:
            scan_rssi[key]['mean'] = statistics.mean(filtered_rssi_values)
            if len(filtered_rssi_values) > 1:
                scan_rssi[key]['stdev'] = statistics.stdev(
                    filtered_rssi_values)
            else:
                # A single sample has no spread.
                scan_rssi[key]['stdev'] = 0
        else:
            scan_rssi[key]['mean'] = RSSI_ERROR_VAL
            scan_rssi[key]['stdev'] = RSSI_ERROR_VAL
    return scan_rssi
859
860
@nonblocking
def get_scan_rssi_nb(dut, tracked_bssids, num_measurements=1):
    """Non-blocking wrapper around get_scan_rssi.

    Returns a future whose result is the get_scan_rssi output.
    """
    return get_scan_rssi(dut,
                         tracked_bssids,
                         num_measurements=num_measurements)
864
865
866# Attenuator Utilities
def atten_by_label(atten_list, path_label, atten_level):
    """Attenuate signals according to their path label.

    Args:
        atten_list: list of attenuators to iterate over
        path_label: path label on which to set desired attenuation
        atten_level: attenuation desired on path
    """
    matching_attens = (a for a in atten_list if path_label in a.path)
    for attenuator in matching_attens:
        attenuator.set_atten(atten_level)
878
879
def get_atten_for_target_rssi(target_rssi, attenuators, dut, ping_server):
    """Function to estimate attenuation to hit a target RSSI.

    This function estimates a constant attenuation setting on all atennuation
    ports to hit a target RSSI. The estimate is not meant to be exact or
    guaranteed.

    Args:
        target_rssi: rssi of interest
        attenuators: list of attenuator ports
        dut: android device object assumed connected to a wifi network.
        ping_server: ssh connection object to ping server
    Returns:
        target_atten: attenuation setting to achieve target_rssi
    """
    logging.info('Searching attenuation for RSSI = {}dB'.format(target_rssi))
    # Set attenuator to 0 dB
    for atten in attenuators:
        atten.set_atten(0, strict=False)
    # Start ping traffic
    dut_ip = dut.droid.connectivityGetIPv4Addresses('wlan0')[0]
    # Measure starting RSSI
    # Ping traffic runs in the background so the RSSI samples below are
    # taken on an active link.
    ping_future = get_ping_stats_nb(src_device=ping_server,
                                    dest_address=dut_ip,
                                    ping_duration=1.5,
                                    ping_interval=0.02,
                                    ping_size=64)
    current_rssi = get_connected_rssi(dut,
                                      num_measurements=4,
                                      polling_frequency=0.25,
                                      first_measurement_delay=0.5,
                                      disconnect_warning=1,
                                      ignore_samples=1)
    current_rssi = current_rssi['signal_poll_rssi']['mean']
    ping_future.result()
    target_atten = 0
    logging.debug("RSSI @ {0:.2f}dB attenuation = {1:.2f}".format(
        target_atten, current_rssi))
    within_range = 0
    # Iteratively step the attenuation towards the target, at most 20 times.
    for idx in range(20):
        # Step by the RSSI error, capped at +/-20 dB per iteration and
        # quantized to 0.25 dB steps.
        atten_delta = max(min(current_rssi - target_rssi, 20), -20)
        target_atten = int((target_atten + atten_delta) * 4) / 4
        # Clamp the estimate to the attenuator's physical range.
        if target_atten < 0:
            return 0
        if target_atten > attenuators[0].get_max_atten():
            return attenuators[0].get_max_atten()
        for atten in attenuators:
            atten.set_atten(target_atten, strict=False)
        ping_future = get_ping_stats_nb(src_device=ping_server,
                                        dest_address=dut_ip,
                                        ping_duration=1.5,
                                        ping_interval=0.02,
                                        ping_size=64)
        current_rssi = get_connected_rssi(dut,
                                          num_measurements=4,
                                          polling_frequency=0.25,
                                          first_measurement_delay=0.5,
                                          disconnect_warning=1,
                                          ignore_samples=1)
        current_rssi = current_rssi['signal_poll_rssi']['mean']
        ping_future.result()
        logging.info("RSSI @ {0:.2f}dB attenuation = {1:.2f}".format(
            target_atten, current_rssi))
        if abs(current_rssi - target_rssi) < 1:
            # Require two consecutive in-range readings before declaring
            # convergence, to filter out single noisy measurements.
            if within_range:
                logging.info(
                    'Reached RSSI: {0:.2f}. Target RSSI: {1:.2f}.'
                    'Attenuation: {2:.2f}, Iterations = {3:.2f}'.format(
                        current_rssi, target_rssi, target_atten, idx))
                return target_atten
            else:
                within_range = True
        else:
            within_range = False
    # Did not converge within the iteration budget; return the last estimate.
    return target_atten
955
956
def get_current_atten_dut_chain_map(attenuators, dut, ping_server):
    """Function to detect mapping between attenuator ports and DUT chains.

    This function detects the mapping between attenuator ports and DUT chains
    in cases where DUT chains are connected to only one attenuator port. The
    function assumes the DUT is already connected to a wifi network. The
    function starts by measuring per chain RSSI at 0 attenuation, then
    attenuates one port at a time looking for the chain that reports a lower
    RSSI.

    Args:
        attenuators: list of attenuator ports
        dut: android device object assumed connected to a wifi network.
        ping_server: ssh connection object to ping server
    Returns:
        chain_map: list of dut chains, one entry per attenuator port
    """
    # Set attenuator to 0 dB
    for atten in attenuators:
        atten.set_atten(0, strict=False)
    # Start ping traffic
    dut_ip = dut.droid.connectivityGetIPv4Addresses('wlan0')[0]
    # The 11-second ping covers the whole mapping routine so the RSSI polls
    # below are taken on an active link.
    ping_future = get_ping_stats_nb(ping_server, dut_ip, 11, 0.02, 64)
    # Measure starting RSSI
    base_rssi = get_connected_rssi(dut, 4, 0.25, 1)
    chain0_base_rssi = base_rssi['chain_0_rssi']['mean']
    chain1_base_rssi = base_rssi['chain_1_rssi']['mean']
    # A weak baseline leaves little headroom to detect the 10 dB drop used
    # below as the mapping criterion.
    if chain0_base_rssi < -70 or chain1_base_rssi < -70:
        logging.warning('RSSI might be too low to get reliable chain map.')
    # Compile chain map by attenuating one path at a time and seeing which
    # chain's RSSI degrades
    chain_map = []
    for test_atten in attenuators:
        # Set one attenuator to 30 dB down
        test_atten.set_atten(30, strict=False)
        # Get new RSSI
        test_rssi = get_connected_rssi(dut, 4, 0.25, 1)
        # Assign attenuator to path that has lower RSSI
        # A chain maps to this port when its RSSI dropped by more than 10 dB
        # from a usable (> -70 dBm) baseline; chain 0 is checked first.
        if chain0_base_rssi > -70 and chain0_base_rssi - test_rssi[
                'chain_0_rssi']['mean'] > 10:
            chain_map.append('DUT-Chain-0')
        elif chain1_base_rssi > -70 and chain1_base_rssi - test_rssi[
                'chain_1_rssi']['mean'] > 10:
            chain_map.append('DUT-Chain-1')
        else:
            # Neither chain clearly degraded; port not mapped to this DUT.
            chain_map.append(None)
        # Reset attenuator to 0
        test_atten.set_atten(0, strict=False)
    ping_future.result()
    logging.debug('Chain Map: {}'.format(chain_map))
    return chain_map
1008
1009
def get_full_rf_connection_map(attenuators, dut, ping_server, networks):
    """Function to detect per-network connections between attenuator and DUT.

    This function detects the mapping between attenuator ports and DUT chains
    on all networks in its arguments. The function connects the DUT to each
    network then calls get_current_atten_dut_chain_map to get the connection
    map on the current network. The function outputs the results in two formats
    to enable easy access when users are interested in indexing by network or
    attenuator port.

    Args:
        attenuators: list of attenuator ports
        dut: android device object assumed connected to a wifi network.
        ping_server: ssh connection object to ping server
        networks: dict of network IDs and configs
    Returns:
        rf_map_by_network: dict of RF connections indexed by network.
        rf_map_by_atten: list of RF connections indexed by attenuator
    """
    # Start from a known, unattenuated state.
    for attenuator in attenuators:
        attenuator.set_atten(0, strict=False)

    rf_map_by_network = collections.OrderedDict()
    rf_map_by_atten = [[] for _ in attenuators]
    for net_id, net_config in networks.items():
        # Connect to the current network before probing the chain map.
        wutils.reset_wifi(dut)
        wutils.wifi_connect(dut,
                            net_config,
                            num_of_tries=1,
                            assert_on_fail=False,
                            check_connectivity=False)
        chain_map = get_current_atten_dut_chain_map(attenuators, dut,
                                                    ping_server)
        rf_map_by_network[net_id] = chain_map
        # Invert the mapping so it can also be indexed by attenuator port.
        for port_idx, chain in enumerate(chain_map):
            if chain:
                rf_map_by_atten[port_idx].append({
                    "network": net_id,
                    "dut_chain": chain
                })
    logging.debug("RF Map (by Network): {}".format(rf_map_by_network))
    logging.debug("RF Map (by Atten): {}".format(rf_map_by_atten))

    return rf_map_by_network, rf_map_by_atten
1053
1054
1055# Miscellaneous Wifi Utilities
def extract_sub_dict(full_dict, fields):
    """Return an OrderedDict containing only the requested fields.

    Args:
        full_dict: dict to extract fields from
        fields: iterable of keys to keep, in the desired output order
    Returns:
        sub_dict: OrderedDict of the selected key/value pairs
    """
    sub_dict = collections.OrderedDict()
    for field in fields:
        sub_dict[field] = full_dict[field]
    return sub_dict
1060
1061
def validate_network(dut, ssid):
    """Check that DUT has a valid internet connection through expected SSID

    Args:
        dut: android device of interest
        ssid: expected ssid
    Returns:
        True if the DUT is connected to the internet through the expected
        SSID, False otherwise.
    """
    current_network = dut.droid.wifiGetConnectionInfo()
    try:
        connected = wutils.validate_connection(dut) is not None
    except Exception:
        # Narrowed from a bare except: treat any validation failure as
        # "not connected" without swallowing SystemExit/KeyboardInterrupt.
        connected = False
    return connected and current_network['SSID'] == ssid
1078
1079
def get_server_address(ssh_connection, dut_ip, subnet_mask):
    """Get server address on a specific subnet,

    This function retrieves the LAN IP of a remote machine used in testing,
    i.e., it returns the server's IP belonging to the same LAN as the DUT.

    Args:
        ssh_connection: object representing server for which we want an ip
        dut_ip: string in ip address format, i.e., xxx.xxx.xxx.xxx, specifying
        the DUT LAN IP we wish to connect to
        subnet_mask: string representing subnet mask
    Returns:
        the server IP on the DUT's subnet, or None if no interface matches
    """
    subnet_mask = subnet_mask.split('.')
    # Apply the mask octet-wise to get the DUT's network address.
    dut_subnet = [
        int(dut) & int(subnet)
        for dut, subnet in zip(dut_ip.split('.'), subnet_mask)
    ]
    ifconfig_out = ssh_connection.run('ifconfig').stdout
    # Raw string with escaped dots: previously '.' could match any character
    # (e.g. '1x2'), and '\d' in a non-raw string relied on lenient parsing.
    ip_list = re.findall(r'inet (?:addr:)?(\d+\.\d+\.\d+\.\d+)', ifconfig_out)
    for current_ip in ip_list:
        current_subnet = [
            int(ip) & int(subnet)
            for ip, subnet in zip(current_ip.split('.'), subnet_mask)
        ]
        if current_subnet == dut_subnet:
            return current_ip
    logging.error('No IP address found in requested subnet')
1107
1108
def get_iperf_arg_string(duration,
                         reverse_direction,
                         interval=1,
                         traffic_type='TCP',
                         tcp_window=None,
                         tcp_processes=1,
                         udp_throughput='1000M'):
    """Function to format iperf client arguments.

    This function takes in iperf client parameters and returns a properly
    formatter iperf arg string to be used in throughput tests.

    Args:
        duration: iperf duration in seconds
        reverse_direction: boolean controlling the -R flag for iperf clients
        interval: iperf print interval
        traffic_type: string specifying TCP or UDP traffic
        tcp_window: string specifying TCP window, e.g., 2M
        tcp_processes: int specifying number of tcp processes
        udp_throughput: string specifying TX throughput in UDP tests, e.g. 100M
    Returns:
        iperf_args: string of formatted iperf args
    """
    iperf_args = '-i {} -t {} -J '.format(interval, duration)
    if traffic_type.upper() == 'UDP':
        iperf_args = iperf_args + '-u -b {} -l 1400'.format(udp_throughput)
    elif traffic_type.upper() == 'TCP':
        iperf_args = iperf_args + '-P {}'.format(tcp_processes)
        if tcp_window:
            # Leading space is required; without it the flags run together
            # and produce invalid args such as '-P 1-w 2M'.
            iperf_args = iperf_args + ' -w {}'.format(tcp_window)
    if reverse_direction:
        iperf_args = iperf_args + ' -R'
    return iperf_args
1142
1143
def get_dut_temperature(dut):
    """Function to get dut temperature.

    The function fetches and returns the reading from the temperature sensor
    used for skin temperature and thermal throttling.

    Args:
        dut: AndroidDevice of interest
    Returns:
        temperature: device temperature. 0 if temperature could not be read
    """
    thermal_zones = [
        'skin-therm', 'sdm-therm-monitor', 'sdm-therm-adc', 'back_therm'
    ]
    temperature = 0
    # Probe candidate sensors in priority order; stop at the first one that
    # yields an integer reading.
    for zone in thermal_zones:
        try:
            zone_reading = dut.adb.shell(
                'cat /sys/class/thermal/tz-by-name/{}/temp'.format(zone))
            temperature = int(zone_reading)
            break
        except ValueError:
            temperature = 0
    if temperature == 0:
        logging.debug('Could not check DUT temperature.')
    elif temperature > 100:
        # Readings above 100 are assumed to be in millidegrees Celsius.
        temperature = temperature / 1000
    return temperature
1171
1172
def wait_for_dut_cooldown(dut, target_temp=50, timeout=300):
    """Function to wait for a DUT to cool down.

    Polls the DUT temperature until it drops below target_temp or the
    timeout expires.

    Args:
        dut: AndroidDevice of interest
        target_temp: target cooldown temperature
        timeout: max time to wait for cooldown
    """
    start_time = time.time()
    # Poll once up front so temperature is always bound, even when
    # timeout <= 0 (the original loop could reference it unassigned).
    temperature = get_dut_temperature(dut)
    while temperature >= target_temp and time.time() - start_time < timeout:
        time.sleep(SHORT_SLEEP)
        temperature = get_dut_temperature(dut)
    elapsed_time = time.time() - start_time
    logging.debug("DUT Final Temperature: {}C. Cooldown duration: {}".format(
        temperature, elapsed_time))
1190
1191
def health_check(dut, batt_thresh=5, temp_threshold=53, cooldown=1):
    """Function to check health status of a DUT.

    The function checks both battery levels and temperature to avoid DUT
    powering off during the test.

    Args:
        dut: AndroidDevice of interest
        batt_thresh: battery level threshold
        temp_threshold: temperature threshold
        cooldown: flag to wait for DUT to cool down when overheating
    Returns:
        health_check: boolean confirming device is healthy
    """
    device_healthy = True

    # Battery check: a low battery risks shutdown mid-test.
    battery_level = utils.get_battery_level(dut)
    if battery_level < batt_thresh:
        logging.warning("Battery level low ({}%)".format(battery_level))
        device_healthy = False
    else:
        logging.debug("Battery level = {}%".format(battery_level))

    # Temperature check: optionally wait out overheating instead of failing.
    temperature = get_dut_temperature(dut)
    if temperature <= temp_threshold:
        logging.debug("DUT Temperature = {} C".format(temperature))
    elif cooldown:
        logging.warning(
            "Waiting for DUT to cooldown. ({} C)".format(temperature))
        wait_for_dut_cooldown(dut, target_temp=temp_threshold - 5)
    else:
        logging.warning("DUT Overheating ({} C)".format(temperature))
        device_healthy = False
    return device_healthy
1226
1227
def get_sw_signature(dut):
    """Function that checks the signature for wifi firmware and config files.

    Returns:
        bdf_signature: signature consisting of last three digits of bdf cksums
        fw_signature: floating point firmware version, i.e., major.minor
    """
    # Sum the cksum values of all board data files, keeping three digits.
    bdf_output = dut.adb.shell('cksum /vendor/firmware/bdwlan*')
    logging.debug('BDF Checksum output: {}'.format(bdf_output))
    checksums = (int(line.split(' ')[0]) for line in bdf_output.splitlines())
    bdf_signature = sum(checksums) % 1000

    # Reduce the firmware version string to a major.minor float.
    fw_output = dut.adb.shell('halutil -logger -get fw')
    logging.debug('Firmware version output: {}'.format(fw_output))
    fw_version = re.search(FW_REGEX, fw_output).group('firmware')
    major_minor = fw_version.split('.')[-3:-1]
    fw_signature = float('.'.join(major_minor))

    # Stable three-digit hash of the device serial.
    serial_hash = int(hashlib.md5(dut.serial.encode()).hexdigest(), 16) % 1000
    return {
        'bdf_signature': bdf_signature,
        'fw_signature': fw_signature,
        'serial_hash': serial_hash
    }
1251
1252
def push_bdf(dut, bdf_file):
    """Function to push Wifi BDF files

    This function checks for existing wifi bdf files and over writes them all,
    for simplicity, with the bdf file provided in the arguments. The dut is
    rebooted for the bdf file to take effect

    Args:
        dut: dut to push bdf file to
        bdf_file: path to bdf_file to push
    """
    existing_bdf_files = dut.adb.shell(
        'ls /vendor/firmware/bdwlan*').splitlines()
    # Overwrite every existing bdf file with the provided one.
    for destination in existing_bdf_files:
        dut.push_system_file(bdf_file, destination)
    # Reboot so the new board data file takes effect.
    dut.reboot()
1268
1269
def push_firmware(dut, wlanmdsp_file, datamsc_file):
    """Function to push Wifi firmware files

    Args:
        dut: dut to push bdf file to
        wlanmdsp_file: path to wlanmdsp.mbn file
        datamsc_file: path to Data.msc file
    """
    firmware_pushes = [(wlanmdsp_file, '/vendor/firmware/wlanmdsp.mbn'),
                       (datamsc_file, '/vendor/firmware/Data.msc')]
    for src_file, dst_file in firmware_pushes:
        dut.push_system_file(src_file, dst_file)
    # Reboot so the new firmware takes effect.
    dut.reboot()
1281
1282
1283def _set_ini_fields(ini_file_path, ini_field_dict):
1284    template_regex = r'^{}=[0-9,.x-]+'
1285    with open(ini_file_path, 'r') as f:
1286        ini_lines = f.read().splitlines()
1287        for idx, line in enumerate(ini_lines):
1288            for field_name, field_value in ini_field_dict.items():
1289                line_regex = re.compile(template_regex.format(field_name))
1290                if re.match(line_regex, line):
1291                    ini_lines[idx] = "{}={}".format(field_name, field_value)
1292                    print(ini_lines[idx])
1293    with open(ini_file_path, 'w') as f:
1294        f.write("\n".join(ini_lines) + "\n")
1295
1296
def _edit_dut_ini(dut, ini_fields):
    """Function to edit Wifi ini files."""
    dut_ini_path = '/vendor/firmware/wlan/qca_cld/WCNSS_qcom_cfg.ini'
    local_ini_path = os.path.expanduser('~/WCNSS_qcom_cfg.ini')
    # Pull the ini locally, patch the requested fields, then push it back.
    dut.pull_files(dut_ini_path, local_ini_path)
    _set_ini_fields(local_ini_path, ini_fields)
    dut.push_system_file(local_ini_path, dut_ini_path)
    # Reboot so the driver picks up the new configuration.
    dut.reboot()
1307
1308
def set_ini_single_chain_mode(dut, chain):
    """Configure the wlan ini for single-chain operation on the given chain.

    Args:
        dut: dut whose ini file to edit
        chain: zero-based chain index (chainmask fields are one-based)
    """
    chainmask = chain + 1
    ini_fields = {
        'gEnable2x2': 0,
        'gSetTxChainmask1x1': chainmask,
        'gSetRxChainmask1x1': chainmask,
        'gDualMacFeatureDisable': 1,
        'gDot11Mode': 0
    }
    _edit_dut_ini(dut, ini_fields)
1318
1319
def set_ini_two_chain_mode(dut):
    """Configure the wlan ini for two-chain (2x2) operation."""
    _edit_dut_ini(
        dut, {
            'gEnable2x2': 2,
            'gSetTxChainmask1x1': 1,
            'gSetRxChainmask1x1': 1,
            'gDualMacFeatureDisable': 6,
            'gDot11Mode': 0
        })
1329
1330
def set_ini_tx_mode(dut, mode):
    """Configure the wlan ini Dot11 TX mode.

    Args:
        dut: dut whose ini file to edit
        mode: human-readable mode name, e.g., 'Auto', '11ac', '11n only'
    """
    # Mapping from human-readable mode names to gDot11Mode ini values.
    tx_mode_values = {
        "Auto": 0,
        "11n": 4,
        "11ac": 9,
        "11abg": 1,
        "11b": 2,
        "11g": 3,
        "11g only": 5,
        "11n only": 6,
        "11b only": 7,
        "11ac only": 8
    }
    ini_fields = {
        'gEnable2x2': 2,
        'gSetTxChainmask1x1': 1,
        'gSetRxChainmask1x1': 1,
        'gDualMacFeatureDisable': 6,
        'gDot11Mode': tx_mode_values[mode]
    }
    _edit_dut_ini(dut, ini_fields)
1353