1"""Experimentally determines a camera's rolling shutter skew.
2
3See the accompanying PDF for instructions on how to use this test.
4"""
5from __future__ import division
6from __future__ import print_function
7
8import argparse
9import glob
10import math
11import os
12import sys
13import tempfile
14
15import cv2
16import its.caps
17import its.device
18import its.image
19import its.objects
20import numpy as np
21
# Global debug switch.  main() overwrites this with the value of the --debug
# command line flag; debug_print() and the drawing code consult it.
DEBUG = False

# Constants for which direction the camera is facing.  These are compared
# against the 'android.lens.facing' camera property in collect_data().
FACING_FRONT = 0
FACING_BACK = 1
FACING_EXTERNAL = 2

# Camera capture defaults, used when the matching command line argument is not
# given.  FPS is in frames per second, WIDTH and HEIGHT are in pixels, and
# TEST_LENGTH is in seconds.
FPS = 30
WIDTH = 640
HEIGHT = 480
TEST_LENGTH = 1

# Each circle in a cluster must be within this distance of some other circle
# in the cluster.  The value here is a fraction of the image height (50 pixels
# for the default 480 pixel high image); main() multiplies it by the actual
# frame height to turn it into a distance in pixels.
CLUSTER_DISTANCE = 50.0 / HEIGHT
# A cluster must consist of at least this fraction (0.0 - 1.0) of the total
# contours for it to be allowed into the computation.
MAJORITY_THRESHOLD = 0.7

# Constants to make sure the slope of the fitted line is reasonable.  A warning
# is printed when the average slope falls outside of this range.
SLOPE_MIN_THRESHOLD = 0.5
SLOPE_MAX_THRESHOLD = 1.5

# To improve readability of unit conversions.  Multiply by these factors to
# convert between the named units.
SEC_TO_NSEC = float(10**9)
MSEC_TO_NSEC = float(10**6)
NSEC_TO_MSEC = 1.0 / float(10**6)
50
51
class RollingShutterArgumentParser(object):
    """Command line front end for the rolling shutter test.

    Wraps an argparse parser and additionally tolerates the CameraITS-style
    'device=<DEVICE ID>' argument, which argparse itself cannot parse.
    """

    def __init__(self):
        parser = argparse.ArgumentParser(
                description='Run rolling shutter test')

        # Every supported option as (flags, add_argument keyword arguments),
        # listed in the order the options are registered with argparse.
        option_table = (
            (('-d', '--debug'),
             dict(action='store_true',
                  help='print and write data useful for debugging')),
            (('-f', '--fps'),
             dict(type=int,
                  help='FPS to capture with during the test (defaults to '
                       '30)')),
            (('-i', '--img_size'),
             dict(help='comma-separated dimensions of captured images '
                       '(defaults to 640x480). '
                       'Example: --img_size=<width>,<height>')),
            (('-l', '--led_time'),
             dict(type=float,
                  required=True,
                  help='how many milliseconds each column of the LED array '
                       'is lit for')),
            (('-p', '--panel_distance'),
             dict(type=float,
                  help='how far the LED panel is from the camera (in '
                       'meters)')),
            (('-r', '--read_dir'),
             dict(help="read existing test data from specified directory.  "
                       "If not specified, new test data is collected from "
                       "the device's camera)")),
            (('--device_id',),
             dict(help="device ID for device being tested (can also use "
                       "'device=<DEVICE ID>')")),
            (('-t', '--test_length'),
             dict(type=int,
                  help='how many seconds the test should run for (defaults '
                       'to 1 second)')),
            (('-o', '--debug_dir'),
             dict(help="write debugging information in a folder in the "
                       "specified directory.  Otherwise, the system's "
                       "default location for temporary folders is used.  "
                       "--debug must be specified along with this "
                       "argument.")),
        )
        for flags, options in option_table:
            parser.add_argument(*flags, **options)

        self.__parser = parser

    def parse_args(self):
        """Parses sys.argv and returns the resulting argparse namespace.

        Arguments containing 'device=' are hidden from argparse, because they
        follow the CameraITS convention rather than the argparse one and would
        be rejected.  If --device_id was given instead, an equivalent
        'device=<DEVICE ID>' entry is appended to sys.argv so that its.device
        can pick it up later on.
        """
        its_device_marker = 'device='
        argparse_argv = []
        for token in sys.argv[1:]:
            if its_device_marker in token:
                continue
            argparse_argv.append(token)

        parsed = self.__parser.parse_args(argparse_argv)
        if parsed.device_id:
            sys.argv.append('device=%s' % parsed.device_id)
        return parsed
113
114
def main():
    """Runs the rolling shutter skew test from the command line.

    Parses the command line, obtains frames (captured from the device, or
    loaded from --read_dir), measures the average shutter skew over those
    frames and prints it next to the skew reported by the device.

    The process exits with a non-zero status when --debug_dir is given without
    --debug, or when there are no frames to process.

    Note that this rebinds the module-level DEBUG and CLUSTER_DISTANCE globals,
    so it is only meant to be called once per process.
    """
    global DEBUG
    global CLUSTER_DISTANCE

    parser = RollingShutterArgumentParser()
    args = parser.parse_args()

    DEBUG = args.debug
    if not DEBUG and args.debug_dir:
        print('argument --debug_dir requires --debug', file=sys.stderr)
        # Use a non-zero status so that callers can detect the usage error.
        sys.exit(1)

    if args.read_dir is None:
        # Collect new data.
        raw_caps, reported_skew = collect_data(args)
        frames = [its.image.convert_capture_to_rgb_image(c) for c in raw_caps]
    else:
        # Load existing data.
        frames, reported_skew = load_data(args.read_dir)

    if not frames:
        # Without this check, frames[0] below fails with an IndexError (for
        # example when --read_dir contains no PNG files).
        print('No frames were captured or loaded.', file=sys.stderr)
        sys.exit(1)

    # Make the cluster distance relative to the height of the image.
    (frame_h, _, _) = frames[0].shape
    CLUSTER_DISTANCE = frame_h * CLUSTER_DISTANCE
    debug_print('Setting cluster distance to %spx.' % CLUSTER_DISTANCE)

    if DEBUG:
        debug_dir = setup_debug_dir(args.debug_dir)
        # Write raw frames.
        for i, img in enumerate(frames):
            its.image.write_image(img, '%s/raw/%03d.png' % (debug_dir, i))
    else:
        debug_dir = None

    avg_shutter_skew, num_frames_used = find_average_shutter_skew(
            frames, args.led_time, debug_dir)
    if debug_dir:
        # Write the reported skew with the raw images, so the directory can also
        # be used to read from.
        with open(debug_dir + '/raw/reported_skew.txt', 'w') as f:
            f.write('%sms\n' % reported_skew)

    if avg_shutter_skew is None:
        print('Could not find usable frames.')
    else:
        print('Device reported shutter skew of %sms.' % reported_skew)
        print('Measured shutter skew is %sms (averaged over %s frames).' %
              (avg_shutter_skew, num_frames_used))
162
163
def collect_data(args):
    """Capture a new set of frames from the device's camera.

    Args:
        args: Parsed command line arguments.  The fps, img_size, test_length,
              led_time and panel_distance attributes are read.

    Returns:
        A tuple of (the captures returned by the device, the rolling shutter
        skew reported by the device in milliseconds).  The captures are not
        converted to RGB images here.

    Raises:
        AssertionError: If the camera is neither front nor back facing, or if
                        the captured frames do not all report the same rolling
                        shutter skew.
    """
    fps = args.fps if args.fps else FPS
    if args.img_size:
        w, h = map(int, args.img_size.split(','))
    else:
        w, h = WIDTH, HEIGHT
    test_length = args.test_length if args.test_length else TEST_LENGTH

    with its.device.ItsSession() as cam:
        props = cam.get_camera_properties()
        its.caps.skip_unless(its.caps.manual_sensor(props))
        facing = props['android.lens.facing']
        if facing not in (FACING_FRONT, FACING_BACK):
            message = 'Unknown lens facing %s' % facing
            print(message)
            # Raised explicitly (instead of `assert 0`) so that the check is
            # not stripped when Python runs with -O.
            raise AssertionError(message)

        fmt = {'format': 'yuv', 'width': w, 'height': h}
        s, e, _, _, _ = cam.do_3a(get_results=True, do_af=False)
        req = its.objects.manual_capture_request(s, e)
        req['android.control.aeTargetFpsRange'] = [fps, fps]

        # Convert from milliseconds to nanoseconds.  We only want enough
        # exposure time to saturate approximately one column.
        exposure_time = (args.led_time / 2.0) * MSEC_TO_NSEC
        print('Using exposure time of %sns.' % exposure_time)
        req['android.sensor.exposureTime'] = exposure_time
        req['android.sensor.frameDuration'] = int(SEC_TO_NSEC / fps)

        if args.panel_distance is not None:
            # Convert meters to diopters and use that for the focus distance.
            req['android.lens.focusDistance'] = 1 / args.panel_distance
        print('Starting capture')
        # One request per frame: fps frames per second for test_length seconds.
        raw_caps = cam.do_capture([req]*fps*test_length, fmt)
        print('Finished capture')

        # Convert from nanoseconds to milliseconds.  A set is used so that
        # identical values collapse into a single entry.
        shutter_skews = {c['metadata']['android.sensor.rollingShutterSkew'] *
                         NSEC_TO_MSEC for c in raw_caps}
        # All frames should have same rolling shutter skew.
        if len(shutter_skews) != 1:
            raise AssertionError(
                    'Expected exactly one rolling shutter skew across all '
                    'frames, got %s' % sorted(shutter_skews))
        shutter_skew = shutter_skews.pop()

        return raw_caps, shutter_skew
215
216
def load_data(dir_name):
    """Reads camera frame data from an existing directory.

    The directory is expected to contain the frames as PNG files plus a
    'reported_skew.txt' file whose first line is the reported skew followed by
    an 'ms' suffix (the format main() writes into the debug 'raw' directory).

    Args:
        dir_name: Name of the directory to read data from.

    Returns:
        A tuple of (list of images loaded from the PNG files in sorted file
        name order, the reported shutter skew as a string without the 'ms'
        suffix).
    """
    frame_files = sorted(glob.glob('%s/*.png' % dir_name))
    frames = [its.image.load_rgb_image(frame_file)
              for frame_file in frame_files]

    with open('%s/reported_skew.txt' % dir_name, 'r') as f:
        # Strip the line terminator first; slicing two characters off the raw
        # line would remove 's\n' rather than the 'ms' suffix.
        reported_skew = f.readline().strip()
    if reported_skew.endswith('ms'):
        reported_skew = reported_skew[:-2]
    return frames, reported_skew
233
234
def find_average_shutter_skew(frames, led_time, debug_dir=None):
    """Computes the confidence-weighted average shutter skew over all frames.

    Every frame is measured individually.  Frames for which no measurement
    could be made are left out, and the remaining measurements are averaged
    using their confidence as the weight.  A warning is printed to stderr when
    the weighted average slope of the fitted lines is outside of the
    [SLOPE_MIN_THRESHOLD, SLOPE_MAX_THRESHOLD] range.

    Args:
        frames:    List of RGB images from the camera being tested.
        led_time:  How long a single LED column is lit for (in milliseconds).
        debug_dir: (optional) Directory to write debugging information to.

    Returns:
        The average calculated shutter skew and the number of frames used to
        calculate the average, or (None, None) if no frame was usable.
    """
    skew_sum = 0.0
    slope_sum = 0.0
    confidence_sum = 0.0
    used_count = 0

    banner = '------------------------'
    for index, image in enumerate(frames):
        debug_print(banner)
        debug_print('| PROCESSING FRAME %03d |' % index)
        debug_print(banner)

        skew, confidence, slope = calculate_shutter_skew(
                image, led_time, index, debug_dir=debug_dir)
        if skew is None:
            debug_print('Skipped frame.')
            continue

        debug_print('Shutter skew is %sms (confidence: %s).' %
                    (skew, confidence))
        # Accumulate confidence-weighted sums; they are normalized below.
        skew_sum += skew * confidence
        slope_sum += slope * confidence
        confidence_sum += confidence
        used_count += 1

    debug_print('\n')
    if used_count == 0:
        return None, None

    mean_skew = skew_sum / confidence_sum
    mean_slope = slope_sum / confidence_sum

    # (description of the slope, suggested change to the LED panel speed)
    if mean_slope < SLOPE_MIN_THRESHOLD:
        complaint = ('flat', 'slower')
    elif mean_slope > SLOPE_MAX_THRESHOLD:
        complaint = ('steep', 'faster')
    else:
        complaint = None

    if complaint is not None:
        template = ('The average slope of the fitted line was too %s '
                    'to get an accurate measurement (slope was %s).  '
                    'Try making the LED panel %s.')
        print(template % (complaint[0], mean_slope, complaint[1]),
              file=sys.stderr)
    return mean_skew, used_count
288
289
def calculate_shutter_skew(frame, led_time, frame_num=None, debug_dir=None):
    """Measures the shutter skew from a single frame.

    A line is fitted through the centers of the largest cluster of detected
    circles.  The slope of that line, together with the number of LED columns
    the cluster spans and the time each column is lit, gives the time that
    passes between the top and the bottom of the frame.

    Args:
        frame:     A single RGB image captured by the camera being tested.
        led_time:  How long a single LED column is lit for (in milliseconds).
        frame_num: (optional) Number of the given frame.
        debug_dir: (optional) Directory to write debugging information to.

    Returns:
        The shutter skew (in milliseconds), the confidence in the accuracy of
        the measurement (useful for weighting averages), and the slope of the
        fitted line.  (None, None, None) is returned when the frame has no
        majority cluster or the cluster has at most one circle.
    """
    write_debug_images = debug_dir is not None

    # Work on a copy, because find_contours() modifies its input.
    contours, scratch_img, contour_img, mono_img = find_contours(frame.copy())
    if write_debug_images:
        cv2.imwrite('%s/contour/%03d.png' % (debug_dir, frame_num), contour_img)
        cv2.imwrite('%s/mono/%03d.png' % (debug_dir, frame_num), mono_img)

    cluster, confidence = find_largest_cluster(contours, scratch_img)
    if cluster is None:
        debug_print('No majority cluster found.')
        return None, None, None
    if len(cluster) <= 1:
        debug_print('Majority cluster was too small.')
        return None, None, None
    debug_print('%s points in the largest cluster.' % len(cluster))

    # Fit a line through the circle centers.
    centers = np.array([[circle.x, circle.y] for circle in cluster])
    [vx], [vy], [x0], [y0] = cv2.fitLine(centers, cv2.cv.CV_DIST_L2,
                                         0, 0.01, 0.01)
    slope = vy / vx
    debug_print('Slope is %s.' % slope)

    # Draw the fitted line onto the scratch frame, extended by 1000 direction
    # vector lengths to both sides of the point on the line.
    line_start = (int(x0 - vx * 1000), int(y0 - vy * 1000))
    line_end = (int(x0 + vx * 1000), int(y0 + vy * 1000))
    cv2.line(scratch_img, line_start, line_end, (0, 255, 255), thickness=3)

    # Only the width (third element) of the bounding rectangle is needed.
    cluster_w = find_cluster_bounding_rect(cluster, scratch_img)[2]

    num_columns = find_num_columns_spanned(cluster)
    debug_print('%s columns spanned by cluster.' % num_columns)

    frame_h, frame_w, _ = frame.shape
    # Time it takes for the lit column to travel from the left of the bounding
    # rectangle to the right, expressed per horizontal pixel.
    sweep_time = led_time * num_columns
    ms_per_x_pixel = sweep_time / cluster_w
    # Horizontal distance between the points where the line crosses the top
    # and the bottom of the frame.
    x_span = frame_h / slope
    shutter_skew = ms_per_x_pixel * x_span
    # The sensor is 4:3; correct for the cropping that happens when the frame
    # has a different aspect ratio.
    aspect_correction = (float(frame_w) / float(frame_h)) / (4.0 / 3.0)
    shutter_skew *= aspect_correction

    if write_debug_images:
        cv2.imwrite('%s/scratch/%03d.png' % (debug_dir, frame_num),
                    scratch_img)

    return shutter_skew, confidence, slope
353
354
def find_contours(img):
    """Finds contours in the given image.

    Note that the given array is scaled in place (img *= 255) before it is
    converted, so callers that still need the original values should pass a
    copy.

    Args:
        img: Image in Android camera RGB format.  Presumably float values in
             the range [0, 1], since it is multiplied by 255 before being
             converted to bytes -- TODO confirm against its.image.

    Returns:
        OpenCV-formatted contours, the original image in OpenCV format (BGR
        byte values), a thresholded image with the contours drawn on (None
        unless DEBUG is set), and a grayscale version of the image (the red
        channel only).
    """
    # Convert to format OpenCV can work with (BGR ordering with byte-ranged
    # values).
    img *= 255
    img = img.astype(np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    # Since the LED colors for the panel we're using are red, we can get better
    # contours for the LEDs if we ignore the green and blue channels.  This also
    # makes it so we don't pick up the blue control screen of the LED panel.
    # (Index 2 is the red channel, because the image is in BGR order now.)
    red_img = img[:, :, 2]
    # With THRESH_OTSU the threshold is chosen automatically; NOTE(review):
    # the explicit threshold value of 0 is then presumably ignored -- confirm
    # against the OpenCV documentation.
    _, thresh = cv2.threshold(red_img, 0, 255, cv2.THRESH_BINARY +
                              cv2.THRESH_OTSU)

    # Remove noise before finding contours by eroding the thresholded image and
    # then re-dilating it.  The size of the kernel represents how many
    # neighboring pixels to consider for the result of a single pixel.
    kernel = np.ones((3, 3), np.uint8)
    opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=2)

    # The debug image is created before findContours() is called.
    # NOTE(review): presumably because findContours() modifies its input image
    # in older OpenCV versions -- confirm before reordering.
    if DEBUG:
        # Need to convert it back to BGR if we want to draw colored contours.
        contour_img = cv2.cvtColor(opening, cv2.COLOR_GRAY2BGR)
    else:
        contour_img = None
    contours, _ = cv2.findContours(opening,
                                   cv2.cv.CV_RETR_EXTERNAL,
                                   cv2.cv.CV_CHAIN_APPROX_NONE)
    if DEBUG:
        # Draw all contours (index -1) in red (BGR ordering).
        cv2.drawContours(contour_img, contours, -1, (0, 0, 255), thickness=2)
    return contours, img, contour_img, red_img
396
397
def convert_to_circles(contours):
    """Approximates every given contour with a circle.

    Args:
        contours: Contours generated by OpenCV.  Each contour is indexed as
                  contour[:, 0, 0] for its x coordinates and contour[:, 0, 1]
                  for its y coordinates.

    Returns:
        A list with one circle object per contour, in the same order.  Each
        circle has integer center coordinates x and y and a radius r.
    """

    class Circle(object):
        """Circle approximation (center and radius) of a single contour."""

        def __init__(self, contour):
            xs = contour[:, 0, 0]
            ys = contour[:, 0, 1]
            # The center is the (truncated) mean of the contour points.
            self.x = int(np.mean(xs))
            self.y = int(np.mean(ys))
            # Half of the extent along each axis gives a radius per axis; the
            # circle's radius is the average of the two.
            radius_x = (xs.max() - xs.min()) / 2.0
            radius_y = (ys.max() - ys.min()) / 2.0
            self.r = (radius_x + radius_y) / 2.0
            assert self.r > 0.0

        def distance_to(self, other):
            """Returns the gap between the circle edges (negative on overlap)."""
            dx = other.x - self.x
            dy = other.y - self.y
            center_distance = math.sqrt(dx**2 + dy**2)
            return center_distance - self.r - other.r

        def intersects(self, other):
            """Returns True if the two circles touch or overlap."""
            return not self.distance_to(other) > 0.0

    return [Circle(contour) for contour in contours]
430
431
def find_largest_cluster(contours, frame):
    """Returns the biggest proximity cluster if it holds a majority.

    Args:
        contours: Contours generated by OpenCV.
        frame:    Image that the circles of the chosen cluster are drawn onto
                  when DEBUG is set.

    Returns:
        The cluster with the most contours in it and the fraction of all
        contours that the cluster contains.  (None, None) is returned when
        there are no clusters, or when the biggest cluster holds less than
        MAJORITY_THRESHOLD of all contours.
    """
    candidates = proximity_clusters(contours)
    if not candidates:
        # Nothing to choose from.
        return None, None

    biggest = max(candidates, key=len)
    share = len(biggest) / len(contours)
    if share < MAJORITY_THRESHOLD:
        return None, None

    if DEBUG:
        # Outline every circle of the chosen cluster on the scratch frame.
        outline_color = (0, 255, 0)
        for member in biggest:
            center = (int(member.x), int(member.y))
            cv2.circle(frame, center, int(member.r), outline_color,
                       thickness=2)

    return biggest, share
461
462
def proximity_clusters(contours):
    """Sorts the given contours into groups by distance.

    Converts every given contour to a circle and clusters by adding a circle to
    a cluster only if it is close to at least one other circle in the cluster.
    Two circles are close when the gap between them is smaller than
    CLUSTER_DISTANCE.

    TODO: Make algorithm faster (currently O(n**2)).

    Args:
        contours: Contours generated by OpenCV.

    Returns:
        A collection of clusters (the values of a dict), where each cluster is
        a list of the circles contained in the cluster.
    """
    circles = convert_to_circles(contours)

    # Use disjoint-set data structure to store assignments.  Start every point
    # in their own cluster.
    #
    # Encoding: a non-negative entry is the index of the parent circle; a
    # negative entry marks the root of a cluster and holds the negated size of
    # that cluster (so -1 is a cluster of one circle).
    cluster_assignments = [-1 for i in range(len(circles))]

    def get_canonical_index(i):
        # Follows the parent links up to the root of the cluster containing i.
        if cluster_assignments[i] >= 0:
            index = get_canonical_index(cluster_assignments[i])
            # Collapse tree for better runtime.
            cluster_assignments[i] = index
            return index
        else:
            return i

    def get_cluster_size(i):
        # The root stores the negated size, so negate it again.
        return -cluster_assignments[get_canonical_index(i)]

    for i, curr in enumerate(circles):
        # Indices of all other circles that are close enough to this one.
        close_circles = [j for j, p in enumerate(circles) if i != j and
                         curr.distance_to(p) < CLUSTER_DISTANCE]
        if close_circles:
            # Note: largest_cluster is an index into cluster_assignments.
            # NOTE(review): get_cluster_size() returns a positive size, so
            # min() selects the close circle whose cluster is the smallest,
            # despite the variable name.  The first branch below is therefore
            # only taken when every close circle is in a bigger cluster than
            # the current one -- confirm this is intended before changing it.
            largest_cluster = min(close_circles, key=get_cluster_size)
            largest_size = get_cluster_size(largest_cluster)
            curr_index = get_canonical_index(i)
            curr_size = get_cluster_size(i)
            if largest_size > curr_size:
                # largest_cluster is larger than us.
                target_index = get_canonical_index(largest_cluster)
                # Add our cluster size to the bigger one.
                cluster_assignments[target_index] -= curr_size
                # Reroute our group to the bigger one.
                cluster_assignments[curr_index] = target_index
            else:
                # We're the largest (or equal to the largest) cluster.  Reroute
                # all groups to us.
                for j in close_circles:
                    smaller_size = get_cluster_size(j)
                    smaller_index = get_canonical_index(j)
                    if smaller_index != curr_index:
                        # We only want to modify clusters that aren't already in
                        # the current one.

                        # Add the smaller cluster's size to ours.
                        cluster_assignments[curr_index] -= smaller_size
                        # Reroute their group to us.
                        cluster_assignments[smaller_index] = curr_index

    # Convert assignments list into list of clusters.  Circles are grouped by
    # the root index of their cluster and keep their original relative order.
    clusters_dict = {}
    for i in range(len(cluster_assignments)):
        canonical_index = get_canonical_index(i)
        if canonical_index not in clusters_dict:
            clusters_dict[canonical_index] = []
        clusters_dict[canonical_index].append(circles[i])
    return clusters_dict.values()
535
536
def find_cluster_bounding_rect(cluster, scratch_frame):
    """Computes the axis-aligned rectangle that encloses the given cluster.

    The rectangle encloses every circle of the cluster (including its radius)
    and is padded on every side by the average distance between neighboring
    circles.

    Args:
        cluster:       Cluster being used to find the bounding rectangle.
        scratch_frame: Image that rectangle is drawn onto for debugging
                       purposes (only when DEBUG is set).

    Returns:
        The leftmost and topmost x and y coordinates, respectively, along with
        the width and height of the rectangle.
    """
    padding = find_average_neighbor_distance(cluster)
    debug_print('Average distance between points in largest cluster is %s '
                'pixels.' % padding)

    # Outermost circle edges in every direction.
    left_edge = min(circle.x - circle.r for circle in cluster)
    top_edge = min(circle.y - circle.r for circle in cluster)
    right_edge = max(circle.x + circle.r for circle in cluster)
    bottom_edge = max(circle.y + circle.r for circle in cluster)

    x = left_edge - padding
    y = top_edge - padding
    w = (right_edge + padding) - x
    h = (bottom_edge + padding) - y

    if DEBUG:
        # Corners in drawing order: top left, top right, bottom right, bottom
        # left.
        corners = np.array([[x, y], [x + w, y], [x + w, y + h], [x, y + h]],
                           np.int32)
        cv2.polylines(scratch_frame, [corners], True, (255, 0, 0), thickness=2)

    return x, y, w, h
571
572
def find_average_neighbor_distance(cluster):
    """Finds the average distance between every circle and its closest neighbor.

    Args:
        cluster: List of circles.  Each circle must provide a distance_to()
                 method that accepts another circle.

    Returns:
        The average distance.  0.0 is returned for clusters with fewer than two
        circles, since no circle in them has a neighbor.
    """
    if len(cluster) < 2:
        # Avoids dividing by zero for an empty cluster and adding a missing
        # distance for a cluster of one.
        return 0.0

    total_distance = 0.0
    for a in cluster:
        # Distance from this circle to the nearest of all the other circles.
        total_distance += min(a.distance_to(b) for b in cluster if b is not a)
    return total_distance / len(cluster)
596
597
def find_num_columns_spanned(circles):
    """Counts the LED panel columns covered by the given circles.

    The circles are visited from left to right.  A circle starts a new column
    when it does not overlap horizontally with the circle that started the
    current column.

    Args:
        circles: List of circles (assumed to be from the LED panel).

    Returns:
        The number of columns spanned (0 if there are no circles).
    """
    if not circles:
        return 0

    ordered = sorted(circles, key=lambda circle: circle.x)
    column_start = ordered[0]
    column_count = 1
    for candidate in ordered[1:]:
        center_gap = abs(candidate.x - column_start.x)
        if center_gap < (candidate.r + column_start.r):
            # Still overlaps the circle that opened the current column.
            continue
        column_start = candidate
        column_count += 1

    return column_count
622
623
def setup_debug_dir(dir_name=None):
    """Prepares the directory tree that debugging output is written to.

    The directory gets one subdirectory per processing step.  PNG files that
    are already present in those subdirectories are removed.

    Args:
        dir_name: The directory to create.  If none is specified, a temp
        directory is created.

    Returns:
        The name of the directory that is used.
    """
    if dir_name is not None:
        force_mkdir(dir_name)
    else:
        dir_name = tempfile.mkdtemp()
    print('Saving debugging files to "%s"' % dir_name)

    # raw:     original captured images.
    # mono:    monochrome images.
    # contour: contours generated from monochrome images.
    # scratch: post-contour debugging information.
    for subdir in ('raw', 'mono', 'contour', 'scratch'):
        force_mkdir(dir_name + '/' + subdir, clean=True)
    return dir_name
650
651
def force_mkdir(dir_name, clean=False):
    """Makes sure the given directory exists.

    Args:
        dir_name: Name of the directory to create.  Missing parent directories
                  are created as well.
        clean:    (optional) If set to true, cleans image files (*.png) from
                  the directory (if it already exists).
    """
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
        return

    if not clean:
        return

    # The directory was already there: remove PNG files from previous runs.
    stale_images = glob.glob('%s/*.png' % dir_name)
    for stale_image in stale_images:
        os.remove(stale_image)
666
667
def debug_print(s, *args, **kwargs):
    """Forwards all arguments to print(), but only when DEBUG is set."""
    if not DEBUG:
        return
    print(s, *args, **kwargs)
672
673
# Only run the test when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
    main()
676