#!/usr/bin/env python
import asio
import logging
import os
import sys
# Upper bound, in bytes, passed to every asyncRead() call in this file (64 KiB).
MAX_READ_BYTES = 64 * 1024
def createLogger():
    """Return the 'proxy' logger, set up to write formatted records to the console.

    The logger level is INFO; the console handler itself accepts DEBUG and
    above, so the logger level is the effective threshold.

    logging.getLogger() hands back the same object for the same name, so the
    console handler is attached only when the logger has no handler yet.
    Calling this function more than once therefore never duplicates output.
    """
    proxyLogger = logging.getLogger('proxy')
    proxyLogger.setLevel(logging.INFO)
    if not proxyLogger.handlers:
        consoleHandler = logging.StreamHandler()
        consoleHandler.setLevel(logging.DEBUG)
        consoleHandler.setFormatter(logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
        proxyLogger.addHandler(consoleHandler)
    return proxyLogger
# Module-level logger used by the classes and functions defined in this file.
logger = createLogger()
class Connection(object):
    """Relay bytes between one accepted client socket and a remote endpoint.

    Each direction alternates between a single asyncRead on one socket and a
    single asyncWriteAll of that data to the other socket.  An error, an
    empty read, or finding the opposite socket already closed tears the
    connection down through close().
    """

    def __init__(self, ioService, clientToProxySocket,
                 remoteAddress, remotePort):
        """Remember the client socket and begin connecting to the remote.

        ioService -- provides createAsyncSocket() for the outbound socket.
        clientToProxySocket -- the socket for the client side of the relay.
        remoteAddress, remotePort -- endpoint handed to asyncConnect().
        """
        # True while an asyncWriteAll is outstanding on that socket; close()
        # skips a socket for as long as its flag is set.
        self.__writingToClient = False
        self.__writingToRemote = False
        self.__clientToProxySocket = clientToProxySocket
        self.__clientToProxyString = ('%s -> %s' %
                                      (clientToProxySocket.getpeername(),
                                       clientToProxySocket.getsockname()))
        self.__proxyToRemoteSocket = ioService.createAsyncSocket()
        # Left empty until the connect succeeds; close() logs a disconnect
        # only for a non-empty description.
        self.__proxyToRemoteString = ''
        self.__proxyToRemoteSocket.asyncConnect(
            (remoteAddress, remotePort),
            self.__connectCallback)

    def close(self):
        """Close both sockets, skipping any that still has a write in flight.

        A skipped socket is closed later: its write callback clears the
        flag, sees the opposite socket closed, and calls close() again.
        """
        if not self.__writingToClient:
            self.__logAndClose(self.__clientToProxySocket,
                               self.__clientToProxyString)
        if not self.__writingToRemote:
            self.__logAndClose(self.__proxyToRemoteSocket,
                               self.__proxyToRemoteString)

    def __logAndClose(self, sock, description):
        """Log a disconnect for a still-open, described socket, then close it."""
        if (not sock.closed()) and description:
            logger.info('disconnect %s', description)
        sock.close()

    def __connectCallback(self, err):
        """Start relaying in both directions once the remote connect completes."""
        if err != 0:
            logger.info('connect error \'%s\'', os.strerror(err))
            self.close()
            return
        # Local endpoint first: this link runs from the proxy to the remote,
        # so the description must read "proxy -> remote".
        self.__proxyToRemoteString = ('%s -> %s' %
                                      (self.__proxyToRemoteSocket.getsockname(),
                                       self.__proxyToRemoteSocket.getpeername()))
        logger.info('connect %s', self.__proxyToRemoteString)
        self.__clientToProxySocket.asyncRead(
            MAX_READ_BYTES,
            self.__readFromClientCallback)
        self.__proxyToRemoteSocket.asyncRead(
            MAX_READ_BYTES,
            self.__readFromRemoteCallback)

    def __readFromClientCallback(self, data, err):
        """Forward data read from the client to the remote, or shut down."""
        if self.__proxyToRemoteSocket.closed() or (err != 0) or (not data):
            self.close()
            return
        self.__writingToRemote = True
        self.__proxyToRemoteSocket.asyncWriteAll(
            data, self.__writeToRemoteCallback)

    def __readFromRemoteCallback(self, data, err):
        """Forward data read from the remote to the client, or shut down."""
        if self.__clientToProxySocket.closed() or (err != 0) or (not data):
            self.close()
            return
        self.__writingToClient = True
        self.__clientToProxySocket.asyncWriteAll(
            data, self.__writeToClientCallback)

    def __writeToRemoteCallback(self, err):
        """After a write to the remote, read from the client again or shut down."""
        # Clear the flag first so close() may now close the remote socket.
        self.__writingToRemote = False
        if self.__clientToProxySocket.closed() or (err != 0):
            self.close()
            return
        self.__clientToProxySocket.asyncRead(
            MAX_READ_BYTES, self.__readFromClientCallback)

    def __writeToClientCallback(self, err):
        """After a write to the client, read from the remote again or shut down."""
        # Clear the flag first so close() may now close the client socket.
        self.__writingToClient = False
        if self.__proxyToRemoteSocket.closed() or (err != 0):
            self.close()
            return
        self.__proxyToRemoteSocket.asyncRead(
            MAX_READ_BYTES, self.__readFromRemoteCallback)
class Acceptor(object):
    """Listen on one local endpoint and create a Connection per accepted socket."""

    def __init__(self, ioService,
                 localAddress, localPort,
                 remoteAddress, remotePort):
        """Bind, listen, and post the first asynchronous accept.

        ioService -- provides createAsyncSocket(); also passed to each Connection.
        localAddress, localPort -- endpoint to bind the listening socket to.
        remoteAddress, remotePort -- endpoint every Connection is pointed at.
        """
        self.__ioService = ioService
        self.__remoteAddress = remoteAddress
        self.__remotePort = remotePort
        self.__asyncSocket = ioService.createAsyncSocket()
        self.__asyncSocket.setReuseAddress()
        self.__asyncSocket.bind((localAddress, localPort))
        self.__asyncSocket.listen()
        self.__asyncSocket.asyncAccept(self.__acceptCallback)
        logger.info('listening on %s' % str(self.__asyncSocket.getsockname()))

    def __acceptCallback(self, sock, err):
        """Hand an accepted socket to a new Connection, then accept again."""
        if (err == 0) and (sock is not None):
            logger.info('accept %s -> %s' %
                        (sock.getpeername(), sock.getsockname()))
            Connection(self.__ioService, sock,
                       self.__remoteAddress, self.__remotePort)
        elif err != 0:
            # Report the failure instead of dropping it silently.  err is
            # treated as an errno value, as Connection does for connect
            # errors -- NOTE(review): confirm against the asio module.
            logger.info('accept error \'%s\'' % (os.strerror(err)))
        # Re-arm unconditionally so a failed accept does not stop the listener.
        self.__asyncSocket.asyncAccept(self.__acceptCallback)
def parseAddrPortString(addrPortString):
    """Split an '<addr>:<port>' string into an (address, port) tuple.

    The split happens at the LAST colon so that addresses which themselves
    contain colons (for example IPv6 literals such as '::1:8080') parse.
    Square brackets around the address, as in '[::1]:8080', are removed.

    Returns a tuple (address string, port as int).
    Raises ValueError when there is no colon or the port is not an integer.
    """
    address, separator, port = addrPortString.rpartition(':')
    if not separator:
        raise ValueError(
            "expected '<addr>:<port>', got %r" % (addrPortString,))
    if address.startswith('[') and address.endswith(']'):
        address = address[1:-1]
    return (address, int(port))
def printUsage():
    """Log the command-line synopsis at ERROR level.

    Every argument is parsed by parseAddrPortString, so each one must carry
    a ':<port>' suffix; the synopsis spells that out.
    """
    logger.error(
        'Usage: %s <listen addr>:<port> [<listen addr>:<port> ...] '
        '<remote addr>:<port>' %
        sys.argv[0])
def main():
    """Parse the command line, start one Acceptor per listen endpoint, run the loop.

    sys.argv[1:-1] are the listen endpoints and sys.argv[-1] is the remote
    endpoint, each parsed by parseAddrPortString.  When arguments are missing
    or cannot be parsed, the usage text is logged and the process exits with
    status 1.
    """
    if len(sys.argv) < 3:
        printUsage()
        sys.exit(1)
    try:
        # Parse everything eagerly, before any socket is created, so a bad
        # argument is reported up front (map() would be lazy on Python 3).
        localAddressPortList = [parseAddrPortString(arg)
                                for arg in sys.argv[1:-1]]
        (remoteAddress, remotePort) = parseAddrPortString(sys.argv[-1])
    except (IndexError, ValueError):
        printUsage()
        sys.exit(1)
    ioService = asio.createAsyncIOService()
    logger.info('ioService = ' + str(ioService))
    for (localAddress, localPort) in localAddressPortList:
        Acceptor(ioService,
                 localAddress=localAddress,
                 localPort=localPort,
                 remoteAddress=remoteAddress,
                 remotePort=remotePort)
    logger.info('remote address %s' % str((remoteAddress, remotePort)))
    ioService.run()
# Script entry point: run the proxy only when this file is executed directly.
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # An interrupt from the keyboard ends the run without a traceback.
        pass