Welcome, guest | Sign In | My Account | Store | Cart
"""
Allow coroutines which yield in nested routines.

"""

"""
Based on the abstract syntax tree as defined in section 32 of the Python
library documentation (http://docs.python.org/lib/node892.html).
"""
from __future__ import with_statement
import _ast
class indent(object):
    """Context manager that indents a writer for the duration of a block.

    ``t`` is any object offering ``indent()`` and ``deindent()``; the first
    is called on entry and the second on exit, whether or not the block
    raised.  Exceptions are never suppressed.
    """

    def __init__(self, t):
        self.t = t

    def __enter__(self):
        writer = self.t
        writer.indent()

    def __exit__(self, exc_type, exc_val, exc_tb):
        writer = self.t
        writer.deindent()
            
class suspend_transform(object):
    """Context manager that temporarily restricts ``t.transform``.

    On entry the current value of ``t.transform`` is saved and replaced by
    ``t.transform and f()``; on exit the saved value is restored.  The flag
    can therefore only be switched off inside the block, never on.

    t -- object carrying a boolean ``transform`` attribute.
    f -- optional zero-argument callable deciding whether transforming may
         continue inside the block.  When omitted the block always runs
         with transforming switched off.
    """

    def __init__(self, t, f=None):
        self.t = t
        if f is None:
            f = self._never
        self.f = f

    def _never(self):
        # Default predicate.  It used to be a method called ``False``, which
        # shadowed the builtin constant and is a syntax error on Python 3.
        return False

    def __enter__(self):
        self.oldtransform = self.t.transform
        # ``and`` short-circuits: f() is only consulted while transforming.
        self.t.transform = self.t.transform and self.f()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.t.transform = self.oldtransform
        
class bracket(object):
    """Context manager that surrounds a block's output with two delimiters.

    t -- object whose ``f`` attribute has a ``write`` method.
    l -- sequence of exactly two strings: ``l[0]`` is written on entry and
         ``l[1]`` on exit (also when the block raised).
    """

    def __init__(self, t, l):
        assert len(l) == 2
        self.t = t
        self.l = l

    def __enter__(self):
        opener = self.l[0]
        self.t.f.write(opener)

    def __exit__(self, exc_type, exc_val, exc_tb):
        closer = self.l[1]
        self.t.f.write(closer)
        
class transform(object):
    """Write Python source for an ``_ast`` tree, rewriting calls as yields.

    Every node is handled by the method named after the node's class (see
    ``_dispatch``); each handler writes source text to ``self.f``.  While
    ``self.transform`` is true, a call is written as a ``yield`` of a
    request tuple tagged ``self.__class__.__metaclass__.YIELD_*`` instead
    of a direct call.  ``return`` statements are always written as a
    ``YIELD_RETURN`` yield.
    """

    def __init__(self, ast, f, firstFunction):
        """Remember the tree and the output stream.

        ast           -- root node to unparse.
        f             -- object with a ``write`` method receiving the source.
        firstFunction -- when true, a ``from __future__ import
                         with_statement`` line is written first.
        """
        self.f = f
        if firstFunction:
            self.f.write("from __future__ import with_statement\n")
        self.ast = ast
        self._indent = ''
        self.transform = True

    # ------------------------------------------------------------------
    # layout helpers

    def indent(self):
        """Increase the current indentation by four spaces."""
        self._indent = self._indent + '    '

    def deindent(self):
        """Decrease the current indentation by four spaces."""
        self._indent = self._indent[4:]

    def write_indent(self):
        """Write the current indentation."""
        self.f.write(self._indent)

    def newline(self):
        """Write a line break."""
        self.f.write("\n")

    def newline_and_write_indent(self):
        """Start a new line at the current indentation."""
        self.newline()
        self.write_indent()

    def _suite(self, statements):
        """Write each statement on its own line, one level deeper."""
        with indent(self):
            for statement in statements:
                self.newline_and_write_indent()
                self._dispatch(statement)

    def _dispatch_separated(self, nodes, separator=","):
        """Write ``nodes`` in order with ``separator`` between neighbours."""
        for position, node in enumerate(nodes):
            if position:
                self.f.write(separator)
            self._dispatch(node)

    def orelse(self, s):
        """Write the ``else:`` suite of a loop or ``try`` node, if it has one."""
        if s.orelse:
            self.newline_and_write_indent()
            self.f.write("else:")
            # ``orelse`` is a list of statements, so each one is dispatched
            # separately.
            self._suite(s.orelse)

    def _dispatch(self, ast):
        """Call the handler named after the class of ``ast``.

        Raises AttributeError for node classes without a handler.
        """
        return getattr(self, ast.__class__.__name__)(ast)

    def doit(self):
        """Unparse the tree passed to the constructor."""
        self._dispatch(self.ast)

    # ------------------------------------------------------------------
    # operators

    def Add(self, op): self.f.write("+")
    def And(self, boolop): self.f.write(" and ")
    def BitAnd(self, op): self.f.write("&")
    def BitOr(self, op): self.f.write("|")
    def BitXor(self, op): self.f.write("^")
    def Div(self, op): self.f.write("/")
    def Eq(self, eq): self.f.write("==")
    def FloorDiv(self, op): self.f.write("//")
    def Gt(self, gt): self.f.write(">")
    def GtE(self, gte): self.f.write(">=")
    def In(self, i): self.f.write(" in  ")
    def Invert(self, op): self.f.write("~")
    def Is(self, i): self.f.write(" is ")
    def IsNot(self, isnot): self.f.write(" is not ")
    def LShift(self, op): self.f.write("<<")
    def Lt(self, lt): self.f.write("<")
    def LtE(self, lte): self.f.write("<=")
    def Mod(self, m): self.f.write("%")
    def Mult(self, op): self.f.write("*")
    def Not(self, op): self.f.write(" not ")
    def NotEq(self, noteq): self.f.write("!=")
    def NotIn(self, notin): self.f.write(" not in ")
    def Or(self, boolop): self.f.write(" or ")
    def Pow(self, op): self.f.write("**")
    def RShift(self, op): self.f.write(">>")
    def Sub(self, op): self.f.write("-")
    def UAdd(self, op): self.f.write("+")
    def USub(self, op): self.f.write("-")

    # ------------------------------------------------------------------
    # node handlers

    def alias(self, a):
        """Write ``name`` or ``name as asname`` of an import."""
        self.f.write(a.name)
        if a.asname:
            self.f.write(" as ")
            self.f.write(a.asname)

    def arguments(self, arguments):
        """Write a parameter list (without the surrounding parentheses).

        ``arguments.defaults`` belongs to the *last* ``len(defaults)``
        entries of ``arguments.args``.  ``vararg`` and ``kwarg`` are written
        verbatim, i.e. they are expected to be plain strings.
        """
        write = self.f.write
        args = arguments.args
        defaults = arguments.defaults
        first_default = len(args) - len(defaults)
        for position, arg in enumerate(args):
            if position:
                write(",")
            self._dispatch(arg)
            if position >= first_default:
                write("=")
                self._dispatch(defaults[position - first_default])
        first = not args
        vararg = arguments.vararg
        if vararg:
            if not first:
                write(",")
            first = False
            write("*")
            write(vararg)
        kwarg = arguments.kwarg
        if kwarg:
            if not first:
                write(",")
            write("**")
            write(kwarg)

    # Source text of the augmented assignment for each operator class.
    op_dict = {
        _ast.Add: '+=',
        _ast.Sub: '-=',
        _ast.Mult: '*=',
        _ast.Div: '/=',
        _ast.Mod: '%=',
        _ast.Pow: '**=',
        _ast.RShift: '>>=',
        _ast.LShift: '<<=',
        _ast.BitOr: '|=',
        _ast.BitXor: '^=',
        _ast.BitAnd: '&=',
        }

    def AugAssign(self, a):
        """Write ``target op= value``; raises KeyError for unlisted operators."""
        self._dispatch(a.target)
        self.f.write(self.op_dict[a.op.__class__])
        self._dispatch(a.value)

    def Assert(self, a):
        """Write ``assert test`` with the optional message."""
        self.f.write("assert ")
        self._dispatch(a.test)
        if a.msg:
            self.f.write(",")
            self._dispatch(a.msg)

    def Assign(self, assign):
        """Write an assignment.

        Several targets mean a chained assignment (``a = b = value``), so
        they are joined with ``=``; tuple unpacking arrives as one target.
        """
        for target in assign.targets:
            self._dispatch(target)
            self.f.write("=")
        self._dispatch(assign.value)

    def Attribute(self, attribute):
        """Write ``value.attr``."""
        self._dispatch(attribute.value)
        self.f.write(".")
        self.f.write(attribute.attr)

    def BinOp(self, binop):
        """Write a binary operation, parenthesised to keep its grouping."""
        self.f.write("(")
        for p in (binop.left, binop.op, binop.right):
            self._dispatch(p)
        self.f.write(")")

    def BoolOp(self, boolOp):
        """Write ``and``/``or`` chains with every operand parenthesised."""
        with bracket(self, "()"):
            for position, value in enumerate(boolOp.values):
                if position:
                    self._dispatch(boolOp.op)
                self.f.write("(")
                self._dispatch(value)
                self.f.write(")")

    def Break(self, b):
        """Write ``break``."""
        self.f.write("break")

    def Call(self, call):
        """Write a call, as a ``yield`` request while transforming.

        A call of a bare name found in ``__builtins__`` is always written
        as an ordinary call.  Otherwise, while ``self.transform`` is true,
        the call becomes one of

            (yield <meta>.YIELD_SIMPLECALL,(func,args))
            (yield <meta>.YIELD_CALL_WITH_KEYWORDS,(func,args,{kw}))
            (yield <meta>.YIELD_CALL,(func,args,{kw},starargs,kwargs))

        where ``args`` is always a tuple expression and missing
        ``starargs``/``kwargs`` are written as ``()`` and ``{}``.
        """
        write = self.f.write

        def newtransform():
            func = call.func
            if not isinstance(func, _ast.Name):
                return True
            # __builtins__ is a module in __main__ and a dict elsewhere.
            names = __builtins__
            if not isinstance(names, dict):
                names = vars(names)
            return func.id not in names

        def positional_tuple():
            # Positional arguments as an expression that is always a tuple.
            if not call.args:
                write("tuple()")
                return
            with bracket(self, "()"):
                self._dispatch_separated(call.args)
                if len(call.args) == 1:
                    write(",")

        def keyword_dict():
            # Keyword arguments as a dict literal keyed by argument name.
            with bracket(self, "{}"):
                for position, keyword in enumerate(call.keywords):
                    if position:
                        write(",")
                    write('"')
                    write(keyword.arg)
                    write('":')
                    self._dispatch(keyword.value)

        def plain_call():
            # Ordinary call syntax, keeping keywords, *args and **kwargs.
            self._dispatch(call.func)
            with bracket(self, "()"):
                explicit = list(call.args) + list(call.keywords)
                self._dispatch_separated(explicit)
                separator = "," if explicit else ""
                if call.starargs:
                    write(separator + "*")
                    self._dispatch(call.starargs)
                    separator = ","
                if call.kwargs:
                    write(separator + "**")
                    self._dispatch(call.kwargs)

        with suspend_transform(self, newtransform):
            if not self.transform:
                plain_call()
                return
            if call.starargs or call.kwargs:
                c = 'YIELD_CALL'
            elif call.keywords:
                c = 'YIELD_CALL_WITH_KEYWORDS'
            else:
                c = 'YIELD_SIMPLECALL'
            write("(yield self.__class__.__metaclass__.%s,(" % c)
            self._dispatch(call.func)
            write(",")
            positional_tuple()
            if c != 'YIELD_SIMPLECALL':
                write(",")
                keyword_dict()
            if c == 'YIELD_CALL':
                write(",")
                if call.starargs:
                    # Any iterable may follow ``*``; normalise it to a tuple.
                    write("tuple(")
                    self._dispatch(call.starargs)
                    write(")")
                else:
                    write("()")
                write(",")
                if call.kwargs:
                    self._dispatch(call.kwargs)
                else:
                    write("{}")
            write("))")

    def ClassDef(self, cd):
        """Write a class statement; its body is never transformed.

        A ``pass`` is written first so the class body is never empty.
        """
        write = self.f.write
        write("class ")
        write(cd.name)
        write("(")
        self._dispatch_separated(cd.bases)
        write("):")
        with indent(self):
            self.newline_and_write_indent()
            write("pass")
            with suspend_transform(self):
                for stmt in cd.body:
                    self.newline_and_write_indent()
                    self._dispatch(stmt)
        self.newline()

    def Compare(self, compare):
        """Write a (possibly chained) comparison inside parentheses."""
        self.f.write("(")
        self._dispatch(compare.left)
        for op, expr in zip(compare.ops, compare.comparators):
            self._dispatch(op)
            self._dispatch(expr)
        self.f.write(")")

    def comprehension(self, c):
        """Write one ``for target in iter`` clause with its ``if`` filters."""
        self.f.write(" for ")
        self._dispatch(c.target)
        self.f.write(" in ")
        self._dispatch(c.iter)
        for e in c.ifs:
            self.f.write(" if ")
            self._dispatch(e)

    def Continue(self, c):
        """Write ``continue``."""
        self.f.write("continue")

    def Delete(self, d):
        """Write ``del`` followed by the comma separated targets."""
        self.f.write("del ")
        self._dispatch_separated(d.targets)

    def Dict(self, d):
        """Write a dict literal with one ``key:value,`` entry per line."""
        with bracket(self, "{}"):
            with indent(self):
                for key, value in zip(d.keys, d.values):
                    self.newline_and_write_indent()
                    self._dispatch(key)
                    self.f.write(":")
                    self._dispatch(value)
                    self.f.write(",")
            self.newline_and_write_indent()

    def excepthandler(self, eh):
        """Write one ``except [type[,name]]:`` clause and its suite."""
        self.newline_and_write_indent()
        self.f.write("except ")
        if eh.type:
            self._dispatch(eh.type)
        if eh.name:
            self.f.write(",")
            self._dispatch(eh.name)
        self.f.write(":")
        self._suite(eh.body)

    # Python 2.6 and later name the node class ``ExceptHandler``.
    ExceptHandler = excepthandler

    def Exec(self, e):
        """Write ``exec body [in globals[,locals]]``."""
        self.f.write("exec ")
        self._dispatch(e.body)
        if e.globals:
            self.f.write(" in ")
            self._dispatch(e.globals)
            if e.locals:
                self.f.write(",")
                self._dispatch(e.locals)

    def GeneratorExp(self, ge):
        """Write a parenthesised generator expression."""
        with bracket(self, "()"):
            self._dispatch(ge.elt)
            for generator in ge.generators:
                self._dispatch(generator)

    def For(self, f):
        """Write a ``for`` loop with its optional ``else`` suite."""
        self.f.write("for ")
        self._dispatch(f.target)
        self.f.write(" in ")
        self._dispatch(f.iter)
        self.f.write(":")
        self._suite(f.body)
        self.orelse(f)

    def Import(self, i):
        """Write ``import`` followed by the comma separated aliases."""
        self.f.write("import ")
        self._dispatch_separated(i.names)

    def ImportFrom(self, ia):
        """Write ``from module import names`` with comma separated names."""
        write = self.f.write
        write("from ")
        # Relative imports carry leading dots and possibly no module name.
        write("." * (getattr(ia, "level", 0) or 0))
        if ia.module:
            write(ia.module)
        write(" import ")
        self._dispatch_separated(ia.names)

    def Expr(self, expr):
        """Write an expression statement."""
        self._dispatch(expr.value)

    def FunctionDef(self, functiondef):
        """Write a function definition.

        Decorators are written untransformed.  The body is transformed
        unless the name starts with ``__`` or the function is decorated.
        A final statement is appended to the body: a ``YIELD_RETURN`` of
        ``None`` when transforming, ``pass`` otherwise.
        """
        write = self.f.write
        # The field is ``decorators`` in Python 2.5, ``decorator_list`` later.
        decorators = getattr(functiondef, "decorator_list", None)
        if decorators is None:
            decorators = functiondef.decorators

        def new_transform():
            # __init__ and other special methods must remain functions.
            if functiondef.name.startswith("__"):
                return False
            return len(decorators) == 0

        with suspend_transform(self):
            for expr in decorators:
                write("@")
                self._dispatch(expr)
                self.newline_and_write_indent()
        write("def ")
        write(functiondef.name)
        write("(")
        self._dispatch(functiondef.args)
        write("):")
        with indent(self):
            with suspend_transform(self, new_transform):
                for stmt in functiondef.body:
                    self.newline_and_write_indent()
                    self._dispatch(stmt)
                self.newline_and_write_indent()
                if self.transform:
                    write("yield (self.__class__.__metaclass__.YIELD_RETURN, None)")
                else:
                    write("pass")

    def Global(self, g):
        """Write a ``global`` statement."""
        self.f.write("global ")
        self.f.write(','.join(g.names))

    def If(self, ast):
        """Write an ``if`` statement; ``elif`` comes out as a nested ``if``."""
        self.f.write("if ")
        self._dispatch(ast.test)
        self.f.write(":")
        self._suite(ast.body)
        self.orelse(ast)
        self.newline()

    def IfExp(self, i):
        """Write ``(body if test else orelse)``."""
        with bracket(self, "()"):
            self._dispatch(i.body)
            self.f.write(" if ")
            self._dispatch(i.test)
            self.f.write(" else ")
            self._dispatch(i.orelse)

    def Index(self, i):
        """Write the expression of a simple subscript."""
        self._dispatch(i.value)

    def keyword(self, k):
        """Write ``arg=value``."""
        self.f.write(k.arg)
        self.f.write("=")
        self._dispatch(k.value)

    def Lambda(self, l):
        """Write ``lambda args:body``."""
        self.f.write("lambda ")
        self._dispatch(l.args)
        self.f.write(":")
        self._dispatch(l.body)

    def List(self, l):
        """Write a list display."""
        with bracket(self, "[]"):
            self._dispatch_separated(l.elts)

    def ListComp(self, lc):
        """Write a list comprehension."""
        with bracket(self, "[]"):
            self._dispatch(lc.elt)
            for g in lc.generators:
                self._dispatch(g)

    def Module(self, module):
        """Write every top level statement on its own line."""
        for stmt in module.body:
            self.newline_and_write_indent()
            self._dispatch(stmt)

    def Name(self, name):
        """Write an identifier."""
        self.f.write(name.id)

    def Num(self, num):
        """Write a numeric literal; ``repr`` keeps full float precision."""
        self.f.write(repr(num.n))

    def Pass(self, ast):
        """Write ``pass``."""
        self.f.write("pass")

    def Print(self, print_):
        """Write a ``print`` statement, keeping ``>>dest`` and a trailing comma."""
        write = self.f.write
        write("print ")
        dest = print_.dest
        if dest:
            write(">>")
            self._dispatch(dest)
        # After a destination every value needs a leading comma.
        first = not dest
        for value in print_.values:
            if first:
                first = False
            else:
                write(",")
            self._dispatch(value)
        if not print_.nl:
            write(",")

    def Raise(self, r):
        """Write ``raise [type[,inst[,tback]]]``."""
        self.f.write("raise ")
        if r.type:
            self._dispatch(r.type)
            if r.inst:
                self.f.write(",")
                self._dispatch(r.inst)
                if r.tback:
                    self.f.write(",")
                    self._dispatch(r.tback)

    def Repr(self, r):
        """Write a backquoted expression."""
        with bracket(self, "``"):
            self._dispatch(r.value)

    def Return(self, ret):
        """Write ``return`` as a ``YIELD_RETURN`` yield.

        NOTE(review): the yield is written even while ``self.transform`` is
        false, so a ``return`` inside an untransformed function turns that
        function into a generator -- confirm whether this is intended.
        """
        write = self.f.write
        value = ret.value
        if value:
            write("yield (self.__class__.__metaclass__.YIELD_RETURN,(")
            self._dispatch(value)
            write("))")
        else:
            write("(yield (self.__class__.__metaclass__.YIELD_RETURN, None))")

    def Slice(self, s):
        """Write ``lower:upper[:step]`` with absent bounds left empty."""
        if s.lower:
            self._dispatch(s.lower)
        self.f.write(":")
        if s.upper:
            self._dispatch(s.upper)
        if s.step:
            self.f.write(":")
            self._dispatch(s.step)

    def Str(self, s):
        """Write a string literal."""
        self.f.write(repr(s.s))

    def Subscript(self, s):
        """Write ``value[slice]``."""
        self._dispatch(s.value)
        with bracket(self, "[]"):
            self._dispatch(s.slice)

    def TryExcept(self, te):
        """Write ``try``/``except``; handlers and ``else`` are not transformed."""
        self.f.write("try:")
        self._suite(te.body)
        with suspend_transform(self):
            for eh in te.handlers:
                self._dispatch(eh)
            self.orelse(te)

    def TryFinally(self, tf):
        """Write ``try``/``finally`` with transforming switched off throughout."""
        with suspend_transform(self):
            self.f.write("try:")
            self._suite(tf.body)
            self.newline_and_write_indent()
            self.f.write("finally:")
            self._suite(tf.finalbody)

    def Tuple(self, t):
        """Write a tuple display; one element gets the trailing comma."""
        with bracket(self, "()"):
            self._dispatch_separated(t.elts)
            if len(t.elts) == 1:
                self.f.write(",")

    def UnaryOp(self, uo):
        """Write a unary operation, parenthesised to keep its grouping."""
        self.f.write("(")
        self._dispatch(uo.op)
        self._dispatch(uo.operand)
        self.f.write(")")

    def While(self, w):
        """Write a ``while`` loop with its optional ``else`` suite."""
        self.f.write("while ")
        self._dispatch(w.test)
        self.f.write(":")
        self._suite(w.body)
        self.orelse(w)

    def With(self, w):
        """Write a ``with`` statement."""
        self.f.write("with ")
        self._dispatch(w.context_expr)
        if w.optional_vars:
            self.f.write(" as ")
            self._dispatch(w.optional_vars)
        self.f.write(":")
        self._suite(w.body)

    def Yield(self, y):
        """Write a parenthesised ``yield`` expression."""
        with bracket(self, "()"):
            self.f.write("yield ")
            if y.value:
                self._dispatch(y.value)

import __future__
import _ast, inspect, string, os
from StringIO import StringIO
from types import FunctionType
from pprint import pprint

from transform_source import transform

class coroutine_metaclass(type):
    """Metaclass that replaces plain methods by transformed generators.

    When a class is created, the source of every function in its body that
    is not already a generator function is run through ``transform`` and
    written to ``<filesdir>/<classname>.py``.  That file is then imported
    and the transformed functions take the place of the originals; each is
    marked with ``transformed_by_coroutine_metaclass = True``.

    The ``YIELD_*`` constants are the request tags used by the generated
    code and dispatched on by ``mthread``.
    """

    # Directory receiving the generated modules; created on first use.
    filesdir = os.path.join(os.getcwd(), "transformed")
    if not os.path.exists(filesdir):
        os.mkdir(filesdir)

    def __new__(mcl, classname, bases, classdict):
        """Create the class after transforming its non-generator functions.

        NOTE(review): the generated module is loaded with
        ``__import__(classname)``, so a class named like an already
        imported module would get that module instead -- confirm that
        class names never collide with module names.
        """
        filesdir = coroutine_metaclass.filesdir
        newdict = dict(classdict)
        to_replace = []
        f = open(os.path.join(filesdir, classname + ".py"), "w")
        try:
            firstFunction = True
            for key, value in classdict.items():
                if not isinstance(value, FunctionType):
                    continue
                # 0x20 is the generator bit of the code object's flags.
                if value.func_code.co_flags & 0x20:
                    continue
                to_replace.append(key)
                sourcelines = inspect.getsourcelines(value)[0]
                # Strip the indentation of the ``def`` line from every line
                # so the method compiles as a top level function.
                firstline = sourcelines[0]
                margin = len(firstline) - len(firstline.lstrip())
                source = ''.join([line[margin:] for line in sourcelines])
                ast = compile(source, inspect.getfile(value), 'exec',
                              _ast.PyCF_ONLY_AST | __future__.CO_FUTURE_WITH_STATEMENT)
                transform(ast, f, firstFunction).doit()
                firstFunction = False
        finally:
            # Close the file even when reading or transforming a method fails.
            f.close()
        if to_replace:
            import sys
            if sys.path[0] != filesdir:
                sys.path.insert(0, filesdir)
            replacements = __import__(classname)
            for key in to_replace:
                newdict[key] = getattr(replacements, key)
                # Mark the transformed functions.
                newdict[key].transformed_by_coroutine_metaclass = True
        return super(coroutine_metaclass, mcl).__new__(mcl, classname, bases, newdict)

    YIELD_RETURN             = 0
    YIELD_SIMPLECALL         = 1
    YIELD_CALL_WITH_KEYWORDS = 2
    YIELD_CALL               = 3
    YIELD_RERAISE            = 4
    YIELD_RERAISE            = 4
    
class ReturnValue(Exception):
    """Raised when the outermost frame of an ``mthread`` returns.

    result  -- the value that was returned.
    mthread -- the thread that finished.
    """

    def __init__(self, result, mthread):
        self.mthread = mthread
        self.result = result
        
class mthread(object):
    def __init__(self, code, args):
        self._frames = []
        self.tickactions = [None]*5
        self.tickactions[coroutine.YIELD_RETURN] = self.yieldreturn
        self.tickactions[coroutine.YIELD_SIMPLECALL]  = self.yieldsimplecall
        self.tickactions[coroutine.YIELD_CALL_WITH_KEYWORDS] = self.yieldcallWithKeywords
        self.tickactions[coroutine.YIELD_CALL] = self.yieldcall
        self.tickactions[coroutine.YIELD_RERAISE] = self.yieldreraise
        self.i = 2
        self.result = (code, args)
        self.yieldsimplecall()

    def call_common(self, code, g):
        self._frames.append(g)
        if hasattr(code, "transformed_by_coroutine_metaclass"):
            if hasattr(code, "is_generator"):
                self.i, self.result = generator_proxy(code)
            else:
                self.i, self.result = g.next()
        else:
            # regular function call
            self.i, self.result = coroutine.YIELD_RETURN, g

    def yieldsimplecall(self):
        code, args = self.result
        g = code(*args)
        self.call_common(code, g)

    def yieldreturn(self):
        self._frames.pop()
        if self._frames:
            self.i, self.result = self._frames[-1].send(self.result)
        else:
            raise ReturnValue(self.result, self)

    def yieldyield(self):
        """
        I think that we could integrate generator functions, but then the runtime would
        be quite a bit more complicated.
        """
        raise NotImplementedError

    def yieldreraise(self):
        try:
            self.i, self.result = self._frames[-1].throw(self.result.__class__, self.result)
        except Exception, ex:
            self._frames.pop()
            if self._frames:
                self.i, self.result = coroutine.YIELD_RERAISE, ex
            else:
                raise
            
    def yieldcallWithKeywords(self):
        code, regularargs, keywords = self.result
        g = code(*regularargs, **keywords)
        self.call_common(code, g)

    def yieldcall(self):
        code, args1, keywords1, args2, keywords2 = self.result
        keywords = keywords1.copy()
        keywords.update(keywords2)
        g = code(*(args1 + args2), **keywords)
        self.call_common(code, g)
        
    def tick(self):
        try:
            self.tickactions[self.i]()
        except ReturnValue:
            raise
        except Exception, ex:
            if self._frames:
                self._frames.pop()
                if self._frames:
                    self.i, self.result = coroutine.YIELD_RERAISE, ex
                else:
                    raise
            else:
                raise

class coroutine_runtime(object):
    """Round-robin scheduler for ``mthread`` objects."""

    def __init__(self):
        self.mthreads = []

    def more(self):
        """Return True while at least one thread is still scheduled."""
        return bool(self.mthreads)

    def call(self, method, args):
        """Schedule ``method(*args)`` as a new thread."""
        self.mthreads.append(mthread(method, args))

    def tick(self):
        """Advance every scheduled thread by one step.

        When a thread finishes it raises ReturnValue; the thread is then
        removed and the exception propagates, which ends this pass.
        """
        try:
            # ``thread`` rather than ``mthread``: the latter would shadow
            # the mthread class inside this method.
            for thread in self.mthreads:
                thread.tick()
        except ReturnValue:
            self.mthreads.remove(thread)
            raise
    
# Base class for user code: classes deriving from it are created by
# coroutine_metaclass, which is also where the YIELD_* constants read as
# ``coroutine.YIELD_*`` are defined.
class coroutine(object):
    __metaclass__ = coroutine_metaclass

def run_threads(runtime):
    try:
        while 1:
            runtime.tick()
    except ReturnValue, rv:
        return rv.result
    
def simple_call(f, args):
    """
    Shows how a coroutine can be called: schedule ``f(*args)`` on a fresh
    runtime and keep running until no thread is left.
    """
    scheduler = coroutine_runtime()
    scheduler.call(f, args)
    while True:
        if not scheduler.more():
            break
        run_threads(scheduler)

if __name__ == '__main__':
    # a small usage example:
    # Both methods below are plain functions, so coroutine_metaclass
    # replaces them by their transformed versions when the class is created.
    class ack(coroutine):
        def outer(self, m, n):
            result = self.ack(m, n)
            if 1:
                print "Ackermann ", m, ",", n, "=", result
        def ack(self, m, n):
            if m == 0:
                return n + 1
            if m > 0 and n == 0:
                result = self.ack(m-1, 1)
                return result
            return self.ack(m-1, self.ack(m, n-1))

    runtime = coroutine_runtime()
    
    # Schedule one thread per (m, n) pair ...
    for i, j in ( (0 , 0), (0 , 1), (1 , 0), (1 , 1), (2 , 0), (2 , 1), (3 , 0), (3 , 1), (4 , 0), (3 , 2)):
            runtime.call(ack().outer, (i, j))
    # ... and run them until every thread has finished.
    while runtime.more():
        run_threads(runtime)

History

  • revision 2 (17 years ago)
  • previous revisions are not available