diff --git a/gunicorn-logging-extension/src/gunicorn_logging_extension/aiohttp.py b/gunicorn-logging-extension/src/gunicorn_logging_extension/aiohttp.py index 3aeccd9..a5df0ad 100644 --- a/gunicorn-logging-extension/src/gunicorn_logging_extension/aiohttp.py +++ b/gunicorn-logging-extension/src/gunicorn_logging_extension/aiohttp.py @@ -1,10 +1,19 @@ +from typing import Type + from aiohttp.abc import AbstractAccessLogger from aiohttp.web_request import BaseRequest from aiohttp.web_response import StreamResponse + from . import REDIRECT_CODES -class AccessLogger(AbstractAccessLogger): +try: + import uvloop + from aiohttp.worker import GunicornUVLoopWebWorker as GunicornWebWorker +except ImportError: + from aiohttp.worker import GunicornWebWorker + +class AccessLogger(AbstractAccessLogger): def log(self, request: BaseRequest, response: StreamResponse, time: float): level = self.logger.info if response.status >= 400: @@ -25,8 +34,79 @@ class AccessLogger(AbstractAccessLogger): extra["REFERER"] = request.headers.get("Referer") if request.headers.get("user-agent", False): extra["USER_AGENT"] = request.headers.get("user-agent") - location="" + location = "" if response.status_code in REDIRECT_CODES: extra["LOCATION"] = request.headers.get("location") - location = f" -> {extra["LOCATION"]}" - level(f"Access({response.status}) {request.method} {request.rel_url}{}", extra=extra) + location = f" -> {extra['LOCATION']}" + level( + f"Access({response.status}) {request.method} {request.rel_url}{location}", + extra=extra, + ) + + +class ExtendedGunicornWebWorker(GunicornWebWorker): + access_log_class: Type[AbstractAccessLogger] = AccessLogger + + async def _run(self) -> None: + runner = None + if isinstance(self.wsgi, Application): + app = self.wsgi + elif asyncio.iscoroutinefunction(self.wsgi): + wsgi = await self.wsgi() + if isinstance(wsgi, web.AppRunner): + runner = wsgi + app = runner.app + else: + app = wsgi + else: + raise RuntimeError( + "wsgi app should be either Application or " + "async function returning Application, got {}".format(self.wsgi) + ) + + if runner is None: + access_log = self.log.access_log if self.cfg.accesslog else None + runner = web.AppRunner( + app, + logger=self.log, + keepalive_timeout=self.cfg.keepalive, + access_log=access_log, + access_log_class=self.access_log_class, + shutdown_timeout=self.cfg.graceful_timeout / 100 * 95, + ) + await runner.setup() + + ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None + + runner = runner + assert runner is not None + server = runner.server + assert server is not None + for sock in self.sockets: + site = web.SockSite( + runner, + sock, + ssl_context=ctx, + ) + await site.start() + + # If our parent changed then we shut down. + pid = os.getpid() + try: + while self.alive: # type: ignore[has-type] + self.notify() + + cnt = server.requests_count + if self.max_requests and cnt > self.max_requests: + self.alive = False + self.log.info("Max requests, shutting down: %s", self) + + elif pid == os.getpid() and self.ppid != os.getppid(): + self.alive = False + self.log.info("Parent changed, shutting down: %s", self) + else: + await self._wait_next_notify() + except BaseException: + pass + + await runner.cleanup()