#!/usr/bin/env python3.4
#
# Copyright 2019 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import bokeh, bokeh.plotting
import collections
import functools
import hashlib
import itertools
import json
import logging
import math
import os
import re
import statistics
import time
from acts.controllers.android_device import AndroidDevice
from acts.controllers.utils_lib import ssh
from acts import utils
from acts.test_utils.wifi import wifi_test_utils as wutils
from concurrent.futures import ThreadPoolExecutor

SHORT_SLEEP = 1
MED_SLEEP = 6
TEST_TIMEOUT = 10
STATION_DUMP = 'iw wlan0 station dump'
SCAN = 'wpa_cli scan'
SCAN_RESULTS = 'wpa_cli scan_results'
SIGNAL_POLL = 'wpa_cli signal_poll'
WPA_CLI_STATUS = 'wpa_cli status'
CONST_3dB = 3.01029995664
RSSI_ERROR_VAL = float('nan')
RTT_REGEX = re.compile(r'^\[(?P<timestamp>\S+)\] .*? time=(?P<rtt>\S+)')
LOSS_REGEX = re.compile(r'(?P<loss>\S+)% packet loss')
FW_REGEX = re.compile(r'FW:(?P<firmware>\S+) HW:')


# Threading decorator
def nonblocking(f):
    """Creates a decorator transforming function calls to non-blocking.

    The decorated function is submitted to a single-worker
    ThreadPoolExecutor and the call returns the resulting Future
    immediately; use future.result() to wait for the return value.
    """
    @functools.wraps(f)
    def wrap(*args, **kwargs):
        executor = ThreadPoolExecutor(max_workers=1)
        thread_future = executor.submit(f, *args, **kwargs)
        # Ensure resources are freed up when executor returns or raises
        executor.shutdown(wait=False)
        return thread_future

    return wrap


# Link layer stats utilities
class LinkLayerStats():
    """Reads and tracks per-MCS link layer stats from a DUT.

    llstats_cumulative holds the most recently parsed snapshot and
    llstats_incremental holds its difference from the previous snapshot.
    Both contain an 'mcs_stats' dict keyed by an MCS description string and
    a 'summary' dict describing the most common TX/RX MCS.
    """

    LLSTATS_CMD = 'cat /d/wlan0/ll_stats'
    PEER_REGEX = 'LL_STATS_PEER_ALL'
    # Every segment is a raw string so that '\S' is passed to the regex
    # engine rather than being treated as an (invalid) string escape.
    MCS_REGEX = re.compile(
        r'preamble: (?P<mode>\S+), nss: (?P<num_streams>\S+), bw: (?P<bw>\S+), '
        r'mcs: (?P<mcs>\S+), bitrate: (?P<rate>\S+), txmpdu: (?P<txmpdu>\S+), '
        r'rxmpdu: (?P<rxmpdu>\S+), mpdu_lost: (?P<mpdu_lost>\S+), '
        r'retries: (?P<retries>\S+), retries_short: (?P<retries_short>\S+), '
        r'retries_long: (?P<retries_long>\S+)')
    MCS_ID = collections.namedtuple(
        'mcs_id', ['mode', 'num_streams', 'bandwidth', 'mcs', 'rate'])
    MODE_MAP = {'0': '11a/g', '1': '11b', '2': '11n', '3': '11ac'}
    BW_MAP = {'0': 20, '1': 40, '2': 80}

    def __init__(self, dut, llstats_enabled=True):
        self.dut = dut
        self.llstats_enabled = llstats_enabled
        self.llstats_cumulative = self._empty_llstats()
        self.llstats_incremental = self._empty_llstats()

    def update_stats(self):
        """Reads link layer stats from the DUT and updates tracked stats.

        When stats are disabled, or reading them fails, an empty output is
        parsed, which resets the tracked stats.
        """
        llstats_output = ''
        if self.llstats_enabled:
            try:
                llstats_output = self.dut.adb.shell(self.LLSTATS_CMD,
                                                    timeout=0.1)
            except Exception:
                # Best-effort read: any failure is treated as "no stats".
                # Unlike a bare except, this lets KeyboardInterrupt and
                # SystemExit propagate.
                llstats_output = ''
        self._update_stats(llstats_output)

    def reset_stats(self):
        """Clears both cumulative and incremental stats."""
        self.llstats_cumulative = self._empty_llstats()
        self.llstats_incremental = self._empty_llstats()

    def _empty_llstats(self):
        return collections.OrderedDict(mcs_stats=collections.OrderedDict(),
                                       summary=collections.OrderedDict())

    def _empty_mcs_stat(self):
        return collections.OrderedDict(txmpdu=0,
                                       rxmpdu=0,
                                       mpdu_lost=0,
                                       retries=0,
                                       retries_short=0,
                                       retries_long=0)

    def _mcs_id_to_string(self, mcs_id):
        mcs_string = '{} {}MHz Nss{} MCS{} {}Mbps'.format(
            mcs_id.mode, mcs_id.bandwidth, mcs_id.num_streams, mcs_id.mcs,
            mcs_id.rate)
        return mcs_string

    def _parse_mcs_stats(self, llstats_output):
        """Parses per-MCS counters from raw ll_stats output.

        Returns an empty dict (and resets tracked stats) if the output has
        no per-peer section.
        """
        llstats_dict = collections.OrderedDict()
        # Look for per-peer stats
        if not re.search(self.PEER_REGEX, llstats_output):
            self.reset_stats()
            return collections.OrderedDict()
        # Find and process all matches for per stream stats
        for match in self.MCS_REGEX.finditer(llstats_output):
            current_mcs = self.MCS_ID(self.MODE_MAP[match.group('mode')],
                                      int(match.group('num_streams')) + 1,
                                      self.BW_MAP[match.group('bw')],
                                      int(match.group('mcs')),
                                      int(match.group('rate'), 16) / 1000)
            current_stats = collections.OrderedDict(
                txmpdu=int(match.group('txmpdu')),
                rxmpdu=int(match.group('rxmpdu')),
                mpdu_lost=int(match.group('mpdu_lost')),
                retries=int(match.group('retries')),
                retries_short=int(match.group('retries_short')),
                retries_long=int(match.group('retries_long')))
            llstats_dict[self._mcs_id_to_string(current_mcs)] = current_stats
        return llstats_dict

    def _diff_mcs_stats(self, new_stats, old_stats):
        stats_diff = collections.OrderedDict()
        for stat_key in new_stats.keys():
            stats_diff[stat_key] = new_stats[stat_key] - old_stats[stat_key]
        return stats_diff

    def _generate_stats_summary(self, llstats_dict):
        """Finds the most used TX and RX MCS and their share of all MPDUs."""
        llstats_summary = collections.OrderedDict(common_tx_mcs=None,
                                                  common_tx_mcs_count=0,
                                                  common_tx_mcs_freq=0,
                                                  common_rx_mcs=None,
                                                  common_rx_mcs_count=0,
                                                  common_rx_mcs_freq=0)
        txmpdu_count = 0
        rxmpdu_count = 0
        for mcs_id, mcs_stats in llstats_dict['mcs_stats'].items():
            if mcs_stats['txmpdu'] > llstats_summary['common_tx_mcs_count']:
                llstats_summary['common_tx_mcs'] = mcs_id
                llstats_summary['common_tx_mcs_count'] = mcs_stats['txmpdu']
            if mcs_stats['rxmpdu'] > llstats_summary['common_rx_mcs_count']:
                llstats_summary['common_rx_mcs'] = mcs_id
                llstats_summary['common_rx_mcs_count'] = mcs_stats['rxmpdu']
            txmpdu_count += mcs_stats['txmpdu']
            rxmpdu_count += mcs_stats['rxmpdu']
        if txmpdu_count:
            llstats_summary['common_tx_mcs_freq'] = (
                llstats_summary['common_tx_mcs_count'] / txmpdu_count)
        if rxmpdu_count:
            llstats_summary['common_rx_mcs_freq'] = (
                llstats_summary['common_rx_mcs_count'] / rxmpdu_count)
        return llstats_summary

    def _update_stats(self, llstats_output):
        # Parse stats
        new_llstats = self._empty_llstats()
        new_llstats['mcs_stats'] = self._parse_mcs_stats(llstats_output)
        # Save old stats and set new cumulative stats
        old_llstats = self.llstats_cumulative.copy()
        self.llstats_cumulative = new_llstats.copy()
        # Compute difference between new and old stats; an MCS not seen
        # before is diffed against all-zero counters.
        self.llstats_incremental = self._empty_llstats()
        for mcs_id, new_mcs_stats in new_llstats['mcs_stats'].items():
            old_mcs_stats = old_llstats['mcs_stats'].get(
                mcs_id, self._empty_mcs_stat())
            self.llstats_incremental['mcs_stats'][
                mcs_id] = self._diff_mcs_stats(new_mcs_stats, old_mcs_stats)
        # Generate llstats summary
        self.llstats_incremental['summary'] = self._generate_stats_summary(
            self.llstats_incremental)
        self.llstats_cumulative['summary'] = self._generate_stats_summary(
            self.llstats_cumulative)


# JSON serializer
def serialize_dict(input_dict):
    """Function to serialize dicts to enable JSON output"""
    output_dict = collections.OrderedDict()
    for key, value in input_dict.items():
        output_dict[_serialize_value(key)] = _serialize_value(value)
    return output_dict


def _serialize_value(value):
    """Function to recursively serialize dict entries to enable JSON output.

    Tuples (including namedtuples) become their str() representation; lists
    and dicts are serialized recursively; other values pass through.
    """
    if isinstance(value, tuple):
        return str(value)
    elif isinstance(value, list):
        return [_serialize_value(x) for x in value]
    elif isinstance(value, dict):
        return serialize_dict(value)
    else:
        return value


# Plotting Utilities
class BokehFigure():
    """Class enabling simplified Bokeh plotting."""

    COLORS = [
        'black',
        'blue',
        'blueviolet',
        'brown',
        'burlywood',
        'cadetblue',
        'cornflowerblue',
        'crimson',
        'cyan',
        'darkblue',
        'darkgreen',
        'darkmagenta',
        'darkorange',
        'darkred',
        'deepskyblue',
        'goldenrod',
        'green',
        'grey',
        'indigo',
        'navy',
        'olive',
        'orange',
        'red',
        'salmon',
        'teal',
        'yellow',
    ]
    MARKERS = [
        'asterisk', 'circle', 'circle_cross', 'circle_x', 'cross', 'diamond',
        'diamond_cross', 'hex', 'inverted_triangle', 'square', 'square_x',
        'square_cross', 'triangle', 'x'
    ]

    TOOLS = ('box_zoom,box_select,pan,crosshair,redo,undo,reset,hover,save')
    TOOLTIPS = [
        ('index', '$index'),
        ('(x,y)', '($x, $y)'),
        ('info', '@hover_text'),
    ]

    def __init__(self,
                 title=None,
                 x_label=None,
                 primary_y_label=None,
                 secondary_y_label=None,
                 height=700,
                 width=1100,
                 title_size='15pt',
                 axis_label_size='12pt',
                 json_file=None):
        if json_file:
            self.load_from_json(json_file)
        else:
            self.figure_data = []
            self.fig_property = {
                'title': title,
                'x_label': x_label,
                'primary_y_label': primary_y_label,
                'secondary_y_label': secondary_y_label,
                'num_lines': 0,
                'height': height,
                'width': width,
                'title_size': title_size,
                'axis_label_size': axis_label_size
            }

    def init_plot(self):
        self.plot = bokeh.plotting.figure(
            sizing_mode='scale_both',
            plot_width=self.fig_property['width'],
            plot_height=self.fig_property['height'],
            title=self.fig_property['title'],
            tools=self.TOOLS,
            output_backend='webgl')
        self.plot.hover.tooltips = self.TOOLTIPS
        self.plot.add_tools(
            bokeh.models.tools.WheelZoomTool(dimensions='width'))
        self.plot.add_tools(
            bokeh.models.tools.WheelZoomTool(dimensions='height'))

    def _filter_line(self, x_data, y_data, hover_text=None):
        """Function to remove NaN (and missing/None) points from bokeh plots."""
        x_data_filtered = []
        y_data_filtered = []
        hover_text_filtered = []
        # zip_longest cannot iterate None, so treat missing hover text as [].
        for x, y, hover in itertools.zip_longest(x_data, y_data,
                                                 hover_text or []):
            if y is not None and not math.isnan(y):
                x_data_filtered.append(x)
                y_data_filtered.append(y)
                hover_text_filtered.append(hover)
        return x_data_filtered, y_data_filtered, hover_text_filtered

    def add_line(self,
                 x_data,
                 y_data,
                 legend,
                 hover_text=None,
                 color=None,
                 width=3,
                 style='solid',
                 marker=None,
                 marker_size=10,
                 shaded_region=None,
                 y_axis='default'):
        """Function to add line to existing BokehFigure.

        Args:
            x_data: list containing x-axis values for line
            y_data: list containing y_axis values for line
            legend: string containing line title
            hover_text: text to display when hovering over lines
            color: string describing line color
            width: integer line width
            style: string describing line style, e.g, solid or dashed
            marker: string specifying line marker, e.g., cross
            marker_size: marker size
            shaded_region: data describing shaded region to plot; a dict with
                'x_vector', 'lower_limit' and 'upper_limit' lists
            y_axis: identifier for y-axis to plot line against
        """
        if y_axis not in ['default', 'secondary']:
            raise ValueError('y_axis must be default or secondary')
        if color is None:
            color = self.COLORS[self.fig_property['num_lines'] %
                                len(self.COLORS)]
        if style == 'dashed':
            style = [5, 5]
        if not hover_text:
            hover_text = ['y={}'.format(y) for y in y_data]
        x_data_filter, y_data_filter, hover_text_filter = self._filter_line(
            x_data, y_data, hover_text)
        self.figure_data.append({
            'x_data': x_data_filter,
            'y_data': y_data_filter,
            'legend': legend,
            'hover_text': hover_text_filter,
            'color': color,
            'width': width,
            'style': style,
            'marker': marker,
            'marker_size': marker_size,
            'shaded_region': shaded_region,
            'y_axis': y_axis
        })
        self.fig_property['num_lines'] += 1

    def add_scatter(self,
                    x_data,
                    y_data,
                    legend,
                    hover_text=None,
                    color=None,
                    marker=None,
                    marker_size=10,
                    y_axis='default'):
        """Function to add scatter plot to existing BokehFigure.

        Args:
            x_data: list containing x-axis values for scatter
            y_data: list containing y_axis values for scatter
            legend: string containing scatter title
            hover_text: text to display when hovering over points
            color: string describing marker color
            marker: string specifying marker, e.g., cross
            marker_size: marker size
            y_axis: identifier for y-axis to plot scatter against
        """
        if y_axis not in ['default', 'secondary']:
            raise ValueError('y_axis must be default or secondary')
        if color is None:
            color = self.COLORS[self.fig_property['num_lines'] %
                                len(self.COLORS)]
        if marker is None:
            marker = self.MARKERS[self.fig_property['num_lines'] %
                                  len(self.MARKERS)]
        if not hover_text:
            hover_text = ['y={}'.format(y) for y in y_data]
        self.figure_data.append({
            'x_data': x_data,
            'y_data': y_data,
            'legend': legend,
            'hover_text': hover_text,
            'color': color,
            'width': 0,
            'style': 'solid',
            'marker': marker,
            'marker_size': marker_size,
            'shaded_region': None,
            'y_axis': y_axis
        })
        self.fig_property['num_lines'] += 1

    def generate_figure(self, output_file=None, save_json=True):
        """Function to generate and save BokehFigure.

        Args:
            output_file: string specifying output file path; if None the
                figure is generated but not saved
            save_json: flag controlling json outputs when saving
        Returns:
            the generated bokeh plot
        """
        self.init_plot()
        two_axes = False
        for line in self.figure_data:
            source = bokeh.models.ColumnDataSource(
                data=dict(x=line['x_data'],
                          y=line['y_data'],
                          hover_text=line['hover_text']))
            if line['width'] > 0:
                self.plot.line(x='x',
                               y='y',
                               legend_label=line['legend'],
                               line_width=line['width'],
                               color=line['color'],
                               line_dash=line['style'],
                               name=line['y_axis'],
                               y_range_name=line['y_axis'],
                               source=source)
            if line['shaded_region']:
                shaded_region = line['shaded_region']
                # Build the closed band polygon in new lists: extending the
                # stored vectors in place would corrupt figure_data (and the
                # JSON dump) and grow the band on every regeneration.
                band_x = (list(shaded_region['x_vector']) +
                          list(shaded_region['x_vector'][::-1]))
                band_y = (list(shaded_region['lower_limit']) +
                          list(shaded_region['upper_limit'][::-1]))
                self.plot.patch(band_x,
                                band_y,
                                color='#7570B3',
                                line_alpha=0.1,
                                fill_alpha=0.1)
            if line['marker'] in self.MARKERS:
                marker_func = getattr(self.plot, line['marker'])
                marker_func(x='x',
                            y='y',
                            size=line['marker_size'],
                            legend_label=line['legend'],
                            line_color=line['color'],
                            fill_color=line['color'],
                            name=line['y_axis'],
                            y_range_name=line['y_axis'],
                            source=source)
            if line['y_axis'] == 'secondary':
                two_axes = True

        #x-axis formatting
        self.plot.xaxis.axis_label = self.fig_property['x_label']
        self.plot.x_range.range_padding = 0
        self.plot.xaxis[0].axis_label_text_font_size = self.fig_property[
            'axis_label_size']
        #y-axis formatting
        self.plot.yaxis[0].axis_label = self.fig_property['primary_y_label']
        self.plot.yaxis[0].axis_label_text_font_size = self.fig_property[
            'axis_label_size']
        self.plot.y_range = bokeh.models.DataRange1d(names=['default'])
        if two_axes and 'secondary' not in self.plot.extra_y_ranges:
            self.plot.extra_y_ranges = {
                'secondary': bokeh.models.DataRange1d(names=['secondary'])
            }
            self.plot.add_layout(
                bokeh.models.LinearAxis(
                    y_range_name='secondary',
                    axis_label=self.fig_property['secondary_y_label'],
                    axis_label_text_font_size=self.
                    fig_property['axis_label_size']), 'right')
        # plot formatting
        self.plot.legend.location = 'top_right'
        self.plot.legend.click_policy = 'hide'
        self.plot.title.text_font_size = self.fig_property['title_size']

        if output_file is not None:
            self.save_figure(output_file, save_json)
        return self.plot

    def load_from_json(self, file_path):
        """Restores fig_property and figure_data from a saved JSON file."""
        with open(file_path, 'r') as json_file:
            fig_dict = json.load(json_file)
        self.fig_property = fig_dict['fig_property']
        self.figure_data = fig_dict['figure_data']

    def _save_figure_json(self, output_file):
        """Function to save a json format of a figure"""
        figure_dict = collections.OrderedDict(fig_property=self.fig_property,
                                              figure_data=self.figure_data)
        output_file = output_file.replace('.html', '_plot_data.json')
        with open(output_file, 'w') as outfile:
            json.dump(figure_dict, outfile, indent=4)

    def save_figure(self, output_file, save_json=True):
        """Function to save BokehFigure.

        Args:
            output_file: string specifying output file path
            save_json: flag controlling json outputs
        """
        bokeh.plotting.output_file(output_file)
        bokeh.plotting.save(self.plot)
        if save_json:
            self._save_figure_json(output_file)

    @staticmethod
    def save_figures(figure_array, output_file_path, save_json=True):
        """Function to save list of BokehFigures in one file.

        Args:
            figure_array: list of BokehFigure object to be plotted
            output_file_path: string specifying output file path
            save_json: flag controlling per-figure json outputs
        """
        for idx, figure in enumerate(figure_array):
            figure.generate_figure()
            if save_json:
                json_file_path = output_file_path.replace(
                    '.html', '{}-plot_data.json'.format(idx))
                figure._save_figure_json(json_file_path)
        plot_array = [figure.plot for figure in figure_array]
        all_plots = bokeh.layouts.column(children=plot_array,
                                         sizing_mode='scale_width')
        bokeh.plotting.output_file(output_file_path)
        bokeh.plotting.save(all_plots)


# Ping utilities
class PingResult(object):
    """An object that contains the results of running ping command.

    Attributes:
        connected: True if a connection was made. False otherwise.
        packet_loss_percentage: The total percentage of packets lost.
        transmission_times: The list of PingTransmissionTimes containing the
            timestamps gathered for transmitted packets.
        rtts: An list-like object enumerating all round-trip-times of
            transmitted packets.
        timestamps: A list-like object enumerating the beginning timestamps of
            each packet transmission.
        ping_interarrivals: A list-like object enumerating the amount of time
            between the beginning of each subsequent transmission.
    """
    def __init__(self, ping_output):
        """Parses ping output.

        Args:
            ping_output: list of lines printed by `ping -D`. Lines that do
                not match the expected formats are ignored.
        """
        self.packet_loss_percentage = 100
        self.transmission_times = []

        self.rtts = _ListWrap(self.transmission_times, lambda entry: entry.rtt)
        self.timestamps = _ListWrap(self.transmission_times,
                                    lambda entry: entry.timestamp)
        self.ping_interarrivals = _PingInterarrivals(self.transmission_times)

        self.start_time = 0
        for line in ping_output:
            if 'loss' in line:
                match = re.search(LOSS_REGEX, line)
                if match:
                    self.packet_loss_percentage = float(match.group('loss'))
            if 'time=' in line:
                match = re.search(RTT_REGEX, line)
                if match:
                    # Timestamps are stored relative to the first reply.
                    if self.start_time == 0:
                        self.start_time = float(match.group('timestamp'))
                    self.transmission_times.append(
                        PingTransmissionTimes(
                            float(match.group('timestamp')) - self.start_time,
                            float(match.group('rtt'))))
        self.connected = len(
            ping_output) > 1 and self.packet_loss_percentage < 100

    def __getitem__(self, item):
        if item == 'rtt':
            return self.rtts
        if item == 'connected':
            return self.connected
        if item == 'packet_loss_percentage':
            return self.packet_loss_percentage
        raise ValueError('Invalid key. Please use an attribute instead.')

    def as_dict(self):
        return {
            'connected': 1 if self.connected else 0,
            'rtt': list(self.rtts),
            'time_stamp': list(self.timestamps),
            'ping_interarrivals': list(self.ping_interarrivals),
            'packet_loss_percentage': self.packet_loss_percentage
        }


class PingTransmissionTimes(object):
    """A class that holds the timestamps for a packet sent via the ping command.

    Attributes:
        rtt: The round trip time for the packet sent.
        timestamp: The timestamp the packet started its trip.
    """
    def __init__(self, timestamp, rtt):
        self.rtt = rtt
        self.timestamp = timestamp


class _ListWrap(object):
    """A convenient helper class for treating list iterators as native lists.

    Exposes a read-only, live view of wrapped_list with func applied to
    every element.
    """
    def __init__(self, wrapped_list, func):
        self.__wrapped_list = wrapped_list
        self.__func = func

    def __getitem__(self, key):
        return self.__func(self.__wrapped_list[key])

    def __iter__(self):
        for item in self.__wrapped_list:
            yield self.__func(item)

    def __len__(self):
        return len(self.__wrapped_list)


class _PingInterarrivals(object):
    """A helper class for treating ping interarrivals as a native list.

    Element i is the timestamp difference between entries i + 1 and i.
    """
    def __init__(self, ping_entries):
        self.__ping_entries = ping_entries

    def __getitem__(self, key):
        return (self.__ping_entries[key + 1].timestamp -
                self.__ping_entries[key].timestamp)

    def __iter__(self):
        for index in range(len(self.__ping_entries) - 1):
            yield self[index]

    def __len__(self):
        return max(0, len(self.__ping_entries) - 1)


def get_ping_stats(src_device, dest_address, ping_duration, ping_interval,
                   ping_size):
    """Run ping from src_device to dest_address.

    src_device may be an AndroidDevice (ping runs through adb shell, e.g.
    to ping a remote ip from the DUT) or an SshConnection (ping runs with
    sudo on the remote host, e.g. to ping the DUT).

    Args:
        src_device: object representing device to ping from
        dest_address: ip address to ping
        ping_duration: timeout to set on the ping process (in seconds)
        ping_interval: time between pings (in seconds)
        ping_size: size of ping packet payload
    Returns:
        ping_result: PingResult parsed from the ping output
    Raises:
        TypeError: if src_device is of an unsupported type
    """
    ping_count = int(ping_duration / ping_interval)
    ping_deadline = int(ping_count * ping_interval) + 1
    ping_cmd = 'ping -c {} -w {} -i {} -s {} -D'.format(
        ping_count,
        ping_deadline,
        ping_interval,
        ping_size,
    )
    if isinstance(src_device, AndroidDevice):
        ping_cmd = '{} {}'.format(ping_cmd, dest_address)
        ping_output = src_device.adb.shell(ping_cmd,
                                           timeout=ping_deadline + SHORT_SLEEP,
                                           ignore_status=True)
    elif isinstance(src_device, ssh.connection.SshConnection):
        ping_cmd = 'sudo {} {}'.format(ping_cmd, dest_address)
        ping_output = src_device.run(ping_cmd,
                                     timeout=ping_deadline + SHORT_SLEEP,
                                     ignore_status=True).stdout
    else:
        raise TypeError('Unable to ping using src_device of type %s.' %
                        type(src_device))
    return PingResult(ping_output.splitlines())


@nonblocking
def get_ping_stats_nb(src_device, dest_address, ping_duration, ping_interval,
                      ping_size):
    """Non-blocking get_ping_stats; returns a Future of the PingResult."""
    return get_ping_stats(src_device, dest_address, ping_duration,
                          ping_interval, ping_size)


@nonblocking
def start_iperf_client_nb(iperf_client, iperf_server_address, iperf_args, tag,
                          timeout):
    """Non-blocking iperf_client.start; returns a Future of its result."""
    return iperf_client.start(iperf_server_address, iperf_args, tag, timeout)


# Rssi Utilities
def empty_rssi_result():
    """Returns an RSSI result dict with empty data and unset statistics."""
    return collections.OrderedDict([('data', []), ('mean', None),
                                    ('stdev', None)])


def _compute_rssi_stats(rssi_result, ignore_samples=0):
    """Fills in the 'mean' and 'stdev' fields of an RSSI result in place.

    NaN readings are discarded. If more valid readings than ignore_samples
    remain, the leading ignore_samples readings are also dropped. When no
    valid readings remain, mean and stdev are set to RSSI_ERROR_VAL.

    Args:
        rssi_result: dict with a 'data' list, as built by empty_rssi_result
        ignore_samples: number of leading valid samples to discard
    """
    valid_values = [x for x in rssi_result['data'] if not math.isnan(x)]
    if len(valid_values) > ignore_samples:
        valid_values = valid_values[ignore_samples:]
    if valid_values:
        rssi_result['mean'] = statistics.mean(valid_values)
        if len(valid_values) > 1:
            rssi_result['stdev'] = statistics.stdev(valid_values)
        else:
            rssi_result['stdev'] = 0
    else:
        rssi_result['mean'] = RSSI_ERROR_VAL
        rssi_result['stdev'] = RSSI_ERROR_VAL


def get_connected_rssi(dut,
                       num_measurements=1,
                       polling_frequency=SHORT_SLEEP,
                       first_measurement_delay=0,
                       disconnect_warning=True,
                       ignore_samples=0):
    """Gets all RSSI values reported for the connected access point/BSSID.

    Args:
        dut: android device object from which to get RSSI
        num_measurements: number of scans done, and RSSIs collected
        polling_frequency: time (in seconds) between the starts of
            consecutive RSSI measurements
        first_measurement_delay: time (in seconds) to wait before the first
            measurement
        disconnect_warning: boolean controlling disconnection logging messages
        ignore_samples: number of leading samples to ignore
    Returns:
        connected_rssi: dict containing the measurements results for
        all reported RSSI values (signal_poll, per chain, etc.) and their
        statistics. Missing readings are recorded as RSSI_ERROR_VAL.
    """
    # yapf: disable
    connected_rssi = collections.OrderedDict(
        [('time_stamp', []),
         ('bssid', []), ('frequency', []),
         ('signal_poll_rssi', empty_rssi_result()),
         ('signal_poll_avg_rssi', empty_rssi_result()),
         ('chain_0_rssi', empty_rssi_result()),
         ('chain_1_rssi', empty_rssi_result())])
    # yapf: enable
    previous_bssid = 'disconnected'
    t0 = time.time()
    time.sleep(first_measurement_delay)
    for _ in range(num_measurements):
        measurement_start_time = time.time()
        connected_rssi['time_stamp'].append(measurement_start_time - t0)
        # Get signal poll RSSI
        status_output = dut.adb.shell(WPA_CLI_STATUS)
        match = re.search('bssid=.*', status_output)
        if match:
            current_bssid = match.group(0).split('=')[1]
            connected_rssi['bssid'].append(current_bssid)
        else:
            current_bssid = 'disconnected'
            connected_rssi['bssid'].append(current_bssid)
            # Warn only on the connected -> disconnected transition.
            if disconnect_warning and previous_bssid != 'disconnected':
                logging.warning('WIFI DISCONNECT DETECTED!')
        previous_bssid = current_bssid
        signal_poll_output = dut.adb.shell(SIGNAL_POLL)
        match = re.search('FREQUENCY=.*', signal_poll_output)
        if match:
            frequency = int(match.group(0).split('=')[1])
            connected_rssi['frequency'].append(frequency)
        else:
            connected_rssi['frequency'].append(RSSI_ERROR_VAL)
        # Anchor to the start of a line so 'AVG_RSSI=' can never be picked
        # up as the instantaneous 'RSSI=' value.
        match = re.search('^RSSI=.*', signal_poll_output, flags=re.MULTILINE)
        if match:
            temp_rssi = int(match.group(0).split('=')[1])
            # -9999 and 0 are treated as invalid readings.
            if temp_rssi == -9999 or temp_rssi == 0:
                connected_rssi['signal_poll_rssi']['data'].append(
                    RSSI_ERROR_VAL)
            else:
                connected_rssi['signal_poll_rssi']['data'].append(temp_rssi)
        else:
            connected_rssi['signal_poll_rssi']['data'].append(RSSI_ERROR_VAL)
        match = re.search('AVG_RSSI=.*', signal_poll_output)
        if match:
            connected_rssi['signal_poll_avg_rssi']['data'].append(
                int(match.group(0).split('=')[1]))
        else:
            connected_rssi['signal_poll_avg_rssi']['data'].append(
                RSSI_ERROR_VAL)
        # Get per chain RSSI from the first bracketed list in station dump.
        station_dump_output = dut.adb.shell(STATION_DUMP)
        chain_values = []
        if re.search('.*signal avg:.*', station_dump_output):
            bracket_match = re.search(r'\[([^\]]*)\]', station_dump_output)
            if bracket_match:
                chain_values = bracket_match.group(1).split(', ')
        # Chains absent from the output (e.g. single-chain devices) are
        # recorded as RSSI_ERROR_VAL instead of raising IndexError.
        for chain_idx, chain_key in enumerate(('chain_0_rssi',
                                               'chain_1_rssi')):
            if chain_idx < len(chain_values):
                connected_rssi[chain_key]['data'].append(
                    int(chain_values[chain_idx]))
            else:
                connected_rssi[chain_key]['data'].append(RSSI_ERROR_VAL)
        measurement_elapsed_time = time.time() - measurement_start_time
        time.sleep(max(0, polling_frequency - measurement_elapsed_time))

    # Compute mean RSSIs. Only average valid readings.
    # Output RSSI_ERROR_VAL if no valid connected readings found.
    for val in connected_rssi.values():
        if 'data' not in val:
            continue
        _compute_rssi_stats(val, ignore_samples)
    return connected_rssi


@nonblocking
def get_connected_rssi_nb(dut,
                          num_measurements=1,
                          polling_frequency=SHORT_SLEEP,
                          first_measurement_delay=0,
                          disconnect_warning=True,
                          ignore_samples=0):
    """Non-blocking get_connected_rssi; returns a Future of its result."""
    return get_connected_rssi(dut, num_measurements, polling_frequency,
                              first_measurement_delay, disconnect_warning,
                              ignore_samples)


def get_scan_rssi(dut, tracked_bssids, num_measurements=1):
    """Gets scan RSSI for specified BSSIDs.

    Args:
        dut: android device object from which to get RSSI
        tracked_bssids: array of BSSIDs to gather RSSI data for
        num_measurements: number of scans done, and RSSIs collected
    Returns:
        scan_rssi: dict containing the measurement results as well as the
        statistics of the scan RSSI for all BSSIDs in tracked_bssids
    """
    scan_rssi = collections.OrderedDict()
    for bssid in tracked_bssids:
        scan_rssi[bssid] = empty_rssi_result()
    for _ in range(num_measurements):
        dut.adb.shell(SCAN)
        time.sleep(MED_SLEEP)
        scan_output = dut.adb.shell(SCAN_RESULTS)
        for bssid in tracked_bssids:
            bssid_result = re.search(bssid + '.*',
                                     scan_output,
                                     flags=re.IGNORECASE)
            if bssid_result:
                # Tab-separated scan result; field 2 holds the signal level.
                bssid_result = bssid_result.group(0).split('\t')
                scan_rssi[bssid]['data'].append(int(bssid_result[2]))
            else:
                scan_rssi[bssid]['data'].append(RSSI_ERROR_VAL)
    # Compute mean RSSIs. Only average valid readings.
    # Output RSSI_ERROR_VAL if no readings found.
    for val in scan_rssi.values():
        _compute_rssi_stats(val)
    return scan_rssi


@nonblocking
def get_scan_rssi_nb(dut, tracked_bssids, num_measurements=1):
    """Non-blocking get_scan_rssi; returns a Future of its result."""
    return get_scan_rssi(dut, tracked_bssids, num_measurements)


# Attenuator Utilities
def atten_by_label(atten_list, path_label, atten_level):
    """Attenuate signals according to their path label.

    Args:
        atten_list: list of attenuators to iterate over
        path_label: path label on which to set desired attenuation
        atten_level: attenuation desired on path
    """
    for atten in atten_list:
        if path_label in atten.path:
            atten.set_atten(atten_level)


def _measure_rssi_with_ping(dut, dut_ip, ping_server):
    """Measures mean signal_poll RSSI while ping traffic flows to the DUT."""
    ping_future = get_ping_stats_nb(src_device=ping_server,
                                    dest_address=dut_ip,
                                    ping_duration=1.5,
                                    ping_interval=0.02,
                                    ping_size=64)
    current_rssi = get_connected_rssi(dut,
                                      num_measurements=4,
                                      polling_frequency=0.25,
                                      first_measurement_delay=0.5,
                                      disconnect_warning=1,
                                      ignore_samples=1)
    current_rssi = current_rssi['signal_poll_rssi']['mean']
    ping_future.result()
    return current_rssi


def get_atten_for_target_rssi(target_rssi, attenuators, dut, ping_server):
    """Function to estimate attenuation to hit a target RSSI.

    This function estimates a constant attenuation setting on all attenuation
    ports to hit a target RSSI. The estimate is not meant to be exact or
    guaranteed.

    Args:
        target_rssi: rssi of interest
        attenuators: list of attenuator ports
        dut: android device object assumed connected to a wifi network.
        ping_server: ssh connection object to ping server
    Returns:
        target_atten: attenuation setting to achieve target_rssi
    Raises:
        ValueError: if no valid RSSI can be measured (e.g. DUT disconnected)
    """
    logging.info('Searching attenuation for RSSI = {}dB'.format(target_rssi))
    # Set attenuator to 0 dB
    for atten in attenuators:
        atten.set_atten(0, strict=False)
    # Measure starting RSSI while running ping traffic
    dut_ip = dut.droid.connectivityGetIPv4Addresses('wlan0')[0]
    current_rssi = _measure_rssi_with_ping(dut, dut_ip, ping_server)
    target_atten = 0
    logging.debug("RSSI @ {0:.2f}dB attenuation = {1:.2f}".format(
        target_atten, current_rssi))
    within_range = 0
    for idx in range(20):
        # A NaN mean means no valid RSSI sample; fail with a clear message
        # instead of an opaque int(NaN) conversion error.
        if math.isnan(current_rssi):
            raise ValueError(
                'Unable to measure RSSI at {}dB attenuation. '
                'DUT may be disconnected.'.format(target_atten))
        # Step by the RSSI error, clamped to +/-20 dB, on a 0.25 dB grid.
        atten_delta = max(min(current_rssi - target_rssi, 20), -20)
        target_atten = int((target_atten + atten_delta) * 4) / 4
        if target_atten < 0:
            return 0
        if target_atten > attenuators[0].get_max_atten():
            return attenuators[0].get_max_atten()
        for atten in attenuators:
            atten.set_atten(target_atten, strict=False)
        current_rssi = _measure_rssi_with_ping(dut, dut_ip, ping_server)
        logging.info("RSSI @ {0:.2f}dB attenuation = {1:.2f}".format(
            target_atten, current_rssi))
        # Require two consecutive readings within 1 dB before returning.
        if abs(current_rssi - target_rssi) < 1:
            if within_range:
                logging.info(
                    'Reached RSSI: {0:.2f}. Target RSSI: {1:.2f}. '
                    'Attenuation: {2:.2f}, Iterations = {3}'.format(
                        current_rssi, target_rssi, target_atten, idx))
                return target_atten
            else:
                within_range = True
        else:
            within_range = False
    return target_atten


def get_current_atten_dut_chain_map(attenuators, dut, ping_server):
    """Function to detect mapping between attenuator ports and DUT chains.

    This function detects the mapping between attenuator ports and DUT chains
    in cases where DUT chains are connected to only one attenuator port. The
    function assumes the DUT is already connected to a wifi network. The
    function starts by measuring per chain RSSI at 0 attenuation, then
    attenuates one port at a time looking for the chain that reports a lower
    RSSI.

    Args:
        attenuators: list of attenuator ports
        dut: android device object assumed connected to a wifi network.
        ping_server: ssh connection object to ping server
    Returns:
        chain_map: list of dut chains, one entry per attenuator port
    """
    # Set attenuator to 0 dB
    for atten in attenuators:
        atten.set_atten(0, strict=False)
    # Start ping traffic
    dut_ip = dut.droid.connectivityGetIPv4Addresses('wlan0')[0]
    ping_future = get_ping_stats_nb(ping_server, dut_ip, 11, 0.02, 64)
    # Measure starting RSSI
    base_rssi = get_connected_rssi(dut, 4, 0.25, 1)
    chain0_base_rssi = base_rssi['chain_0_rssi']['mean']
    chain1_base_rssi = base_rssi['chain_1_rssi']['mean']
    if chain0_base_rssi < -70 or chain1_base_rssi < -70:
        logging.warning('RSSI might be too low to get reliable chain map.')
    # Compile chain map by attenuating one path at a time and seeing which
    # chain's RSSI degrades
    chain_map = []
    for test_atten in attenuators:
        # Set one attenuator to 30 dB down
        test_atten.set_atten(30, strict=False)
        # Get new RSSI
        test_rssi = get_connected_rssi(dut, 4, 0.25, 1)
        # Assign attenuator to path that has lower RSSI
        if chain0_base_rssi > -70 and chain0_base_rssi - test_rssi[
                'chain_0_rssi']['mean'] > 10:
            chain_map.append('DUT-Chain-0')
        elif chain1_base_rssi > -70 and chain1_base_rssi - test_rssi[
                'chain_1_rssi']['mean'] > 10:
            chain_map.append('DUT-Chain-1')
        else:
            chain_map.append(None)
        # Reset attenuator to 0
        test_atten.set_atten(0, strict=False)
    ping_future.result()
    logging.debug('Chain Map: {}'.format(chain_map))
    return chain_map


def get_full_rf_connection_map(attenuators, dut, ping_server, networks):
    """Function to detect per-network connections between attenuator and DUT.

    This function detects the mapping between attenuator ports and DUT chains
    on all networks in its arguments. The function connects the DUT to each
    network then calls get_current_atten_dut_chain_map to get the connection
    map on the current network. The function outputs the results in two formats
    to enable easy access when users are interested in indexing by network or
    attenuator port.

    Args:
        attenuators: list of attenuator ports
        dut: android device object assumed connected to a wifi network.
        ping_server: ssh connection object to ping server
        networks: dict of network IDs and configs
    Returns:
        rf_map_by_network: dict of RF connections indexed by network.
        rf_map_by_atten: list of RF connections indexed by attenuator
    """
    for atten in attenuators:
        atten.set_atten(0, strict=False)

    rf_map_by_network = collections.OrderedDict()
    rf_map_by_atten = [[] for _ in attenuators]
    for net_id, net_config in networks.items():
        wutils.reset_wifi(dut)
        wutils.wifi_connect(dut,
                            net_config,
                            num_of_tries=1,
                            assert_on_fail=False,
                            check_connectivity=False)
        rf_map_by_network[net_id] = get_current_atten_dut_chain_map(
            attenuators, dut, ping_server)
        for idx, chain in enumerate(rf_map_by_network[net_id]):
            if chain:
                rf_map_by_atten[idx].append({
                    "network": net_id,
                    "dut_chain": chain
                })
    logging.debug("RF Map (by Network): {}".format(rf_map_by_network))
    logging.debug("RF Map (by Atten): {}".format(rf_map_by_atten))

    return rf_map_by_network, rf_map_by_atten


# Miscellaneous Wifi Utilities
def extract_sub_dict(full_dict, fields):
    """Returns an OrderedDict of full_dict restricted to fields, in order.

    Raises KeyError if any field is missing from full_dict.
    """
    sub_dict = collections.OrderedDict(
        (field, full_dict[field]) for field in fields)
    return sub_dict


def validate_network(dut, ssid):
    """Check that DUT has a valid internet connection through expected SSID

    Args:
        dut: android device of interest
        ssid: expected ssid
    """
    current_network = dut.droid.wifiGetConnectionInfo()
1070 try: 1071 connected = wutils.validate_connection(dut) is not None 1072 except: 1073 connected = False 1074 if connected and current_network['SSID'] == ssid: 1075 return True 1076 else: 1077 return False 1078 1079 1080def get_server_address(ssh_connection, dut_ip, subnet_mask): 1081 """Get server address on a specific subnet, 1082 1083 This function retrieves the LAN IP of a remote machine used in testing, 1084 i.e., it returns the server's IP belonging to the same LAN as the DUT. 1085 1086 Args: 1087 ssh_connection: object representing server for which we want an ip 1088 dut_ip: string in ip address format, i.e., xxx.xxx.xxx.xxx, specifying 1089 the DUT LAN IP we wish to connect to 1090 subnet_mask: string representing subnet mask 1091 """ 1092 subnet_mask = subnet_mask.split('.') 1093 dut_subnet = [ 1094 int(dut) & int(subnet) 1095 for dut, subnet in zip(dut_ip.split('.'), subnet_mask) 1096 ] 1097 ifconfig_out = ssh_connection.run('ifconfig').stdout 1098 ip_list = re.findall('inet (?:addr:)?(\d+.\d+.\d+.\d+)', ifconfig_out) 1099 for current_ip in ip_list: 1100 current_subnet = [ 1101 int(ip) & int(subnet) 1102 for ip, subnet in zip(current_ip.split('.'), subnet_mask) 1103 ] 1104 if current_subnet == dut_subnet: 1105 return current_ip 1106 logging.error('No IP address found in requested subnet') 1107 1108 1109def get_iperf_arg_string(duration, 1110 reverse_direction, 1111 interval=1, 1112 traffic_type='TCP', 1113 tcp_window=None, 1114 tcp_processes=1, 1115 udp_throughput='1000M'): 1116 """Function to format iperf client arguments. 1117 1118 This function takes in iperf client parameters and returns a properly 1119 formatter iperf arg string to be used in throughput tests. 
1120 1121 Args: 1122 duration: iperf duration in seconds 1123 reverse_direction: boolean controlling the -R flag for iperf clients 1124 interval: iperf print interval 1125 traffic_type: string specifying TCP or UDP traffic 1126 tcp_window: string specifying TCP window, e.g., 2M 1127 tcp_processes: int specifying number of tcp processes 1128 udp_throughput: string specifying TX throughput in UDP tests, e.g. 100M 1129 Returns: 1130 iperf_args: string of formatted iperf args 1131 """ 1132 iperf_args = '-i {} -t {} -J '.format(interval, duration) 1133 if traffic_type.upper() == 'UDP': 1134 iperf_args = iperf_args + '-u -b {} -l 1400'.format(udp_throughput) 1135 elif traffic_type.upper() == 'TCP': 1136 iperf_args = iperf_args + '-P {}'.format(tcp_processes) 1137 if tcp_window: 1138 iperf_args = iperf_args + '-w {}'.format(tcp_window) 1139 if reverse_direction: 1140 iperf_args = iperf_args + ' -R' 1141 return iperf_args 1142 1143 1144def get_dut_temperature(dut): 1145 """Function to get dut temperature. 1146 1147 The function fetches and returns the reading from the temperature sensor 1148 used for skin temperature and thermal throttling. 1149 1150 Args: 1151 dut: AndroidDevice of interest 1152 Returns: 1153 temperature: device temperature. 0 if temperature could not be read 1154 """ 1155 candidate_zones = [ 1156 'skin-therm', 'sdm-therm-monitor', 'sdm-therm-adc', 'back_therm' 1157 ] 1158 for zone in candidate_zones: 1159 try: 1160 temperature = int( 1161 dut.adb.shell( 1162 'cat /sys/class/thermal/tz-by-name/{}/temp'.format(zone))) 1163 break 1164 except ValueError: 1165 temperature = 0 1166 if temperature == 0: 1167 logging.debug('Could not check DUT temperature.') 1168 elif temperature > 100: 1169 temperature = temperature / 1000 1170 return temperature 1171 1172 1173def wait_for_dut_cooldown(dut, target_temp=50, timeout=300): 1174 """Function to wait for a DUT to cool down. 
1175 1176 Args: 1177 dut: AndroidDevice of interest 1178 target_temp: target cooldown temperature 1179 timeout: maxt time to wait for cooldown 1180 """ 1181 start_time = time.time() 1182 while time.time() - start_time < timeout: 1183 temperature = get_dut_temperature(dut) 1184 if temperature < target_temp: 1185 break 1186 time.sleep(SHORT_SLEEP) 1187 elapsed_time = time.time() - start_time 1188 logging.debug("DUT Final Temperature: {}C. Cooldown duration: {}".format( 1189 temperature, elapsed_time)) 1190 1191 1192def health_check(dut, batt_thresh=5, temp_threshold=53, cooldown=1): 1193 """Function to check health status of a DUT. 1194 1195 The function checks both battery levels and temperature to avoid DUT 1196 powering off during the test. 1197 1198 Args: 1199 dut: AndroidDevice of interest 1200 batt_thresh: battery level threshold 1201 temp_threshold: temperature threshold 1202 cooldown: flag to wait for DUT to cool down when overheating 1203 Returns: 1204 health_check: boolean confirming device is healthy 1205 """ 1206 health_check = True 1207 battery_level = utils.get_battery_level(dut) 1208 if battery_level < batt_thresh: 1209 logging.warning("Battery level low ({}%)".format(battery_level)) 1210 health_check = False 1211 else: 1212 logging.debug("Battery level = {}%".format(battery_level)) 1213 1214 temperature = get_dut_temperature(dut) 1215 if temperature > temp_threshold: 1216 if cooldown: 1217 logging.warning( 1218 "Waiting for DUT to cooldown. ({} C)".format(temperature)) 1219 wait_for_dut_cooldown(dut, target_temp=temp_threshold - 5) 1220 else: 1221 logging.warning("DUT Overheating ({} C)".format(temperature)) 1222 health_check = False 1223 else: 1224 logging.debug("DUT Temperature = {} C".format(temperature)) 1225 return health_check 1226 1227 1228def get_sw_signature(dut): 1229 """Function that checks the signature for wifi firmware and config files. 
1230 1231 Returns: 1232 bdf_signature: signature consisting of last three digits of bdf cksums 1233 fw_signature: floating point firmware version, i.e., major.minor 1234 """ 1235 bdf_output = dut.adb.shell('cksum /vendor/firmware/bdwlan*') 1236 logging.debug('BDF Checksum output: {}'.format(bdf_output)) 1237 bdf_signature = sum( 1238 [int(line.split(' ')[0]) for line in bdf_output.splitlines()]) % 1000 1239 1240 fw_output = dut.adb.shell('halutil -logger -get fw') 1241 logging.debug('Firmware version output: {}'.format(fw_output)) 1242 fw_version = re.search(FW_REGEX, fw_output).group('firmware') 1243 fw_signature = fw_version.split('.')[-3:-1] 1244 fw_signature = float('.'.join(fw_signature)) 1245 serial_hash = int(hashlib.md5(dut.serial.encode()).hexdigest(), 16) % 1000 1246 return { 1247 'bdf_signature': bdf_signature, 1248 'fw_signature': fw_signature, 1249 'serial_hash': serial_hash 1250 } 1251 1252 1253def push_bdf(dut, bdf_file): 1254 """Function to push Wifi BDF files 1255 1256 This function checks for existing wifi bdf files and over writes them all, 1257 for simplicity, with the bdf file provided in the arguments. 
The dut is 1258 rebooted for the bdf file to take effect 1259 1260 Args: 1261 dut: dut to push bdf file to 1262 bdf_file: path to bdf_file to push 1263 """ 1264 bdf_files_list = dut.adb.shell('ls /vendor/firmware/bdwlan*').splitlines() 1265 for dst_file in bdf_files_list: 1266 dut.push_system_file(bdf_file, dst_file) 1267 dut.reboot() 1268 1269 1270def push_firmware(dut, wlanmdsp_file, datamsc_file): 1271 """Function to push Wifi firmware files 1272 1273 Args: 1274 dut: dut to push bdf file to 1275 wlanmdsp_file: path to wlanmdsp.mbn file 1276 datamsc_file: path to Data.msc file 1277 """ 1278 dut.push_system_file(wlanmdsp_file, '/vendor/firmware/wlanmdsp.mbn') 1279 dut.push_system_file(datamsc_file, '/vendor/firmware/Data.msc') 1280 dut.reboot() 1281 1282 1283def _set_ini_fields(ini_file_path, ini_field_dict): 1284 template_regex = r'^{}=[0-9,.x-]+' 1285 with open(ini_file_path, 'r') as f: 1286 ini_lines = f.read().splitlines() 1287 for idx, line in enumerate(ini_lines): 1288 for field_name, field_value in ini_field_dict.items(): 1289 line_regex = re.compile(template_regex.format(field_name)) 1290 if re.match(line_regex, line): 1291 ini_lines[idx] = "{}={}".format(field_name, field_value) 1292 print(ini_lines[idx]) 1293 with open(ini_file_path, 'w') as f: 1294 f.write("\n".join(ini_lines) + "\n") 1295 1296 1297def _edit_dut_ini(dut, ini_fields): 1298 """Function to edit Wifi ini files.""" 1299 dut_ini_path = '/vendor/firmware/wlan/qca_cld/WCNSS_qcom_cfg.ini' 1300 local_ini_path = os.path.expanduser('~/WCNSS_qcom_cfg.ini') 1301 dut.pull_files(dut_ini_path, local_ini_path) 1302 1303 _set_ini_fields(local_ini_path, ini_fields) 1304 1305 dut.push_system_file(local_ini_path, dut_ini_path) 1306 dut.reboot() 1307 1308 1309def set_ini_single_chain_mode(dut, chain): 1310 ini_fields = { 1311 'gEnable2x2': 0, 1312 'gSetTxChainmask1x1': chain + 1, 1313 'gSetRxChainmask1x1': chain + 1, 1314 'gDualMacFeatureDisable': 1, 1315 'gDot11Mode': 0 1316 } 1317 _edit_dut_ini(dut, 
ini_fields) 1318 1319 1320def set_ini_two_chain_mode(dut): 1321 ini_fields = { 1322 'gEnable2x2': 2, 1323 'gSetTxChainmask1x1': 1, 1324 'gSetRxChainmask1x1': 1, 1325 'gDualMacFeatureDisable': 6, 1326 'gDot11Mode': 0 1327 } 1328 _edit_dut_ini(dut, ini_fields) 1329 1330 1331def set_ini_tx_mode(dut, mode): 1332 TX_MODE_DICT = { 1333 "Auto": 0, 1334 "11n": 4, 1335 "11ac": 9, 1336 "11abg": 1, 1337 "11b": 2, 1338 "11g": 3, 1339 "11g only": 5, 1340 "11n only": 6, 1341 "11b only": 7, 1342 "11ac only": 8 1343 } 1344 1345 ini_fields = { 1346 'gEnable2x2': 2, 1347 'gSetTxChainmask1x1': 1, 1348 'gSetRxChainmask1x1': 1, 1349 'gDualMacFeatureDisable': 6, 1350 'gDot11Mode': TX_MODE_DICT[mode] 1351 } 1352 _edit_dut_ini(dut, ini_fields) 1353