#!/usr/bin/env python3 #-*- coding: iso-8859-1 -*- ################################################################################ # # Function guards for Python 3. # # (c) 2016, Dmitry Dvoinikov <dmitry@targeted.org> # Distributed under MIT license. # # Samples: # # from funcguard import guard # # @guard # def abs(a, _when = "a >= 0"): # return a # # @guard # def abs(a, _when = "a < 0"): # return -a # # assert abs(1) == abs(-1) == 1 # # @guard # def factorial(n): # no _when expression => default # return 1 # # @guard # def factorial(n, _when = "n > 1"): # return n * factorial(n - 1) # # assert factorial(10) == 3628800 # # class TypeTeller: # @staticmethod # @guard # def typeof(value, _when = "isinstance(value, int)"): # return int # @staticmethod # @guard # def typeof(value, _when = "isinstance(value, str)"): # return str # # assert TypeTeller.typeof(0) is int # TypeTeller.typeof(0.0) # throws # # class AllowedProcessor: # def __init__(self, allowed): # self._allowed = allowed # @guard # def process(self, value, _when = "value in self._allowed"): # return "ok" # @guard # def process(self, value): # no _when expression => default # return "fail" # # ap = AllowedProcessor({1, 2, 3}) # assert ap.process(1) == "ok" # assert ap.process(0) == "fail" # # guard.default_eval_args( # values to insert to all guards scopes # office_hours = lambda: 9 <= datetime.now().hour < 18) # # @guard # def at_work(*args, _when = "office_hours()", **kwargs): # print("welcome") # # @guard # def at_work(*args, **kwargs): # print("come back tomorrow") # # at_work() # either "welcome" or "come back tomorrow" # # The complete source code with self-tests is available from: # https://github.com/targeted/funcguard # ################################################################################ __all__ = [ "guard", "GuardException", "IncompatibleFunctionsException", "FunctionArgumentsMatchException", "GuardExpressionException", "DuplicateDefaultGuardException", "GuardEvalException", "NoMatchingFunctionException" ] ################################################################################ import inspect; from inspect import getfullargspec import functools; from functools import wraps import sys; from sys import modules try: (lambda: None).__qualname__ except AttributeError: import qualname; from qualname import qualname # prior to Python 3.3 workaround else: qualname = lambda f: f.__qualname__ ################################################################################ class GuardException(Exception): pass class IncompatibleFunctionsException(GuardException): pass class FunctionArgumentsMatchException(GuardException): pass class GuardExpressionException(GuardException): pass class DuplicateDefaultGuardException(GuardException): pass class GuardEvalException(GuardException): pass class NoMatchingFunctionException(GuardException): pass ################################################################################ # takes an argument specification for a function and a set of actual call # positional and keyword arguments, returns a flat namespace-like dict # mapping parameter names to their actual values def _eval_args(argspec, args, kwargs): # match positional arguments matched_args = {} expected_args = argspec.args default_args = argspec.defaults or () _many = lambda t: "argument" + ("s" if len(t) != 1 else "") # copy provided args to expected, append defaults if necessary for i, name in enumerate(expected_args): if i < len(args): value = args[i] elif i >= len(expected_args) - len(default_args): value = argspec.defaults[i - len(expected_args) + len(default_args)] else: missing_args = expected_args[len(args):len(expected_args) - len(default_args)] raise FunctionArgumentsMatchException("missing required positional {0:s}: {1:s}".\ format(_many(missing_args), ", ".join(missing_args))) matched_args[name] = value # put extra provided args to *args if the function allows if argspec.varargs: matched_args[argspec.varargs] = args[len(expected_args):] if len(args) > len(expected_args) else () elif len(args) > len(expected_args): raise FunctionArgumentsMatchException( "takes {0:d} positional {1:s} but {2:d} {3:s} given". format(len(expected_args), _many(expected_args), len(args), len(args) == 1 and "was" or "were")) # match keyword arguments matched_kwargs = {} expected_kwargs = argspec.kwonlyargs default_kwargs = argspec.kwonlydefaults or {} # extract expected kwargs from provided, using defaults if necessary missing_kwargs = [] for name in expected_kwargs: if name in kwargs: matched_kwargs[name] = kwargs[name] elif name in default_kwargs: matched_kwargs[name] = default_kwargs[name] else: missing_kwargs.append(name) if missing_kwargs: raise FunctionArgumentsMatchException("missing required keyword {0:s}: {1:s}".\ format(_many(missing_kwargs), ", ".join(missing_kwargs))) extra_kwarg_names = [ name for name in kwargs if name not in matched_kwargs ] if argspec.varkw: if extra_kwarg_names: extra_kwargs = { name: kwargs[name] for name in extra_kwarg_names } else: extra_kwargs = {} matched_args[argspec.varkw] = extra_kwargs elif extra_kwarg_names: raise FunctionArgumentsMatchException("got unexpected keyword {0:s}: {1:s}".\ format(_many(extra_kwarg_names), ", ".join(extra_kwarg_names))) # both positional and keyword argument are returned in the same scope-like dict for name, value in matched_kwargs.items(): matched_args[name] = value return matched_args ################################################################################ # takes an argument specification for a function, from it extracts and returns # a compiled expression which is to be matched against call arguments def _get_guard_expr(func_name, argspec): guard_expr_text = None if "_when" in argspec.args: defaults = argspec.defaults or () i = argspec.args.index("_when") if i >= len(argspec.args) - len(defaults): guard_expr_text = defaults[i - len(argspec.args) + len(defaults)] elif "_when" in argspec.kwonlyargs: guard_expr_text = (argspec.kwonlydefaults or {}).get("_when") else: return None # indicates default guard if guard_expr_text is None: raise GuardExpressionException("guarded function {0:s}() requires a \"_when\" " "argument with guard expression text as its " "default value".format(func_name)) try: guard_expr = compile(guard_expr_text, func_name, "eval") except Exception as e: error = str(e) else: error = None if error is not None: raise GuardExpressionException("invalid guard expression for {0:s}(): " "{1:s}".format(func_name, error)) return guard_expr ################################################################################ # checks whether two functions' argspecs are compatible to be guarded as one, # compatible argspecs have identical positional and keyword parameters except # for "_when" and annotations def _compatible_argspecs(argspec1, argspec2): return _stripped_argspec(argspec1) == _stripped_argspec(argspec2) def _stripped_argspec(argspec): args = argspec.args[:] defaults = list(argspec.defaults or ()) kwonlyargs = argspec.kwonlyargs[:] kwonlydefaults = (argspec.kwonlydefaults or {}).copy() if "_when" in args: i = args.index("_when") if i >= len(args) - len(defaults): del defaults[i - len(args) + len(defaults)] del args[i] elif "_when" in kwonlyargs and "_when" in kwonlydefaults: i = kwonlyargs.index("_when") del kwonlyargs[i] del kwonlydefaults["_when"] return (args, defaults, kwonlyargs, kwonlydefaults, argspec.varargs, argspec.varkw) ################################################################################ def guard(func, module = None): # the main decorator function # see if it is a function of a lambda try: eval(func.__name__) except SyntaxError: return func # <lambda> => not guarded except NameError: pass # valid name # get to the bottom of a possible decorator chain # to get the original function's specification original_func = func while hasattr(original_func, "__wrapped__"): original_func = original_func.__wrapped__ func_name = qualname(original_func) func_module = module or modules[func.__module__] # module serves only as a place to keep state argspec = getfullargspec(original_func) # the registry of known guarded function is attached to the module containg them guarded_functions = getattr(func_module, "__guarded_functions__", None) if guarded_functions is None: guarded_functions = func_module.__guarded_functions__ = {} original_argspec, first_guard, last_guard = guard_info = \ guarded_functions.setdefault(func_name, [argspec, None, None]) # all the guarded functions with the same name must have identical signature if argspec is not original_argspec and not _compatible_argspecs(argspec, original_argspec): raise IncompatibleFunctionsException("function signature is incompatible " "with the previosly registered {0:s}()".format(func_name)) @wraps(func) def func_guard(*args, **kwargs): # the call proxy function # since all versions of the function have essentially identical signatures, # their mapping to the actually provided arguments can be calculated once # for each call and not against every version of the function try: eval_args = _eval_args(argspec, args, kwargs) except FunctionArgumentsMatchException as e: error = str(e) else: error = None if error is not None: raise FunctionArgumentsMatchException("{0:s}() {1:s}".format(func_name, error)) for name, value in guard.__default_eval_args__.items(): eval_args.setdefault(name, value) # walk the chain of function versions starting with the first, looking # for the one for which the guard expression evaluates to truth current_guard = func_guard.__first_guard__ while current_guard: try: if not current_guard.__guard_expr__ or \ eval(current_guard.__guard_expr__, globals(), eval_args): break except Exception as e: error = str(e) else: error = None if error is not None: raise GuardEvalException("guard expression evaluation failed for " "{0:s}(): {1:s}".format(func_name, error)) current_guard = current_guard.__next_guard__ else: raise NoMatchingFunctionException("none of the guard expressions for {0:s}() " "matched the call arguments".format(func_name)) return current_guard.__wrapped__(*args, **kwargs) # call the winning function version # in different version of Python @wraps behaves differently with regards # to __wrapped__, therefore we set it the way we need it here func_guard.__wrapped__ = func # the guard expression is attached func_guard.__guard_expr__ = _get_guard_expr(func_name, argspec) # maintain a linked list for all versions of the function if last_guard and not last_guard.__guard_expr__: # the list is not empty and the # last guard is already a default if not func_guard.__guard_expr__: raise DuplicateDefaultGuardException("the default version of {0:s}() has already " "been specified".format(func_name)) # the new guard has to be inserted one before the last if first_guard is last_guard: # the list contains just one guard # new becomes first, last is not changed first_guard.__first_guard__ = func_guard.__first_guard__ = func_guard func_guard.__next_guard__ = first_guard first_guard = guard_info[1] = func_guard else: # the list contains more than one guard # neither first nor last are changed prev_guard = first_guard while prev_guard.__next_guard__ is not last_guard: prev_guard = prev_guard.__next_guard__ func_guard.__first_guard__ = first_guard func_guard.__next_guard__ = last_guard prev_guard.__next_guard__ = func_guard else: # the new guard is inserted last if not first_guard: first_guard = guard_info[1] = func_guard func_guard.__first_guard__ = first_guard func_guard.__next_guard__ = None if last_guard: last_guard.__next_guard__ = func_guard last_guard = guard_info[2] = func_guard return func_guard guard.__default_eval_args__ = {} guard.default_eval_args = lambda *args, **kwargs: guard.__default_eval_args__.update(*args, **kwargs) ################################################################################ # EOF