1# Copyright (C) 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import print_function
16
17import array
18import copy
19import functools
20import heapq
21import itertools
22import logging
23import multiprocessing
24import os
25import os.path
26import re
27import sys
28import threading
29import zlib
30from collections import deque, namedtuple, OrderedDict
31
32import common
33from images import EmptyImage
34from rangelib import RangeSet
35
36__all__ = ["BlockImageDiff"]
37
38logger = logging.getLogger(__name__)
39
40# The tuple contains the style and bytes of a bsdiff|imgdiff patch.
41PatchInfo = namedtuple("PatchInfo", ["imgdiff", "content"])
42
43
44def compute_patch(srcfile, tgtfile, imgdiff=False):
45  """Calls bsdiff|imgdiff to compute the patch data, returns a PatchInfo."""
46  patchfile = common.MakeTempFile(prefix='patch-')
47
48  cmd = ['imgdiff', '-z'] if imgdiff else ['bsdiff']
49  cmd.extend([srcfile, tgtfile, patchfile])
50
51  # Don't dump the bsdiff/imgdiff commands, which are not useful for the case
52  # here, since they contain temp filenames only.
53  proc = common.Run(cmd, verbose=False)
54  output, _ = proc.communicate()
55
56  if proc.returncode != 0:
57    raise ValueError(output)
58
59  with open(patchfile, 'rb') as f:
60    return PatchInfo(imgdiff, f.read())
61
62
63class Transfer(object):
64  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, tgt_sha1,
65               src_sha1, style, by_id):
66    self.tgt_name = tgt_name
67    self.src_name = src_name
68    self.tgt_ranges = tgt_ranges
69    self.src_ranges = src_ranges
70    self.tgt_sha1 = tgt_sha1
71    self.src_sha1 = src_sha1
72    self.style = style
73
74    # We use OrderedDict rather than dict so that the output is repeatable;
75    # otherwise it would depend on the hash values of the Transfer objects.
76    self.goes_before = OrderedDict()
77    self.goes_after = OrderedDict()
78
79    self.stash_before = []
80    self.use_stash = []
81
82    self.id = len(by_id)
83    by_id.append(self)
84
85    self._patch_info = None
86
87  @property
88  def patch_info(self):
89    return self._patch_info
90
91  @patch_info.setter
92  def patch_info(self, info):
93    if info:
94      assert self.style == "diff"
95    self._patch_info = info
96
97  def NetStashChange(self):
98    return (sum(sr.size() for (_, sr) in self.stash_before) -
99            sum(sr.size() for (_, sr) in self.use_stash))
100
101  def ConvertToNew(self):
102    assert self.style != "new"
103    self.use_stash = []
104    self.style = "new"
105    self.src_ranges = RangeSet()
106    self.patch_info = None
107
108  def __str__(self):
109    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
110            " to " + str(self.tgt_ranges) + ">")
111
112
113@functools.total_ordering
114class HeapItem(object):
115  def __init__(self, item):
116    self.item = item
117    # Negate the score since python's heap is a min-heap and we want the
118    # maximum score.
119    self.score = -item.score
120
121  def clear(self):
122    self.item = None
123
124  def __bool__(self):
125    return self.item is not None
126
127  # Python 2 uses __nonzero__, while Python 3 uses __bool__.
128  __nonzero__ = __bool__
129
130  # The rest operations are generated by functools.total_ordering decorator.
131  def __eq__(self, other):
132    return self.score == other.score
133
134  def __le__(self, other):
135    return self.score <= other.score
136
137
138class ImgdiffStats(object):
139  """A class that collects imgdiff stats.
140
141  It keeps track of the files that will be applied imgdiff while generating
142  BlockImageDiff. It also logs the ones that cannot use imgdiff, with specific
143  reasons. The stats is only meaningful when imgdiff not being disabled by the
144  caller of BlockImageDiff. In addition, only files with supported types
145  (BlockImageDiff.FileTypeSupportedByImgdiff()) are allowed to be logged.
146  """
147
148  USED_IMGDIFF = "APK files diff'd with imgdiff"
149  USED_IMGDIFF_LARGE_APK = "Large APK files split and diff'd with imgdiff"
150
151  # Reasons for not applying imgdiff on APKs.
152  SKIPPED_NONMONOTONIC = "Not used imgdiff due to having non-monotonic ranges"
153  SKIPPED_SHARED_BLOCKS = "Not used imgdiff due to using shared blocks"
154  SKIPPED_INCOMPLETE = "Not used imgdiff due to incomplete RangeSet"
155
156  # The list of valid reasons, which will also be the dumped order in a report.
157  REASONS = (
158      USED_IMGDIFF,
159      USED_IMGDIFF_LARGE_APK,
160      SKIPPED_NONMONOTONIC,
161      SKIPPED_SHARED_BLOCKS,
162      SKIPPED_INCOMPLETE,
163  )
164
165  def  __init__(self):
166    self.stats = {}
167
168  def Log(self, filename, reason):
169    """Logs why imgdiff can or cannot be applied to the given filename.
170
171    Args:
172      filename: The filename string.
173      reason: One of the reason constants listed in REASONS.
174
175    Raises:
176      AssertionError: On unsupported filetypes or invalid reason.
177    """
178    assert BlockImageDiff.FileTypeSupportedByImgdiff(filename)
179    assert reason in self.REASONS
180
181    if reason not in self.stats:
182      self.stats[reason] = set()
183    self.stats[reason].add(filename)
184
185  def Report(self):
186    """Prints a report of the collected imgdiff stats."""
187
188    def print_header(header, separator):
189      logger.info(header)
190      logger.info('%s\n', separator * len(header))
191
192    print_header('  Imgdiff Stats Report  ', '=')
193    for key in self.REASONS:
194      if key not in self.stats:
195        continue
196      values = self.stats[key]
197      section_header = ' {} (count: {}) '.format(key, len(values))
198      print_header(section_header, '-')
199      logger.info(''.join(['  {}\n'.format(name) for name in values]))
200
201
202class BlockImageDiff(object):
203  """Generates the diff of two block image objects.
204
205  BlockImageDiff works on two image objects. An image object is anything that
206  provides the following attributes:
207
208     blocksize: the size in bytes of a block, currently must be 4096.
209
210     total_blocks: the total size of the partition/image, in blocks.
211
212     care_map: a RangeSet containing which blocks (in the range [0,
213       total_blocks) we actually care about; i.e. which blocks contain data.
214
215     file_map: a dict that partitions the blocks contained in care_map into
216         smaller domains that are useful for doing diffs on. (Typically a domain
217         is a file, and the key in file_map is the pathname.)
218
219     clobbered_blocks: a RangeSet containing which blocks contain data but may
220         be altered by the FS. They need to be excluded when verifying the
221         partition integrity.
222
223     ReadRangeSet(): a function that takes a RangeSet and returns the data
224         contained in the image blocks of that RangeSet. The data is returned as
225         a list or tuple of strings; concatenating the elements together should
226         produce the requested data. Implementations are free to break up the
227         data into list/tuple elements in any way that is convenient.
228
229     RangeSha1(): a function that returns (as a hex string) the SHA-1 hash of
230         all the data in the specified range.
231
232     TotalSha1(): a function that returns (as a hex string) the SHA-1 hash of
233         all the data in the image (ie, all the blocks in the care_map minus
234         clobbered_blocks, or including the clobbered blocks if
235         include_clobbered_blocks is True).
236
237  When creating a BlockImageDiff, the src image may be None, in which case the
238  list of transfers produced will never read from the original image.
239  """
240
241  def __init__(self, tgt, src=None, threads=None, version=4,
242               disable_imgdiff=False):
243    if threads is None:
244      threads = multiprocessing.cpu_count() // 2
245      if threads == 0:
246        threads = 1
247    self.threads = threads
248    self.version = version
249    self.transfers = []
250    self.src_basenames = {}
251    self.src_numpatterns = {}
252    self._max_stashed_size = 0
253    self.touched_src_ranges = RangeSet()
254    self.touched_src_sha1 = None
255    self.disable_imgdiff = disable_imgdiff
256    self.imgdiff_stats = ImgdiffStats() if not disable_imgdiff else None
257
258    assert version in (3, 4)
259
260    self.tgt = tgt
261    if src is None:
262      src = EmptyImage()
263    self.src = src
264
265    # The updater code that installs the patch always uses 4k blocks.
266    assert tgt.blocksize == 4096
267    assert src.blocksize == 4096
268
269    # The range sets in each filemap should comprise a partition of
270    # the care map.
271    self.AssertPartition(src.care_map, src.file_map.values())
272    self.AssertPartition(tgt.care_map, tgt.file_map.values())
273
274  @property
275  def max_stashed_size(self):
276    return self._max_stashed_size
277
278  @staticmethod
279  def FileTypeSupportedByImgdiff(filename):
280    """Returns whether the file type is supported by imgdiff."""
281    return filename.lower().endswith(('.apk', '.jar', '.zip'))
282
283  def CanUseImgdiff(self, name, tgt_ranges, src_ranges, large_apk=False):
284    """Checks whether we can apply imgdiff for the given RangeSets.
285
286    For files in ZIP format (e.g., APKs, JARs, etc.) we would like to use
287    'imgdiff -z' if possible. Because it usually produces significantly smaller
288    patches than bsdiff.
289
290    This is permissible if all of the following conditions hold.
291      - The imgdiff hasn't been disabled by the caller (e.g. squashfs);
292      - The file type is supported by imgdiff;
293      - The source and target blocks are monotonic (i.e. the data is stored with
294        blocks in increasing order);
295      - Both files don't contain shared blocks;
296      - Both files have complete lists of blocks;
297      - We haven't removed any blocks from the source set.
298
299    If all these conditions are satisfied, concatenating all the blocks in the
300    RangeSet in order will produce a valid ZIP file (plus possibly extra zeros
301    in the last block). imgdiff is fine with extra zeros at the end of the file.
302
303    Args:
304      name: The filename to be diff'd.
305      tgt_ranges: The target RangeSet.
306      src_ranges: The source RangeSet.
307      large_apk: Whether this is to split a large APK.
308
309    Returns:
310      A boolean result.
311    """
312    if self.disable_imgdiff or not self.FileTypeSupportedByImgdiff(name):
313      return False
314
315    if not tgt_ranges.monotonic or not src_ranges.monotonic:
316      self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_NONMONOTONIC)
317      return False
318
319    if (tgt_ranges.extra.get('uses_shared_blocks') or
320        src_ranges.extra.get('uses_shared_blocks')):
321      self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_SHARED_BLOCKS)
322      return False
323
324    if tgt_ranges.extra.get('incomplete') or src_ranges.extra.get('incomplete'):
325      self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_INCOMPLETE)
326      return False
327
328    reason = (ImgdiffStats.USED_IMGDIFF_LARGE_APK if large_apk
329              else ImgdiffStats.USED_IMGDIFF)
330    self.imgdiff_stats.Log(name, reason)
331    return True
332
333  def Compute(self, prefix):
334    # When looking for a source file to use as the diff input for a
335    # target file, we try:
336    #   1) an exact path match if available, otherwise
337    #   2) a exact basename match if available, otherwise
338    #   3) a basename match after all runs of digits are replaced by
339    #      "#" if available, otherwise
340    #   4) we have no source for this target.
341    self.AbbreviateSourceNames()
342    self.FindTransfers()
343
344    self.FindSequenceForTransfers()
345
346    # Ensure the runtime stash size is under the limit.
347    if common.OPTIONS.cache_size is not None:
348      stash_limit = (common.OPTIONS.cache_size *
349                     common.OPTIONS.stash_threshold / self.tgt.blocksize)
350      # Ignore the stash limit and calculate the maximum simultaneously stashed
351      # blocks needed.
352      _, max_stashed_blocks = self.ReviseStashSize(ignore_stash_limit=True)
353
354      # We cannot stash more blocks than the stash limit simultaneously. As a
355      # result, some 'diff' commands will be converted to new; leading to an
356      # unintended large package. To mitigate this issue, we can carefully
357      # choose the transfers for conversion. The number '1024' can be further
358      # tweaked here to balance the package size and build time.
359      if max_stashed_blocks > stash_limit + 1024:
360        self.SelectAndConvertDiffTransfersToNew(
361            max_stashed_blocks - stash_limit)
362        # Regenerate the sequence as the graph has changed.
363        self.FindSequenceForTransfers()
364
365      # Revise the stash size again to keep the size under limit.
366      self.ReviseStashSize()
367
368    # Double-check our work.
369    self.AssertSequenceGood()
370    self.AssertSha1Good()
371
372    self.ComputePatches(prefix)
373    self.WriteTransfers(prefix)
374
375    # Report the imgdiff stats.
376    if not self.disable_imgdiff:
377      self.imgdiff_stats.Report()
378
379  def WriteTransfers(self, prefix):
380    def WriteSplitTransfers(out, style, target_blocks):
381      """Limit the size of operand in command 'new' and 'zero' to 1024 blocks.
382
383      This prevents the target size of one command from being too large; and
384      might help to avoid fsync errors on some devices."""
385
386      assert style == "new" or style == "zero"
387      blocks_limit = 1024
388      total = 0
389      while target_blocks:
390        blocks_to_write = target_blocks.first(blocks_limit)
391        out.append("%s %s\n" % (style, blocks_to_write.to_string_raw()))
392        total += blocks_to_write.size()
393        target_blocks = target_blocks.subtract(blocks_to_write)
394      return total
395
396    out = []
397    total = 0
398
399    # In BBOTA v3+, it uses the hash of the stashed blocks as the stash slot
400    # id. 'stashes' records the map from 'hash' to the ref count. The stash
401    # will be freed only if the count decrements to zero.
402    stashes = {}
403    stashed_blocks = 0
404    max_stashed_blocks = 0
405
406    for xf in self.transfers:
407
408      for _, sr in xf.stash_before:
409        sh = self.src.RangeSha1(sr)
410        if sh in stashes:
411          stashes[sh] += 1
412        else:
413          stashes[sh] = 1
414          stashed_blocks += sr.size()
415          self.touched_src_ranges = self.touched_src_ranges.union(sr)
416          out.append("stash %s %s\n" % (sh, sr.to_string_raw()))
417
418      if stashed_blocks > max_stashed_blocks:
419        max_stashed_blocks = stashed_blocks
420
421      free_string = []
422      free_size = 0
423
424      #   <# blocks> <src ranges>
425      #     OR
426      #   <# blocks> <src ranges> <src locs> <stash refs...>
427      #     OR
428      #   <# blocks> - <stash refs...>
429
430      size = xf.src_ranges.size()
431      src_str_buffer = [str(size)]
432
433      unstashed_src_ranges = xf.src_ranges
434      mapped_stashes = []
435      for _, sr in xf.use_stash:
436        unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
437        sh = self.src.RangeSha1(sr)
438        sr = xf.src_ranges.map_within(sr)
439        mapped_stashes.append(sr)
440        assert sh in stashes
441        src_str_buffer.append("%s:%s" % (sh, sr.to_string_raw()))
442        stashes[sh] -= 1
443        if stashes[sh] == 0:
444          free_string.append("free %s\n" % (sh,))
445          free_size += sr.size()
446          stashes.pop(sh)
447
448      if unstashed_src_ranges:
449        src_str_buffer.insert(1, unstashed_src_ranges.to_string_raw())
450        if xf.use_stash:
451          mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
452          src_str_buffer.insert(2, mapped_unstashed.to_string_raw())
453          mapped_stashes.append(mapped_unstashed)
454          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
455      else:
456        src_str_buffer.insert(1, "-")
457        self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
458
459      src_str = " ".join(src_str_buffer)
460
461      # version 3+:
462      #   zero <rangeset>
463      #   new <rangeset>
464      #   erase <rangeset>
465      #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
466      #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
467      #   move hash <tgt rangeset> <src_str>
468
469      tgt_size = xf.tgt_ranges.size()
470
471      if xf.style == "new":
472        assert xf.tgt_ranges
473        assert tgt_size == WriteSplitTransfers(out, xf.style, xf.tgt_ranges)
474        total += tgt_size
475      elif xf.style == "move":
476        assert xf.tgt_ranges
477        assert xf.src_ranges.size() == tgt_size
478        if xf.src_ranges != xf.tgt_ranges:
479          # take into account automatic stashing of overlapping blocks
480          if xf.src_ranges.overlaps(xf.tgt_ranges):
481            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
482            if temp_stash_usage > max_stashed_blocks:
483              max_stashed_blocks = temp_stash_usage
484
485          self.touched_src_ranges = self.touched_src_ranges.union(
486              xf.src_ranges)
487
488          out.append("%s %s %s %s\n" % (
489              xf.style,
490              xf.tgt_sha1,
491              xf.tgt_ranges.to_string_raw(), src_str))
492          total += tgt_size
493      elif xf.style in ("bsdiff", "imgdiff"):
494        assert xf.tgt_ranges
495        assert xf.src_ranges
496        # take into account automatic stashing of overlapping blocks
497        if xf.src_ranges.overlaps(xf.tgt_ranges):
498          temp_stash_usage = stashed_blocks + xf.src_ranges.size()
499          if temp_stash_usage > max_stashed_blocks:
500            max_stashed_blocks = temp_stash_usage
501
502        self.touched_src_ranges = self.touched_src_ranges.union(xf.src_ranges)
503
504        out.append("%s %d %d %s %s %s %s\n" % (
505            xf.style,
506            xf.patch_start, xf.patch_len,
507            xf.src_sha1,
508            xf.tgt_sha1,
509            xf.tgt_ranges.to_string_raw(), src_str))
510        total += tgt_size
511      elif xf.style == "zero":
512        assert xf.tgt_ranges
513        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
514        assert WriteSplitTransfers(out, xf.style, to_zero) == to_zero.size()
515        total += to_zero.size()
516      else:
517        raise ValueError("unknown transfer style '%s'\n" % xf.style)
518
519      if free_string:
520        out.append("".join(free_string))
521        stashed_blocks -= free_size
522
523      if common.OPTIONS.cache_size is not None:
524        # Validation check: abort if we're going to need more stash space than
525        # the allowed size (cache_size * threshold). There are two purposes
526        # of having a threshold here. a) Part of the cache may have been
527        # occupied by some recovery logs. b) It will buy us some time to deal
528        # with the oversize issue.
529        cache_size = common.OPTIONS.cache_size
530        stash_threshold = common.OPTIONS.stash_threshold
531        max_allowed = cache_size * stash_threshold
532        assert max_stashed_blocks * self.tgt.blocksize <= max_allowed, \
533               'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
534                   max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
535                   self.tgt.blocksize, max_allowed, cache_size,
536                   stash_threshold)
537
538    self.touched_src_sha1 = self.src.RangeSha1(self.touched_src_ranges)
539
540    if self.tgt.hashtree_info:
541      out.append("compute_hash_tree {} {} {} {} {}\n".format(
542          self.tgt.hashtree_info.hashtree_range.to_string_raw(),
543          self.tgt.hashtree_info.filesystem_range.to_string_raw(),
544          self.tgt.hashtree_info.hash_algorithm,
545          self.tgt.hashtree_info.salt,
546          self.tgt.hashtree_info.root_hash))
547
548    # Zero out extended blocks as a workaround for bug 20881595.
549    if self.tgt.extended:
550      assert (WriteSplitTransfers(out, "zero", self.tgt.extended) ==
551              self.tgt.extended.size())
552      total += self.tgt.extended.size()
553
554    # We erase all the blocks on the partition that a) don't contain useful
555    # data in the new image; b) will not be touched by dm-verity. Out of those
556    # blocks, we erase the ones that won't be used in this update at the
557    # beginning of an update. The rest would be erased at the end. This is to
558    # work around the eMMC issue observed on some devices, which may otherwise
559    # get starving for clean blocks and thus fail the update. (b/28347095)
560    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
561    all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
562    new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
563
564    erase_first = new_dontcare.subtract(self.touched_src_ranges)
565    if erase_first:
566      out.insert(0, "erase %s\n" % (erase_first.to_string_raw(),))
567
568    erase_last = new_dontcare.subtract(erase_first)
569    if erase_last:
570      out.append("erase %s\n" % (erase_last.to_string_raw(),))
571
572    out.insert(0, "%d\n" % (self.version,))   # format version number
573    out.insert(1, "%d\n" % (total,))
574    # v3+: the number of stash slots is unused.
575    out.insert(2, "0\n")
576    out.insert(3, str(max_stashed_blocks) + "\n")
577
578    with open(prefix + ".transfer.list", "w") as f:
579      for i in out:
580        f.write(i)
581
582    self._max_stashed_size = max_stashed_blocks * self.tgt.blocksize
583    OPTIONS = common.OPTIONS
584    if OPTIONS.cache_size is not None:
585      max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
586      logger.info(
587          "max stashed blocks: %d  (%d bytes), limit: %d bytes (%.2f%%)\n",
588          max_stashed_blocks, self._max_stashed_size, max_allowed,
589          self._max_stashed_size * 100.0 / max_allowed)
590    else:
591      logger.info(
592          "max stashed blocks: %d  (%d bytes), limit: <unknown>\n",
593          max_stashed_blocks, self._max_stashed_size)
594
595  def ReviseStashSize(self, ignore_stash_limit=False):
596    """ Revises the transfers to keep the stash size within the size limit.
597
598    Iterates through the transfer list and calculates the stash size each
599    transfer generates. Converts the affected transfers to new if we reach the
600    stash limit.
601
602    Args:
603      ignore_stash_limit: Ignores the stash limit and calculates the max
604      simultaneous stashed blocks instead. No change will be made to the
605      transfer list with this flag.
606
607    Return:
608      A tuple of (tgt blocks converted to new, max stashed blocks)
609    """
610    logger.info("Revising stash size...")
611    stash_map = {}
612
613    # Create the map between a stash and its def/use points. For example, for a
614    # given stash of (raw_id, sr), stash_map[raw_id] = (sr, def_cmd, use_cmd).
615    for xf in self.transfers:
616      # Command xf defines (stores) all the stashes in stash_before.
617      for stash_raw_id, sr in xf.stash_before:
618        stash_map[stash_raw_id] = (sr, xf)
619
620      # Record all the stashes command xf uses.
621      for stash_raw_id, _ in xf.use_stash:
622        stash_map[stash_raw_id] += (xf,)
623
624    max_allowed_blocks = None
625    if not ignore_stash_limit:
626      # Compute the maximum blocks available for stash based on /cache size and
627      # the threshold.
628      cache_size = common.OPTIONS.cache_size
629      stash_threshold = common.OPTIONS.stash_threshold
630      max_allowed_blocks = cache_size * stash_threshold / self.tgt.blocksize
631
632    # See the comments for 'stashes' in WriteTransfers().
633    stashes = {}
634    stashed_blocks = 0
635    new_blocks = 0
636    max_stashed_blocks = 0
637
638    # Now go through all the commands. Compute the required stash size on the
639    # fly. If a command requires excess stash than available, it deletes the
640    # stash by replacing the command that uses the stash with a "new" command
641    # instead.
642    for xf in self.transfers:
643      replaced_cmds = []
644
645      # xf.stash_before generates explicit stash commands.
646      for stash_raw_id, sr in xf.stash_before:
647        # Check the post-command stashed_blocks.
648        stashed_blocks_after = stashed_blocks
649        sh = self.src.RangeSha1(sr)
650        if sh not in stashes:
651          stashed_blocks_after += sr.size()
652
653        if max_allowed_blocks and stashed_blocks_after > max_allowed_blocks:
654          # We cannot stash this one for a later command. Find out the command
655          # that will use this stash and replace the command with "new".
656          use_cmd = stash_map[stash_raw_id][2]
657          replaced_cmds.append(use_cmd)
658          logger.info("%10d  %9s  %s", sr.size(), "explicit", use_cmd)
659        else:
660          # Update the stashes map.
661          if sh in stashes:
662            stashes[sh] += 1
663          else:
664            stashes[sh] = 1
665          stashed_blocks = stashed_blocks_after
666          max_stashed_blocks = max(max_stashed_blocks, stashed_blocks)
667
668      # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
669      # ComputePatches(), they both have the style of "diff".
670      if xf.style == "diff":
671        assert xf.tgt_ranges and xf.src_ranges
672        if xf.src_ranges.overlaps(xf.tgt_ranges):
673          if (max_allowed_blocks and
674              stashed_blocks + xf.src_ranges.size() > max_allowed_blocks):
675            replaced_cmds.append(xf)
676            logger.info("%10d  %9s  %s", xf.src_ranges.size(), "implicit", xf)
677          else:
678            # The whole source ranges will be stashed for implicit stashes.
679            max_stashed_blocks = max(max_stashed_blocks,
680                                     stashed_blocks + xf.src_ranges.size())
681
682      # Replace the commands in replaced_cmds with "new"s.
683      for cmd in replaced_cmds:
684        # It no longer uses any commands in "use_stash". Remove the def points
685        # for all those stashes.
686        for stash_raw_id, sr in cmd.use_stash:
687          def_cmd = stash_map[stash_raw_id][1]
688          assert (stash_raw_id, sr) in def_cmd.stash_before
689          def_cmd.stash_before.remove((stash_raw_id, sr))
690
691        # Add up blocks that violates space limit and print total number to
692        # screen later.
693        new_blocks += cmd.tgt_ranges.size()
694        cmd.ConvertToNew()
695
696      # xf.use_stash may generate free commands.
697      for _, sr in xf.use_stash:
698        sh = self.src.RangeSha1(sr)
699        assert sh in stashes
700        stashes[sh] -= 1
701        if stashes[sh] == 0:
702          stashed_blocks -= sr.size()
703          stashes.pop(sh)
704
705    num_of_bytes = new_blocks * self.tgt.blocksize
706    logger.info(
707        "  Total %d blocks (%d bytes) are packed as new blocks due to "
708        "insufficient cache size. Maximum blocks stashed simultaneously: %d",
709        new_blocks, num_of_bytes, max_stashed_blocks)
710    return new_blocks, max_stashed_blocks
711
712  def ComputePatches(self, prefix):
713    logger.info("Reticulating splines...")
714    diff_queue = []
715    patch_num = 0
716    with open(prefix + ".new.dat", "wb") as new_f:
717      for index, xf in enumerate(self.transfers):
718        if xf.style == "zero":
719          tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
720          logger.info(
721              "%10d %10d (%6.2f%%) %7s %s %s", tgt_size, tgt_size, 100.0,
722              xf.style, xf.tgt_name, str(xf.tgt_ranges))
723
724        elif xf.style == "new":
725          self.tgt.WriteRangeDataToFd(xf.tgt_ranges, new_f)
726          tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
727          logger.info(
728              "%10d %10d (%6.2f%%) %7s %s %s", tgt_size, tgt_size, 100.0,
729              xf.style, xf.tgt_name, str(xf.tgt_ranges))
730
731        elif xf.style == "diff":
732          # We can't compare src and tgt directly because they may have
733          # the same content but be broken up into blocks differently, eg:
734          #
735          #    ["he", "llo"]  vs  ["h", "ello"]
736          #
737          # We want those to compare equal, ideally without having to
738          # actually concatenate the strings (these may be tens of
739          # megabytes).
740          if xf.src_sha1 == xf.tgt_sha1:
741            # These are identical; we don't need to generate a patch,
742            # just issue copy commands on the device.
743            xf.style = "move"
744            xf.patch_info = None
745            tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
746            if xf.src_ranges != xf.tgt_ranges:
747              logger.info(
748                  "%10d %10d (%6.2f%%) %7s %s %s (from %s)", tgt_size, tgt_size,
749                  100.0, xf.style,
750                  xf.tgt_name if xf.tgt_name == xf.src_name else (
751                      xf.tgt_name + " (from " + xf.src_name + ")"),
752                  str(xf.tgt_ranges), str(xf.src_ranges))
753          else:
754            if xf.patch_info:
755              # We have already generated the patch (e.g. during split of large
756              # APKs or reduction of stash size)
757              imgdiff = xf.patch_info.imgdiff
758            else:
759              imgdiff = self.CanUseImgdiff(
760                  xf.tgt_name, xf.tgt_ranges, xf.src_ranges)
761            xf.style = "imgdiff" if imgdiff else "bsdiff"
762            diff_queue.append((index, imgdiff, patch_num))
763            patch_num += 1
764
765        else:
766          assert False, "unknown style " + xf.style
767
768    patches = self.ComputePatchesForInputList(diff_queue, False)
769
770    offset = 0
771    with open(prefix + ".patch.dat", "wb") as patch_fd:
772      for index, patch_info, _ in patches:
773        xf = self.transfers[index]
774        xf.patch_len = len(patch_info.content)
775        xf.patch_start = offset
776        offset += xf.patch_len
777        patch_fd.write(patch_info.content)
778
779        tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
780        logger.info(
781            "%10d %10d (%6.2f%%) %7s %s %s %s", xf.patch_len, tgt_size,
782            xf.patch_len * 100.0 / tgt_size, xf.style,
783            xf.tgt_name if xf.tgt_name == xf.src_name else (
784                xf.tgt_name + " (from " + xf.src_name + ")"),
785            xf.tgt_ranges, xf.src_ranges)
786
787  def AssertSha1Good(self):
788    """Check the SHA-1 of the src & tgt blocks in the transfer list.
789
790    Double check the SHA-1 value to avoid the issue in b/71908713, where
791    SparseImage.RangeSha1() messed up with the hash calculation in multi-thread
792    environment. That specific problem has been fixed by protecting the
793    underlying generator function 'SparseImage._GetRangeData()' with lock.
794    """
795    for xf in self.transfers:
796      tgt_sha1 = self.tgt.RangeSha1(xf.tgt_ranges)
797      assert xf.tgt_sha1 == tgt_sha1
798      if xf.style == "diff":
799        src_sha1 = self.src.RangeSha1(xf.src_ranges)
800        assert xf.src_sha1 == src_sha1
801
802  def AssertSequenceGood(self):
803    # Simulate the sequences of transfers we will output, and check that:
804    # - we never read a block after writing it, and
805    # - we write every block we care about exactly once.
806
807    # Start with no blocks having been touched yet.
808    touched = array.array("B", b"\0" * self.tgt.total_blocks)
809
810    # Imagine processing the transfers in order.
811    for xf in self.transfers:
812      # Check that the input blocks for this transfer haven't yet been touched.
813
814      x = xf.src_ranges
815      for _, sr in xf.use_stash:
816        x = x.subtract(sr)
817
818      for s, e in x:
819        # Source image could be larger. Don't check the blocks that are in the
820        # source image only. Since they are not in 'touched', and won't ever
821        # be touched.
822        for i in range(s, min(e, self.tgt.total_blocks)):
823          assert touched[i] == 0
824
825      # Check that the output blocks for this transfer haven't yet
826      # been touched, and touch all the blocks written by this
827      # transfer.
828      for s, e in xf.tgt_ranges:
829        for i in range(s, e):
830          assert touched[i] == 0
831          touched[i] = 1
832
833    if self.tgt.hashtree_info:
834      for s, e in self.tgt.hashtree_info.hashtree_range:
835        for i in range(s, e):
836          assert touched[i] == 0
837          touched[i] = 1
838
839    # Check that we've written every target block.
840    for s, e in self.tgt.care_map:
841      for i in range(s, e):
842        assert touched[i] == 1
843
844  def FindSequenceForTransfers(self):
845    """Finds a sequence for the given transfers.
846
847     The goal is to minimize the violation of order dependencies between these
848     transfers, so that fewer blocks are stashed when applying the update.
849    """
850
851    # Clear the existing dependency between transfers
852    for xf in self.transfers:
853      xf.goes_before = OrderedDict()
854      xf.goes_after = OrderedDict()
855
856      xf.stash_before = []
857      xf.use_stash = []
858
859    # Find the ordering dependencies among transfers (this is O(n^2)
860    # in the number of transfers).
861    self.GenerateDigraph()
862    # Find a sequence of transfers that satisfies as many ordering
863    # dependencies as possible (heuristically).
864    self.FindVertexSequence()
865    # Fix up the ordering dependencies that the sequence didn't
866    # satisfy.
867    self.ReverseBackwardEdges()
868    self.ImproveVertexSequence()
869
870  def ImproveVertexSequence(self):
871    logger.info("Improving vertex order...")
872
873    # At this point our digraph is acyclic; we reversed any edges that
874    # were backwards in the heuristically-generated sequence.  The
875    # previously-generated order is still acceptable, but we hope to
876    # find a better order that needs less memory for stashed data.
877    # Now we do a topological sort to generate a new vertex order,
878    # using a greedy algorithm to choose which vertex goes next
879    # whenever we have a choice.
880
881    # Make a copy of the edge set; this copy will get destroyed by the
882    # algorithm.
883    for xf in self.transfers:
884      xf.incoming = xf.goes_after.copy()
885      xf.outgoing = xf.goes_before.copy()
886
887    L = []   # the new vertex order
888
889    # S is the set of sources in the remaining graph; we always choose
890    # the one that leaves the least amount of stashed data after it's
891    # executed.
892    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
893         if not u.incoming]
894    heapq.heapify(S)
895
896    while S:
897      _, _, xf = heapq.heappop(S)
898      L.append(xf)
899      for u in xf.outgoing:
900        del u.incoming[xf]
901        if not u.incoming:
902          heapq.heappush(S, (u.NetStashChange(), u.order, u))
903
904    # if this fails then our graph had a cycle.
905    assert len(L) == len(self.transfers)
906
907    self.transfers = L
908    for i, xf in enumerate(L):
909      xf.order = i
910
911  def ReverseBackwardEdges(self):
912    """Reverse unsatisfying edges and compute pairs of stashed blocks.
913
914    For each transfer, make sure it properly stashes the blocks it touches and
915    will be used by later transfers. It uses pairs of (stash_raw_id, range) to
916    record the blocks to be stashed. 'stash_raw_id' is an id that uniquely
917    identifies each pair. Note that for the same range (e.g. RangeSet("1-5")),
918    it is possible to have multiple pairs with different 'stash_raw_id's. Each
919    'stash_raw_id' will be consumed by one transfer. In BBOTA v3+, identical
920    blocks will be written to the same stash slot in WriteTransfers().
921    """
922
923    logger.info("Reversing backward edges...")
924    in_order = 0
925    out_of_order = 0
926    stash_raw_id = 0
927    stash_size = 0
928
929    for xf in self.transfers:
930      for u in xf.goes_before.copy():
931        # xf should go before u
932        if xf.order < u.order:
933          # it does, hurray!
934          in_order += 1
935        else:
936          # it doesn't, boo.  modify u to stash the blocks that it
937          # writes that xf wants to read, and then require u to go
938          # before xf.
939          out_of_order += 1
940
941          overlap = xf.src_ranges.intersect(u.tgt_ranges)
942          assert overlap
943
944          u.stash_before.append((stash_raw_id, overlap))
945          xf.use_stash.append((stash_raw_id, overlap))
946          stash_raw_id += 1
947          stash_size += overlap.size()
948
949          # reverse the edge direction; now xf must go after u
950          del xf.goes_before[u]
951          del u.goes_after[xf]
952          xf.goes_after[u] = None    # value doesn't matter
953          u.goes_before[xf] = None
954
955    logger.info(
956        "  %d/%d dependencies (%.2f%%) were violated; %d source blocks "
957        "stashed.", out_of_order, in_order + out_of_order,
958        (out_of_order * 100.0 / (in_order + out_of_order)) if (
959            in_order + out_of_order) else 0.0,
960        stash_size)
961
962  def FindVertexSequence(self):
963    logger.info("Finding vertex sequence...")
964
965    # This is based on "A Fast & Effective Heuristic for the Feedback
966    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
967    # it as starting with the digraph G and moving all the vertices to
968    # be on a horizontal line in some order, trying to minimize the
969    # number of edges that end up pointing to the left.  Left-pointing
970    # edges will get removed to turn the digraph into a DAG.  In this
971    # case each edge has a weight which is the number of source blocks
972    # we'll lose if that edge is removed; we try to minimize the total
973    # weight rather than just the number of edges.
974
975    # Make a copy of the edge set; this copy will get destroyed by the
976    # algorithm.
977    for xf in self.transfers:
978      xf.incoming = xf.goes_after.copy()
979      xf.outgoing = xf.goes_before.copy()
980      xf.score = sum(xf.outgoing.values()) - sum(xf.incoming.values())
981
982    # We use an OrderedDict instead of just a set so that the output
983    # is repeatable; otherwise it would depend on the hash values of
984    # the transfer objects.
985    G = OrderedDict()
986    for xf in self.transfers:
987      G[xf] = None
988    s1 = deque()  # the left side of the sequence, built from left to right
989    s2 = deque()  # the right side of the sequence, built from right to left
990
991    heap = []
992    for xf in self.transfers:
993      xf.heap_item = HeapItem(xf)
994      heap.append(xf.heap_item)
995    heapq.heapify(heap)
996
997    # Use OrderedDict() instead of set() to preserve the insertion order. Need
998    # to use 'sinks[key] = None' to add key into the set. sinks will look like
999    # { key1: None, key2: None, ... }.
1000    sinks = OrderedDict.fromkeys(u for u in G if not u.outgoing)
1001    sources = OrderedDict.fromkeys(u for u in G if not u.incoming)
1002
1003    def adjust_score(iu, delta):
1004      iu.score += delta
1005      iu.heap_item.clear()
1006      iu.heap_item = HeapItem(iu)
1007      heapq.heappush(heap, iu.heap_item)
1008
1009    while G:
1010      # Put all sinks at the end of the sequence.
1011      while sinks:
1012        new_sinks = OrderedDict()
1013        for u in sinks:
1014          if u not in G:
1015            continue
1016          s2.appendleft(u)
1017          del G[u]
1018          for iu in u.incoming:
1019            adjust_score(iu, -iu.outgoing.pop(u))
1020            if not iu.outgoing:
1021              new_sinks[iu] = None
1022        sinks = new_sinks
1023
1024      # Put all the sources at the beginning of the sequence.
1025      while sources:
1026        new_sources = OrderedDict()
1027        for u in sources:
1028          if u not in G:
1029            continue
1030          s1.append(u)
1031          del G[u]
1032          for iu in u.outgoing:
1033            adjust_score(iu, +iu.incoming.pop(u))
1034            if not iu.incoming:
1035              new_sources[iu] = None
1036        sources = new_sources
1037
1038      if not G:
1039        break
1040
1041      # Find the "best" vertex to put next.  "Best" is the one that
1042      # maximizes the net difference in source blocks saved we get by
1043      # pretending it's a source rather than a sink.
1044
1045      while True:
1046        u = heapq.heappop(heap)
1047        if u and u.item in G:
1048          u = u.item
1049          break
1050
1051      s1.append(u)
1052      del G[u]
1053      for iu in u.outgoing:
1054        adjust_score(iu, +iu.incoming.pop(u))
1055        if not iu.incoming:
1056          sources[iu] = None
1057
1058      for iu in u.incoming:
1059        adjust_score(iu, -iu.outgoing.pop(u))
1060        if not iu.outgoing:
1061          sinks[iu] = None
1062
1063    # Now record the sequence in the 'order' field of each transfer,
1064    # and by rearranging self.transfers to be in the chosen sequence.
1065
1066    new_transfers = []
1067    for x in itertools.chain(s1, s2):
1068      x.order = len(new_transfers)
1069      new_transfers.append(x)
1070      del x.incoming
1071      del x.outgoing
1072
1073    self.transfers = new_transfers
1074
1075  def GenerateDigraph(self):
1076    logger.info("Generating digraph...")
1077
1078    # Each item of source_ranges will be:
1079    #   - None, if that block is not used as a source,
1080    #   - an ordered set of transfers.
1081    source_ranges = []
1082    for b in self.transfers:
1083      for s, e in b.src_ranges:
1084        if e > len(source_ranges):
1085          source_ranges.extend([None] * (e-len(source_ranges)))
1086        for i in range(s, e):
1087          if source_ranges[i] is None:
1088            source_ranges[i] = OrderedDict.fromkeys([b])
1089          else:
1090            source_ranges[i][b] = None
1091
1092    for a in self.transfers:
1093      intersections = OrderedDict()
1094      for s, e in a.tgt_ranges:
1095        for i in range(s, e):
1096          if i >= len(source_ranges):
1097            break
1098          # Add all the Transfers in source_ranges[i] to the (ordered) set.
1099          if source_ranges[i] is not None:
1100            for j in source_ranges[i]:
1101              intersections[j] = None
1102
1103      for b in intersections:
1104        if a is b:
1105          continue
1106
1107        # If the blocks written by A are read by B, then B needs to go before A.
1108        i = a.tgt_ranges.intersect(b.src_ranges)
1109        if i:
1110          if b.src_name == "__ZERO":
1111            # the cost of removing source blocks for the __ZERO domain
1112            # is (nearly) zero.
1113            size = 0
1114          else:
1115            size = i.size()
1116          b.goes_before[a] = size
1117          a.goes_after[b] = size
1118
1119  def ComputePatchesForInputList(self, diff_queue, compress_target):
1120    """Returns a list of patch information for the input list of transfers.
1121
1122      Args:
1123        diff_queue: a list of transfers with style 'diff'
1124        compress_target: If True, compresses the target ranges of each
1125            transfers; and save the size.
1126
1127      Returns:
1128        A list of (transfer order, patch_info, compressed_size) tuples.
1129    """
1130
1131    if not diff_queue:
1132      return []
1133
1134    if self.threads > 1:
1135      logger.info("Computing patches (using %d threads)...", self.threads)
1136    else:
1137      logger.info("Computing patches...")
1138
1139    diff_total = len(diff_queue)
1140    patches = [None] * diff_total
1141    error_messages = []
1142
1143    # Using multiprocessing doesn't give additional benefits, due to the
1144    # pattern of the code. The diffing work is done by subprocess.call, which
1145    # already runs in a separate process (not affected much by the GIL -
1146    # Global Interpreter Lock). Using multiprocess also requires either a)
1147    # writing the diff input files in the main process before forking, or b)
1148    # reopening the image file (SparseImage) in the worker processes. Doing
1149    # neither of them further improves the performance.
1150    lock = threading.Lock()
1151
1152    def diff_worker():
1153      while True:
1154        with lock:
1155          if not diff_queue:
1156            return
1157          xf_index, imgdiff, patch_index = diff_queue.pop()
1158          xf = self.transfers[xf_index]
1159
1160        message = []
1161        compressed_size = None
1162
1163        patch_info = xf.patch_info
1164        if not patch_info:
1165          src_file = common.MakeTempFile(prefix="src-")
1166          with open(src_file, "wb") as fd:
1167            self.src.WriteRangeDataToFd(xf.src_ranges, fd)
1168
1169          tgt_file = common.MakeTempFile(prefix="tgt-")
1170          with open(tgt_file, "wb") as fd:
1171            self.tgt.WriteRangeDataToFd(xf.tgt_ranges, fd)
1172
1173          try:
1174            patch_info = compute_patch(src_file, tgt_file, imgdiff)
1175          except ValueError as e:
1176            message.append(
1177                "Failed to generate %s for %s: tgt=%s, src=%s:\n%s" % (
1178                    "imgdiff" if imgdiff else "bsdiff",
1179                    xf.tgt_name if xf.tgt_name == xf.src_name else
1180                    xf.tgt_name + " (from " + xf.src_name + ")",
1181                    xf.tgt_ranges, xf.src_ranges, e.message))
1182
1183        if compress_target:
1184          tgt_data = self.tgt.ReadRangeSet(xf.tgt_ranges)
1185          try:
1186            # Compresses with the default level
1187            compress_obj = zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS)
1188            compressed_data = (compress_obj.compress("".join(tgt_data))
1189                               + compress_obj.flush())
1190            compressed_size = len(compressed_data)
1191          except zlib.error as e:
1192            message.append(
1193                "Failed to compress the data in target range {} for {}:\n"
1194                "{}".format(xf.tgt_ranges, xf.tgt_name, e.message))
1195
1196        if message:
1197          with lock:
1198            error_messages.extend(message)
1199
1200        with lock:
1201          patches[patch_index] = (xf_index, patch_info, compressed_size)
1202
1203    threads = [threading.Thread(target=diff_worker)
1204               for _ in range(self.threads)]
1205    for th in threads:
1206      th.start()
1207    while threads:
1208      threads.pop().join()
1209
1210    if error_messages:
1211      logger.error('ERROR:')
1212      logger.error('\n'.join(error_messages))
1213      logger.error('\n\n\n')
1214      sys.exit(1)
1215
1216    return patches
1217
1218  def SelectAndConvertDiffTransfersToNew(self, violated_stash_blocks):
1219    """Converts the diff transfers to reduce the max simultaneous stash.
1220
1221    Since the 'new' data is compressed with deflate, we can select the 'diff'
1222    transfers for conversion by comparing its patch size with the size of the
1223    compressed data. Ideally, we want to convert the transfers with a small
1224    size increase, but using a large number of stashed blocks.
1225    """
1226    TransferSizeScore = namedtuple("TransferSizeScore",
1227                                   "xf, used_stash_blocks, score")
1228
1229    logger.info("Selecting diff commands to convert to new.")
1230    diff_queue = []
1231    for xf in self.transfers:
1232      if xf.style == "diff" and xf.src_sha1 != xf.tgt_sha1:
1233        use_imgdiff = self.CanUseImgdiff(xf.tgt_name, xf.tgt_ranges,
1234                                         xf.src_ranges)
1235        diff_queue.append((xf.order, use_imgdiff, len(diff_queue)))
1236
1237    # Remove the 'move' transfers, and compute the patch & compressed size
1238    # for the remaining.
1239    result = self.ComputePatchesForInputList(diff_queue, True)
1240
1241    conversion_candidates = []
1242    for xf_index, patch_info, compressed_size in result:
1243      xf = self.transfers[xf_index]
1244      if not xf.patch_info:
1245        xf.patch_info = patch_info
1246
1247      size_ratio = len(xf.patch_info.content) * 100.0 / compressed_size
1248      diff_style = "imgdiff" if xf.patch_info.imgdiff else "bsdiff"
1249      logger.info("%s, target size: %d blocks, style: %s, patch size: %d,"
1250                  " compression_size: %d, ratio %.2f%%", xf.tgt_name,
1251                  xf.tgt_ranges.size(), diff_style,
1252                  len(xf.patch_info.content), compressed_size, size_ratio)
1253
1254      used_stash_blocks = sum(sr.size() for _, sr in xf.use_stash)
1255      # Convert the transfer to new if the compressed size is smaller or equal.
1256      # We don't need to maintain the stash_before lists here because the
1257      # graph will be regenerated later.
1258      if len(xf.patch_info.content) >= compressed_size:
1259        # Add the transfer to the candidate list with negative score. And it
1260        # will be converted later.
1261        conversion_candidates.append(TransferSizeScore(xf, used_stash_blocks,
1262                                                       -1))
1263      elif used_stash_blocks > 0:
1264        # This heuristic represents the size increase in the final package to
1265        # remove per unit of stashed data.
1266        score = ((compressed_size - len(xf.patch_info.content)) * 100.0
1267                 / used_stash_blocks)
1268        conversion_candidates.append(TransferSizeScore(xf, used_stash_blocks,
1269                                                       score))
1270    # Transfers with lower score (i.e. less expensive to convert) will be
1271    # converted first.
1272    conversion_candidates.sort(key=lambda x: x.score)
1273
1274    # TODO(xunchang), improve the logic to find the transfers to convert, e.g.
1275    # convert the ones that contribute to the max stash, run ReviseStashSize
1276    # multiple times etc.
1277    removed_stashed_blocks = 0
1278    for xf, used_stash_blocks, _ in conversion_candidates:
1279      logger.info("Converting %s to new", xf.tgt_name)
1280      xf.ConvertToNew()
1281      removed_stashed_blocks += used_stash_blocks
1282      # Experiments show that we will get a smaller package size if we remove
1283      # slightly more stashed blocks than the violated stash blocks.
1284      if removed_stashed_blocks >= violated_stash_blocks:
1285        break
1286
1287    logger.info("Removed %d stashed blocks", removed_stashed_blocks)
1288
1289  def FindTransfers(self):
1290    """Parse the file_map to generate all the transfers."""
1291
1292    def AddSplitTransfersWithFixedSizeChunks(tgt_name, src_name, tgt_ranges,
1293                                             src_ranges, style, by_id):
1294      """Add one or multiple Transfer()s by splitting large files.
1295
1296      For BBOTA v3, we need to stash source blocks for resumable feature.
1297      However, with the growth of file size and the shrink of the cache
1298      partition source blocks are too large to be stashed. If a file occupies
1299      too many blocks, we split it into smaller pieces by getting multiple
1300      Transfer()s.
1301
1302      The downside is that after splitting, we may increase the package size
1303      since the split pieces don't align well. According to our experiments,
1304      1/8 of the cache size as the per-piece limit appears to be optimal.
1305      Compared to the fixed 1024-block limit, it reduces the overall package
1306      size by 30% for volantis, and 20% for angler and bullhead."""
1307
1308      pieces = 0
1309      while (tgt_ranges.size() > max_blocks_per_transfer and
1310             src_ranges.size() > max_blocks_per_transfer):
1311        tgt_split_name = "%s-%d" % (tgt_name, pieces)
1312        src_split_name = "%s-%d" % (src_name, pieces)
1313        tgt_first = tgt_ranges.first(max_blocks_per_transfer)
1314        src_first = src_ranges.first(max_blocks_per_transfer)
1315
1316        Transfer(tgt_split_name, src_split_name, tgt_first, src_first,
1317                 self.tgt.RangeSha1(tgt_first), self.src.RangeSha1(src_first),
1318                 style, by_id)
1319
1320        tgt_ranges = tgt_ranges.subtract(tgt_first)
1321        src_ranges = src_ranges.subtract(src_first)
1322        pieces += 1
1323
1324      # Handle remaining blocks.
1325      if tgt_ranges.size() or src_ranges.size():
1326        # Must be both non-empty.
1327        assert tgt_ranges.size() and src_ranges.size()
1328        tgt_split_name = "%s-%d" % (tgt_name, pieces)
1329        src_split_name = "%s-%d" % (src_name, pieces)
1330        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges,
1331                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
1332                 style, by_id)
1333
1334    def AddSplitTransfers(tgt_name, src_name, tgt_ranges, src_ranges, style,
1335                          by_id):
1336      """Find all the zip files and split the others with a fixed chunk size.
1337
1338      This function will construct a list of zip archives, which will later be
1339      split by imgdiff to reduce the final patch size. For the other files,
1340      we will plainly split them based on a fixed chunk size with the potential
1341      patch size penalty.
1342      """
1343
1344      assert style == "diff"
1345
1346      # Change nothing for small files.
1347      if (tgt_ranges.size() <= max_blocks_per_transfer and
1348          src_ranges.size() <= max_blocks_per_transfer):
1349        Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
1350                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
1351                 style, by_id)
1352        return
1353
1354      # Split large APKs with imgdiff, if possible. We're intentionally checking
1355      # file types one more time (CanUseImgdiff() checks that as well), before
1356      # calling the costly RangeSha1()s.
1357      if (self.FileTypeSupportedByImgdiff(tgt_name) and
1358          self.tgt.RangeSha1(tgt_ranges) != self.src.RangeSha1(src_ranges)):
1359        if self.CanUseImgdiff(tgt_name, tgt_ranges, src_ranges, True):
1360          large_apks.append((tgt_name, src_name, tgt_ranges, src_ranges))
1361          return
1362
1363      AddSplitTransfersWithFixedSizeChunks(tgt_name, src_name, tgt_ranges,
1364                                           src_ranges, style, by_id)
1365
1366    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
1367                    split=False):
1368      """Wrapper function for adding a Transfer()."""
1369
1370      # We specialize diff transfers only (which covers bsdiff/imgdiff/move);
1371      # otherwise add the Transfer() as is.
1372      if style != "diff" or not split:
1373        Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
1374                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
1375                 style, by_id)
1376        return
1377
1378      # Handle .odex files specially to analyze the block-wise difference. If
1379      # most of the blocks are identical with only few changes (e.g. header),
1380      # we will patch the changed blocks only. This avoids stashing unchanged
1381      # blocks while patching. We limit the analysis to files without size
1382      # changes only. This is to avoid sacrificing the OTA generation cost too
1383      # much.
1384      if (tgt_name.split(".")[-1].lower() == 'odex' and
1385          tgt_ranges.size() == src_ranges.size()):
1386
1387        # 0.5 threshold can be further tuned. The tradeoff is: if only very
1388        # few blocks remain identical, we lose the opportunity to use imgdiff
1389        # that may have better compression ratio than bsdiff.
1390        crop_threshold = 0.5
1391
1392        tgt_skipped = RangeSet()
1393        src_skipped = RangeSet()
1394        tgt_size = tgt_ranges.size()
1395        tgt_changed = 0
1396        for src_block, tgt_block in zip(src_ranges.next_item(),
1397                                        tgt_ranges.next_item()):
1398          src_rs = RangeSet(str(src_block))
1399          tgt_rs = RangeSet(str(tgt_block))
1400          if self.src.ReadRangeSet(src_rs) == self.tgt.ReadRangeSet(tgt_rs):
1401            tgt_skipped = tgt_skipped.union(tgt_rs)
1402            src_skipped = src_skipped.union(src_rs)
1403          else:
1404            tgt_changed += tgt_rs.size()
1405
1406          # Terminate early if no clear sign of benefits.
1407          if tgt_changed > tgt_size * crop_threshold:
1408            break
1409
1410        if tgt_changed < tgt_size * crop_threshold:
1411          assert tgt_changed + tgt_skipped.size() == tgt_size
1412          logger.info(
1413              '%10d %10d (%6.2f%%) %s', tgt_skipped.size(), tgt_size,
1414              tgt_skipped.size() * 100.0 / tgt_size, tgt_name)
1415          AddSplitTransfers(
1416              "%s-skipped" % (tgt_name,),
1417              "%s-skipped" % (src_name,),
1418              tgt_skipped, src_skipped, style, by_id)
1419
1420          # Intentionally change the file extension to avoid being imgdiff'd as
1421          # the files are no longer in their original format.
1422          tgt_name = "%s-cropped" % (tgt_name,)
1423          src_name = "%s-cropped" % (src_name,)
1424          tgt_ranges = tgt_ranges.subtract(tgt_skipped)
1425          src_ranges = src_ranges.subtract(src_skipped)
1426
1427          # Possibly having no changed blocks.
1428          if not tgt_ranges:
1429            return
1430
1431      # Add the transfer(s).
1432      AddSplitTransfers(
1433          tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
1434
1435    def ParseAndValidateSplitInfo(patch_size, tgt_ranges, src_ranges,
1436                                  split_info):
1437      """Parse the split_info and return a list of info tuples.
1438
1439      Args:
1440        patch_size: total size of the patch file.
1441        tgt_ranges: Ranges of the target file within the original image.
1442        src_ranges: Ranges of the source file within the original image.
1443        split_info format:
1444          imgdiff version#
1445          count of pieces
1446          <patch_size_1> <tgt_size_1> <src_ranges_1>
1447          ...
1448          <patch_size_n> <tgt_size_n> <src_ranges_n>
1449
1450      Returns:
1451        [patch_start, patch_len, split_tgt_ranges, split_src_ranges]
1452      """
1453
1454      version = int(split_info[0])
1455      assert version == 2
1456      count = int(split_info[1])
1457      assert len(split_info) - 2 == count
1458
1459      split_info_list = []
1460      patch_start = 0
1461      tgt_remain = copy.deepcopy(tgt_ranges)
1462      # each line has the format <patch_size>, <tgt_size>, <src_ranges>
1463      for line in split_info[2:]:
1464        info = line.split()
1465        assert len(info) == 3
1466        patch_length = int(info[0])
1467
1468        split_tgt_size = int(info[1])
1469        assert split_tgt_size % 4096 == 0
1470        assert split_tgt_size // 4096 <= tgt_remain.size()
1471        split_tgt_ranges = tgt_remain.first(split_tgt_size // 4096)
1472        tgt_remain = tgt_remain.subtract(split_tgt_ranges)
1473
1474        # Find the split_src_ranges within the image file from its relative
1475        # position in file.
1476        split_src_indices = RangeSet.parse_raw(info[2])
1477        split_src_ranges = RangeSet()
1478        for r in split_src_indices:
1479          curr_range = src_ranges.first(r[1]).subtract(src_ranges.first(r[0]))
1480          assert not split_src_ranges.overlaps(curr_range)
1481          split_src_ranges = split_src_ranges.union(curr_range)
1482
1483        split_info_list.append((patch_start, patch_length,
1484                                split_tgt_ranges, split_src_ranges))
1485        patch_start += patch_length
1486
1487      # Check that the sizes of all the split pieces add up to the final file
1488      # size for patch and target.
1489      assert tgt_remain.size() == 0
1490      assert patch_start == patch_size
1491      return split_info_list
1492
1493    def SplitLargeApks():
1494      """Split the large apks files.
1495
1496      Example: Chrome.apk will be split into
1497        src-0: Chrome.apk-0, tgt-0: Chrome.apk-0
1498        src-1: Chrome.apk-1, tgt-1: Chrome.apk-1
1499        ...
1500
1501      After the split, the target pieces are continuous and block aligned; and
1502      the source pieces are mutually exclusive. During the split, we also
1503      generate and save the image patch between src-X & tgt-X. This patch will
1504      be valid because the block ranges of src-X & tgt-X will always stay the
1505      same afterwards; but there's a chance we don't use the patch if we
1506      convert the "diff" command into "new" or "move" later.
1507      """
1508
1509      while True:
1510        with transfer_lock:
1511          if not large_apks:
1512            return
1513          tgt_name, src_name, tgt_ranges, src_ranges = large_apks.pop(0)
1514
1515        src_file = common.MakeTempFile(prefix="src-")
1516        tgt_file = common.MakeTempFile(prefix="tgt-")
1517        with open(src_file, "wb") as src_fd:
1518          self.src.WriteRangeDataToFd(src_ranges, src_fd)
1519        with open(tgt_file, "wb") as tgt_fd:
1520          self.tgt.WriteRangeDataToFd(tgt_ranges, tgt_fd)
1521
1522        patch_file = common.MakeTempFile(prefix="patch-")
1523        patch_info_file = common.MakeTempFile(prefix="split_info-")
1524        cmd = ["imgdiff", "-z",
1525               "--block-limit={}".format(max_blocks_per_transfer),
1526               "--split-info=" + patch_info_file,
1527               src_file, tgt_file, patch_file]
1528        proc = common.Run(cmd)
1529        imgdiff_output, _ = proc.communicate()
1530        assert proc.returncode == 0, \
1531            "Failed to create imgdiff patch between {} and {}:\n{}".format(
1532                src_name, tgt_name, imgdiff_output)
1533
1534        with open(patch_info_file) as patch_info:
1535          lines = patch_info.readlines()
1536
1537        patch_size_total = os.path.getsize(patch_file)
1538        split_info_list = ParseAndValidateSplitInfo(patch_size_total,
1539                                                    tgt_ranges, src_ranges,
1540                                                    lines)
1541        for index, (patch_start, patch_length, split_tgt_ranges,
1542                    split_src_ranges) in enumerate(split_info_list):
1543          with open(patch_file, 'rb') as f:
1544            f.seek(patch_start)
1545            patch_content = f.read(patch_length)
1546
1547          split_src_name = "{}-{}".format(src_name, index)
1548          split_tgt_name = "{}-{}".format(tgt_name, index)
1549          split_large_apks.append((split_tgt_name,
1550                                   split_src_name,
1551                                   split_tgt_ranges,
1552                                   split_src_ranges,
1553                                   patch_content))
1554
1555    logger.info("Finding transfers...")
1556
1557    large_apks = []
1558    split_large_apks = []
1559    cache_size = common.OPTIONS.cache_size
1560    split_threshold = 0.125
1561    assert cache_size is not None
1562    max_blocks_per_transfer = int(cache_size * split_threshold /
1563                                  self.tgt.blocksize)
1564    empty = RangeSet()
1565    for tgt_fn, tgt_ranges in sorted(self.tgt.file_map.items()):
1566      if tgt_fn == "__ZERO":
1567        # the special "__ZERO" domain is all the blocks not contained
1568        # in any file and that are filled with zeros.  We have a
1569        # special transfer style for zero blocks.
1570        src_ranges = self.src.file_map.get("__ZERO", empty)
1571        AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
1572                    "zero", self.transfers)
1573        continue
1574
1575      elif tgt_fn == "__COPY":
1576        # "__COPY" domain includes all the blocks not contained in any
1577        # file and that need to be copied unconditionally to the target.
1578        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1579        continue
1580
1581      elif tgt_fn == "__HASHTREE":
1582        continue
1583
1584      elif tgt_fn in self.src.file_map:
1585        # Look for an exact pathname match in the source.
1586        AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
1587                    "diff", self.transfers, True)
1588        continue
1589
1590      b = os.path.basename(tgt_fn)
1591      if b in self.src_basenames:
1592        # Look for an exact basename match in the source.
1593        src_fn = self.src_basenames[b]
1594        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
1595                    "diff", self.transfers, True)
1596        continue
1597
1598      b = re.sub("[0-9]+", "#", b)
1599      if b in self.src_numpatterns:
1600        # Look for a 'number pattern' match (a basename match after
1601        # all runs of digits are replaced by "#").  (This is useful
1602        # for .so files that contain version numbers in the filename
1603        # that get bumped.)
1604        src_fn = self.src_numpatterns[b]
1605        AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
1606                    "diff", self.transfers, True)
1607        continue
1608
1609      AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
1610
1611    transfer_lock = threading.Lock()
1612    threads = [threading.Thread(target=SplitLargeApks)
1613               for _ in range(self.threads)]
1614    for th in threads:
1615      th.start()
1616    while threads:
1617      threads.pop().join()
1618
1619    # Sort the split transfers for large apks to generate a determinate package.
1620    split_large_apks.sort()
1621    for (tgt_name, src_name, tgt_ranges, src_ranges,
1622         patch) in split_large_apks:
1623      transfer_split = Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
1624                                self.tgt.RangeSha1(tgt_ranges),
1625                                self.src.RangeSha1(src_ranges),
1626                                "diff", self.transfers)
1627      transfer_split.patch_info = PatchInfo(True, patch)
1628
1629  def AbbreviateSourceNames(self):
1630    for k in self.src.file_map.keys():
1631      b = os.path.basename(k)
1632      self.src_basenames[b] = k
1633      b = re.sub("[0-9]+", "#", b)
1634      self.src_numpatterns[b] = k
1635
1636  @staticmethod
1637  def AssertPartition(total, seq):
1638    """Assert that all the RangeSets in 'seq' form a partition of the
1639    'total' RangeSet (ie, they are nonintersecting and their union
1640    equals 'total')."""
1641
1642    so_far = RangeSet()
1643    for i in seq:
1644      assert not so_far.overlaps(i)
1645      so_far = so_far.union(i)
1646    assert so_far == total
1647