1# Copyright (C) 2019 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13
14import os
15import threading
16from hashlib import sha1
17
18from rangelib import RangeSet
19
# Names exported by a star-import of this module. The Image base class
# defined below is not in this list.
__all__ = ["EmptyImage", "DataImage", "FileImage"]
21
22
class Image(object):
  """Common interface for the image classes defined in this module.

  Every method here raises NotImplementedError; the concrete subclasses
  provide the real implementations.
  """

  def RangeSha1(self, ranges):
    """Hook for the hex SHA-1 digest of the data covered by `ranges`."""
    raise NotImplementedError()

  def ReadRangeSet(self, ranges):
    """Hook for reading the data covered by `ranges`."""
    raise NotImplementedError()

  def TotalSha1(self, include_clobbered_blocks=False):
    """Hook for the hex SHA-1 digest of the image as a whole."""
    raise NotImplementedError()

  def WriteRangeDataToFd(self, ranges, fd):
    """Hook for writing the data covered by `ranges` to `fd`."""
    raise NotImplementedError()
35
36
class EmptyImage(Image):
  """An image that holds no blocks at all."""

  def __init__(self):
    self.blocksize = 4096
    self.total_blocks = 0
    # With zero blocks, every range set and the file map start out empty.
    self.care_map = RangeSet()
    self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()
    self.file_map = {}
    self.hashtree_info = None

  def RangeSha1(self, ranges):
    """Return the SHA-1 hex digest of no data; `ranges` is not consulted."""
    empty_digest = sha1()
    return empty_digest.hexdigest()

  def ReadRangeSet(self, ranges):
    """Return an empty tuple; `ranges` is not consulted."""
    return tuple()

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the SHA-1 hex digest of no data."""
    # clobbered_blocks is always empty for an EmptyImage, which makes the
    # include_clobbered_blocks flag irrelevant here.
    assert self.clobbered_blocks.size() == 0
    empty_digest = sha1()
    return empty_digest.hexdigest()

  def WriteRangeDataToFd(self, ranges, fd):
    """Always raise ValueError: there is no data to write."""
    raise ValueError("Can't write data from EmptyImage to file")
63
64
class DataImage(Image):
  """An image wrapped around a single in-memory string of data.

  The data is addressed in blocks of `blocksize` (4096) bytes.

  Args:
    data: the image contents (bytes; a plain str on Python 2).
    trim: if True, drop a trailing partial block.
    pad: if True, zero-fill a trailing partial block up to a whole block.
      Mutually exclusive with `trim`.

  Raises:
    ValueError: if the data length is not a multiple of the block size and
      neither `trim` nor `pad` was requested.
  """

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    # Use a zero filler of the same string type as the data. A text '\0'
    # can neither be appended to nor compare equal to bytes on Python 3,
    # which would break padding and hide every all-zero block.
    if isinstance(self.data, (bytes, bytearray)):
      zero = b'\0'
    else:
      zero = '\0'

    partial = len(self.data) % self.blocksize
    padded = False
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        # Build a new object rather than using +=, so that a caller-owned
        # mutable buffer is never modified in place.
        self.data = self.data + zero * (self.blocksize - partial)
        padded = True
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    # When the last block is padded, we always write the whole block even for
    # incremental OTAs. Because otherwise the last block may get skipped if
    # unchanged for an incremental, but would fail the post-install
    # verification if it has non-zero contents in the padding bytes.
    # Bug: 23828506
    if padded:
      clobbered_blocks = [self.total_blocks-1, self.total_blocks]
      self.clobbered_blocks = RangeSet(data=clobbered_blocks)
    else:
      clobbered_blocks = []
      # Store a RangeSet (not a bare list), matching EmptyImage/FileImage
      # and what TotalSha1() hands to care_map.subtract().
      self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()

    zero_blocks = []
    nonzero_blocks = []
    reference = zero * self.blocksize

    # The padded last block is reported under "__COPY" below, so it is left
    # out of the zero/non-zero classification.
    scan_blocks = self.total_blocks - 1 if padded else self.total_blocks
    for i in range(scan_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      # Each block contributes a (start, end) pair to the flat range list.
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    assert zero_blocks or nonzero_blocks or clobbered_blocks

    self.file_map = dict()
    if zero_blocks:
      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
    if nonzero_blocks:
      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
    if clobbered_blocks:
      self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)

  def _GetRangeData(self, ranges):
    """Yield the slice of data for each (start, end) block pair in `ranges`."""
    for s, e in ranges:
      yield self.data[s*self.blocksize:e*self.blocksize]

  def RangeSha1(self, ranges):
    """Return the SHA-1 hex digest of the data covered by `ranges`."""
    h = sha1()
    for data in self._GetRangeData(ranges): # pylint: disable=not-an-iterable
      h.update(data)
    return h.hexdigest()

  def ReadRangeSet(self, ranges):
    """Return a list with one data chunk per (start, end) pair in `ranges`."""
    return list(self._GetRangeData(ranges))

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the SHA-1 hex digest of the image data.

    Unless `include_clobbered_blocks` is True, the blocks recorded in
    `clobbered_blocks` are subtracted from the care map before hashing.
    """
    if not include_clobbered_blocks:
      return self.RangeSha1(self.care_map.subtract(self.clobbered_blocks))
    return sha1(self.data).hexdigest()

  def WriteRangeDataToFd(self, ranges, fd):
    """Write the data covered by `ranges` to `fd`, one chunk per range."""
    for data in self._GetRangeData(ranges): # pylint: disable=not-an-iterable
      fd.write(data)
147
148
class FileImage(Image):
  """An image wrapped around a raw image file.

  The file is opened in binary mode and kept open for the lifetime of the
  object; data is addressed in blocks of `blocksize` (4096) bytes.

  Args:
    path: path of the raw image file.
    hashtree_info_generator: optional object whose Generate(image) result is
      stored as `hashtree_info`; when that result is truthy, its
      `hashtree_range` is recorded in the file map under "__HASHTREE".

  Raises:
    ValueError: if the file size is not a multiple of the block size.
  """

  def __init__(self, path, hashtree_info_generator=None):
    self.path = path
    self.blocksize = 4096
    # Set up front so __del__ is safe even if construction fails early.
    self._file = None
    self._file_size = os.path.getsize(self.path)

    # Validate the size before opening, so the error path holds no open file.
    if self._file_size % self.blocksize != 0:
      # The format arguments must be one tuple; without the parentheses the
      # % operator only receives self.path and raises TypeError instead.
      raise ValueError("Size of file %s must be multiple of %d bytes, but is %d"
                       % (self.path, self.blocksize, self._file_size))

    self._file = open(self.path, 'rb')

    self.total_blocks = self._file_size // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))
    self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()

    self.generator_lock = threading.Lock()

    self.hashtree_info = None
    if hashtree_info_generator:
      self.hashtree_info = hashtree_info_generator.Generate(self)

    zero_blocks = []
    nonzero_blocks = []
    # The file is read in binary mode, so compare against a bytes literal;
    # a text '\0' block never equals bytes on Python 3.
    reference = b'\0' * self.blocksize

    # Generate() above is handed this image and may have read through the
    # shared file handle; rewind so the scan really starts at block 0.
    self._file.seek(0)
    for i in range(self.total_blocks):
      d = self._file.read(self.blocksize)
      # Each block contributes a (start, end) pair to the flat range list.
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    assert zero_blocks or nonzero_blocks

    self.file_map = {}
    if zero_blocks:
      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
    if nonzero_blocks:
      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
    if self.hashtree_info:
      self.file_map["__HASHTREE"] = self.hashtree_info.hashtree_range

  def __del__(self):
    # _file is missing or None when __init__ failed before opening the file.
    image_file = getattr(self, '_file', None)
    if image_file is not None:
      image_file.close()

  def _GetRangeData(self, ranges):
    """Yield the file data for `ranges`, one block-sized read at a time."""
    # Use a lock to protect the generator so that we will not run two
    # instances of this generator on the same object simultaneously.
    with self.generator_lock:
      for s, e in ranges:
        self._file.seek(s * self.blocksize)
        for _ in range(s, e):
          yield self._file.read(self.blocksize)

  def RangeSha1(self, ranges):
    """Return the SHA-1 hex digest of the data covered by `ranges`."""
    h = sha1()
    for data in self._GetRangeData(ranges): # pylint: disable=not-an-iterable
      h.update(data)
    return h.hexdigest()

  def ReadRangeSet(self, ranges):
    """Return a list of the blocks covered by `ranges`, in range order."""
    return list(self._GetRangeData(ranges))

  def TotalSha1(self, include_clobbered_blocks=False):
    """Return the SHA-1 hex digest of every block in the care map.

    `include_clobbered_blocks` has no effect: the method asserts that
    `clobbered_blocks` is empty.
    """
    assert not self.clobbered_blocks
    return self.RangeSha1(self.care_map)

  def WriteRangeDataToFd(self, ranges, fd):
    """Write the data covered by `ranges` to `fd`, block by block."""
    for data in self._GetRangeData(ranges): # pylint: disable=not-an-iterable
      fd.write(data)