1#!/usr/bin/env python3
2
3#
4# Copyright (C) 2018 The Android Open Source Project
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10#      http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17#
18
19"""A command line utility to pull multiple change lists from Gerrit."""
20
21from __future__ import print_function
22
23import argparse
24import collections
25import itertools
26import json
27import multiprocessing
28import os
29import os.path
30import re
31import sys
32import xml.dom.minidom
33
34from gerrit import create_url_opener_from_args, query_change_lists
35
try:
    # On Python 2, rebind `input` to `raw_input` so that `input()` returns the
    # typed line as a string without evaluating it.
    # pylint: disable=redefined-builtin
    from __builtin__ import raw_input as input  # PY2
except ImportError:
    # The `__builtin__` module cannot be imported here (this is expected on
    # Python 3), so the built-in `input()` is used unchanged.
    pass
41
try:
    from shlex import quote as _sh_quote  # PY3.3
except ImportError:
    # Shell language simple string pattern.  If a string matches this pattern,
    # it doesn't have to be quoted.  `\Z` (instead of `$`) makes sure that a
    # string with a trailing newline is not treated as a simple string.
    _SHELL_SIMPLE_PATTERN = re.compile(r'^[a-zA-Z0-9_./-]+\Z')

    def _sh_quote(txt):
        """Quote a string if it contains special characters.

        Strings that only consist of characters matched by
        `_SHELL_SIMPLE_PATTERN` are returned unchanged.  Other strings
        (including the empty string) are wrapped in single quotes so that the
        shell does not expand `$`, backquotes, or backslashes.  An embedded
        single quote is emitted as `'"'"'` (close the quote, emit a
        double-quoted single quote, and reopen the quote).
        """
        if _SHELL_SIMPLE_PATTERN.match(txt):
            return txt
        return "'" + txt.replace("'", "'\"'\"'") + "'"
52
try:
    from subprocess import PIPE, run  # PY3.5
except ImportError:
    from subprocess import CalledProcessError, PIPE, Popen

    class CompletedProcess(object):
        """Process execution result returned by subprocess.run()."""
        # pylint: disable=too-few-public-methods

        def __init__(self, args, returncode, stdout, stderr):
            # The command that was executed
            self.args = args
            # Exit status of the child process
            self.returncode = returncode
            # Captured output (None unless the stream was redirected to PIPE)
            self.stdout = stdout
            self.stderr = stderr

    def run(*args, **kwargs):
        """Run a command with subprocess.Popen() and redirect input/output.

        This is a fallback for interpreters whose `subprocess` module does not
        provide `run()`.  Positional and keyword arguments are forwarded to
        `Popen()`, except for:

        check: If true, raise `CalledProcessError` when the exit status is
            not zero.
        input: Data to be sent to the standard input of the child process.
            It cannot be combined with `stdin`.

        Returns a `CompletedProcess` instance.
        """

        check = kwargs.pop('check', False)

        if 'input' in kwargs:
            # Raise an exception explicitly because `assert` statements are
            # stripped when the interpreter runs with `-O`.
            if 'stdin' in kwargs:
                raise ValueError(
                    'stdin and input arguments may not both be used.')
            stdin = kwargs.pop('input')
            kwargs['stdin'] = PIPE
        else:
            stdin = None

        # The command is the first argument of Popen().  Report it (rather
        # than the whole positional argument tuple) in the result and in the
        # exception.
        cmd = args[0] if args else kwargs.get('args')

        proc = Popen(*args, **kwargs)
        try:
            stdout, stderr = proc.communicate(stdin)
        except:  # pylint: disable=bare-except
            # Do not leave the child process behind on any interruption
            # (including KeyboardInterrupt), then re-raise the exception.
            proc.kill()
            proc.wait()
            raise
        returncode = proc.wait()

        if check and returncode:
            raise CalledProcessError(returncode, cmd, stdout)
        return CompletedProcess(cmd, returncode, stdout, stderr)
92
93
if bytes is str:
    def write_bytes(data, file):  # PY2
        """Write bytes to a file."""
        # pylint: disable=redefined-builtin
        file.write(data)
else:
    def write_bytes(data, file):  # PY3
        """Write bytes to a text file object.

        The bytes are written to the underlying binary buffer (`file.buffer`)
        when the file object has one.  Otherwise, the bytes are decoded as
        UTF-8 (undecodable bytes are replaced) and written as text.
        """
        # pylint: disable=redefined-builtin

        # Flush the pending text first.  Otherwise, the bytes written to the
        # underlying buffer may show up before the text written earlier.
        file.flush()
        buffer = getattr(file, 'buffer', None)
        if buffer is None:
            file.write(data.decode('utf-8', 'replace'))
        else:
            buffer.write(data)
104
105
def _confirm(question, default, file=sys.stderr):
    """Prompt a yes/no question and convert the answer to a boolean value.

    Args:
        question: The question to be printed before the `[y/n]` hint.
        default: The value to be returned when the user enters an empty line
            or when the standard input reaches end-of-file.
        file: The file object that the prompt is written to.

    Returns:
        True for `y`/`yes`, False for `n`/`no` (case-insensitive, surrounding
        white spaces are ignored), or `default` for an empty answer.  The
        question is asked again for any other answer.

    Raises:
        EOFError: The standard input reached end-of-file and `default` is
            None (i.e. there is no default answer to fall back to).
    """
    # pylint: disable=redefined-builtin
    answers = {'': default, 'y': True, 'yes': True, 'n': False, 'no': False}
    suffix = ' [Y/n] ' if default else ' [y/N] '
    while True:
        file.write(question + suffix)
        file.flush()
        try:
            reply = input()
        except EOFError:
            if default is None:
                raise
            # Terminate the prompt line and treat EOF as an empty answer.
            file.write('\n')
            return default
        ans = answers.get(reply.strip().lower())
        if ans is not None:
            return ans
117
118
class ChangeList(object):
    """A change list (one revision of a Gerrit change) to be checked out."""
    # pylint: disable=too-few-public-methods,too-many-instance-attributes

    def __init__(self, project, fetch, commit_sha1, commit, change_list):
        """Initialize a ChangeList instance.

        Args:
            project: The Gerrit project name.
            fetch: A dict that maps fetch protocol names to dicts with `url`
                and `ref` keys.
            commit_sha1: The commit SHA-1 of the revision.
            commit: A dict that describes the commit.  It must have a
                `parents` key.
            change_list: The change list dict.  It must have a `_number` key.

        Raises:
            ValueError: None of the `http`, `sso`, or `rpc` protocols has a
                usable entry in `fetch`.
        """
        # pylint: disable=too-many-arguments

        self.project = project
        self.number = change_list['_number']
        self.fetch = fetch

        # Pick the first usable entry in the order of preference.
        candidates = (fetch.get(name) for name in ('http', 'sso', 'rpc'))
        fetch_git = next((info for info in candidates if info), None)
        if fetch_git is None:
            raise ValueError(
                'unknown fetch protocols: ' + str(list(fetch.keys())))

        self.fetch_url = fetch_git['url']
        self.fetch_ref = fetch_git['ref']

        self.commit_sha1 = commit_sha1
        self.commit = commit
        self.parents = commit['parents']

        self.change_list = change_list

    def is_merge(self):
        """Check whether this change list is a merge commit (i.e. it has more
        than one parent)."""
        return len(self.parents) > 1
155
156
def find_manifest_xml(dir_path):
    """Find the path to manifest.xml for this Android source tree.

    Look for `.repo/manifest.xml` in `dir_path` and then in each of its
    ancestor directories.  The search stops when `os.path.dirname()` no longer
    yields a different directory.

    Raises:
        ValueError: No `.repo/manifest.xml` was found.
    """
    previous = None
    current = dir_path
    while current != previous:
        candidate = os.path.join(current, '.repo', 'manifest.xml')
        if os.path.exists(candidate):
            return candidate
        # Step to the parent directory and remember where we came from.
        previous, current = current, os.path.dirname(current)
    raise ValueError('.repo dir not found')
167
168
def build_project_name_dir_dict(manifest_path):
    """Build the mapping from Gerrit project name to source tree project
    directory path.

    Manifests referred to by `<include>` elements are loaded recursively
    first.  A relative include name is resolved against the directory of the
    (symlink-resolved) `manifest_path`.  The `<project>` elements of this
    manifest are applied afterwards, so they override the included entries
    with the same name.  A project without a `path` attribute maps to its
    own name.
    """
    document = xml.dom.minidom.parse(manifest_path)
    name_to_dir = {}

    for node in document.getElementsByTagName('include'):
        included = node.getAttribute('name')
        if not os.path.isabs(included):
            base_dir = os.path.dirname(os.path.realpath(manifest_path))
            included = os.path.join(base_dir, included)
        name_to_dir.update(build_project_name_dir_dict(included))

    for node in document.getElementsByTagName('project'):
        name = node.getAttribute('name')
        # getAttribute() returns an empty string for a missing attribute.
        name_to_dir[name] = node.getAttribute('path') or name

    return name_to_dir
193
194
def group_and_sort_change_lists(change_lists):
    """Build a list of topologically sorted change lists for each project.

    Args:
        change_lists: An iterable of change list dicts.  Each dict must have
            the `project`, `_number`, and `revisions` keys.  If a dict has a
            `current_revision` key that names one of its revisions, that
            revision is picked.  Otherwise, the last revision listed in
            `revisions` is picked.

    Returns:
        A list with one entry per project (sorted by project name).  Each
        entry is a list of `ChangeList` objects in which every change comes
        after its parents that belong to the same query result.

    Raises:
        ValueError: A change list does not have any revision.
        KeyError: A commit SHA-1 shows up more than once in a project.
    """

    # Build a dict that map projects to dicts that map commits to changes.
    projects = collections.defaultdict(dict)
    for change_list in change_lists:
        revisions = change_list['revisions']

        # Prefer the revision that is explicitly marked as the current one.
        # NOTE(review): presumably `current_revision` is only present when
        # the query asks for it -- confirm against query_change_lists().
        commit_sha1 = change_list.get('current_revision')
        if commit_sha1 not in revisions:
            commit_sha1 = None
            for commit_sha1 in revisions:
                pass  # Keep the last listed revision.

        if not commit_sha1:
            raise ValueError('bad revision')

        revision = revisions[commit_sha1]
        fetch = revision['fetch']
        commit = revision['commit']

        project = change_list['project']

        project_changes = projects[project]
        if commit_sha1 in project_changes:
            raise KeyError('repeated commit sha1 "{}" in project "{}"'.format(
                commit_sha1, project))

        project_changes[commit_sha1] = ChangeList(
            project, fetch, commit_sha1, commit, change_list)

    # Sort all change lists in a project in post ordering.
    def _sort_project_change_lists(changes):
        visited_changes = set()
        sorted_changes = []

        # Traverse with an explicit stack instead of recursion so that a long
        # chain of dependent changes cannot hit the recursion limit.
        for root in sorted(changes.values(), key=lambda x: x.number):
            if root in visited_changes:
                continue
            visited_changes.add(root)
            stack = [(root, iter(root.parents))]
            while stack:
                change, parents = stack[-1]
                for parent in parents:
                    parent_change = changes.get(parent['commit'])
                    if (parent_change is not None and
                            parent_change not in visited_changes):
                        # Descend into the parent.  The iterator kept on the
                        # stack resumes with the next parent later.
                        visited_changes.add(parent_change)
                        stack.append(
                            (parent_change, iter(parent_change.parents)))
                        break
                else:
                    # All parents have been emitted.
                    stack.pop()
                    sorted_changes.append(change)

        return sorted_changes

    # Sort changes in each projects
    return [_sort_project_change_lists(projects[project])
            for project in sorted(projects.keys())]
245
246
def _main_json(args):
    """Query the change lists and print them to the standard output as an
    indented JSON document followed by a newline."""
    text = json.dumps(_get_change_lists_from_args(args), indent=4,
                      separators=(', ', ': '))
    print(text)
252
253
# Git commands for merge commits (keyed by the value of the `--merge` option).
# `FETCH_HEAD` is appended to the selected command by build_pull_commands().
_MERGE_COMMANDS = {
    option: command.split()
    for option, command in (
        ('merge', 'git merge --no-edit'),
        ('merge-ff-only', 'git merge --no-edit --ff-only'),
        ('merge-no-ff', 'git merge --no-edit --no-ff'),
        ('reset', 'git reset --hard'),
        ('checkout', 'git checkout'),
    )
}


# Git commands for non-merge commits (keyed by the value of the `--pick`
# option).
_PICK_COMMANDS = {
    option: command.split()
    for option, command in (
        ('pick', 'git cherry-pick --allow-empty'),
        ('merge', 'git merge --no-edit'),
        ('merge-ff-only', 'git merge --no-edit --ff-only'),
        ('merge-no-ff', 'git merge --no-edit --no-ff'),
        ('reset', 'git reset --hard'),
        ('checkout', 'git checkout'),
    )
}
273
274
def build_pull_commands(change, branch_name, merge_opt, pick_opt):
    """Build command lines for each change.  The command lines will be passed
    to subprocess.run().

    The result consists of an optional `repo start` command (only when
    `branch_name` is not None), a `git fetch` command for the change, and the
    command selected by `merge_opt` (for merge commits) or `pick_opt` (for
    other commits) applied to `FETCH_HEAD`.
    """
    start_cmds = []
    if branch_name is not None:
        start_cmds = [['repo', 'start', branch_name]]

    fetch_cmd = ['git', 'fetch', change.fetch_url, change.fetch_ref]

    if change.is_merge():
        table, option = _MERGE_COMMANDS, merge_opt
    else:
        table, option = _PICK_COMMANDS, pick_opt

    return start_cmds + [fetch_cmd, table[option] + ['FETCH_HEAD']]
288
289
def _sh_quote_command(cmd):
    """Build a shell command string from a command (an argument to
    subprocess.run()): every word is quoted with _sh_quote() and the words
    are joined by single spaces."""
    quoted_words = map(_sh_quote, cmd)
    return ' '.join(quoted_words)
294
295
def _sh_quote_commands(cmds):
    """Build one shell command line from multiple commands (arguments to
    subprocess.run()): each command is converted with _sh_quote_command() and
    the results are chained with ` && `."""
    command_lines = [_sh_quote_command(cmd) for cmd in cmds]
    return ' && '.join(command_lines)
300
301
def _main_bash(args):
    """Print the bash command to pull the change lists.

    One line is printed for each change.  The line enters the project
    directory with `pushd`, runs the commands built by build_pull_commands(),
    and leaves the directory with `popd`.  The Gerrit project name is used as
    the directory if the project is not listed in the manifest.
    """

    branch_name = _get_local_branch_name_from_args(args)

    project_dirs = build_project_name_dir_dict(
        _get_manifest_xml_from_args(args))

    change_list_groups = group_and_sort_change_lists(
        _get_change_lists_from_args(args))

    for change in itertools.chain.from_iterable(change_list_groups):
        cmds = [['pushd', project_dirs.get(change.project, change.project)]]
        cmds += build_pull_commands(
            change, branch_name, args.merge, args.pick)
        cmds += [['popd']]
        print(_sh_quote_commands(cmds))
322
323
def _do_pull_change_lists_for_project(task):
    """Pick a list of changes (usually under a project directory).

    Args:
        task: A `(changes, task_opts)` pair.  `task_opts` is a dict with the
            `branch_name`, `merge_opt`, `pick_opt`, and `project_dirs` keys.

    Returns:
        None if all commands succeeded.  Otherwise, a tuple of the failed
        change, the changes that were not attempted, the failed command (an
        empty list if the project directory is unknown), and the error
        output in bytes.
    """
    changes, options = task

    branch_name = options['branch_name']
    merge_opt = options['merge_opt']
    pick_opt = options['pick_opt']
    project_dirs = options['project_dirs']

    # `position` is one-based, thus `changes[position:]` are the changes that
    # come after the current one.
    for position, change in enumerate(changes, 1):
        try:
            cwd = project_dirs[change.project]
        except KeyError:
            message = (
                'error: project "{}" cannot be found in manifest.xml\n'
                .format(change.project).encode('utf-8'))
            return (change, changes[position:], [], message)

        print(change.commit_sha1[:10], position, cwd)
        pull_cmds = build_pull_commands(
            change, branch_name, merge_opt, pick_opt)
        for cmd in pull_cmds:
            proc = run(cmd, cwd=cwd, stderr=PIPE)
            if proc.returncode != 0:
                # Stop at the first failed command.
                return (change, changes[position:], cmd, proc.stderr)
    return None
348
349
def _print_pull_failures(failures, file=sys.stderr):
    """Print pull failures and the error output of the failed commands.

    Args:
        failures: An iterable of `(failed_change, skipped_changes, cmd,
            errors)` tuples, which are the values returned by
            _do_pull_change_lists_for_project().  `errors` is in bytes.
        file: The file object that the report is written to.
    """
    # pylint: disable=redefined-builtin

    separator = '=' * 78
    separator_sub = '-' * 78

    print(separator, file=file)
    for failed_change, skipped_changes, cmd, errors in failures:
        print('PROJECT:', failed_change.project, file=file)
        print('FAILED COMMIT:', failed_change.commit_sha1, file=file)
        for change in skipped_changes:
            print('PENDING COMMIT:', change.commit_sha1, file=file)
        print(separator_sub, file=file)
        print('FAILED COMMAND:', _sh_quote_command(cmd), file=file)
        # write_bytes() may write to the underlying binary buffer directly.
        # Flush the text written so far to keep the report in order.
        file.flush()
        write_bytes(errors, file=file)
        print(separator, file=file)
367
368
def _main_pull(args):
    """Pull the change lists.

    The change lists are grouped by project.  The groups are processed one
    by one when `args.parallel` is less than or equal to 1, or by a pool of
    `args.parallel` worker processes otherwise.  If any group fails, the
    failures are printed and the program exits with status 1.
    """

    branch_name = _get_local_branch_name_from_args(args)

    manifest_path = _get_manifest_xml_from_args(args)
    project_dirs = build_project_name_dir_dict(manifest_path)

    # Collect change lists
    change_lists = _get_change_lists_from_args(args)
    change_list_groups = group_and_sort_change_lists(change_lists)

    # Build the options list for tasks
    task_opts = {
        'branch_name': branch_name,
        'merge_opt': args.merge,
        'pick_opt': args.pick,
        'project_dirs': project_dirs,
    }

    # Run the commands to pull the change lists
    if args.parallel <= 1:
        results = [_do_pull_change_lists_for_project((changes, task_opts))
                   for changes in change_list_groups]
    else:
        pool = multiprocessing.Pool(processes=args.parallel)
        try:
            results = pool.map(_do_pull_change_lists_for_project,
                               zip(change_list_groups,
                                   itertools.repeat(task_opts)))
        finally:
            # Always release the worker processes.  map() only returns after
            # all tasks are completed, thus terminate() cannot cut off any
            # work on the normal path.
            pool.terminate()
            pool.join()

    # Print failures and tracebacks
    failures = [result for result in results if result]
    if failures:
        _print_pull_failures(failures)
        sys.exit(1)
403
404
def _parse_args():
    """Parse command line options.

    Returns:
        The namespace object returned by `ArgumentParser.parse_args()`.  It
        has the `command`, `query`, `gerrit`, `gitcookies`, `manifest`,
        `limits`, `merge`, `pick`, `branch`, and `parallel` attributes.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('command', choices=['pull', 'bash', 'json'],
                        help='Commands')

    parser.add_argument('query', help='Change list query string')
    parser.add_argument('-g', '--gerrit', required=True,
                        help='Gerrit review URL')

    parser.add_argument('--gitcookies',
                        default=os.path.expanduser('~/.gitcookies'),
                        help='Gerrit cookie file')
    parser.add_argument('--manifest', help='Manifest')
    # Convert the value to an integer so that a user-specified value has the
    # same type as the default value.
    parser.add_argument('--limits', default=1000, type=int,
                        help='Max number of change lists')

    parser.add_argument('-m', '--merge',
                        choices=sorted(_MERGE_COMMANDS),
                        default='merge-ff-only',
                        help='Method to pull merge commits')

    parser.add_argument('-p', '--pick',
                        choices=sorted(_PICK_COMMANDS),
                        default='pick',
                        help='Method to pull non-merge commits')

    parser.add_argument('-b', '--branch',
                        help='Local branch name for `repo start`')

    parser.add_argument('-j', '--parallel', default=1, type=int,
                        help='Number of parallel running commands')

    return parser.parse_args()
440
441
def _get_manifest_xml_from_args(args):
    """Get the path to manifest.xml from args.

    Return `args.manifest` if it is specified.  Otherwise, search for the
    manifest from the current working directory with find_manifest_xml().
    """
    return args.manifest or find_manifest_xml(os.getcwd())
448
449
def _get_change_lists_from_args(args):
    """Query the change lists by args.

    Create a URL opener from `args`, then pass it to query_change_lists()
    together with the Gerrit URL, the query string, and the limit.
    """
    opener = create_url_opener_from_args(args)
    gerrit_url, query, limits = args.gerrit, args.query, args.limits
    return query_change_lists(opener, gerrit_url, query, limits)
454
455
def _get_local_branch_name_from_args(args):
    """Get the local branch name from args.

    If no branch name is specified, ask the user whether to continue without
    one.  Print an error message and exit with status 1 unless the user
    agrees.
    """
    if args.branch:
        return args.branch
    if _confirm('Do you want to continue without local branch name?', False):
        return args.branch
    print('error: `-b` or `--branch` must be specified', file=sys.stderr)
    sys.exit(1)
463
464
def main():
    """Parse the command line and run the handler of the chosen command."""
    args = _parse_args()

    # Map command names to their handlers.
    handlers = {
        'json': _main_json,
        'bash': _main_bash,
        'pull': _main_pull,
    }
    handler = handlers.get(args.command)
    if handler is None:
        raise KeyError('unknown command')
    handler(args)


if __name__ == '__main__':
    main()
479