Welcome, guest | Sign In | My Account | Store | Cart

A function that works like a "where" clause for generator expressions. The code below

x, y, z = 1, 2, 3
((x, y, z) for _ in range(5))

Is equivalent to:

((x, y, z) for _ in range(5)) < where(x=1, y=2, z=3)
Python, 104 lines
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
import types
import opcode


def _constant_index(consts, value):
    """Return the index of *value* in *consts*, appending it when absent.

    Matching is done by identity (``is``) rather than equality so that
    equal-but-distinct objects (``1`` / ``True`` / ``1.0``) never share a
    slot in the constant pool.
    """
    for pos, existing in enumerate(consts):
        if existing is value:
            return pos
    consts.append(value)
    return len(consts) - 1


def _replace_globals_and_closures(generator, **constants):
    """Rebuild *generator* with the names in *constants* bound as constants.

    Every ``LOAD_GLOBAL`` or ``LOAD_DEREF`` instruction in the generator's
    code that refers to a name present in *constants* is rewritten into a
    ``LOAD_CONST`` of the supplied value.  A new code object and function
    are built from the patched bytecode and called with the generator
    frame's remaining locals.

    The scan assumes the 3-byte instruction encoding (one opcode byte
    followed by a 16-bit little-endian argument); ``EXTENDED_ARG`` is not
    handled.
    NOTE(review): the positional ``types.CodeType`` call below matches the
    constructor of the interpreter versions this recipe was written for;
    newer interpreters take a different argument list -- confirm before
    porting.

    Parameters:
        generator: the generator object whose code is patched.
        **constants: mapping of variable name to the value to bind.

    Returns:
        Whatever calling the rebuilt function returns (a new generator
        when *generator* comes from a generator expression).

    Raises:
        TypeError: if the generator closes over a variable that is not
            listed in *constants*; the rebuilt function has no closure,
            so such a variable could not be resolved.
    """
    gi_code = generator.gi_code
    new_code = list(gi_code.co_code)
    new_consts = list(gi_code.co_consts)
    frame_locals = generator.gi_frame.f_locals

    # LOAD_DEREF arguments index the cell variables first and the free
    # variables after them.  Names are resolved against the ORIGINAL
    # tuples; the surviving free variables are tracked separately so that
    # removing one does not shift the indices used by later instructions.
    cell_count = len(gi_code.co_cellvars)
    original_freevars = gi_code.co_freevars
    remaining_freevars = list(original_freevars)

    load_global = opcode.opmap['LOAD_GLOBAL']
    load_deref = opcode.opmap['LOAD_DEREF']
    load_const = opcode.opmap['LOAD_CONST']

    i = 0
    while i < len(new_code):
        op_code = new_code[i]
        if op_code < opcode.HAVE_ARGUMENT:
            i += 1
            continue

        name = None
        if op_code == load_global:
            oparg = new_code[i + 1] + (new_code[i + 2] << 8)
            name = gi_code.co_names[oparg]
        elif op_code == load_deref:
            oparg = new_code[i + 1] + (new_code[i + 2] << 8)
            # Slots below cell_count are the generator's own cells, not
            # closure variables, and are left untouched.
            if oparg >= cell_count:
                name = original_freevars[oparg - cell_count]

        if name is not None and name in constants:
            pos = _constant_index(new_consts, constants[name])
            new_code[i] = load_const
            new_code[i + 1] = pos & 0xFF
            new_code[i + 2] = pos >> 8
            if op_code == load_deref:
                # The variable is now a constant: it is no longer a free
                # variable and must not be passed as a keyword argument.
                frame_locals.pop(name, None)
                if name in remaining_freevars:
                    remaining_freevars.remove(name)

        i += 3

    if remaining_freevars:
        raise TypeError(
            'closure variable(s) not bound by where(): %s'
            % ', '.join(remaining_freevars))

    code_object = types.CodeType(
        gi_code.co_argcount,
        gi_code.co_kwonlyargcount,
        gi_code.co_nlocals,
        gi_code.co_stacksize,
        gi_code.co_flags,
        # Build the byte string directly from the integer list; going
        # through chr() + UTF-8 would expand every byte >= 0x80 into two.
        bytes(new_code),
        tuple(new_consts),
        gi_code.co_names,
        gi_code.co_varnames,
        gi_code.co_filename,
        gi_code.co_name,
        gi_code.co_firstlineno,
        gi_code.co_lnotab,
        tuple(remaining_freevars),
        gi_code.co_cellvars)

    function = types.FunctionType(
        code_object,
        generator.gi_frame.f_globals,
        generator.__name__,
    )

    return function(**frame_locals)


class WhereType:
    """Bind constants to a generator expression through the ``<`` operator.

    In ``gen < where(x=1)`` the left operand is a generator, which does not
    implement ``<`` itself, so Python falls back to the reflected operation
    on the right operand -- this class's ``__gt__``.

    Attributes are documented inline: ``constants`` holds the name -> value
    mapping that will be bound into the generator.
    """

    def __init__(self, **constants):
        # Always present, so a bare instance behaves as "bind nothing"
        # instead of failing on a missing attribute.
        self.constants = constants

    def __gt__(self, other):
        """Return *other* rebuilt with this instance's constants bound."""
        return _replace_globals_and_closures(other, **self.constants)

    def __call__(self, **constants):
        """Return a new instance carrying *constants*.

        A fresh object is returned rather than mutating ``self`` so that
        several ``where(...)`` results can coexist without overwriting
        each other's constants.
        """
        return type(self)(**constants)


# !!! The where function
# Module-level WhereType instance, used as
# ``<generator expression> < where(name=value, ...)``.
where = WhereType()


if __name__ == '__main__':
    # Demo: echo the interactive session, then show the actual result of
    # binding x, y and z on a generator expression.
    for echoed_line in (
            ">>> gen = ((x, y, z) for _ in range(5)) < where(x=1, y=2, z=3)",
            ">>> list(gen)"):
        print(echoed_line)
    bound_gen = ((x, y, z) for _ in range(5)) < where(x=1, y=2, z=3)
    print(list(bound_gen))

This recipe was inspired by Raymond Hettinger's recipe "Decorator for BindingConstants at compile time" (Python recipe).