1#
2# Copyright (C) 2013 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15#
16
17"""Utilities for unit testing."""
18
19from __future__ import absolute_import
20from __future__ import print_function
21
22import io
23import hashlib
24import os
25import struct
26import subprocess
27
28from update_payload import common
29from update_payload import payload
30from update_payload import update_metadata_pb2
31
32
class TestError(Exception):
  """Raised when something goes wrong while testing update payload code."""
  pass
35
36
# RSA key pair (private and public halves) used for testing; both paths are
# built relative to the directory containing this module.
_PRIVKEY_FILE_NAME = os.path.join(
    os.path.dirname(__file__), 'payload-test-key.pem')
_PUBKEY_FILE_NAME = os.path.join(
    os.path.dirname(__file__), 'payload-test-key.pub')
42
43
def KiB(count):
  """Returns |count| shifted left by 10 bits, i.e. |count| * 2**10."""
  kib_shift = 10
  return count << kib_shift
46
47
def MiB(count):
  """Returns |count| shifted left by 20 bits, i.e. |count| * 2**20."""
  mib_shift = 20
  return count << mib_shift
50
51
def GiB(count):
  """Returns |count| shifted left by 30 bits, i.e. |count| * 2**30."""
  gib_shift = 30
  return count << gib_shift
54
55
def _WriteInt(file_obj, size, is_unsigned, val):
  """Writes a binary-encoded integer to a file.

  The value is packed with struct.pack() using the format string returned by
  common.IntPackingFmtStr(size, is_unsigned), which selects the conversion
  based on the reported size and signedness. Assumes a network (big-endian)
  byte ordering.

  Args:
    file_obj: a file object open for binary writing
    size: the integer size in bytes (2, 4 or 8)
    is_unsigned: True to encode an unsigned integer, False for a signed one
    val: integer value to encode

  Raises:
    PayloadError if a write error occurred.
  """
  # Packing cannot raise IOError, so only the actual write is guarded.
  packed_val = struct.pack(common.IntPackingFmtStr(size, is_unsigned), val)
  try:
    file_obj.write(packed_val)
  except IOError as e:
    # In-memory streams (e.g. io.BytesIO) have no 'name' attribute; fall back
    # to the object's repr so the write error is reported rather than masked
    # by an AttributeError.
    file_name = getattr(file_obj, 'name', repr(file_obj))
    raise payload.PayloadError('error writing to file (%s): %s' %
                               (file_name, e))
77
78
def _SetMsgField(msg, field_name, val):
  """Assigns |val| to a protobuf message field, clearing it when |val| is None.

  Args:
    msg: the message object to update
    field_name: name of the field to set or clear
    val: the value to assign; None means the field is cleared instead
  """
  if val is not None:
    setattr(msg, field_name, val)
    return
  msg.ClearField(field_name)
85
86
def SignSha256(data, privkey_file_name):
  """Signs the data's SHA256 hash with an RSA private key.

  The SHA256 digest of |data| is prefixed with common.SIG_ASN1_HEADER and the
  result is piped to 'openssl rsautl -sign', whose standard output is the
  signature.

  Args:
    data: the data whose SHA256 hash we want to sign
    privkey_file_name: private key used for signing data

  Returns:
    The signature of the ASN1-header-prefixed hash, as emitted by openssl.

  Raises:
    TestError if the signing subprocess cannot be run or exits with a non-zero
    status.
  """
  data_sha256_hash = common.SIG_ASN1_HEADER + hashlib.sha256(data).digest()
  sign_cmd = ['openssl', 'rsautl', '-sign', '-inkey', privkey_file_name]
  try:
    sign_process = subprocess.Popen(sign_cmd, stdin=subprocess.PIPE,
                                    stdout=subprocess.PIPE)
    sig, _ = sign_process.communicate(input=data_sha256_hash)
  except Exception as e:
    # Any failure to launch or talk to openssl is reported uniformly.
    raise TestError('signing subprocess failed: %s' % e)

  # A non-zero exit status means openssl did not produce a valid signature;
  # without this check the caller would silently receive bogus output.
  if sign_process.returncode != 0:
    raise TestError('signing subprocess failed: exit status %d' %
                    sign_process.returncode)

  return sig
110
111
class SignaturesGenerator(object):
  """Builds the signatures data block of a payload.

  Attributes:
    sigs: the Signatures protobuf accumulating the added signatures
  """

  def __init__(self):
    self.sigs = update_metadata_pb2.Signatures()

  def AddSig(self, version, data):
    """Appends a new signature entry to the signatures message.

    Args:
      version: signature version (left unassigned when None)
      data: signature binary data (left unassigned when None)
    """
    new_sig = self.sigs.signatures.add()
    for attr_name, attr_val in (('version', version), ('data', data)):
      if attr_val is None:
        continue
      setattr(new_sig, attr_name, attr_val)

  def ToBinary(self):
    """Serializes the accumulated signatures message and returns the result."""
    serialized_sigs = self.sigs.SerializeToString()
    return serialized_sigs
134
135
class PayloadGenerator(object):
  """Generates an update payload allowing low-level control.

  Attributes:
    manifest: the protobuf containing the payload manifest
    version: the payload version identifier
    block_size: the block size pertaining to update operations
  """

  def __init__(self, version=1):
    self.manifest = update_metadata_pb2.DeltaArchiveManifest()
    self.version = version
    self.block_size = 0

  @staticmethod
  def _WriteExtent(ex, val):
    """Populates an Extent message from a (start_block, num_blocks) pair.

    Args:
      ex: the Extent message to fill in
      val: a (start_block, num_blocks) pair; a None element clears the
           corresponding field rather than setting it
    """
    start_block, num_blocks = val
    _SetMsgField(ex, 'start_block', start_block)
    _SetMsgField(ex, 'num_blocks', num_blocks)

  @staticmethod
  def _AddValuesToRepeatedField(repeated_field, values, write_func):
    """Adds values to a repeated message field.

    Args:
      repeated_field: the repeated field; a new item is obtained via add()
      values: iterable of values to add (None or empty adds nothing)
      write_func: callable taking (new_item, value) that populates the item
    """
    if values:
      for val in values:
        new_item = repeated_field.add()
        write_func(new_item, val)

  @staticmethod
  def _AddExtents(extents_field, values):
    """Adds extents, given as (start_block, num_blocks) pairs, to a field."""
    PayloadGenerator._AddValuesToRepeatedField(
        extents_field, values, PayloadGenerator._WriteExtent)

  def _GetOrAddPartition(self, part_name):
    """Returns the partition named |part_name|, adding it if missing.

    Args:
      part_name: The name of the partition to look up in the manifest.

    Returns:
      The first manifest partition entry whose partition_name matches, or a
      newly added entry carrying that name when none exists yet.
    """
    for partition in self.manifest.partitions:
      if partition.partition_name == part_name:
        return partition

    partition = self.manifest.partitions.add()
    partition.partition_name = part_name
    return partition

  def SetBlockSize(self, block_size):
    """Sets the payload's block size (None clears the manifest field)."""
    self.block_size = block_size
    _SetMsgField(self.manifest, 'block_size', block_size)

  def SetPartInfo(self, part_name, is_new, part_size, part_hash):
    """Set the partition info entry.

    The partition entry is added to the manifest if it does not exist yet.

    Args:
      part_name: The name of the partition.
      is_new: Whether to set old (False) or new (True) info.
      part_size: The partition size (in fact, filesystem size).
      part_hash: The partition hash.
    """
    partition = self._GetOrAddPartition(part_name)
    part_info = (partition.new_partition_info if is_new
                 else partition.old_partition_info)
    _SetMsgField(part_info, 'size', part_size)
    _SetMsgField(part_info, 'hash', part_hash)

  def AddOperation(self, part_name, op_type, data_offset=None,
                   data_length=None, src_extents=None, src_length=None,
                   dst_extents=None, dst_length=None, data_sha256_hash=None):
    """Adds an InstallOperation entry.

    The partition entry is added to the manifest if it does not exist yet.
    Scalar fields given as None are cleared rather than set.

    Args:
      part_name: The name of the partition the operation belongs to.
      op_type: the operation type to assign
      data_offset: value for the operation's data_offset field
      data_length: value for the operation's data_length field
      src_extents: list of (start_block, num_blocks) pairs for src_extents
      src_length: value for the operation's src_length field
      dst_extents: list of (start_block, num_blocks) pairs for dst_extents
      dst_length: value for the operation's dst_length field
      data_sha256_hash: value for the operation's data_sha256_hash field
    """
    partition = self._GetOrAddPartition(part_name)
    op = partition.operations.add()
    op.type = op_type

    _SetMsgField(op, 'data_offset', data_offset)
    _SetMsgField(op, 'data_length', data_length)

    self._AddExtents(op.src_extents, src_extents)
    _SetMsgField(op, 'src_length', src_length)

    self._AddExtents(op.dst_extents, dst_extents)
    _SetMsgField(op, 'dst_length', dst_length)

    _SetMsgField(op, 'data_sha256_hash', data_sha256_hash)

  def SetSignatures(self, sigs_offset, sigs_size):
    """Set the payload's signature block descriptors."""
    _SetMsgField(self.manifest, 'signatures_offset', sigs_offset)
    _SetMsgField(self.manifest, 'signatures_size', sigs_size)

  def SetMinorVersion(self, minor_version):
    """Set the payload's minor version field."""
    _SetMsgField(self.manifest, 'minor_version', minor_version)

  def _WriteHeaderToFile(self, file_obj, manifest_len):
    """Writes a payload header to a file.

    The header consists of the magic string, the payload version and the
    manifest length, in that order.

    Args:
      file_obj: a file object open for writing
      manifest_len: the manifest length to record in the header
    """
    # We need to access protected members in Payload for writing the header.
    # pylint: disable=W0212
    file_obj.write(payload.Payload._PayloadHeader._MAGIC)
    _WriteInt(file_obj, payload.Payload._PayloadHeader._VERSION_SIZE, True,
              self.version)
    _WriteInt(file_obj, payload.Payload._PayloadHeader._MANIFEST_LEN_SIZE, True,
              manifest_len)

  def WriteToFile(self, file_obj, manifest_len=-1, data_blobs=None,
                  sigs_data=None, padding=None):
    """Writes the payload content to a file.

    Content is written in this order: header, serialized manifest, data
    blobs, signatures blob, padding.

    Args:
      file_obj: a file object open for writing
      manifest_len: manifest len to dump (a negative value means it is
                    computed from the serialized manifest)
      data_blobs: a list of data blobs to be concatenated to the payload
      sigs_data: a binary Signatures message to be concatenated to the payload
      padding: stuff to dump past the normal data blobs provided (optional)
    """
    manifest = self.manifest.SerializeToString()
    if manifest_len < 0:
      manifest_len = len(manifest)
    self._WriteHeaderToFile(file_obj, manifest_len)
    file_obj.write(manifest)
    if data_blobs:
      for data_blob in data_blobs:
        file_obj.write(data_blob)
    if sigs_data:
      file_obj.write(sigs_data)
    if padding:
      file_obj.write(padding)
264
265
class EnhancedPayloadGenerator(PayloadGenerator):
  """Payload generator with automatic handling of data blobs.

  Attributes:
    data_blobs: a list of blobs, in the order they were added
    curr_offset: the currently consumed offset of blobs added to the payload
  """

  def __init__(self):
    super(EnhancedPayloadGenerator, self).__init__()
    self.data_blobs = []
    self.curr_offset = 0

  def AddData(self, data_blob):
    """Adds a (possibly orphan) data blob.

    Args:
      data_blob: the blob to append to the internally maintained list

    Returns:
      A (data_length, data_offset) pair: the blob's length and the offset at
      which it starts relative to all previously added blobs.
    """
    data_length = len(data_blob)
    data_offset = self.curr_offset
    self.curr_offset += data_length
    self.data_blobs.append(data_blob)
    return data_length, data_offset

  def AddOperationWithData(self, part_name, op_type, src_extents=None,
                           src_length=None, dst_extents=None, dst_length=None,
                           data_blob=None, do_hash_data_blob=True):
    """Adds an install operation and associated data blob.

    This takes care of obtaining a hash of the data blob (if so instructed)
    and appending it to the internally maintained list of blobs, including the
    necessary offset/length accounting.

    Args:
      part_name: The name of the partition (e.g. kernel or root).
      op_type: the operation type (e.g. REPLACE, REPLACE_BZ, REPLACE_XZ).
      src_extents: list of (start, length) pairs indicating src block ranges
      src_length: size of the src data in bytes (needed for diff operations)
      dst_extents: list of (start, length) pairs indicating dst block ranges
      dst_length: size of the dst data in bytes (needed for diff operations)
      data_blob: a data blob associated with this operation
      do_hash_data_blob: whether or not to compute and add a data blob hash
    """
    # Without a blob, the data-related operation fields are left cleared.
    data_offset = data_length = data_sha256_hash = None
    if data_blob is not None:
      if do_hash_data_blob:
        data_sha256_hash = hashlib.sha256(data_blob).digest()
      data_length, data_offset = self.AddData(data_blob)

    self.AddOperation(part_name, op_type, data_offset=data_offset,
                      data_length=data_length, src_extents=src_extents,
                      src_length=src_length, dst_extents=dst_extents,
                      dst_length=dst_length, data_sha256_hash=data_sha256_hash)

  def WriteToFileWithData(self, file_obj, sigs_data=None,
                          privkey_file_name=None, padding=None):
    """Writes the payload content to a file, optionally signing the content.

    Args:
      file_obj: a file object open for writing
      sigs_data: signatures blob to be appended to the payload (optional;
                 payload signature fields assumed to be preset by the caller)
      privkey_file_name: key used for signing the payload (optional; used only
                         if explicit signatures blob not provided)
      padding: stuff to dump past the normal data blobs provided (optional)

    Raises:
      TestError: if arguments are inconsistent or something goes wrong.
    """
    # Do we need to generate a genuine signatures blob?
    do_generate_sigs_data = sigs_data is None and privkey_file_name

    if do_generate_sigs_data:
      # First, sign some arbitrary data to obtain the size of a signature blob.
      # The size must be known up front because it is recorded in the manifest,
      # which is itself part of the signed content.
      fake_sig = SignSha256(b'fake-payload-data', privkey_file_name)
      fake_sigs_gen = SignaturesGenerator()
      fake_sigs_gen.AddSig(1, fake_sig)
      sigs_len = len(fake_sigs_gen.ToBinary())

      # Update the payload with proper signature attributes.
      self.SetSignatures(self.curr_offset, sigs_len)

      # Once all payload fields are updated, dump and sign it.
      temp_payload_file = io.BytesIO()
      self.WriteToFile(temp_payload_file, data_blobs=self.data_blobs)
      sig = SignSha256(temp_payload_file.getvalue(), privkey_file_name)
      sigs_gen = SignaturesGenerator()
      sigs_gen.AddSig(1, sig)
      sigs_data = sigs_gen.ToBinary()

      # Raise explicitly instead of asserting, so the check survives running
      # under 'python -O' and matches the documented TestError contract.
      if len(sigs_data) != sigs_len:
        raise TestError('signature blob lengths mismatch')

    # Dump the whole thing, complete with data and signature blob, to a file.
    self.WriteToFile(file_obj, data_blobs=self.data_blobs, sigs_data=sigs_data,
                     padding=padding)
360