1# -*- coding:utf-8 -*-
2# Copyright 2016 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Git helper functions."""
17
18from __future__ import print_function
19
20import os
21import re
22import sys
23
# Prepend the directory two levels above this file (i.e. the parent of the
# directory holding this module) to sys.path, unless it is already first, so
# that the `import rh.utils` below can be resolved from it.  Presumably the
# `rh` package lives in that directory -- confirm if the layout changes.
_path = os.path.realpath(__file__ + '/../..')
if sys.path[0] != _path:
    sys.path.insert(0, _path)
# Drop the temporary so it does not linger as a module attribute.
del _path
28
29# pylint: disable=wrong-import-position
30import rh.utils
31
32
def get_upstream_remote():
    """Look up the name of the remote that the checked-out branch tracks.

    Returns:
      The value of `branch.<current branch>.remote` from git config, with
      surrounding whitespace removed.
    """
    def _git_output(args):
        # Run git with |args| and hand back its trimmed stdout.
        return rh.utils.run(['git'] + args, capture_output=True).stdout.strip()

    current = _git_output(['rev-parse', '--abbrev-ref', 'HEAD'])
    return _git_output(['config', 'branch.%s.remote' % current])
44
45
def get_upstream_branch():
    """Returns the upstream tracking branch of the current branch.

    The branch's `branch.<name>.merge` ref (e.g. `refs/heads/main`) is
    rewritten into the matching remote-tracking ref
    (e.g. `refs/remotes/<remote>/main`).

    Returns:
      The fully qualified remote-tracking ref name.

    Raises:
      ValueError: if HEAD does not name a branch with both a merge ref and a
        remote configured.  Failures of the underlying git commands are
        propagated from rh.utils.run.
    """
    heads_prefix = 'refs/heads/'

    cmd = ['git', 'symbolic-ref', 'HEAD']
    result = rh.utils.run(cmd, capture_output=True)
    current_branch = result.stdout.strip()
    # Only strip the leading namespace; a plain str.replace() would also
    # mangle a branch name that happens to contain "refs/heads/".
    if current_branch.startswith(heads_prefix):
        current_branch = current_branch[len(heads_prefix):]
    if not current_branch:
        raise ValueError('Need to be on a tracking branch')

    cfg_option = 'branch.' + current_branch + '.%s'
    cmd = ['git', 'config', cfg_option % 'merge']
    result = rh.utils.run(cmd, capture_output=True)
    full_upstream = result.stdout.strip()

    cmd = ['git', 'config', cfg_option % 'remote']
    result = rh.utils.run(cmd, capture_output=True)
    remote = result.stdout.strip()

    # Validate before qualifying the merge ref: once the implicit namespace is
    # prepended, an empty value would look non-empty and slip through.
    if not remote or not full_upstream:
        raise ValueError('Need to be on a tracking branch')

    # If the merge ref is not fully qualified, add an implicit namespace.
    if '/' not in full_upstream:
        full_upstream = heads_prefix + full_upstream

    # Map refs/heads/<x> to refs/remotes/<remote>/<x>.  Only the namespace is
    # rewritten, so branch names that contain "heads" are left intact.
    if full_upstream.startswith(heads_prefix):
        branch_path = full_upstream[len(heads_prefix):]
        return 'refs/remotes/%s/%s' % (remote, branch_path)
    return full_upstream
72
73
def get_commit_for_ref(ref):
    """Returns the latest commit for |ref|, as reported by `git rev-parse`.

    Returns:
      The rev-parse output with surrounding whitespace removed.
    """
    proc = rh.utils.run(['git', 'rev-parse', ref], capture_output=True)
    return proc.stdout.strip()
79
80
def get_remote_revision(ref, remote):
    """Strip the `refs/remotes/<remote>/` prefix from |ref|.

    Args:
      ref: A ref name, e.g. `refs/remotes/origin/main`.
      remote: The remote whose namespace should be removed.

    Returns:
      |ref| without the prefix, or |ref| unchanged if it does not start with
      that remote's namespace.
    """
    remote_prefix = 'refs/remotes/%s/' % remote
    if not ref.startswith(remote_prefix):
        return ref
    return ref[len(remote_prefix):]
87
88
def get_patch(commit):
    """Render |commit| as a patch via `git format-patch`.

    Returns:
      The unmodified stdout of `git format-patch --stdout -1 <commit>`.
    """
    format_patch = ['git', 'format-patch', '--stdout', '-1', commit]
    proc = rh.utils.run(format_patch, capture_output=True)
    return proc.stdout
93
94
def get_file_content(commit, path):
    """Fetch |path| as it existed at |commit| using `git show <commit>:<path>`.

    The working tree cannot be trusted here: an upload may contain a series of
    changes that each modify the same file, so every commit has to be read
    from the object store.

    Note: for a symlink, git hands back only the link target, not a full file.
    Callers expecting real contents should check for that first; one hint is
    that such "content" contains no newlines.

    Returns:
      The unmodified stdout of the `git show` call.
    """
    blob_spec = '%s:%s' % (commit, path)
    proc = rh.utils.run(['git', 'show', blob_spec], capture_output=True)
    return proc.stdout
107
108
class RawDiffEntry(object):
    """One parsed line of `git diff --raw` output.

    Every constructor argument is stored verbatim as an attribute of the same
    name; no conversion or validation happens here.
    """

    # pylint: disable=redefined-builtin
    def __init__(self, src_mode=0, dst_mode=0, src_sha=None, dst_sha=None,
                 status=None, score=None, src_file=None, dst_file=None,
                 file=None):
        # File modes before/after the change.
        self.src_mode, self.dst_mode = src_mode, dst_mode
        # Object ids before/after the change.
        self.src_sha, self.dst_sha = src_sha, dst_sha
        # Status letter and its optional numeric score.
        self.status, self.score = status, score
        # Source path, optional destination path, and the path to report.
        self.src_file, self.dst_file = src_file, dst_file
        self.file = file
125
126
# This regular expression pulls apart a line of raw formatted git diff output,
# e.g. ":100644 100644 <src sha> <dst sha> M<TAB>path".  Each named group is
# also a keyword argument of RawDiffEntry.  The "(\.)*" after each sha skips
# any dots that follow an abbreviated object id.
DIFF_RE = re.compile(
    r':'
    r'(?P<src_mode>[0-7]*) '
    r'(?P<dst_mode>[0-7]*) '
    r'(?P<src_sha>[0-9a-f]*)(\.)* '
    r'(?P<dst_sha>[0-9a-f]*)(\.)* '
    r'(?P<status>[ACDMRTUX])'
    r'(?P<score>[0-9]+)?'
    r'\t'
    r'(?P<src_file>[^\t]+)'
    r'\t?'
    r'(?P<dst_file>[^\t]+)?'
)
133
134
def raw_diff(path, target):
    """Return the parsed raw format diff of |target|.

    Args:
      path: Path to the git repository to diff in.
      target: The target to diff (anything `git diff` accepts, e.g.
        `<sha>^!`).

    Returns:
      A list of RawDiffEntry's, one per line of `git diff --raw` output.

    Raises:
      ValueError: if a line of output does not match DIFF_RE.
    """
    cmd = ['git', 'diff', '--no-ext-diff', '-M', '--raw', target]
    diff = rh.utils.run(cmd, cwd=path, capture_output=True).stdout

    entries = []
    # Split first and skip blank lines rather than strip()ing the whole
    # output: stripping would also eat trailing whitespace from the file name
    # on the last line.
    for line in diff.splitlines():
        if not line:
            continue
        match = DIFF_RE.match(line)
        if not match:
            raise ValueError('Failed to parse diff output: %s' % line)
        rawdiff = RawDiffEntry(**match.groupdict())
        # The octal digit strings are converted with int() in base 10
        # (e.g. '100644' -> 100644); presumably callers compare against that
        # form, so it is kept as-is.
        rawdiff.src_mode = int(rawdiff.src_mode)
        rawdiff.dst_mode = int(rawdiff.dst_mode)
        # When a second path is present (e.g. renames detected via -M), it is
        # the one to report; otherwise fall back to the only path given.
        rawdiff.file = rawdiff.dst_file or rawdiff.src_file
        entries.append(rawdiff)

    return entries
162
163
def get_affected_files(commit):
    """Returns the raw diff entries for the files touched by |commit|.

    The diff is run in the current working directory against `<commit>^!`.

    Args:
      commit: The commit to inspect.

    Returns:
      A list of RawDiffEntry objects (not plain paths), one per file the diff
      reports.  Deleted and renamed files are included; the path to use is in
      each entry's `file` attribute.
    """
    return raw_diff(os.getcwd(), '%s^!' % commit)
171
172
def get_commits(ignore_merged_commits=False):
    """List the commits on the current branch that are not yet upstream.

    Args:
      ignore_merged_commits: If True, pass --first-parent to `git rev-list`
        so only the first-parent chain is walked.

    Returns:
      The commit ids printed by `git rev-list <upstream branch>..`.
    """
    rev_range = '%s..' % get_upstream_branch()
    extra_args = ['--first-parent'] if ignore_merged_commits else []
    cmd = ['git', 'rev-list', rev_range] + extra_args
    return rh.utils.run(cmd, capture_output=True).stdout.split()
179
180
def get_commit_desc(commit):
    """Return the raw commit message (subject and body) of |commit|.

    Uses `git log --format=%B` limited to the single commit via `^!`.
    """
    log_cmd = ['git', 'log', '--format=%B', commit + '^!']
    proc = rh.utils.run(log_cmd, capture_output=True)
    return proc.stdout
185
186
def find_repo_root(path=None):
    """Locate the top level of this repo checkout starting at |path|.

    Walks from |path| up through its parent directories and returns the first
    one that contains a `.repo` entry.

    Args:
      path: Directory to start searching from; defaults to os.getcwd().

    Returns:
      The absolute path of the directory containing `.repo`.

    Raises:
      ValueError: if no `.repo` is found, including at the filesystem root.
    """
    if path is None:
        path = os.getcwd()
    orig_path = path

    path = os.path.abspath(path)
    while not os.path.exists(os.path.join(path, '.repo')):
        parent = os.path.dirname(path)
        # dirname() of a filesystem root returns the root itself ('/' on
        # POSIX, a drive root like 'C:\\' on Windows).  Stop there, after the
        # root itself has been checked, instead of looping forever.
        if parent == path:
            raise ValueError('Could not locate .repo in %s' % orig_path)
        path = parent

    return path
200
201
def is_git_repository(path):
    """Check whether |path| holds a git repository.

    Asks `git rev-parse --resolve-git-dir` about `<path>/.git` without raising
    on failure; the repository is considered valid iff git exits with 0.
    """
    git_dir = os.path.join(path, '.git')
    proc = rh.utils.run(['git', 'rev-parse', '--resolve-git-dir', git_dir],
                        capture_output=True, check=False)
    return proc.returncode == 0
207