# Author: Giampaolo Rodola' <g.rodola [AT] gmail [DOT] com>
# License: MIT
from __future__ import with_statement
import contextlib
import signal
import sys
def _sigterm_handler(signum, frame):
    """SIGTERM handler: turn the signal into a SystemExit(0) exception.

    Raising SystemExit (rather than letting the default action kill the
    process) makes the interpreter unwind the stack, so ``finally``
    clauses -- including the one in handle_exit() -- get to run.
    """
    sys.exit(0)

# True while a handle_exit() block is active; used to reject nesting.
_sigterm_handler.__enter_ctx__ = False


@contextlib.contextmanager
def handle_exit(callback=None, append=False):
    """A context manager which properly handles SIGTERM and SIGINT
    (KeyboardInterrupt) signals, registering a function which is
    guaranteed to be called after signals are received.

    On entry a SIGTERM handler is installed which raises SystemExit(0).
    Inside the block, KeyboardInterrupt (what Python's default SIGINT
    handler raises) and any successful SystemExit (code 0 or None) are
    swallowed; SystemExit with any other code is re-raised.  Once the
    block has been entered, *callback* (if not None) is always called
    on exit.

    Example:

        app = App()
        with handle_exit(app.stop):
            app.start()

    If append == False raise RuntimeError if there's already a handler
    registered for SIGTERM (that handler is left installed); otherwise
    both new and old handlers are executed in this order.  A previous
    disposition that is not callable (e.g. SIG_IGN) is not chained.

    Raises RuntimeError if used inside another handle_exit() block.

    Note: the SIGTERM handler installed here is not removed when the
    block exits.
    """
    # Reject nesting *before* touching the signal disposition, so that a
    # refused inner context leaves the outer context's handler intact.
    if _sigterm_handler.__enter_ctx__:
        raise RuntimeError("can't use nested contexts")
    old_handler = signal.signal(signal.SIGTERM, _sigterm_handler)
    if (old_handler != signal.SIG_DFL) and (old_handler != _sigterm_handler):
        if not append:
            # Put the user's handler back before bailing out; otherwise
            # this failed call would silently replace it with ours.
            # signal.signal() returns None for a handler that was not
            # installed from Python; such a handler cannot be reinstalled.
            if old_handler is not None:
                signal.signal(signal.SIGTERM, old_handler)
            raise RuntimeError("there is already a handler registered for "
                               "SIGTERM: %r" % old_handler)

        def handler(signum, frame):
            try:
                _sigterm_handler(signum, frame)
            finally:
                # SIG_IGN (and None) are not callable: only chain real
                # Python handlers, or we'd mask SystemExit with TypeError.
                if callable(old_handler):
                    old_handler(signum, frame)
        signal.signal(signal.SIGTERM, handler)
    _sigterm_handler.__enter_ctx__ = True
    try:
        yield
    except KeyboardInterrupt:
        pass
    except SystemExit:
        # sys.exc_info() instead of "except SystemExit as err" keeps this
        # clause valid on Python 2.5 as well as Python 3.
        err = sys.exc_info()[1]
        # A code of 0 or None means a successful exit (our SIGTERM handler
        # raises SystemExit(0)). Any other code refers to an application
        # error (e.g. an explicit sys.exit('some error') call) and must not
        # pass silently. The 'finally' clause below runs either way.
        if err.code not in (0, None):
            raise
    finally:
        _sigterm_handler.__enter_ctx__ = False
        if callback is not None:
            callback()
if __name__ == '__main__':
    # ===============================================================
    # --- test suite
    # ===============================================================
    import unittest
    import os

    class TestOnExit(unittest.TestCase):

        def _reset_handlers(self):
            # Restore the default dispositions of both signals the tests
            # play with, so a handler installed by one test (e.g. the
            # SIGINT lambda in test_sigint_old) cannot leak into the next
            # test or into the rest of the process.
            signal.signal(signal.SIGTERM, signal.SIG_DFL)
            signal.signal(signal.SIGINT, signal.default_int_handler)

        def setUp(self):
            # reset signal handlers
            self._reset_handlers()
            self.flag = None

        def tearDown(self):
            self._reset_handlers()
            # make sure we exited the ctx manager
            self.assertTrue(self.flag is not None)

        def test_base(self):
            with handle_exit():
                pass
            self.flag = True

        def test_callback(self):
            callback = []
            with handle_exit(lambda: callback.append(None)):
                pass
            self.flag = True
            self.assertEqual(callback, [None])

        def test_kinterrupt(self):
            with handle_exit():
                raise KeyboardInterrupt
            self.flag = True

        def test_sigterm(self):
            with handle_exit():
                os.kill(os.getpid(), signal.SIGTERM)
            self.flag = True

        def test_sigint(self):
            with handle_exit():
                os.kill(os.getpid(), signal.SIGINT)
            self.flag = True

        def test_sigterm_old(self):
            # make sure the old handler gets executed
            queue = []
            signal.signal(signal.SIGTERM, lambda s, f: queue.append('old'))
            with handle_exit(lambda: queue.append('new'), append=True):
                os.kill(os.getpid(), signal.SIGTERM)
            self.flag = True
            self.assertEqual(queue, ['old', 'new'])

        def test_sigint_old(self):
            # make sure the old handler gets executed
            queue = []
            signal.signal(signal.SIGINT, lambda s, f: queue.append('old'))
            with handle_exit(lambda: queue.append('new'), append=True):
                os.kill(os.getpid(), signal.SIGINT)
            self.flag = True
            self.assertEqual(queue, ['old', 'new'])

        def test_no_append(self):
            # make sure we can't use the context manager if there's
            # already a handler registered for SIGTERM
            signal.signal(signal.SIGTERM, lambda s, f: sys.exit(0))
            try:
                with handle_exit(lambda: self.flag.append(None)):
                    pass
            except RuntimeError:
                pass
            else:
                self.fail("exception not raised")
            finally:
                self.flag = True

        def test_nested_context(self):
            self.flag = True
            try:
                with handle_exit():
                    with handle_exit():
                        pass
            except RuntimeError:
                pass
            else:
                self.fail("exception not raised")

    unittest.main()
Diff to Previous Revision
--- revision 22 2013-09-23 22:38:32
+++ revision 23 2014-08-01 08:28:07
@@ -10,6 +10,7 @@
def _sigterm_handler(signum, frame):
sys.exit(0)
_sigterm_handler.__enter_ctx__ = False
+
@contextlib.contextmanager
def handle_exit(callback=None, append=False):
@@ -32,13 +33,14 @@
old_handler = signal.signal(signal.SIGTERM, _sigterm_handler)
if (old_handler != signal.SIG_DFL) and (old_handler != _sigterm_handler):
if not append:
- raise RuntimeError("there is already a handler registered for " \
+ raise RuntimeError("there is already a handler registered for "
"SIGTERM: %r" % old_handler)
+
def handler(signum, frame):
try:
_sigterm_handler(signum, frame)
finally:
- old_handler(signum, frame)
+ old_handler(signum, frame)
signal.signal(signal.SIGTERM, handler)
if _sigterm_handler.__enter_ctx__: