import ast from itertools import cycle, chain, islice from collections import namedtuple import operator def explain(code_string, important=None): """Parse a string containing a function-like call, return a namedtuple containing the results >>> explain('mymod.nestmod.func("arg1", "arg2", kw1="kword1", kw2="kword2", *args, **kws') [Call( args=['arg1', 'arg2'], keywords={'kw1': 'kword1', 'kw2': 'kword2'}, starargs='args', func='mymod.nestmod.func', kwargs='kws')] optional 'important' argument is a list of features to parse from the code_string. Features defined for a Call Node: args - positional arguments, keywords - keyword arguments, starargs - excess positional arguments, kwargs - excess keyword arguments, func - chained function attribute lookup. """ node = ast.parse(code_string) visitor = StrNodeVisitor(important) return visitor.visit(node) #-------------------------------------------------------------------- def attrgetter(name): """Get attribute 'name' from object and return a string representation of it.""" getname = operator.attrgetter(name) def str_getattr(self, obj=None): obj = self if obj is None else obj return str(getname(obj)) return str_getattr def strmap(show): """Hardcode a particular ast Node to string representation 'show'.""" return lambda self, node=None: show #-------------------------------------------------------------------- class StrNodeVisitor(ast.NodeVisitor): """A class to return string representations of visited ast nodes.""" visit_Name = attrgetter('id') visit_Num = attrgetter('n') visit_Str = attrgetter('s') # hardcoded these Nodes to return string argument when visited. visit_Add = strmap('+') visit_Sub = strmap('-') visit_Mult = strmap('*') visit_Div = strmap('/') visit_Mod = strmap('%') visit_Pow = strmap('**') visit_LShift = strmap('<<') visit_RShift = strmap('>>') visit_FloorDiv = strmap('//') visit_Not = strmap('not') visit_And = strmap('and') visit_Or = strmap('or') visit_Eq = strmap('==') visit_NotEq = strmap('!=') visit_Lt = strmap('<') visit_LtE = strmap('<=') visit_Gt = strmap('>') visit_GtE = strmap('>=') visit_Is = strmap('is') visit_IsNot = strmap('not is') visit_In = strmap('in') visit_NotIn = strmap('not in') def __init__(self, interested=None): """interested - a sequence of features of a function to include in returned namedtuple. Allowed features: func, args, keywords, starargs, kwargs""" try: self._interested = set(interested) except TypeError: self._interested = interested def visit_Module(self, node): visit = self.visit return [visit(body) for body in node.body] def visit_Expr(self, node): return self.visit(node.value) def visit_Call(self, node): """return a NamedTuple that represents a Call: f(arg, kw=1, *args, **kws). Call node defines: func, args, keywords, starargs, kwargs""" # determine which of the fields we are allowed to handle. defined = set(node._fields) try: interested = self._interested & defined except TypeError: interested = defined fields = {} for field in interested: field_contents = getattr(node, field) if field_contents is None: # short circuit if the node field is a NoneType. fields[field] = None continue # handle the field using one of the convenience functions. fields[field] = getattr(self, field)(field_contents) # return the result as a namedtuple rather than dict. BaseCallTuple = namedtuple(classname(node), interested) class MyCallTuple(BaseCallTuple): """Enable representation in a nicer string format. Don't use this MyCallTuple class if 'func' is not a field. as the string representation relies on it.""" __str__ = CallTuple2Str if 'func' in interested: mytup = MyCallTuple(**fields) else: mytup = BaseCallTuple(**fields) return mytup def visit_List(self, node): """return a string representation of list.""" return self._sequence(node, '[%s]') def visit_Tuple(self, node): """return a string representation of tuple.""" return self._sequence(node, '(%s)') def visit_Dict(self, node): """return a string representation of a dict.""" visit = self.visit keyvals = zip(node.keys, node.values) contents = ', '.join(['%s: %s' % (visit(key), visit(value)) for key, value in keyvals]) return '{%s}' % contents def visit_Attribute(self, node): """Attribute of form: obj.attr.""" return '%s.%s' % (self.visit(node.value), node.attr) def visit_BoolOp(self, node): """BoolOp of form: op values e.g. a and b.""" visit = self.visit op = ' %s ' % visit(node.op) return op.join([visit(n) for n in node.values]) def visit_UnaryOp(self, node): """UnaryOp of form: op operand e.g. not [].""" return '%(op)s %(operand)s' % dict( op=self.visit(node.op), operand=self.visit(node.operand)) def visit_BinOp(self, node): """BinOp of form: left op right e.g. 2 * 3.""" visit = self.visit return '(%(left)s %(op)s %(right)s)' % dict( left=visit(node.left), op=visit(node.op), right=visit(node.right)) def visit_Subscript(self, node): """Subscript of form: value[slice]. e.g. a[1:10:2].""" visit = self.visit return '%s[%s]' % (visit(node.value), visit(node.slice)) def visit_Slice(self, node): """Slice of form: lower:upper:step. e.g. 1:10:2.""" visit = self.visit return '%s:%s:%s' % (visit(node.lower), visit(node.upper), visit(node.step)) def visit_Compare(self, node): """Compare of form: left ops comparators. e.g. x > y > z -> left=x, ops=['>', '>'], comparators=['y', 'z'] """ visit = self.visit rest = ' '.join([visit(r) for r in roundrobin(node.ops, node.comparators)]) return '%s %s' % (visit(node.left), rest) # Convenience functions. def _sequence(self, node, signature): visit = self.visit contents = ', '.join([visit(elt) for elt in node.elts]) return signature % contents def func(self, func): """convenience function called from visit_Call.""" return self.visit(func) def args(self, args): """convenience function called from visit_Call.""" visit = self.visit return [visit(n) for n in args] def keywords(self, keywords): """convenience function called from visit_Call.""" visit = self.visit return dict((kw.arg, visit(kw.value)) for kw in keywords) def starargs(self, starargs): """convenience function called from visit_Call.""" return self.visit(starargs) def kwargs(self, kwargs): """convenience function called from visit_Call.""" return self.visit(kwargs) def generic_visit(self, node): """Called as a fallback handler if all other visit_* functions failed. return ''. if node is NoneType return ''""" if node is None: return '' return '' % classname(node) #-------------------------------------------------------------------- def classname(obj): return obj.__class__.__name__ def roundrobin(*iterables): "roundrobin('ABC', 'D', 'EF') --> A D E B F C" # Recipe credited to George Sakkis pending = len(iterables) nexts = cycle(iter(it).next for it in iterables) while pending: try: for next in nexts: yield next() except StopIteration: pending -= 1 nexts = cycle(islice(nexts, pending)) def CallTuple2Str(self): """replacement for CallTuple's __str__ method. Assumes that func field is present. The print signature should look like: func(args, keywords, *starargs, **kwargs).""" func = self.func order = ['args', 'keywords', 'starargs', 'kwargs'] # handle args. arg_values = getattr(self, 'args', []) args = ', '.join([str(arg) for arg in arg_values]) # handle keywords. kw_values = getattr(self, 'keywords', {}) keywords = ', '.join(['%s=%s' % (k, v) for k, v in kw_values.items()]) # handle starargs. star = getattr(self, 'starargs', None) if star: starargs = '*%s' % star else: starargs = '' # handle kwargs. kwargs = getattr(self, 'kwargs', None) if kwargs: kwargs = '**%s' % kwargs else: kwargs = '' # put it all together. arguments = [args, keywords, starargs, kwargs] signature = ', '.join([arg for arg in arguments if arg != '']) return '%s(%s)' % (func, signature) #-------------------------------------------------------------------- if __name__ == '__main__': tests = dict( tuple_test = "mod1.f_tuple((1,2), kw1=(1,2))", list_test = "f_list([1,2], kw1=[1,2])", dict_test = "f_dict({1:2}, kw1={1:2})", complex_test = "f_complex(1 + 2j, kw1=1 + 2j)", fn_test = "f_func(abs(-1), kw1=explain('f1(2, 3)'))", bool_test = "f_bool(True, False, not [], hello or 'hello')", slice_test = "f_slice(a[:2], b=b[1:2])", lambda_test = "f_lambda(lambda x: x)", compare_test = "f_compare(x > y > z not in [True])", genexp_test = "f_genexp([a for a in range(2)], b=(b for b in range(2)))") for name, test in tests.iteritems(): print '%s: %s' % (name, explain(test, ['func', 'keywords', 'args'])[0])