import types
import opcode
def _replace_globals_and_closures(generator, **constants):
gi_code = generator.gi_code
new_code = list(gi_code.co_code)
new_consts = list(gi_code.co_consts)
locals = generator.gi_frame.f_locals
freevars = list(gi_code.co_freevars)
# Replace global lookups by the values defined in *constants*.
i = 0
while i < len(new_code):
op_code = new_code[i]
if op_code == opcode.opmap['LOAD_GLOBAL']:
oparg = new_code[i + 1] + (new_code[i + 2] << 8)
name = gi_code.co_names[oparg]
if name in constants:
value = constants[name]
for pos, v in enumerate(new_consts):
if v is value:
break
else:
pos = len(new_consts)
new_consts.append(value)
new_code[i] = opcode.opmap['LOAD_CONST']
new_code[i + 1] = pos & 0xFF
new_code[i + 2] = pos >> 8
i += 1
if op_code >= opcode.HAVE_ARGUMENT:
i += 2
# Repalce closures lookups by the values defined in *constants*
i = 0
while i < len(new_code):
op_code = new_code[i]
if op_code == opcode.opmap['LOAD_DEREF']:
oparg = new_code[i + 1] + (new_code[i + 2] << 8)
name = freevars[oparg]
if name in constants:
value = constants[name]
for pos, v in enumerate(new_consts):
if v is value:
break
else:
pos = len(new_consts)
new_consts.append(value)
new_code[i] = opcode.opmap['LOAD_CONST']
new_code[i + 1] = pos & 0xFF
new_code[i + 2] = pos >> 8
if name in locals:
del locals[name]
freevars.remove(name)
i += 1
if op_code >= opcode.HAVE_ARGUMENT:
i += 2
code_str = ''.join(map(chr, new_code))
code_object = types.CodeType(
gi_code.co_argcount,
gi_code.co_kwonlyargcount,
gi_code.co_nlocals,
gi_code.co_stacksize,
gi_code.co_flags,
bytes(code_str, 'utf-8'),
tuple(new_consts),
gi_code.co_names,
gi_code.co_varnames,
gi_code.co_filename,
gi_code.co_name,
gi_code.co_firstlineno,
gi_code.co_lnotab,
tuple(freevars),
gi_code.co_cellvars)
function = types.FunctionType(
code_object,
generator.gi_frame.f_globals,
generator.__name__,
)
return function(**locals)
class WhereType:
"""Implement the *<* operator that apply the function to the generator."""
def __gt__(self, other):
return _replace_globals_and_closures(other, **self.constants)
def __call__(self, **constants):
self.constants = constants
return self
# !!! The where function
where = WhereType()
if __name__ == '__main__':
print(">>> gen = ((x, y, z) for _ in range(5)) < where(x=1, y=2, z=3)")
print(">>> list(gen)")
print(list(((x, y, z) for _ in range(5)) < where(x=1, y=2, z=3)))