1#!/usr/bin/env python
2#
3# Copyright (C) 2017 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Send an A/B update to an Android device over adb."""
19
from __future__ import absolute_import

import argparse
import hashlib
import logging
import os
import socket
import struct
import subprocess
import sys
import threading
import xml.etree.ElementTree
import zipfile

from six.moves import BaseHTTPServer

import update_payload.payload
36
37
# The directory on the device used to store the OTA package when applying the
# package from a file (the --file option); the package is placed there as
# "debug.zip".
OTA_PACKAGE_PATH = '/data/ota_package'

# The path to the payload public key on the device. With --public-key, a tmpfs
# is mounted over this file's directory and the override key is pushed here.
PAYLOAD_KEY_PATH = '/etc/update_engine/update-payload-key.pub.pem'

# The port on the device that update_engine should connect to. When streaming
# the payload, it is forwarded to the host HTTP server with "adb reverse".
DEVICE_PORT = 1234
46
47
def CopyFileObjLength(fsrc, fdst, buffer_size=128 * 1024, copy_length=None):
  """Copy from a file object to another.

  This function is similar to shutil.copyfileobj except that it allows to copy
  less than the full source file.

  Args:
    fsrc: source file object where to read from.
    fdst: destination file object where to write to.
    buffer_size: size of the copy buffer in memory.
    copy_length: maximum number of bytes to copy, or None to copy everything.
        Zero or negative values copy nothing.

  Returns:
    the number of bytes copied.
  """
  copied = 0
  while True:
    chunk_size = buffer_size
    if copy_length is not None:
      chunk_size = min(chunk_size, copy_length - copied)
      # Stop on any non-positive remainder: passing a negative size to read()
      # would read the whole rest of the source instead of nothing.
      if chunk_size <= 0:
        break
    buf = fsrc.read(chunk_size)
    if not buf:
      # Source exhausted before copy_length bytes were available.
      break
    fdst.write(buf)
    copied += len(buf)
  return copied
76
77
class AndroidOTAPackage(object):
  """Android update payload using the .zip format.

  Android OTA packages traditionally used a .zip file to store the payload. When
  applying A/B updates over the network, a payload binary is stored RAW inside
  this .zip file which is used by update_engine to apply the payload. To do
  this, an offset and size inside the .zip file are provided.

  Attributes:
    otafilename: path to the OTA package on the host.
    offset: byte offset of the raw payload data inside the package file.
    size: size in bytes of the payload data.
    properties: raw contents (bytes, as returned by ZipFile.read) of the
        payload properties file.
  """

  # Android OTA package file paths.
  OTA_PAYLOAD_BIN = 'payload.bin'
  OTA_PAYLOAD_PROPERTIES_TXT = 'payload_properties.txt'
  SECONDARY_OTA_PAYLOAD_BIN = 'secondary/payload.bin'
  SECONDARY_OTA_PAYLOAD_PROPERTIES_TXT = 'secondary/payload_properties.txt'

  def __init__(self, otafilename, secondary_payload=False):
    """Locate the payload inside an OTA package and read its properties.

    Args:
      otafilename: path to the OTA package (.zip) on the host.
      secondary_payload: use the entries under secondary/ instead of the
          primary payload.

    Raises:
      KeyError: if the payload or properties entry is missing from the zip.
      zipfile.BadZipfile: if the file is not a zip file or the payload's
          local file header is corrupt.
    """
    self.otafilename = otafilename

    payload_entry = (self.SECONDARY_OTA_PAYLOAD_BIN if secondary_payload else
                     self.OTA_PAYLOAD_BIN)
    property_entry = (self.SECONDARY_OTA_PAYLOAD_PROPERTIES_TXT if
                      secondary_payload else self.OTA_PAYLOAD_PROPERTIES_TXT)

    with zipfile.ZipFile(otafilename, 'r') as otazip:
      payload_info = otazip.getinfo(payload_entry)
      self.properties = otazip.read(property_entry)

    if payload_info.compress_type != zipfile.ZIP_STORED:
      # update_engine reads the payload bytes directly from the package, so a
      # compressed entry cannot be applied from this offset.
      logging.warning(
          'Expected %s to be stored uncompressed, got compression method %d',
          payload_entry, payload_info.compress_type)

    # The payload data starts right after its *local* file header. ZipInfo
    # describes the central directory entry, whose extra field (and thus its
    # length) may differ from the local one, e.g. when zipalign pads the local
    # extra field. Read the lengths from the local header itself instead.
    with open(otafilename, 'rb') as fp:
      fp.seek(payload_info.header_offset)
      raw_header = fp.read(zipfile.sizeFileHeader)
    if (len(raw_header) != zipfile.sizeFileHeader or
        raw_header[:4] != zipfile.stringFileHeader):
      raise zipfile.BadZipfile(
          'Bad local file header for %s in %s' % (payload_entry, otafilename))
    fheader = struct.unpack(zipfile.structFileHeader, raw_header)
    # The last two fields of the local file header are the filename length
    # and the extra field length, both in bytes.
    filename_len, extra_len = fheader[-2], fheader[-1]

    self.offset = (payload_info.header_offset + zipfile.sizeFileHeader +
                   filename_len + extra_len)
    self.size = payload_info.file_size
108
109
class UpdateHandler(BaseHTTPServer.BaseHTTPRequestHandler):
  """An HTTP request handler serving a single payload.

  GET /payload returns the payload bytes and supports a single byte range.
  POST /update answers an Omaha update check with a response that points the
  device at that payload.

  Attributes:
    serving_payload: path to the only payload file we are serving.
    serving_range: the start offset and size tuple of the payload inside
        serving_payload.
  """

  # ServerThread sets these on the class itself, because a new handler
  # instance is created for every request.
  serving_payload = None
  serving_range = None

  @staticmethod
  def _parse_range(range_str, file_size):
    """Parse an HTTP range string.

    Only a single "bytes=first-last", "bytes=first-" or "bytes=-suffix" range
    is supported.

    Args:
      range_str: HTTP Range header in the request, not including "Header:".
          None or empty selects the whole file.
      file_size: total size of the serving file.

    Returns:
      A tuple (start_range, end_range) with the half-open range of bytes
      requested. end_range never exceeds file_size.

    Raises:
      ValueError, IndexError: if range_str is malformed.
    """
    start_range = 0
    end_range = file_size

    if range_str:
      range_str = range_str.split('=', 1)[1]
      s, e = range_str.split('-', 1)
      if s:
        start_range = int(s)
        if e:
          # last-byte-pos is inclusive; a value past the end of the file is
          # valid and means "until the end" (RFC 7233, section 2.1).
          end_range = min(int(e) + 1, file_size)
      elif e:
        if int(e) < file_size:
          start_range = file_size - int(e)
    return start_range, end_range

  def do_GET(self):  # pylint: disable=invalid-name
    """Reply with the requested payload file, or a byte range of it."""
    if self.path != '/payload':
      self.send_error(404, 'Unknown request')
      return

    if not self.serving_payload:
      self.send_error(500, 'No serving payload set')
      return

    try:
      f = open(self.serving_payload, 'rb')
    except IOError:
      self.send_error(404, 'File not found')
      return

    # Close the payload file on every exit path.
    with f:
      serving_start, serving_size = self.serving_range
      range_str = self.headers.get('Range')
      try:
        start_range, end_range = self._parse_range(range_str, serving_size)
      except (ValueError, IndexError):
        self.send_error(400, 'Malformed Range header')
        return
      if range_str and start_range >= end_range:
        self.send_error(416, 'Requested range not satisfiable')
        return

      logging.info('Serving request for %s from %s [%d, %d) length: %d',
                   self.path, self.serving_payload, serving_start + start_range,
                   serving_start + end_range, end_range - start_range)

      # Handle the range request.
      if range_str:
        self.send_response(206)
        # Content-Range carries the inclusive last byte position and the
        # complete length of the payload (RFC 7233, section 4.2).
        self.send_header('Content-Range', 'bytes %d-%d/%d' % (
            start_range, end_range - 1, serving_size))
      else:
        self.send_response(200)
      self.send_header('Accept-Ranges', 'bytes')
      self.send_header('Content-Length', str(end_range - start_range))

      stat = os.fstat(f.fileno())
      self.send_header('Last-Modified', self.date_time_string(stat.st_mtime))
      self.send_header('Content-type', 'application/octet-stream')
      self.end_headers()

      f.seek(serving_start + start_range)
      CopyFileObjLength(f, self.wfile, copy_length=end_range - start_range)

  def do_POST(self):  # pylint: disable=invalid-name
    """Reply with the omaha response xml."""
    if self.path != '/update':
      self.send_error(404, 'Unknown request')
      return

    if not self.serving_payload:
      self.send_error(500, 'No serving payload set')
      return

    try:
      f = open(self.serving_payload, 'rb')
    except IOError:
      self.send_error(404, 'File not found')
      return

    # Close the payload file on every exit path.
    with f:
      # headers.get() works on both the Python 2 and Python 3 header objects;
      # getheader() only exists on Python 2.
      try:
        content_length = int(self.headers.get('Content-Length', 0))
        request_xml = self.rfile.read(content_length)
        xml_root = xml.etree.ElementTree.fromstring(request_xml)
      except (ValueError, xml.etree.ElementTree.ParseError):
        self.send_error(400, 'Malformed Omaha request')
        return

      appid = None
      for app in xml_root.iter('app'):
        if 'appid' in app.attrib:
          appid = app.attrib['appid']
          break
      if not appid:
        self.send_error(400, 'No appid in Omaha request')
        return

      # Hash the payload before sending any response, so a short payload can
      # still be reported as an error instead of a corrupted 200 reply.
      serving_start, serving_size = self.serving_range
      sha256 = hashlib.sha256()
      f.seek(serving_start)
      bytes_to_hash = serving_size
      while bytes_to_hash:
        buf = f.read(min(bytes_to_hash, 1024 * 1024))
        if not buf:
          self.send_error(500, 'Payload too small')
          return
        sha256.update(buf)
        bytes_to_hash -= len(buf)

      payload = update_payload.Payload(f, payload_file_offset=serving_start)
      payload.Init()

      response_xml = '''
        <?xml version="1.0" encoding="UTF-8"?>
        <response protocol="3.0">
          <app appid="{appid}">
            <updatecheck status="ok">
              <urls>
                <url codebase="http://127.0.0.1:{port}/"/>
              </urls>
              <manifest version="0.0.0.1">
                <actions>
                  <action event="install" run="payload"/>
                  <action event="postinstall" MetadataSize="{metadata_size}"/>
                </actions>
                <packages>
                  <package hash_sha256="{payload_hash}" name="payload" size="{payload_size}"/>
                </packages>
              </manifest>
            </updatecheck>
          </app>
        </response>
      '''.format(appid=appid, port=DEVICE_PORT,
                 metadata_size=payload.metadata_size,
                 payload_hash=sha256.hexdigest(),
                 payload_size=serving_size)
      # wfile expects bytes on Python 3.
      body = response_xml.strip().encode('utf-8')

      self.send_response(200)
      self.send_header('Content-type', 'text/xml')
      self.send_header('Content-Length', str(len(body)))
      self.end_headers()
      self.wfile.write(body)
259
260
class ServerThread(threading.Thread):
  """A thread for serving HTTP requests on an ephemeral localhost port.

  Attributes:
    port: the host TCP port the HTTP server is listening on.
  """

  def __init__(self, ota_filename, serving_range):
    threading.Thread.__init__(self)
    # Do not keep the process alive if the main thread exits without having
    # called StopServer().
    self.daemon = True
    # serving_payload and serving_range are class attributes and the
    # UpdateHandler class is instantiated with every request.
    UpdateHandler.serving_payload = ota_filename
    UpdateHandler.serving_range = serving_range
    self._httpd = BaseHTTPServer.HTTPServer(('127.0.0.1', 0), UpdateHandler)
    self.port = self._httpd.server_port

  def run(self):
    try:
      self._httpd.serve_forever()
    except (KeyboardInterrupt, socket.error):
      pass
    logging.info('Server Terminated')

  def StopServer(self):
    """Stop the serve_forever() loop and close the listening socket."""
    # shutdown() asks serve_forever() to return and waits until it has. Only
    # closing the socket may leave serve_forever() polling a closed file
    # descriptor. shutdown() must not be called unless the thread is running,
    # or it would wait forever.
    if self.is_alive():
      self._httpd.shutdown()
    self._httpd.server_close()
282
283
def StartServer(ota_filename, serving_range):
  """Launch a background HTTP server for the payload and return its thread.

  Args:
    ota_filename: path of the file to serve.
    serving_range: (start offset, size) of the payload within that file.

  Returns:
    The already-started ServerThread; its `port` attribute is the host port.
  """
  server = ServerThread(ota_filename, serving_range)
  server.start()
  return server
288
289
def AndroidUpdateCommand(ota_filename, secondary, payload_url, extra_headers):
  """Return the command to run to start the update in the Android device.

  Args:
    ota_filename: path to the OTA package (.zip) on the host.
    secondary: whether to use the secondary payload from the package.
    payload_url: URL the device's update_engine should fetch the package from.
    extra_headers: string appended verbatim after the built-in headers.

  Returns:
    The update_engine_client command line as a list of strings.
  """
  ota = AndroidOTAPackage(ota_filename, secondary)
  headers = ota.properties
  # ZipFile.read() returns bytes, which on Python 3 is not str and cannot be
  # concatenated with the str headers below; on Python 2 bytes is str already.
  if not isinstance(headers, str):
    headers = headers.decode('utf-8')
  headers += 'USER_AGENT=Dalvik (something, something)\n'
  headers += 'NETWORK_ID=0\n'
  headers += extra_headers

  return ['update_engine_client', '--update', '--follow',
          '--payload=%s' % payload_url, '--offset=%d' % ota.offset,
          '--size=%d' % ota.size, '--headers="%s"' % headers]
301
302
def OmahaUpdateCommand(omaha_url):
  """Build the update_engine_client invocation for an Omaha-driven update.

  Args:
    omaha_url: URL the device should send its Omaha update check to.

  Returns:
    The command line as a list of strings.
  """
  command = ['update_engine_client', '--update', '--follow']
  command.append('--omaha_url=%s' % omaha_url)
  return command
307
308
class AdbHost(object):
  """Represents a device connected via ADB."""

  def __init__(self, device_serial=None):
    """Construct an instance.

    Args:
        device_serial: optional string serial number of the attached device.
            When empty or None, no "-s" option is passed to adb.
    """
    self._device_serial = device_serial
    self._command_prefix = ['adb']
    if self._device_serial:
      self._command_prefix += ['-s', self._device_serial]

  def adb(self, command):
    """Run an ADB command like "adb push".

    The command's output is not captured; it is inherited from this process.

    Args:
      command: list of strings containing command and arguments to run

    Returns:
      the program's return code. A non-zero exit status is returned, not
      raised.
    """
    command = self._command_prefix + command
    logging.info('Running: %s', ' '.join(str(x) for x in command))
    # No pipes are used, so there is no text/bytes mode to choose.
    p = subprocess.Popen(command)
    p.wait()
    return p.returncode

  def adb_output(self, command):
    """Run an ADB command like "adb push" and return the output.

    Args:
      command: list of strings containing command and arguments to run

    Returns:
      the program's stdout as a string.

    Raises:
      subprocess.CalledProcessError on command exit != 0.
    """
    command = self._command_prefix + command
    logging.info('Running: %s', ' '.join(str(x) for x in command))
    return subprocess.check_output(command, universal_newlines=True)
356
357
def main():
  """Parse the command line and drive an A/B update on a device via adb.

  Builds a list of adb commands (either pushing the package to the device, or
  serving it from a local HTTP server through "adb reverse"; optionally
  installing an override payload public key), runs them in order, and finally
  runs the cleanup commands.

  Returns:
    0. The exit codes of the individual adb commands are not checked.
  """
  parser = argparse.ArgumentParser(description='Android A/B OTA helper.')
  parser.add_argument('otafile', metavar='PAYLOAD', type=str,
                      help='the OTA package file (a .zip file) or raw payload \
                      if device uses Omaha.')
  parser.add_argument('--file', action='store_true',
                      help='Push the file to the device before updating.')
  parser.add_argument('--no-push', action='store_true',
                      help='Skip the "push" command when using --file')
  parser.add_argument('-s', type=str, default='', metavar='DEVICE',
                      help='The specific device to use.')
  parser.add_argument('--no-verbose', action='store_true',
                      help='Less verbose output')
  parser.add_argument('--public-key', type=str, default='',
                      help='Override the public key used to verify payload.')
  parser.add_argument('--extra-headers', type=str, default='',
                      help='Extra headers to pass to the device.')
  parser.add_argument('--secondary', action='store_true',
                      help='Update with the secondary payload in the package.')
  args = parser.parse_args()
  logging.basicConfig(
      level=logging.WARNING if args.no_verbose else logging.INFO)

  dut = AdbHost(args.s)

  server_thread = None
  # List of commands to execute on exit.
  finalize_cmds = []
  # Commands to execute when canceling an update.
  cancel_cmd = ['shell', 'su', '0', 'update_engine_client', '--cancel']
  # List of commands to perform the update.
  cmds = []

  # Treat the device as Omaha-capable if "omaha" appears anywhere in the
  # output of update_engine_client --help.
  help_cmd = ['shell', 'su', '0', 'update_engine_client', '--help']
  use_omaha = 'omaha' in dut.adb_output(help_cmd)

  if args.file:
    # Update via pushing a file to /data.
    # NOTE(review): no HTTP server or "adb reverse" is set up on this path, yet
    # on an Omaha device the update command below still points at
    # http://127.0.0.1:DEVICE_PORT/update — confirm whether that is intended.
    device_ota_file = os.path.join(OTA_PACKAGE_PATH, 'debug.zip')
    payload_url = 'file://' + device_ota_file
    if not args.no_push:
      # Push to /data/local/tmp first, then move the file into place as root
      # and relabel it (presumably because adb push cannot write to
      # OTA_PACKAGE_PATH directly — TODO confirm).
      data_local_tmp_file = '/data/local/tmp/debug.zip'
      cmds.append(['push', args.otafile, data_local_tmp_file])
      cmds.append(['shell', 'su', '0', 'mv', data_local_tmp_file,
                   device_ota_file])
      cmds.append(['shell', 'su', '0', 'chcon',
                   'u:object_r:ota_package_file:s0', device_ota_file])
    cmds.append(['shell', 'su', '0', 'chown', 'system:cache', device_ota_file])
    cmds.append(['shell', 'su', '0', 'chmod', '0660', device_ota_file])
  else:
    # Update via sending the payload over the network with an "adb reverse"
    # command.
    payload_url = 'http://127.0.0.1:%d/payload' % DEVICE_PORT
    if use_omaha and zipfile.is_zipfile(args.otafile):
      # Omaha devices are served only the payload bytes inside the package.
      ota = AndroidOTAPackage(args.otafile, args.secondary)
      serving_range = (ota.offset, ota.size)
    else:
      # Serve the whole file. For non-Omaha devices the payload location
      # inside the package is passed separately via --offset/--size.
      serving_range = (0, os.stat(args.otafile).st_size)
    server_thread = StartServer(args.otafile, serving_range)
    # Forward the device's DEVICE_PORT to the host server's ephemeral port.
    cmds.append(
        ['reverse', 'tcp:%d' % DEVICE_PORT, 'tcp:%d' % server_thread.port])
    finalize_cmds.append(['reverse', '--remove', 'tcp:%d' % DEVICE_PORT])

  if args.public_key:
    payload_key_dir = os.path.dirname(PAYLOAD_KEY_PATH)
    # Shadow the key directory with a tmpfs so the override key can be
    # written there; it is unmounted again in the cleanup commands.
    cmds.append(
        ['shell', 'su', '0', 'mount', '-t', 'tmpfs', 'tmpfs', payload_key_dir])
    # Allow adb push to payload_key_dir
    cmds.append(['shell', 'su', '0', 'chcon', 'u:object_r:shell_data_file:s0',
                 payload_key_dir])
    cmds.append(['push', args.public_key, PAYLOAD_KEY_PATH])
    # Allow update_engine to read it.
    cmds.append(['shell', 'su', '0', 'chcon', '-R', 'u:object_r:system_file:s0',
                 payload_key_dir])
    finalize_cmds.append(['shell', 'su', '0', 'umount', payload_key_dir])

  try:
    # The main update command using the configured payload_url.
    if use_omaha:
      update_cmd = \
          OmahaUpdateCommand('http://127.0.0.1:%d/update' % DEVICE_PORT)
    else:
      update_cmd = AndroidUpdateCommand(args.otafile, args.secondary,
                                        payload_url, args.extra_headers)
    cmds.append(['shell', 'su', '0'] + update_cmd)

    # Return codes are ignored: a failing step does not stop later steps.
    for cmd in cmds:
      dut.adb(cmd)
  except KeyboardInterrupt:
    # Ask update_engine on the device to cancel the in-progress update.
    dut.adb(cancel_cmd)
  finally:
    # Always stop the local server and undo "adb reverse" / the key mount.
    if server_thread:
      server_thread.StopServer()
    for cmd in finalize_cmds:
      dut.adb(cmd)

  return 0
455
456
if __name__ == '__main__':
  # Use main()'s return value as the process exit status.
  sys.exit(main())
459