1# Copyright 2016 - The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Common Utilities."""
15# pylint: disable=too-many-lines
16from __future__ import print_function
17
from distutils.spawn import find_executable
import base64
import binascii
import collections
import errno
import functools
import getpass
import grp
import logging
import os
import platform
import shlex
import shutil
import signal
import socket
import struct
import subprocess
import sys
import tarfile
import tempfile
import time
import uuid
import webbrowser
import zipfile
41
42import six
43
44from acloud import errors
45from acloud.internal import constants
46
47
logger = logging.getLogger(__name__)

# Base ssh-keygen invocations: generate a new 4096-bit RSA key pair, and
# print the public key derived from an existing private key (-y).
SSH_KEYGEN_CMD = ["ssh-keygen", "-t", "rsa", "-b", "4096"]
SSH_KEYGEN_PUB_CMD = ["ssh-keygen", "-y"]
# Options shared by ssh and scp: skip host key verification and do not
# persist host keys (the known hosts file is /dev/null).
SSH_ARGS = ["-o", "UserKnownHostsFile=/dev/null",
            "-o", "StrictHostKeyChecking=no"]
SSH_CMD = ["ssh"] + SSH_ARGS
SCP_CMD = ["scp"] + SSH_ARGS
# Command prefix for dumping build variables through soong_ui. The path is
# relative, so presumably it is run from an Android source tree root — confirm.
GET_BUILD_VAR_CMD = ["build/soong/soong_ui.bash", "--dumpvar-mode"]
# Default retry pacing. These values match the defaults of
# RetryOnException's retry_backoff_factor and sleep_multiplier arguments.
DEFAULT_RETRY_BACKOFF_FACTOR = 1
DEFAULT_SLEEP_MULTIPLIER = 0

# ssh arguments (filled in with %-formatting) that forward the local VNC and
# ADB ports to the target ports on the remote host's 127.0.0.1. "-N -f" keeps
# the tunnel in the background without running a remote command.
_SSH_TUNNEL_ARGS = ("-i %(rsa_key_file)s -o UserKnownHostsFile=/dev/null "
                    "-o StrictHostKeyChecking=no "
                    "-L %(vnc_port)d:127.0.0.1:%(target_vnc_port)d "
                    "-L %(adb_port)d:127.0.0.1:%(target_adb_port)d "
                    "-N -f -l %(ssh_user)s %(ip_addr)s")
# Arguments for "adb connect" to a locally forwarded ADB port (presumably
# passed to the adb binary by callers not shown here).
_ADB_CONNECT_ARGS = "connect 127.0.0.1:%(adb_port)d"
# Store the ports that vnc/adb are forwarded to, both are integers.
ForwardedPorts = collections.namedtuple("ForwardedPorts", [constants.VNC_PORT,
                                                           constants.ADB_PORT])
# Per-AVD-type (VNC port, ADB port) pairs, taken from the constants module.
AVD_PORT_DICT = {
    constants.TYPE_GCE: ForwardedPorts(constants.GCE_VNC_PORT,
                                       constants.GCE_ADB_PORT),
    constants.TYPE_CF: ForwardedPorts(constants.CF_VNC_PORT,
                                      constants.CF_ADB_PORT),
    constants.TYPE_GF: ForwardedPorts(constants.GF_VNC_PORT,
                                      constants.GF_ADB_PORT),
    constants.TYPE_CHEEPS: ForwardedPorts(constants.CHEEPS_VNC_PORT,
                                          constants.CHEEPS_ADB_PORT)
}

# VNC client (ssvnc) related settings.
_VNC_BIN = "ssvnc"
# Force-kill (-9) processes whose full command line (-f) matches a pattern.
_CMD_KILL = ["pkill", "-9", "-f"]
# Prefix for running a command through "sg" (run as another group); the
# usage is not visible in this section.
_CMD_SG = "sg "
# Launches the VNC viewer against a local port; %(bin)s is the viewer binary.
_CMD_START_VNC = "%(bin)s vnc://127.0.0.1:%(port)d"
_CMD_INSTALL_SSVNC = "sudo apt-get --assume-yes install ssvnc"
_ENV_DISPLAY = "DISPLAY"
# Environment variables set for ssvnc. Presumably they suppress the
# encryption warning, auto-scale the view and use the X11 cursor; confirm
# against the ssvnc documentation.
_SSVNC_ENV_VARS = {"SSVNC_NO_ENC_WARN": "1", "SSVNC_SCALE": "auto", "VNCVIEWER_X11CURSOR": "1"}
_DEFAULT_DISPLAY_SCALE = 1.0
# Presumably the name of the DIST_DIR environment variable — confirm at call sites.
_DIST_DIR = "DIST_DIR"

# For webrtc
_WEBRTC_URL = "https://"
_WEBRTC_PORT = "8443"

# Prompt shown before offering to install ssvnc with _CMD_INSTALL_SSVNC.
_CONFIRM_CONTINUE = ("In order to display the screen to the AVD, we'll need to "
                     "install a vnc client (ssvnc). \nWould you like acloud to "
                     "install it for you? (%s) \nPress 'y' to continue or "
                     "anything else to abort it[y/N]: ") % _CMD_INSTALL_SSVNC
# Result returned by the *Evaluator functions and consumed by TimeExecute.
_EvaluatedResult = collections.namedtuple("EvaluatedResult",
                                          ["is_result_ok", "result_message"])
# Supported operating systems mapped to their supported distributions
# (keys presumably match platform.system() output — confirm).
_SUPPORTED_SYSTEMS_AND_DISTS = {"Linux": ["Ubuntu", "Debian"]}
# Timeout error message; %d is the timeout in seconds.
_DEFAULT_TIMEOUT_ERR = "Function did not complete within %d secs."
# VNC target string used with ssvnc; presumably used to find a running
# viewer for a given port — confirm at call sites.
_SSVNC_VIEWER_PATTERN = "vnc://127.0.0.1:%(vnc_port)d"
104
105
class TempDir(object):
    """Context manager owning a freshly created temporary directory.

    The directory is created (mode 0700) when the object is constructed and
    is removed recursively, with all of its contents, when the with-block
    exits.

    Attributes:
        path: The path of the temporary directory.
    """

    def __init__(self):
        self.path = tempfile.mkdtemp()
        # Owner-only access (rwx for the current user, nothing for others).
        os.chmod(self.path, 0o700)
        logger.debug("Created temporary dir %s", self.path)

    def __enter__(self):
        """Return the temporary directory path to the with-block."""
        return self.path

    def __exit__(self, exc_type, exc_value, traceback):
        """Remove the temporary directory.

        Args:
            exc_type: Type of the exception raised inside the with-block,
                      or None if the block completed normally.
            exc_value: Exception instance raised inside the with-block,
                       or None.
            traceback: Traceback of that exception, or None.

        Raises:
            EnvironmentError/OSError: deletion failed for a reason other than
                the directory already being gone, and the with-block itself
                did not raise.
            Any other exception raised while deleting, again only when the
            with-block itself did not raise.
        """
        body_failed = bool(exc_type)
        try:
            if self.path:
                shutil.rmtree(self.path)
                logger.debug("Deleted temporary dir %s", self.path)
        except EnvironmentError as err:
            # An already-missing directory is not an error. Any other OS error
            # is swallowed when the with-block raised, so that the original
            # exception is not masked.
            if body_failed or err.errno == errno.ENOENT:
                return
            raise
        except Exception as err:  # pylint: disable=W0703
            if not body_failed:
                raise
            logger.error(
                "Encountered error while deleting %s: %s",
                self.path,
                str(err),
                exc_info=True)
155
156
def RetryOnException(retry_checker,
                     max_retries,
                     sleep_multiplier=0,
                     retry_backoff_factor=1):
    """Decorator which retries the function call if |retry_checker| returns true.

    Args:
        retry_checker: A callback function which should take an exception instance
                       and return True if functor(*args, **kwargs) should be retried
                       when such exception is raised, and return False if it should
                       not be retried.
        max_retries: Maximum number of retries allowed.
        sleep_multiplier: Will sleep sleep_multiplier * attempt_count seconds if
                          retry_backoff_factor is 1.  Will sleep
                          sleep_multiplier * (
                              retry_backoff_factor ** (attempt_count -  1))
                          if retry_backoff_factor != 1.
        retry_backoff_factor: See explanation of sleep_multiplier.

    Returns:
        The function wrapper. The wrapped function keeps the decorated
        function's name and docstring.
    """

    def _Wrapper(func):
        # functools.wraps preserves __name__/__doc__ so logs, tracebacks and
        # introspection still refer to the decorated function.
        @functools.wraps(func)
        def _FunctionWrapper(*args, **kwargs):
            return Retry(retry_checker, max_retries, func, sleep_multiplier,
                         retry_backoff_factor, *args, **kwargs)

        return _FunctionWrapper

    return _Wrapper
188
189
def Retry(retry_checker, max_retries, functor, sleep_multiplier,
          retry_backoff_factor, *args, **kwargs):
    """Conditionally retry a function.

    |functor| is always called at least once. A negative |max_retries| is
    treated like 0 (no retries).

    Args:
        retry_checker: A callback function which should take an exception instance
                       and return True if functor(*args, **kwargs) should be retried
                       when such exception is raised, and return False if it should
                       not be retried.
        max_retries: Maximum number of retries allowed.
        functor: The function to call, will call functor(*args, **kwargs).
        sleep_multiplier: Will sleep sleep_multiplier * attempt_count seconds if
                          retry_backoff_factor is 1.  Will sleep
                          sleep_multiplier * (
                              retry_backoff_factor ** (attempt_count -  1))
                          if retry_backoff_factor != 1.
        retry_backoff_factor: See explanation of sleep_multiplier.
        *args: Arguments to pass to the functor.
        **kwargs: Key-val based arguments to pass to the functor.

    Returns:
        The return value of the functor.

    Raises:
        Exception: The exception that functor(*args, **kwargs) throws.
    """
    attempt_count = 0
    # Loop until the functor succeeds or a raise ends it. The previous
    # "while attempt_count <= max_retries" guard skipped the call entirely
    # (returning None) when max_retries was negative.
    while True:
        attempt_count += 1
        try:
            return functor(*args, **kwargs)
        except Exception as e:  # pylint: disable=W0703
            # Give up when the error is not retriable or retries are used up.
            if not retry_checker(e) or attempt_count > max_retries:
                raise
            if retry_backoff_factor != 1:
                sleep = sleep_multiplier * (retry_backoff_factor**
                                            (attempt_count - 1))
            else:
                sleep = sleep_multiplier * attempt_count
            time.sleep(sleep)
232
233
def RetryExceptionType(exception_types, max_retries, functor, *args, **kwargs):
    """Call |functor|, retrying only on the listed exception types.

    Args:
        exception_types: A tuple of exception types, e.g. (ValueError, KeyError).
                         An exception is retried when isinstance() matches.
        max_retries: Max number of retries allowed.
        functor: The function to call.
        *args: Positional arguments forwarded to Retry after |functor|
               (i.e. sleep_multiplier, retry_backoff_factor, then the
               functor's own arguments).
        **kwargs: Keyword arguments forwarded to Retry.

    Returns:
        Whatever |functor| returns.
    """
    def _IsRetriable(exc):
        return isinstance(exc, exception_types)

    return Retry(_IsRetriable, max_retries, functor, *args, **kwargs)
250
251
def PollAndWait(func, expected_return, timeout_exception, timeout_secs,
                sleep_interval_secs, *args, **kwargs):
    """Call a function until the function returns expected value or times out.

    Args:
        func: Function to call.
        expected_return: The expected return value.
        timeout_exception: Exception to raise when it hits timeout.
        timeout_secs: Timeout seconds.
                      If 0 or less than zero, the function will run once and
                      we will not wait on it.
        sleep_interval_secs: Time to sleep between two attempts.
        *args: list of args to pass to func.
        **kwargs: dictionary of keyword based args to pass to func.

    Raises:
        timeout_exception: if the run of function times out.
    """
    # TODO(fdeng): Currently this method does not kill
    # |func|, if |func| takes longer than |timeout_secs|.
    # We can use a more robust version from chromite.
    start = time.time()
    while True:
        if func(*args, **kwargs) == expected_return:
            return
        # Use ">=" so that a timeout of 0 (or less) means exactly one attempt,
        # as documented. With ">", a clock too coarse to register elapsed time
        # made the loop repeat (and busy-spin when sleep_interval_secs is 0).
        if time.time() - start >= timeout_secs:
            raise timeout_exception
        if sleep_interval_secs > 0:
            time.sleep(sleep_interval_secs)
283
284
def GenerateUniqueName(prefix=None, suffix=None):
    """Build a random, dash-separated unique name around a uuid4 hex string.

    Args:
        prefix: String placed before the random part (skipped if falsy).
        suffix: String placed after the random part (skipped if falsy).

    Returns:
        String of the form "[prefix-]<32 hex chars>[-suffix]".
    """
    pieces = [uuid.uuid4().hex]
    if prefix:
        pieces.insert(0, prefix)
    if suffix:
        pieces.append(suffix)
    return "-".join(pieces)
301
302
def MakeTarFile(src_dict, dest):
    """Archive files in tar.gz format to a file named as |dest|.

    Args:
        src_dict: A dictionary that maps a path to be archived
                  to the corresponding name that appears in the archive.
        dest: String, path to output file, e.g. /tmp/myfile.tar.gz
    """
    # list() so the log shows the paths rather than a "dict_keys([...])"
    # view repr on Python 3.
    logger.info("Compressing %s into %s.", list(src_dict), dest)
    with tarfile.open(dest, "w:gz") as tar:
        # dict.items() works on both Python 2 and 3; no six helper needed.
        for src, arcname in src_dict.items():
            tar.add(src, arcname=arcname)
315
def CreateSshKeyPairIfNotExist(private_key_path, public_key_path):
    """Create the ssh key pair if they don't exist.

    Case1. If the private key doesn't exist, we will create both the public key
           and the private key.
    Case2. If the private key exists but public key doesn't, we will create the
           public key by using the private key.
    Case3. If the public key exists but the private key doesn't, we will create
           a new private key and overwrite the public key.

    Args:
        private_key_path: Path to the private key file.
                          e.g. ~/.ssh/acloud_rsa
        public_key_path: Path to the public key file.
                         e.g. ~/.ssh/acloud_rsa.pub

    Raises:
        errors.DriverError: If failed to create the key pair.
    """
    public_key_path = os.path.expanduser(public_key_path)
    private_key_path = os.path.expanduser(private_key_path)
    public_key_exist = os.path.exists(public_key_path)
    private_key_exist = os.path.exists(private_key_path)
    if public_key_exist and private_key_exist:
        logger.debug(
            "The ssh private key (%s) and public key (%s) already exist,"
            "will not automatically create the key pairs.", private_key_path,
            public_key_path)
        return
    key_folder = os.path.dirname(private_key_path)
    # dirname() is "" for a bare file name (current directory), and
    # os.makedirs("") would fail, so only create a folder that is named.
    if key_folder and not os.path.exists(key_folder):
        os.makedirs(key_folder)
    try:
        if private_key_exist:
            cmd = SSH_KEYGEN_PUB_CMD + ["-f", private_key_path]
            logger.info(
                "The ssh public key (%s) do not exist, "
                "automatically creating public key, calling: %s",
                public_key_path, " ".join(cmd))
            # Run ssh-keygen BEFORE opening the output file. Otherwise a
            # failure leaves an empty public key behind, and the next call
            # treats that as an existing key pair. universal_newlines=True
            # returns text (not bytes) on both Python 2 and 3.
            stream_content = subprocess.check_output(
                cmd, universal_newlines=True)
            with open(public_key_path, "w") as outfile:
                outfile.write(
                    stream_content.rstrip("\n") + " " + getpass.getuser())
        else:
            cmd = SSH_KEYGEN_CMD + [
                "-C", getpass.getuser(), "-f", private_key_path
            ]
            logger.info(
                "Creating ssh key pair (%s) via cmd: %s",
                private_key_path, " ".join(cmd))
            # NOTE(review): stdout/stderr look deliberately swapped here;
            # kept as-is.
            subprocess.check_call(cmd, stdout=sys.stderr, stderr=sys.stdout)
    except subprocess.CalledProcessError as e:
        raise errors.DriverError("Failed to create ssh key pair: %s" % str(e))
    except OSError as e:
        raise errors.DriverError(
            "Failed to create ssh key pair, please make sure "
            "'ssh-keygen' is installed: %s" % str(e))

    if not private_key_exist:
        # ssh-keygen writes the public key next to the private key, with
        # ".pub" appended. Move it to the requested public_key_path.
        # In Case2 the public key was already written to public_key_path
        # directly, so there is nothing to move (renaming would fail or
        # clobber it).
        default_pub_key_path = "%s.pub" % private_key_path
        try:
            if default_pub_key_path != public_key_path:
                os.rename(default_pub_key_path, public_key_path)
        except OSError as e:
            raise errors.DriverError(
                "Failed to rename %s to %s: %s" % (default_pub_key_path,
                                                   public_key_path, str(e)))

    logger.info("Created ssh private key (%s) and public key (%s)",
                private_key_path, public_key_path)
388
389
def VerifyRsaPubKey(rsa):
    """Verify the format of rsa public key.

    Args:
        rsa: content of rsa public key. It should follow the format of
             ssh-rsa AAAAB3NzaC1yc2EA.... [email protected]

    Raises:
        DriverError if the format is not correct.
    """
    if not rsa or not all(ord(c) < 128 for c in rsa):
        raise errors.DriverError(
            "rsa key is empty or contains non-ascii character: %s" % rsa)

    elements = rsa.split()
    if len(elements) != 3:
        raise errors.DriverError("rsa key is invalid, wrong format: %s" % rsa)

    key_type, data, _ = elements
    try:
        # b64decode exists on Python 2 and 3. base64.decodestring was
        # removed in Python 3.9.
        binary_data = base64.b64decode(data)
        # number of bytes of int type
        int_length = 4
        # binary_data is like "7ssh-key..." in a binary format.
        # The first 4 bytes should represent 7, which should be
        # the length of the following string "ssh-key".
        # And the next 7 bytes should be string "ssh-key".
        # We will verify that the rsa conforms to this format.
        # ">I" in the following line means "big-endian unsigned integer".
        type_length = struct.unpack(">I", binary_data[:int_length])[0]
        embedded_type = binary_data[int_length:int_length + type_length]
        # The decoded blob is bytes, so compare against the encoded key type.
        # On Python 3 a bytes-vs-str comparison is always unequal. key_type
        # is ASCII, as checked above.
        if embedded_type != key_type.encode("ascii"):
            raise errors.DriverError("rsa key is invalid: %s" % rsa)
    # Python 2's b64decode reports bad padding as TypeError.
    except (struct.error, binascii.Error, TypeError) as e:
        raise errors.DriverError(
            "rsa key is invalid: %s, error: %s" % (rsa, str(e)))
425
426
def Decompress(sourcefile, dest=None):
    """Extract a .zip or .tar.gz archive.

    Args:
        sourcefile: A string, path of the archive to extract.
        dest: A string, destination folder. Defaults to the current
              directory when empty or None.

    Raises:
        errors.UnsupportedCompressionFileType: the file is neither a .zip
            nor a .tar.gz archive.
    """
    logger.info("Start to decompress %s!", sourcefile)
    target_dir = dest or "."
    if sourcefile.endswith(".zip"):
        with zipfile.ZipFile(sourcefile, 'r') as archive:
            archive.extractall(target_dir)
        return
    if sourcefile.endswith(".tar.gz"):
        with tarfile.open(sourcefile, "r:gz") as archive:
            archive.extractall(target_dir)
        return
    raise errors.UnsupportedCompressionFileType(
        "Sorry, we could only support compression file type "
        "for zip or tar.gz.")
449
450
# pylint: disable=old-style-class,no-init
class TextColors:
    """ANSI escape sequences for coloring terminal output.

    Put one of the color/format codes in front of a string and ENDC after
    it to restore the terminal's default attributes.
    """

    # Foreground colors.
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKGREEN = "\033[92m"
    WARNING = "\033[33m"
    FAIL = "\033[91m"

    # Text formatting.
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

    # Resets all attributes.
    ENDC = "\033[0m"
463
464
def PrintColorString(message, colors=TextColors.OKBLUE, **kwargs):
    """Print |message| wrapped in an ANSI color code, then flush stdout.

    Keyword arguments are forwarded to print(). Passing end="" keeps the
    cursor on the same line, so a later call appends to it:

        PrintColorString("Creating GCE instance...", end="")
        # ... the job runs for 20s ...
        PrintColorString("Done! (20s)")

    This first shows "Creating GCE instance...". Once the job finishes,
    the line reads "Creating GCE instance...Done! (20s)".

    Args:
        message: String, the message text.
        colors: String, the ANSI color code placed before the message.
        **kwargs: Keyword arguments forwarded to print().
    """
    colored_message = "".join([colors, message, TextColors.ENDC])
    print(colored_message, **kwargs)
    sys.stdout.flush()
485
486
def InteractWithQuestion(question, colors=TextColors.WARNING):
    """Show a colored |question| prompt and read the user's reply.

    Args:
        question: String, the question to ask user.
        colors: String, ANSI color code used for the prompt.

    Returns:
        String, the user's input with surrounding whitespace removed.
    """
    prompt = colors + question + TextColors.ENDC
    reply = six.moves.input(prompt)
    return str(reply.strip())
498
499
def GetUserAnswerYes(question):
    """Ask the user a yes/no question; only an explicit yes counts.

    Args:
        question: String of question for user. Pressing Enter counts as
                  "no", so hint the default with an upper-case N in square
                  brackets, e.g. "Are you sure to change bucket name[y/N]:".

    Returns:
        Boolean, True if the lower-cased answer is in
        constants.USER_ANSWER_YES, False otherwise.
    """
    normalized_answer = InteractWithQuestion(question).lower()
    return normalized_answer in constants.USER_ANSWER_YES
514
515
class BatchHttpRequestExecutor(object):
    """A helper class that executes requests in batch with retry.

    This executor executes http requests in a batch and retries
    those that have failed. It iteratively updates the dictionary
    self._final_results with latest results, which can be retrieved
    via GetResults.
    """

    def __init__(self,
                 execute_once_functor,
                 requests,
                 retry_http_codes=None,
                 max_retry=None,
                 sleep=None,
                 backoff_factor=None,
                 other_retriable_errors=None):
        """Initializes the executor.

        Args:
            execute_once_functor: A function that execute requests in batch once.
                                  It should return a dictionary like
                                  {request_id: (response, exception)}
            requests: A dictionary where key is request id picked by caller,
                      and value is a apiclient.http.HttpRequest.
            retry_http_codes: A list of http codes to retry. None means no
                              http code is retried.
            max_retry: See utils.Retry. None means no retries.
            sleep: See utils.Retry. None means 0.
            backoff_factor: See utils.Retry. None means 1.
            other_retriable_errors: A tuple of error types that should be retried
                                    other than errors.HttpError. None means
                                    no other error type is retried.
        """
        self._execute_once_functor = execute_once_functor
        self._requests = requests
        # A dictionary that maps request id to pending request.
        self._pending_requests = {}
        # A dictionary that maps request id to a tuple (response, exception).
        self._final_results = {}
        # Normalize the None defaults. Passed through unchanged, they raise
        # TypeError in isinstance()/"in" checks and in Retry's comparisons
        # and sleep arithmetic. The sleep and backoff values match the
        # defaults of RetryOnException.
        self._retry_http_codes = retry_http_codes or ()
        self._max_retry = 0 if max_retry is None else max_retry
        self._sleep = 0 if sleep is None else sleep
        self._backoff_factor = 1 if backoff_factor is None else backoff_factor
        self._other_retriable_errors = other_retriable_errors or ()

    def _ShoudRetry(self, exception):
        """Check if an exception is retriable.

        Args:
            exception: An exception instance.

        Returns:
            True if the exception is one of the other retriable error types,
            or an errors.HttpError whose code is in the retriable codes.
        """
        if isinstance(exception, self._other_retriable_errors):
            return True
        if not self._retry_http_codes:
            return False
        return (isinstance(exception, errors.HttpError)
                and exception.code in self._retry_http_codes)

    def _ExecuteOnce(self):
        """Executes pending requests and update it with failed, retriable ones.

        Raises:
            HasRetriableRequestsError: if some requests fail and are retriable.
        """
        results = self._execute_once_functor(self._pending_requests)
        # Update final_results with latest results.
        self._final_results.update(results)
        # Clear pending_requests
        self._pending_requests.clear()
        for request_id, result in results.items():
            exception = result[1]
            if exception is not None and self._ShoudRetry(exception):
                # If this is a retriable exception, put it in pending_requests
                self._pending_requests[request_id] = self._requests[request_id]
        if self._pending_requests:
            # If there is still retriable requests pending, raise an error
            # so that Retry will retry this function with pending_requests.
            raise errors.HasRetriableRequestsError(
                "Retriable errors: %s" %
                [str(results[rid][1]) for rid in self._pending_requests])

    def Execute(self):
        """Executes the requests and retry if necessary.

        Will populate self._final_results.
        """

        def _ShouldRetryHandler(exc):
            """Check if |exc| is a retriable exception.

            Args:
                exc: An exception.

            Returns:
                True if exception is of type HasRetriableRequestsError; False otherwise.
            """
            should_retry = isinstance(exc, errors.HasRetriableRequestsError)
            if should_retry:
                logger.info("Will retry failed requests.", exc_info=True)
                logger.info("%s", exc)
            return should_retry

        try:
            self._pending_requests = self._requests.copy()
            Retry(
                _ShouldRetryHandler,
                max_retries=self._max_retry,
                functor=self._ExecuteOnce,
                sleep_multiplier=self._sleep,
                retry_backoff_factor=self._backoff_factor)
        except errors.HasRetriableRequestsError:
            logger.debug("Some requests did not succeed after retry.")

    def GetResults(self):
        """Returns final results.

        Returns:
            results, a dictionary in the following format
            {request_id: (response, exception)}
            request_ids are those from requests; response
            is the http response for the request or None on error;
            exception is an instance of DriverError or None if no error.
        """
        return self._final_results
640
641
def DefaultEvaluator(result):
    """Evaluator that accepts every result as a success.

    Args:
        result: The return value of the target function. It is passed
                through unchanged as the result message.

    Returns:
        An _EvaluatedResult with is_result_ok=True and
        result_message=result.
    """
    evaluated = _EvaluatedResult(True, result)
    return evaluated
652
653
def ReportEvaluator(report):
    """Evaluate the acloud operation by the report.

    Args:
        report: acloud.public.report() object, or None.

    Returns:
        _EvaluatedResult namedtuple. is_result_ok is False when the report
        is missing or has errors. In that case result_message holds the
        report's errors, or a short explanation when there is no report.
    """
    # Handle None separately: the old combined check went on to read
    # report.errors on None and raised AttributeError.
    if report is None:
        return _EvaluatedResult(is_result_ok=False,
                                result_message="No report was returned.")
    if report.errors:
        return _EvaluatedResult(is_result_ok=False,
                                result_message=report.errors)

    return _EvaluatedResult(is_result_ok=True, result_message=None)
668
669
def BootEvaluator(boot_dict):
    """Report boot success unless any instance recorded a boot error.

    Args:
        boot_dict: Dict of instance_name: boot error. Empty (or None) means
                   every device booted.

    Returns:
        _EvaluatedResult namedtuple. On failure, result_message is the
        boot_dict itself. On success it is None.
    """
    has_failures = bool(boot_dict)
    return _EvaluatedResult(
        is_result_ok=not has_failures,
        result_message=boot_dict if has_failures else None)
682
683
class TimeExecute(object):
    """Decorator that prints a description and the run time of a call."""

    def __init__(self, function_description=None, print_before_call=True,
                 print_status=True, result_evaluator=DefaultEvaluator,
                 display_waiting_dots=True):
        """Initializes the class.

        Args:
            function_description: String that describes function (e.g."Creating
                                  Instance...")
            print_before_call: Boolean, print the function description before
                               calling the function, default True.
            print_status: Boolean, print the status of the function after the
                          function has completed, default True ("OK" or "Fail").
            result_evaluator: Func object. Pass func to evaluate result.
                              Default evaluator always report result is ok and
                              failed result will be identified only in exception
                              case.
            display_waiting_dots: Boolean, if true print the function_description
                                  followed by waiting dot.
        """
        self._function_description = function_description
        self._print_before_call = print_before_call
        self._print_status = print_status
        self._result_evaluator = result_evaluator
        self._display_waiting_dots = display_waiting_dots

    def __call__(self, func):
        """Wrap |func| so its description, status and run time are printed.

        Args:
            func: The function to decorate.

        Returns:
            The wrapper function. It keeps func's name and docstring.
        """
        @functools.wraps(func)
        def DecoratorFunction(*args, **kwargs):
            # Calls func(*args, **kwargs) and prints the timing/status lines.
            # Any exception from func is re-raised after printing "Fail!".
            timestart = time.time()
            if self._print_before_call:
                waiting_dots = "..." if self._display_waiting_dots else ""
                PrintColorString("%s %s"% (self._function_description,
                                           waiting_dots), end="")
            try:
                result = func(*args, **kwargs)
                result_time = time.time() - timestart
                if not self._print_before_call:
                    PrintColorString("%s (%ds)" % (self._function_description,
                                                   result_time),
                                     TextColors.OKGREEN)
                if self._print_status:
                    evaluated_result = self._result_evaluator(result)
                    if evaluated_result.is_result_ok:
                        PrintColorString("OK! (%ds)" % (result_time),
                                         TextColors.OKGREEN)
                    else:
                        PrintColorString("Fail! (%ds)" % (result_time),
                                         TextColors.FAIL)
                        PrintColorString("Error: %s" %
                                         evaluated_result.result_message,
                                         TextColors.FAIL)
                return result
            # BaseException (same scope as the original bare except) so that
            # "Fail!" is also printed on KeyboardInterrupt. The exception is
            # always re-raised.
            except BaseException:  # pylint: disable=broad-except
                if self._print_status:
                    PrintColorString("Fail! (%ds)" % (time.time() - timestart),
                                     TextColors.FAIL)
                raise
        return DecoratorFunction
753
754
def PickFreePort():
    """Helper to pick a free port.

    Returns:
        Integer, a free port number.
    """
    tcp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # try/finally so the socket is closed even if bind() or getsockname()
    # raises.
    try:
        # Binding to port 0 lets the OS choose an unused port.
        tcp_socket.bind(("", 0))
        return tcp_socket.getsockname()[1]
    finally:
        tcp_socket.close()
766
767
def CheckPortFree(port):
    """Check the availability of the tcp port.

    Args:
        port: Integer, a port number.

    Raises:
        PortOccupied: This port is not available.
    """
    tcp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        tcp_socket.bind(("", port))
    except socket.error:
        raise errors.PortOccupied("Port (%d) is taken, please choose another "
                                  "port." % port)
    finally:
        # Close on both paths. Previously the socket leaked whenever the port
        # was occupied.
        tcp_socket.close()
784
785
def _ExecuteCommand(cmd, args):
    """Run a binary with arguments, discarding its stdout and stderr.

    Args:
        cmd: String, name of the binary to execute.
        args: List of arguments appended after the binary path.

    Raises:
        errors.NoExecuteCmd: The binary cannot be located.
        subprocess.CalledProcessError: The command exits with non-zero status.
    """
    bin_path = FindExecutable(cmd)
    if not bin_path:
        raise errors.NoExecuteCmd("unable to locate %s" % cmd)
    full_command = [bin_path] + args
    logger.debug("Running '%s'", " ".join(full_command))
    with open(os.devnull, "w") as output_sink:
        subprocess.check_call(full_command, stdout=output_sink,
                              stderr=output_sink)
803
804
805# TODO(147337696): create ssh tunnels tear down as adb and vnc.
806# pylint: disable=too-many-locals
def AutoConnect(ip_addr, rsa_key_file, target_vnc_port, target_adb_port,
                ssh_user, client_adb_port=None, extra_args_ssh_tunnel=None):
    """Autoconnect to an AVD instance.

    Creates an ssh tunnel that forwards a local vnc port and a local adb port
    to the target ports on the remote instance, then runs "adb connect"
    against the local adb port.

    Args:
        ip_addr: String, use to build the adb & vnc tunnel between local
                 and remote instance.
        rsa_key_file: String, Private key file path to use when creating
                      the ssh tunnels.
        target_vnc_port: Integer of target vnc port number.
        target_adb_port: Integer of target adb port number.
        ssh_user: String of user login into the instance.
        client_adb_port: Integer, Specified adb port to establish connection.
        extra_args_ssh_tunnel: String, extra args for ssh tunnel connection.

    Returns:
        NamedTuple of (vnc_port, adb_port) SSHTUNNEL of the connect, both are
        integers. Both fields are None when the ssh tunnel can't be created;
        when only "adb connect" fails the forwarded ports are still returned.

    Raises:
        errors.NoExecuteCmd: The ssh or adb binary can't be found (raised by
            _ExecuteCommand).
    """
    def _Quote(value):
        # Quote values interpolated into the ssh argument string so that e.g.
        # a key path containing spaces survives shlex.split() below as one
        # argument. "%s" keeps non-string values rendering as they did before.
        return six.moves.shlex_quote("%s" % value)

    local_free_vnc_port = PickFreePort()
    local_adb_port = client_adb_port or PickFreePort()
    ssh_tunnel_args = _SSH_TUNNEL_ARGS % {
        "rsa_key_file": _Quote(rsa_key_file),
        "vnc_port": local_free_vnc_port,
        "adb_port": local_adb_port,
        "target_vnc_port": target_vnc_port,
        "target_adb_port": target_adb_port,
        "ssh_user": _Quote(ssh_user),
        "ip_addr": _Quote(ip_addr)}
    ssh_tunnel_args_list = shlex.split(ssh_tunnel_args)
    if extra_args_ssh_tunnel:
        ssh_tunnel_args_list.extend(shlex.split(extra_args_ssh_tunnel))
    try:
        _ExecuteCommand(constants.SSH_BIN, ssh_tunnel_args_list)
    except subprocess.CalledProcessError as e:
        PrintColorString("\n%s\nFailed to create ssh tunnels, retry with '#acloud "
                         "reconnect'." % e, TextColors.FAIL)
        return ForwardedPorts(vnc_port=None, adb_port=None)

    try:
        adb_connect_args = _ADB_CONNECT_ARGS % {"adb_port": local_adb_port}
        _ExecuteCommand(constants.ADB_BIN, adb_connect_args.split())
    except subprocess.CalledProcessError:
        # The tunnel is up, so the ports are still usable for a later
        # reconnect; only report the adb failure.
        PrintColorString("Failed to adb connect, retry with "
                         "'#acloud reconnect'", TextColors.FAIL)

    return ForwardedPorts(vnc_port=local_free_vnc_port,
                          adb_port=local_adb_port)
855
856
def GetAnswerFromList(answer_list, enable_choose_all=False):
    """Get answer from a list.

    Prints a numbered menu ([0] exits the program) and keeps prompting until
    the user enters a valid choice.

    Args:
        answer_list: list of the answers to choose from.
        enable_choose_all: True to offer an extra "all" choice that selects
                           every item from answer list.

    Return:
        List holding the answer(s).
    """
    print("[0] to exit.")
    start_index = 1
    max_choice = len(answer_list)

    for num, item in enumerate(answer_list, start_index):
        print("[%d] %s" % (num, item))
    if enable_choose_all:
        max_choice += 1
        print("[%d] for all." % max_choice)

    while True:
        choice = six.moves.input("Enter your choice[0-%d]: " % max_choice)
        try:
            choice = int(choice)
        except ValueError:
            print("'%s' is not a valid integer." % choice)
            continue
        # Filter out choices
        if choice == 0:
            sys.exit(constants.EXIT_BY_USER)
        if enable_choose_all and choice == max_choice:
            return answer_list
        if choice < 0 or choice > max_choice:
            print("please choose between 0 and %d" % max_choice)
        else:
            return [answer_list[choice-start_index]]
895
896
def LaunchVNCFromReport(report, avd_spec, no_prompts=False):
    """Start a vnc client for every device listed in the report.

    Devices without a vnc port are skipped with an error message.

    Args:
        report: Report object, that stores and generates report.
        avd_spec: AVDSpec object that tells us what we're going to create.
        no_prompts: Boolean, True to skip all prompts.
    """
    devices = report.data.get("devices", [])
    for device in devices:
        vnc_port = device.get(constants.VNC_PORT)
        if not vnc_port:
            PrintColorString("No VNC port specified, skipping VNC startup.",
                             TextColors.FAIL)
            continue
        LaunchVncClient(vnc_port,
                        avd_width=avd_spec.hw_property["x_res"],
                        avd_height=avd_spec.hw_property["y_res"],
                        no_prompts=no_prompts)
914
def LaunchBrowserFromReport(report):
    """Open the webrtc page of every device in the report in a browser tab.

    When no display is available, the link is printed instead.

    Args:
        report: Report object, that stores and generates report.
    """
    PrintColorString("(This is an experimental project for webrtc, and since "
                     "the certificate is self-signed, Chrome will mark it as "
                     "an insecure website. keep going.)",
                     TextColors.WARNING)

    for device in report.data.get("devices", []):
        device_ip = device.get("ip")
        if not device_ip:
            PrintColorString("Auto-launch devices webrtc in browser failed!",
                             TextColors.FAIL)
            continue
        webrtc_link = "%s%s:%s" % (_WEBRTC_URL, device_ip, _WEBRTC_PORT)
        has_display = os.environ.get(_ENV_DISPLAY, None)
        if has_display:
            webbrowser.open_new_tab(webrtc_link)
            continue
        PrintColorString("Remote terminal can't support launch webbrowser.",
                         TextColors.FAIL)
        PrintColorString("Open %s to remotely control AVD on the "
                         "browser." % webrtc_link)
940
def LaunchVncClient(port, avd_width=None, avd_height=None, no_prompts=False):
    """Launch ssvnc.

    On a supported platform without ssvnc installed, offer to install it
    first (or install it unconditionally when no_prompts is True).

    Args:
        port: Integer, port number.
        avd_width: String, the width of avd.
        avd_height: String, the height of avd.
        no_prompts: Boolean, True to skip all prompts.
    """
    if _ENV_DISPLAY not in os.environ:
        PrintColorString("Remote terminal can't support VNC. "
                         "Skipping VNC startup.", TextColors.FAIL)
        return

    if IsSupportedPlatform() and not FindExecutable(_VNC_BIN):
        if not (no_prompts or GetUserAnswerYes(_CONFIRM_CONTINUE)):
            return
        try:
            PrintColorString("Installing ssvnc vnc client... ", end="")
            sys.stdout.flush()
            subprocess.check_output(_CMD_INSTALL_SSVNC, shell=True)
            PrintColorString("Done", TextColors.OKGREEN)
        except subprocess.CalledProcessError as cpe:
            PrintColorString("Failed to install ssvnc: %s" %
                             cpe.output, TextColors.FAIL)
            return

    # Resolve the binary once. It can still be missing, e.g. on an
    # unsupported platform where no install was attempted. Without this
    # check the literal string "None" would be passed to Popen.
    vnc_bin = FindExecutable(_VNC_BIN)
    if not vnc_bin:
        PrintColorString("Unable to find %s, skipping VNC startup." % _VNC_BIN,
                         TextColors.FAIL)
        return

    ssvnc_env = os.environ.copy()
    ssvnc_env.update(_SSVNC_ENV_VARS)
    # Override SSVNC_SCALE. CalculateVNCScreenRatio needs both dimensions
    # (it calls float() on each), so only scale when both are known.
    if avd_width and avd_height:
        scale_ratio = CalculateVNCScreenRatio(avd_width, avd_height)
        ssvnc_env["SSVNC_SCALE"] = str(scale_ratio)
        logger.debug("SSVNC_SCALE:%s", scale_ratio)

    ssvnc_args = _CMD_START_VNC % {"bin": vnc_bin, "port": port}
    subprocess.Popen(ssvnc_args.split(), env=ssvnc_env)
981
982
def PrintDeviceSummary(report):
    """Display summary of devices.

    For each entry in report.data["devices"], print its adb serial and, when
    known, its instance name and ip. For example, a report whose data is
        {'devices': [{'instance_name': 'ins-f6a397-none-53363',
                      'ip': u'35.234.10.162'}]}
    prints one device line. Afterwards, print any messages in report.errors.

    Args:
        report: A Report instance.
    """
    PrintColorString("\n")
    PrintColorString("Device summary:")
    for device in report.data.get("devices", []):
        adb_port = device.get("adb_port")
        if adb_port:
            adb_serial = constants.LOCALHOST_ADB_SERIAL % adb_port
        else:
            adb_serial = "(None)"

        instance_name = device.get("instance_name")
        instance_ip = device.get("ip")
        if instance_name:
            instance_details = "(%s[%s])" % (instance_name, instance_ip)
        else:
            instance_details = ""

        summary_line = " - device serial: %s %s" % (adb_serial,
                                                    instance_details)
        PrintColorString(summary_line)
        PrintColorString("   export ANDROID_SERIAL=%s" % adb_serial)

    # TODO(b/117245508): Help user to delete instance if it got created.
    if not report.errors:
        return
    joined_errors = "\n".join(report.errors)
    PrintColorString("Fail in:\n%s\n" % joined_errors, TextColors.FAIL)
1014
1015
def CalculateVNCScreenRatio(avd_width, avd_height):
    """Calculate the vnc screen scale ratio to fit into user's monitor.

    Args:
        avd_width: String, the width of avd.
        avd_height: String, the height of avd.

    Return:
        Float, scale ratio for vnc client. _DEFAULT_DISPLAY_SCALE is returned
        when the screen size can't be queried (no Tk support, or Tk can't
        open a display).
    """
    # Tk is named "tkinter" on python3 and "Tkinter" on python2. Some python
    # interpreters may not be configured for Tk at all; in that case just
    # return the default scale ratio.
    try:
        import tkinter
    except ImportError:
        try:
            import Tkinter as tkinter
        except ImportError:
            return _DEFAULT_DISPLAY_SCALE
    try:
        root = tkinter.Tk()
    except tkinter.TclError:
        # Tk() raises TclError when it can't connect to a display.
        return _DEFAULT_DISPLAY_SCALE
    margin = 100 # leave some space on user's monitor.
    try:
        screen_height = root.winfo_screenheight() - margin
        screen_width = root.winfo_screenwidth() - margin
    finally:
        # Only the screen size is needed; destroy the Tk root created above.
        root.destroy()

    scale_h = _DEFAULT_DISPLAY_SCALE
    scale_w = _DEFAULT_DISPLAY_SCALE
    if float(screen_height) < float(avd_height):
        scale_h = round(float(screen_height) / float(avd_height), 1)

    if float(screen_width) < float(avd_width):
        scale_w = round(float(screen_width) / float(avd_width), 1)

    logger.debug("scale_h: %s (screen_h: %s/avd_h: %s),"
                 " scale_w: %s (screen_w: %s/avd_w: %s)",
                 scale_h, screen_height, avd_height,
                 scale_w, screen_width, avd_width)

    # Return the larger scale-down ratio.
    return scale_h if scale_h < scale_w else scale_w
1050
1051
def IsCommandRunning(command):
    """Report whether a process matching the given command is running.

    Runs pgrep with "-af" and discards its output; a zero exit status means a
    match was found.

    Args:
        command: String of command name.

    Returns:
        Boolean, True if command is running. False otherwise.
    """
    pgrep_cmd = [constants.CMD_PGREP, "-af", command]
    with open(os.devnull, "w") as dev_null:
        exit_status = subprocess.call(pgrep_cmd, stderr=dev_null,
                                      stdout=dev_null)
    return exit_status == 0
1068
1069
def AddUserGroupsToCmd(cmd, user_groups):
    """Add the user groups to the command if necessary.

    Local instance support requires the user to be added to some groups
    during host setup. Those group memberships only apply system-wide after
    the user logs out and back in. So that a local instance can still be
    launched before re-login, the command is wrapped in "sg" calls for each
    group when the current user is missing any of them.

    A here-doc is used instead of chaining with '&' because all operations
    need to run in the same pid. Example result:
    $ sg kvm  << EOF
    sg libvirt
    sg cvdnetwork
    launch_cvd --cpus 2 --x_res 1280 --y_res 720 --dpi 160 --memory_mb 4096
    EOF

    Args:
        cmd: String of the command to prepend the user groups to.
        user_groups: List of user groups name.(String)

    Returns:
        String of the command with the user groups prepended to it if necessary,
        otherwise the same existing command.
    """
    sg_prefix = ""
    if not CheckUserInGroups(user_groups):
        logger.debug("Need to add user groups to the command")
        for position, group_name in enumerate(user_groups):
            # The first sg line opens the here-doc; the rest are plain lines.
            line_end = " <<EOF\n" if position == 0 else "\n"
            sg_prefix += _CMD_SG + group_name + line_end
        cmd = cmd + "\nEOF"
    wrapped_cmd = sg_prefix + cmd
    logger.debug("user group cmd: %s", wrapped_cmd)
    return wrapped_cmd
1109
1110
def CheckUserInGroups(group_name_list):
    """Tell whether the current process belongs to every listed group.

    Each missing group is logged.

    Args:
        group_name_list: The list of group name.
    Returns:
        True if current user is in all the groups.
    """
    logger.info("Checking if user is in following groups: %s", group_name_list)
    member_of = [grp.getgrgid(gid).gr_name for gid in os.getgroups()]
    missing_groups = [name for name in group_name_list
                      if name not in member_of]
    for name in missing_groups:
        logger.info("missing group: %s", name)
    return not missing_groups
1127
1128
def _GetLinuxDistributionName():
    """Return the Linux distribution name, or "" if it can't be determined.

    platform.linux_distribution() was removed in Python 3.8. When it is not
    available, fall back to the NAME field of /etc/os-release.
    NOTE(review): assumes the os-release NAME (e.g. "Ubuntu") matches the
    names listed in _SUPPORTED_SYSTEMS_AND_DISTS — confirm for every
    supported distribution.
    """
    linux_distribution = getattr(platform, "linux_distribution", None)
    if linux_distribution is not None:
        return linux_distribution()[0]
    try:
        with open("/etc/os-release") as os_release:
            for line in os_release:
                key, _, value = line.strip().partition("=")
                if key == "NAME":
                    return value.strip("\"'")
    except (IOError, OSError):
        # Best effort: no os-release file (e.g. non-Linux host) means the
        # distribution is unknown.
        pass
    return ""


def IsSupportedPlatform(print_warning=False):
    """Check if user's os is the supported platform.

    Args:
        print_warning: Boolean, print the unsupported warning
                       if True.
    Returns:
        Boolean, True if user is using supported platform.
    """
    system = platform.system()
    dist = _GetLinuxDistributionName()
    platform_supported = (system in _SUPPORTED_SYSTEMS_AND_DISTS and
                          dist in _SUPPORTED_SYSTEMS_AND_DISTS[system])

    logger.info("supported system and dists: %s",
                _SUPPORTED_SYSTEMS_AND_DISTS)
    platform_supported_msg = ("%s[%s] %s supported platform" %
                              (system,
                               dist,
                               "is a" if platform_supported else "is not a"))
    if print_warning and not platform_supported:
        PrintColorString(platform_supported_msg, TextColors.WARNING)
    else:
        logger.info(platform_supported_msg)

    return platform_supported
1157
1158
def GetDistDir():
    """Return the absolute path to the dist dir.

    The dist dir is queried through the build system's dumpvar mode, run
    from ANDROID_BUILD_TOP.

    Returns:
        String of the dist dir path, or None when ANDROID_BUILD_TOP is unset
        or the query command fails.
    """
    android_build_top = os.environ.get(constants.ENV_ANDROID_BUILD_TOP)
    if not android_build_top:
        return None
    dist_cmd = GET_BUILD_VAR_CMD[:]
    dist_cmd.append(_DIST_DIR)
    try:
        # universal_newlines=True returns text instead of bytes on python3.
        # Without it, os.path.join(str, bytes) below raises TypeError.
        dist_dir = subprocess.check_output(dist_cmd, cwd=android_build_top,
                                           universal_newlines=True)
    except subprocess.CalledProcessError:
        return None
    return os.path.join(android_build_top, dist_dir.strip())
1171
1172
def CleanupProcess(pattern):
    """Kill the processes matching a pattern, if any are running.

    Args:
        pattern: String, string of process pattern.
    """
    if not IsCommandRunning(pattern):
        return
    subprocess.check_call(_CMD_KILL + [pattern])
1182
1183
def TimeoutException(timeout_secs, timeout_error=_DEFAULT_TIMEOUT_ERR):
    """Decorator which sets up a function timeout and raises a custom exception.

    The timeout is implemented with SIGALRM, so the decorated function must
    be called from the main thread.

    Args:
        timeout_secs: Number of maximum seconds of waiting time.
        timeout_error: String to describe timeout exception.

    Returns:
        The function wrapper.
    """
    import functools

    if timeout_error == _DEFAULT_TIMEOUT_ERR:
        timeout_error = timeout_error % timeout_secs

    def _Wrapper(func):
        # pylint: disable=unused-argument
        def _HandleTimeout(signum, frame):
            raise errors.FunctionTimeoutError(timeout_error)

        @functools.wraps(func)
        def _FunctionWrapper(*args, **kwargs):
            previous_handler = signal.signal(signal.SIGALRM, _HandleTimeout)
            signal.alarm(timeout_secs)
            try:
                return func(*args, **kwargs)
            finally:
                signal.alarm(0)
                # Restore the SIGALRM handler that was active before, so this
                # decorator doesn't leak its handler into unrelated code.
                # signal.signal() returns None when the previous handler was
                # not installed from Python; it can't be re-installed then.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

        return _FunctionWrapper

    return _Wrapper
1214
1215
def GetBuildEnvironmentVariable(variable_name):
    """Look up a variable from the Android build environment.

    Args:
        variable_name: String of variable name.

    Returns:
        String, the value of the variable.

    Raises:
        errors.GetAndroidBuildEnvVarError: The variable is not set in the
            environment.
    """
    value = os.environ.get(variable_name)
    if value is not None:
        return value
    raise errors.GetAndroidBuildEnvVarError(
        "Could not get environment var: %s\n"
        "Try to run 'source build/envsetup.sh && lunch <target>'"
        % variable_name)
1236
1237
1238# pylint: disable=no-member
def FindExecutable(filename):
    """Locate an executable on PATH in a python2/python3 compatible way.

    Args:
        filename: String of execution filename.

    Returns:
        String: execution file path.
    """
    if six.PY2:
        return find_executable(filename)
    return shutil.which(filename)
1249
1250
def GetDictItems(namedtuple_object):
    """Return the (field, value) items of the given namedtuple object.

    namedtuple._asdict() is the documented API on both python2.7 and python3
    and returns an ordered mapping, so no version switch is needed.

    Args:
        namedtuple_object: namedtuple object.

    Returns:
        namedtuple_object._asdict().items(), i.e. the (field, value) pairs in
        field order (a list on python2, an items view on python3).
    """
    return namedtuple_object._asdict().items()
1263
1264
def CleanupSSVncviewer(vnc_port):
    """Kill a leftover ssvnc viewer that was bound to the given vnc port.

    Args:
        vnc_port: Integer, port number of vnc.
    """
    CleanupProcess(_SSVNC_VIEWER_PATTERN % dict(vnc_port=vnc_port))
1273