You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

225 lines
7.1 KiB

4 years ago
  1. """Async gunicorn worker for aiohttp.web"""
  2. import asyncio
  3. import os
  4. import re
  5. import signal
  6. import sys
  7. from contextlib import suppress
  8. from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat
  9. from gunicorn.workers import base
  10. from aiohttp import web
  11. from .helpers import AccessLogger, set_result
  12. try:
  13. import ssl
  14. except ImportError: # pragma: no cover
  15. ssl = None # type: ignore
  16. __all__ = ('GunicornWebWorker',
  17. 'GunicornUVLoopWebWorker',
  18. 'GunicornTokioWebWorker')
  19. class GunicornWebWorker(base.Worker):
  20. DEFAULT_AIOHTTP_LOG_FORMAT = AccessLogger.LOG_FORMAT
  21. DEFAULT_GUNICORN_LOG_FORMAT = GunicornAccessLogFormat.default
  22. def __init__(self, *args, **kw): # pragma: no cover
  23. super().__init__(*args, **kw)
  24. self._runner = None
  25. self._task = None
  26. self.exit_code = 0
  27. self._notify_waiter = None
  28. def init_process(self):
  29. # create new event_loop after fork
  30. asyncio.get_event_loop().close()
  31. self.loop = asyncio.new_event_loop()
  32. asyncio.set_event_loop(self.loop)
  33. super().init_process()
  34. def run(self):
  35. access_log = self.log.access_log if self.cfg.accesslog else None
  36. params = dict(
  37. logger=self.log,
  38. keepalive_timeout=self.cfg.keepalive,
  39. access_log=access_log,
  40. access_log_format=self._get_valid_log_format(
  41. self.cfg.access_log_format))
  42. if asyncio.iscoroutinefunction(self.wsgi):
  43. self.wsgi = self.loop.run_until_complete(self.wsgi())
  44. self._runner = web.AppRunner(self.wsgi, **params)
  45. self.loop.run_until_complete(self._runner.setup())
  46. self._task = self.loop.create_task(self._run())
  47. with suppress(Exception): # ignore all finalization problems
  48. self.loop.run_until_complete(self._task)
  49. if hasattr(self.loop, 'shutdown_asyncgens'):
  50. self.loop.run_until_complete(self.loop.shutdown_asyncgens())
  51. self.loop.close()
  52. sys.exit(self.exit_code)
  53. async def _run(self):
  54. ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None
  55. for sock in self.sockets:
  56. site = web.SockSite(
  57. self._runner, sock, ssl_context=ctx,
  58. shutdown_timeout=self.cfg.graceful_timeout / 100 * 95)
  59. await site.start()
  60. # If our parent changed then we shut down.
  61. pid = os.getpid()
  62. try:
  63. while self.alive:
  64. self.notify()
  65. cnt = self._runner.server.requests_count
  66. if self.cfg.max_requests and cnt > self.cfg.max_requests:
  67. self.alive = False
  68. self.log.info("Max requests, shutting down: %s", self)
  69. elif pid == os.getpid() and self.ppid != os.getppid():
  70. self.alive = False
  71. self.log.info("Parent changed, shutting down: %s", self)
  72. else:
  73. await self._wait_next_notify()
  74. except BaseException:
  75. pass
  76. await self._runner.cleanup()
  77. def _wait_next_notify(self):
  78. self._notify_waiter_done()
  79. self._notify_waiter = waiter = self.loop.create_future()
  80. self.loop.call_later(1.0, self._notify_waiter_done, waiter)
  81. return waiter
  82. def _notify_waiter_done(self, waiter=None):
  83. if waiter is None:
  84. waiter = self._notify_waiter
  85. if waiter is not None:
  86. set_result(waiter, True)
  87. if waiter is self._notify_waiter:
  88. self._notify_waiter = None
  89. def init_signals(self):
  90. # Set up signals through the event loop API.
  91. self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit,
  92. signal.SIGQUIT, None)
  93. self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit,
  94. signal.SIGTERM, None)
  95. self.loop.add_signal_handler(signal.SIGINT, self.handle_quit,
  96. signal.SIGINT, None)
  97. self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch,
  98. signal.SIGWINCH, None)
  99. self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1,
  100. signal.SIGUSR1, None)
  101. self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort,
  102. signal.SIGABRT, None)
  103. # Don't let SIGTERM and SIGUSR1 disturb active requests
  104. # by interrupting system calls
  105. signal.siginterrupt(signal.SIGTERM, False)
  106. signal.siginterrupt(signal.SIGUSR1, False)
  107. def handle_quit(self, sig, frame):
  108. self.alive = False
  109. # worker_int callback
  110. self.cfg.worker_int(self)
  111. # wakeup closing process
  112. self._notify_waiter_done()
  113. def handle_abort(self, sig, frame):
  114. self.alive = False
  115. self.exit_code = 1
  116. self.cfg.worker_abort(self)
  117. sys.exit(1)
  118. @staticmethod
  119. def _create_ssl_context(cfg):
  120. """ Creates SSLContext instance for usage in asyncio.create_server.
  121. See ssl.SSLSocket.__init__ for more details.
  122. """
  123. if ssl is None: # pragma: no cover
  124. raise RuntimeError('SSL is not supported.')
  125. ctx = ssl.SSLContext(cfg.ssl_version)
  126. ctx.load_cert_chain(cfg.certfile, cfg.keyfile)
  127. ctx.verify_mode = cfg.cert_reqs
  128. if cfg.ca_certs:
  129. ctx.load_verify_locations(cfg.ca_certs)
  130. if cfg.ciphers:
  131. ctx.set_ciphers(cfg.ciphers)
  132. return ctx
  133. def _get_valid_log_format(self, source_format):
  134. if source_format == self.DEFAULT_GUNICORN_LOG_FORMAT:
  135. return self.DEFAULT_AIOHTTP_LOG_FORMAT
  136. elif re.search(r'%\([^\)]+\)', source_format):
  137. raise ValueError(
  138. "Gunicorn's style options in form of `%(name)s` are not "
  139. "supported for the log formatting. Please use aiohttp's "
  140. "format specification to configure access log formatting: "
  141. "http://docs.aiohttp.org/en/stable/logging.html"
  142. "#format-specification"
  143. )
  144. else:
  145. return source_format
  146. class GunicornUVLoopWebWorker(GunicornWebWorker):
  147. def init_process(self):
  148. import uvloop
  149. # Close any existing event loop before setting a
  150. # new policy.
  151. asyncio.get_event_loop().close()
  152. # Setup uvloop policy, so that every
  153. # asyncio.get_event_loop() will create an instance
  154. # of uvloop event loop.
  155. asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
  156. super().init_process()
  157. class GunicornTokioWebWorker(GunicornWebWorker):
  158. def init_process(self): # pragma: no cover
  159. import tokio
  160. # Close any existing event loop before setting a
  161. # new policy.
  162. asyncio.get_event_loop().close()
  163. # Setup tokio policy, so that every
  164. # asyncio.get_event_loop() will create an instance
  165. # of tokio event loop.
  166. asyncio.set_event_loop_policy(tokio.EventLoopPolicy())
  167. super().init_process()