# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common functions for the stress tester."""

import logging
import os

from absl import flags
import stress_test_pb2
from google.protobuf import text_format

FLAGS = flags.FLAGS

flags.DEFINE_string("resource_path", None,
                    "Optional override path where to grab resources from. By "
                    "default, resources are grabbed from "
                    "stress_test_common.RESOURCE_DIR, specifying this flag "
                    "will instead result in first looking in this path before "
                    "the module defined resource directory.")

RESOURCE_DIR = "resources/"


def MakeDirsIfNeeded(path):
  """Helper function to create all the directories on a path.

  A no-op when `path` is already a directory. This is also safe if another
  process creates the directory concurrently.

  Args:
    path: Directory path to create, along with any missing parents.

  Raises:
    OSError: If creation fails and `path` is still not a directory afterwards
      (for example, `path` names an existing regular file).
  """
  if os.path.isdir(path):
    return
  try:
    os.makedirs(path)
  except OSError:
    # The directory may have been created by someone else between the isdir
    # check and makedirs. Only propagate if it genuinely isn't there.
    if not os.path.isdir(path):
      raise


def _ReadBinaryFile(path):
  """Returns the full contents of `path`, read in binary ("rb") mode.

  A context manager closes the handle promptly instead of leaving it for the
  garbage collector.
  """
  with open(path, "rb") as resource_file:
    return resource_file.read()


def GetResourceContents(resource_name):
  """Gets the contents of the named resource, read in binary ("rb") mode.

  Lookup order:
    1. <--resource_path>/basename(resource_name), if the flag is set and that
       file exists.
    2. `resource_name` exactly as given, if it exists.
    3. RESOURCE_DIR/basename(resource_name).

  Args:
    resource_name: Name or path of the resource to read.

  Returns:
    The raw file contents.

  Raises:
    IOError: If the final fallback location (step 3) cannot be opened.
  """
  # Look in the resource override folder first (just go with the basename to
  # find the file, rather than the full path).
  if FLAGS.resource_path:
    override_path = os.path.join(FLAGS.resource_path,
                                 os.path.basename(resource_name))
    if os.path.exists(override_path):
      return _ReadBinaryFile(override_path)

  # If the full path exists, grab that, otherwise fall back to the basename.
  if os.path.exists(resource_name):
    return _ReadBinaryFile(resource_name)
  return _ReadBinaryFile(
      os.path.join(RESOURCE_DIR, os.path.basename(resource_name)))


def LoadDeviceConfig(device_type, serial_number):
  """Assembles a DeviceConfig proto following all includes, or the default.

  Merge order:
    1. device_config.common.ascii_proto
    2. device_config.<device_type>.ascii_proto
    3. device_config.<serial_number>.ascii_proto

  Within each level, that level's `include`s are merged (recursively) before
  the level itself, so later content overrides earlier content. Afterwards,
  the repeated fields handled at the bottom are de-duplicated by key, keeping
  only the last entry for each key.

  Args:
    device_type: Prefix of the per-device-type config resource. A failure to
      read it is logged.
    serial_number: Prefix of the per-device config resource. A failure to
      read it is ignored silently (errors in files it includes are logged).

  Returns:
    The merged stress_test_pb2.DeviceConfig.

  Raises:
    IOError: If the common config resource cannot be read.
  """
  config = stress_test_pb2.DeviceConfig()
  text_format.Merge(GetResourceContents(
      os.path.join(RESOURCE_DIR, "device_config.common.ascii_proto")), config)

  # Prefixes currently being loaded, outermost first. Used to detect include
  # cycles, which would otherwise recurse until RecursionError.
  include_stack = []

  def RecursiveIncludeToConfig(resource_prefix, print_error):
    """Load configurations recursively, merging includes before the file."""
    if resource_prefix in include_stack:
      logging.error("Circular device config include: %s",
                    " -> ".join(str(prefix) for prefix in
                                include_stack + [resource_prefix]))
      return
    include_stack.append(resource_prefix)
    try:
      new_config = stress_test_pb2.DeviceConfig()
      text_format.Merge(GetResourceContents(
          os.path.join(RESOURCE_DIR,
                       "device_config.%s.ascii_proto" % resource_prefix)),
                        new_config)
      for include_name in new_config.include:
        # If we've managed to import this level properly, then we should print
        # out any errors if we hit them on the included files.
        RecursiveIncludeToConfig(include_name, print_error=True)
      config.MergeFrom(new_config)
    except IOError as err:
      if print_error:
        logging.error(str(err))
    finally:
      include_stack.pop()

  RecursiveIncludeToConfig(device_type, print_error=True)
  RecursiveIncludeToConfig(serial_number, print_error=False)

  def TakeOnlyLatestFromRepeatedField(message, field, key):
    """Keep only the last element for each distinct `key`, preserving order."""
    old_list = list(getattr(message, field))
    message.ClearField(field)
    seen_keys = set()
    kept_newest_first = []
    # Walk from the end so the first time a key is seen is its latest entry.
    # The set makes this O(n) instead of rescanning the tail per element.
    for element in reversed(old_list):
      element_key = getattr(element, key)
      if element_key not in seen_keys:
        seen_keys.add(element_key)
        kept_newest_first.append(element)
    getattr(message, field).extend(reversed(kept_newest_first))

  # We actually need to do a bit of post-processing on the proto - we only want
  # to take the latest version for each (that way people can override stuff if
  # they want)
  TakeOnlyLatestFromRepeatedField(config, "file_to_watch", "source")
  TakeOnlyLatestFromRepeatedField(config, "file_to_move", "source")
  TakeOnlyLatestFromRepeatedField(config, "event", "name")
  TakeOnlyLatestFromRepeatedField(config, "daemon_process", "name")

  return config