1#!/usr/bin/env python3
2#
3#   Copyright 2018 - The Android Open Source Project
4#
5#   Licensed under the Apache License, Version 2.0 (the "License");
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an "AS IS" BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
import bisect
from collections import namedtuple
import copy
import inspect
import numbers
20
21
def _fully_qualified_name(func):
    """Builds a 'module:qualname' identifier for func.

    __qualname__ on its own is not fully qualified: it leaves out the module
    name, so it is combined here with __module__.

    See: https://www.python.org/dev/peps/pep-3155/#naming-choice
    """
    module_name = func.__module__
    qualified_name = func.__qualname__
    return '{}:{}'.format(module_name, qualified_name)
31
32
# Named view of one entry returned by inspect.stack(), so plain tuples (as
# returned on Python 3.4 and earlier) can be read by attribute name.
_FrameInfo = namedtuple(
    '_FrameInfo', 'frame filename lineno function code_context index')
35
36
def _inspect_stack():
    """Wraps every entry of inspect.stack() in a _FrameInfo named tuple.

    Python 3.4 and earlier return plain tuples from inspect.stack(); wrapping
    them allows attribute access (e.g. ``info.function``) on every version.

    Returns:
        list of _FrameInfo named tuples, in the order inspect.stack() returns
        them.
    """
    # The stack list is deliberately not bound to a local name: this frame is
    # itself one of its entries, and holding the list in a local would add a
    # frame <-> list reference cycle.
    return list(map(_FrameInfo._make, inspect.stack()))
46
47
def set_version(get_version_func, min_version, max_version):
    """Returns a decorator that registers one version of a function.

    The decorator adds the decorated function, for the API range
    [min_version, max_version], to the _VersionSelector already bound to the
    same name in the enclosing class body or module (creating a new selector
    if the name is not bound there yet), and returns that selector in place of
    the function.

    Args:
        get_version_func: The function that returns the version level based on
                          the arguments sent to the versioned function.
        min_version: The minimum API level for calling the decorated function.
        max_version: The maximum API level for calling the decorated function.

    Raises:
        SyntaxError (when the decorator is applied) if the decorated object is
        already a staticmethod/classmethod (those decorators must be listed
        above this one), if the name is already bound to something that is not
        a versioned function, if get_version_func differs from the one used by
        earlier versions of the same function, or if no enclosing class body
        or module frame could be found.
        ValueError (from _VersionSelector.add_fn) if the range is invalid or
        overlaps a range already registered for this function.

    Returns:
        A decorator returning the VersionSelector containing all versions of
        the decorated function.
    """
    # Walk outward through the call stack to find the namespace in which the
    # decorated function is being defined; earlier versions of the function
    # are looked up by name in that namespace.
    func_owner_variables = None
    for frame_info in _inspect_stack():
        frame_locals = frame_info.frame.f_locals
        # A '<module>' frame means we reached the most recently imported
        # module without finding a class first, so the decorator is on a
        # module-level function. Otherwise, __qualname__ appears in the locals
        # of class bodies still being executed, so the first such frame is the
        # innermost class containing the function.
        if (frame_info.function == '<module>' or
                '__qualname__' in frame_locals):
            func_owner_variables = frame_locals
            break

    def decorator(func):
        if isinstance(func, (staticmethod, classmethod)):
            raise SyntaxError('@staticmethod and @classmethod decorators must '
                              'be placed before the versioning decorator.')
        func_name = func.__name__
        if func_owner_variables is None:
            # Without the owning namespace, earlier versions of this function
            # cannot be found; fail explicitly rather than with a TypeError.
            raise SyntaxError('Could not find the class or module that '
                              'defines "%s".' % func_name)

        if func_name in func_owner_variables:
            # If the function already exists within the class/module, get it.
            version_selector = func_owner_variables[func_name]
            if isinstance(version_selector, (staticmethod, classmethod)):
                # If the function was also decorated with @staticmethod or
                # @classmethod, the version_selector will be stored in __func__.
                version_selector = version_selector.__func__
            if not isinstance(version_selector, _VersionSelector):
                # The existing binding may not be a function at all (e.g. a
                # plain value), in which case it has no __qualname__.
                raise SyntaxError('The previously defined function "%s" is not '
                                  'decorated with a versioning decorator.' %
                                  getattr(version_selector, '__qualname__',
                                          func_name))
            if (version_selector.comparison_func_name !=
                    _fully_qualified_name(get_version_func)):
                raise SyntaxError('Functions of the same name must be decorated'
                                  ' with the same versioning decorator.')
        else:
            version_selector = _VersionSelector(get_version_func)

        version_selector.add_fn(func, min_version, max_version)
        return version_selector

    return decorator
107
108
class _VersionSelector(object):
    """A class that maps API levels to versioned functions for that API level.

    Registered ranges are stored as boundary entries in a list sorted by
    level. A single level becomes one Entry with direction 0. A range
    [min, max] becomes an Entry with direction 1 at min, immediately followed
    by an Entry with direction -1 at max. Calls are dispatched by bisecting
    this list.

    Attributes:
        entry_list: A sorted list of Entries that define which functions to call
                    for a given API level.
        get_version: The function that maps the call arguments to an API level.
        instance: The object a selector returned by __get__ is bound to, or
                  None for an unbound selector.
        comparison_func_name: The fully-qualified name of get_version, used to
                              check that all versions share the same function.
    """

    class ListWrap(object):
        """Read-only view exposing only the levels of a list of Entry objects.

        bisect compares the sequence items themselves against the searched
        value. Entry defines no ordering, so this view returns each entry's
        level from __getitem__, letting bisect search the entries by level.

        See: https://docs.python.org/3/library/bisect.html#other-examples
        """

        def __init__(self, entry_list):
            self.list = entry_list

        def __len__(self):
            return len(self.list)

        def __getitem__(self, index):
            return self.list[index].level

    class Entry(object):
        """One boundary point of a registered API level range."""

        def __init__(self, level, func, direction):
            """Creates an Entry object.

            Args:
                level: The API level for this point.
                func: The function to call.
                direction: (-1, 0 or 1) the direction the ray from this level
                           points towards: 1 starts a range, -1 ends it and 0
                           marks a single-level range.
            """
            self.level = level
            self.func = func
            self.direction = direction

    def __init__(self, version_func):
        """Creates a VersionSelector object.

        Args:
            version_func: The function that converts the arguments into an
                          integer that represents the API level.
        """
        self.entry_list = list()
        self.get_version = version_func
        self.instance = None
        self.comparison_func_name = _fully_qualified_name(version_func)

    def __name__(self):
        """Returns the name of the first registered function, or a fallback.

        NOTE(review): this is a regular method, so ``selector.__name__`` is a
        bound method rather than a string; kept as-is for compatibility.
        """
        if len(self.entry_list) > 0:
            return self.entry_list[0].func.__name__
        return '%s<%s>' % (self.__class__.__name__, self.get_version.__name__)

    def _display_name(self):
        """Returns the function name used in dispatch error messages."""
        if self.entry_list:
            return self.entry_list[0].func.__qualname__
        # No versions registered yet; avoid indexing an empty list.
        return self.__name__()

    def print_ranges(self):
        """Returns all ranges as a string.

        The string is formatted as '[min_a, max_a], [min_b, max_b], ...'
        """
        ranges = []
        min_boundary = None
        for entry in self.entry_list:
            if entry.direction == 1:
                min_boundary = entry.level
            elif entry.direction == 0:
                ranges.append(str([entry.level, entry.level]))
            else:
                ranges.append(str([min_boundary, entry.level]))
        return ', '.join(ranges)

    def add_fn(self, fn, min_version, max_version):
        """Adds a function to the VersionSelector for the given API range.

        Args:
            fn: The function to call when the API level is met.
            min_version: The minimum version level for calling this function.
            max_version: The maximum version level for calling this function.

        Raises:
            ValueError if min_version > max_version or another versioned
                       function overlaps this new range.
        """
        if min_version > max_version:
            raise ValueError('The minimum API level (%s) must be less than or '
                             'equal to the maximum API level (%s).' %
                             (min_version, max_version))
        insertion_index = bisect.bisect_left(
            _VersionSelector.ListWrap(self.entry_list), min_version)
        if insertion_index != len(self.entry_list):
            right_neighbor = self.entry_list[insertion_index]
            # The new range fits only if it ends before the next boundary and
            # that boundary is not the end of a range containing min_version.
            if (max_version >= right_neighbor.level or
                    right_neighbor.direction == -1):
                raise ValueError('New range overlaps another API level. '
                                 'New range: %s, Existing ranges: %s' %
                                 ([min_version, max_version],
                                  self.print_ranges()))
        if min_version == max_version:
            new_entry = _VersionSelector.Entry(min_version, fn, direction=0)
            self.entry_list.insert(insertion_index, new_entry)
        else:
            # Inserts the 2 entries into the entry list at insertion_index.
            self.entry_list[insertion_index:insertion_index] = [
                _VersionSelector.Entry(min_version, fn, direction=1),
                _VersionSelector.Entry(max_version, fn, direction=-1)]

    def __call__(self, *args, **kwargs):
        """Calls the proper versioned function for the given API level.

        This is a magic python function that gets called whenever parentheses
        immediately follow the attribute access (e.g. obj.version_selector()).

        Args:
            *args, **kwargs: The arguments passed into this call. These
                             arguments are intended for the decorated function.

        Raises:
            ValueError if get_version does not return a number.
            NotImplementedError if no registered range contains the level.

        Returns:
            The result of the called function.
        """
        if self.instance is not None:
            # This selector was bound by __get__; pass the bound object as the
            # first argument, as a bound method would.
            level = self.get_version(self.instance, *args, **kwargs)
        else:
            level = self.get_version(*args, **kwargs)
        if not isinstance(level, numbers.Number):
            kwargs_out = ', '.join('%s=%s' % (key, str(value))
                                   for key, value in kwargs.items())
            args_out = str(list(args))[1:-1]
            raise ValueError(
                'The API level the function %s returned %s for the arguments '
                '(%s). This function must return a number.' %
                (self.get_version.__qualname__, repr(level),
                 ', '.join(i for i in [args_out, kwargs_out] if i)))

        # bisect_left yields the first entry whose level is >= level, so the
        # level is covered only by a matching single point, the start of a
        # range at exactly this level, or the end of the enclosing range.
        index = bisect.bisect_left(_VersionSelector.ListWrap(self.entry_list),
                                   level)

        # Check to make sure the function being called is within the API range
        if index == len(self.entry_list):
            raise NotImplementedError('No function %s exists for API level %s'
                                      % (self._display_name(), level))
        closest_entry = self.entry_list[index]
        if (closest_entry.direction == 0 and closest_entry.level != level or
                closest_entry.direction == 1 and closest_entry.level > level or
                closest_entry.direction == -1 and closest_entry.level < level):
            raise NotImplementedError('No function %s exists for API level %s'
                                      % (self._display_name(), level))

        func = closest_entry.func
        if self.instance is None:
            # Not bound through __get__ (module-level, staticmethod, or
            # accessed via the class), so pass the arguments through unchanged.
            return func(*args, **kwargs)

        return func(self.instance, *args, **kwargs)

    def __get__(self, instance, owner):
        """Returns this selector bound to the instance it is accessed through.

        The selector shared by the class is never mutated. Instead, a shallow
        copy carrying the instance is returned, so separately retrieved
        bound selectors (e.g. ``fa = a.f; fb = b.f``) stay independent. This
        also avoids races between threads. The copy shares entry_list with
        this selector.

        Note that this function will NOT be called on module-level functions.

        Args:
            instance: The instance of the object this function is being called
                      from, or None when accessed through the class.
            owner: The object that owns this function. This is the class object
                   that defines the function.

        Returns:
            self if it is already bound to instance (in particular, when both
            are None); otherwise a copy of this VersionSelector bound to
            instance.
        """
        if instance is self.instance:
            return self
        bound = copy.copy(self)
        bound.instance = instance
        return bound
292