"""Various helper functions"""
|
|
|
|
import asyncio
|
|
import base64
|
|
import binascii
|
|
import cgi
|
|
import datetime
|
|
import functools
|
|
import inspect
|
|
import netrc
|
|
import os
|
|
import re
|
|
import sys
|
|
import time
|
|
import weakref
|
|
from collections import namedtuple
|
|
from collections.abc import Mapping as ABCMapping
|
|
from contextlib import suppress
|
|
from math import ceil
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple # noqa
|
|
from urllib.parse import quote
|
|
from urllib.request import getproxies
|
|
|
|
import async_timeout
|
|
import attr
|
|
from multidict import MultiDict
|
|
from yarl import URL
|
|
|
|
from . import hdrs
|
|
from .abc import AbstractAccessLogger
|
|
from .log import client_logger
|
|
|
|
|
|
# Public names exported by ``from <this module> import *``.
__all__ = ('BasicAuth', 'ChainMapProxy')
|
|
|
|
# Interpreter version feature flags.
PY_36 = sys.version_info >= (3, 6)
PY_37 = sys.version_info >= (3, 7)

if not PY_37:
    # Interpreters older than 3.7 get idna_ssl's patch; presumably this makes
    # ssl host-name matching cope with IDNA host names -- confirm upstream.
    import idna_ssl

    idna_ssl.patch_match_hostname()

# Unique marker object, distinguishable from any real value (including None).
sentinel = object()  # type: Any

# True when AIOHTTP_NO_EXTENSIONS is set to any non-empty value.
NO_EXTENSIONS = bool(os.environ.get('AIOHTTP_NO_EXTENSIONS'))  # type: bool

# Debug mode is on under ``python -X dev`` (sys.flags.dev_mode, which only
# exists on Python 3.7+, hence getattr), or when PYTHONASYNCIODEBUG is set to
# a non-empty value -- unless the interpreter was told to ignore the
# environment (e.g. ``-E``).
if getattr(sys.flags, 'dev_mode', False):
    DEBUG = True  # type: bool
elif sys.flags.ignore_environment:
    DEBUG = False
else:
    DEBUG = bool(os.environ.get('PYTHONASYNCIODEBUG'))
|
|
|
|
|
|
# Character classes from the HTTP grammar (RFC 2616 section 2.2,
# RFC 7230 section 3.2.6); TOKEN is used to validate header tokens.
CHAR = set(chr(i) for i in range(0, 128))
CTL = set(chr(i) for i in range(0, 32)) | {chr(127), }
SEPARATORS = {'(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']',
              '?', '=', '{', '}', ' ', chr(9)}
# token = 1*<any CHAR except CTLs or separators>.  This must be a set
# difference: TAB (chr(9)) is both a CTL and a separator, so a symmetric
# difference (``CHAR ^ CTL ^ SEPARATORS``) would remove it and then add it
# back again.
TOKEN = CHAR - CTL - SEPARATORS
|
|
|
|
|
|
# Module whose private ``_DEBUG`` flag is switched off while ``noop`` is
# decorated below and restored afterwards.
coroutines = asyncio.coroutines
# Remember the current value so it can be restored once ``noop`` exists.
old_debug = coroutines._DEBUG  # type: ignore

# prevent "coroutine noop was never awaited" warning.
# NOTE(review): presumably asyncio.coroutine only installs its debug wrapper
# (the part that reports un-awaited coroutines) when _DEBUG is true at
# decoration time, which is why the flag is flipped around the definition.
# asyncio.coroutine is deprecated since Python 3.8 and removed in 3.11, so
# this block needs replacing before newer interpreters are supported.
coroutines._DEBUG = False  # type: ignore


@asyncio.coroutine
def noop(*args, **kwargs):
    """Accept and ignore any arguments; the body simply returns None."""
    return


# Restore the caller-visible asyncio debug setting.
coroutines._DEBUG = old_debug  # type: ignore
|
|
|
|
|
|
class BasicAuth(namedtuple('BasicAuth', ['login', 'password', 'encoding'])):
    """Http basic authentication helper.

    An immutable ``(login, password, encoding)`` triple that can be parsed
    from, or rendered to, an ``Authorization: Basic ...`` header value.
    """

    def __new__(cls, login: str,
                password: str='',
                encoding: str='latin1') -> 'BasicAuth':
        if login is None:
            raise ValueError('None is not allowed as login value')

        if password is None:
            raise ValueError('None is not allowed as password value')

        if ':' in login:
            # ':' separates login from password in the encoded credentials,
            # so it cannot appear inside the login itself.
            raise ValueError(
                'A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str='latin1') -> 'BasicAuth':
        """Create a BasicAuth object from an Authorization HTTP header.

        *auth_header* must look like ``Basic <base64 credentials>``; the
        decoded credentials are split at the first ``':'`` into login and
        password (a missing ``':'`` gives an empty password).

        Raises ValueError if the header cannot be parsed, names another
        scheme, or carries invalid base64.  Undecodable bytes raise
        UnicodeDecodeError, which is also a ValueError subclass.
        """
        # RFC 7235 allows one or more spaces between the auth scheme and the
        # credentials, so ignore the empty pieces repeated spaces produce.
        split = [part for part in auth_header.strip().split(' ') if part]
        if len(split) != 2:
            raise ValueError('Could not parse authorization header.')
        if split[0].strip().lower() != 'basic':
            raise ValueError('Unknown authorization method %s' % split[0])
        to_decode = split[1]

        try:
            username, _, password = base64.b64decode(
                to_decode.encode('ascii')
            ).decode(encoding).partition(':')
        except binascii.Error:
            raise ValueError('Invalid base64 encoding.')

        return cls(username, password, encoding=encoding)

    @classmethod
    def from_url(cls, url: URL,
                 *, encoding: str='latin1') -> Optional['BasicAuth']:
        """Create BasicAuth from url.

        Returns None when *url* has no user part; a missing password becomes
        ``''``.  Raises TypeError if *url* is not a ``yarl.URL``.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        if url.user is None:
            return None
        return cls(url.user, url.password or '', encoding=encoding)

    def encode(self) -> str:
        """Encode credentials as an ``Authorization`` header value."""
        creds = ('%s:%s' % (self.login, self.password)).encode(self.encoding)
        # base64 output is pure ASCII.  Decoding it with the credential
        # encoding would fail for codecs that are not ASCII-compatible
        # (e.g. utf-16).
        return 'Basic %s' % base64.b64encode(creds).decode('ascii')
|
|
|
|
|
|
def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Split the credentials off *url*.

    Returns ``(url_without_user, BasicAuth)`` when *url* carries a user part,
    otherwise ``(url, None)`` with the URL returned unchanged.
    """
    credentials = BasicAuth.from_url(url)
    if credentials is not None:
        return url.with_user(None), credentials
    return url, None
|
|
|
|
|
|
def netrc_from_env():
    """Load the user's netrc file, if any.

    The path comes from the ``NETRC`` environment variable when set.
    Otherwise it is ``_netrc`` (Windows) or ``.netrc`` (elsewhere) in the
    user's home directory.

    Returns a :class:`netrc.netrc` instance, or None when the file is
    missing, cannot be read or parsed, or the home directory cannot be
    determined.  Problems are reported with ``client_logger`` warnings
    instead of being raised.
    """
    netrc_obj = None
    netrc_path = os.environ.get('NETRC')
    try:
        if netrc_path is not None:
            netrc_path = Path(netrc_path)
        else:
            home_dir = Path.home()
            if os.name == 'nt':  # pragma: no cover
                netrc_path = home_dir.joinpath('_netrc')
            else:
                netrc_path = home_dir.joinpath('.netrc')

        if netrc_path and netrc_path.is_file():
            try:
                netrc_obj = netrc.netrc(str(netrc_path))
            except (netrc.NetrcParseError, OSError) as e:
                client_logger.warning(".netrc file parses fail: %s", e)
        else:
            # Only report "not found" when the file really is absent; a parse
            # failure has already been reported above.
            client_logger.warning("couldn't find .netrc file")
    except RuntimeError as e:  # pragma: no cover
        # Path.home() raises RuntimeError when the home directory cannot be
        # resolved.
        client_logger.warning("couldn't find .netrc file: %s", e)
    return netrc_obj
|
|
|
|
|
|
@attr.s(frozen=True, slots=True)
class ProxyInfo:
    """Immutable pairing of a proxy address with its credentials."""

    # Proxy address.  NOTE(review): proxies_from_env() passes a yarl.URL
    # here although the field is declared as ``str``.
    proxy = attr.ib(type=str)
    # Credentials for the proxy; proxies_from_env() may pass None when none
    # were found in the URL or netrc.
    proxy_auth = attr.ib(type=BasicAuth)
|
|
|
|
|
|
def proxies_from_env():
    """Build proxy settings for plain HTTP(S) traffic from the environment.

    Reads ``urllib.request.getproxies()`` and keeps only the ``'http'`` and
    ``'https'`` entries.  Returns ``{protocol: ProxyInfo}``.

    Credentials embedded in a proxy URL are moved into ``ProxyInfo.proxy_auth``.
    When the URL has none, a netrc entry for the proxy host is used if
    available.  Proxies whose own URL uses the ``https`` scheme are skipped
    with a warning.
    """
    proxy_urls = {}
    for protocol, address in getproxies().items():
        if protocol in ('http', 'https'):
            proxy_urls[protocol] = URL(address)

    netrc_obj = netrc_from_env()
    stripped = {protocol: strip_auth_from_url(url)
                for protocol, url in proxy_urls.items()}

    result = {}
    for protocol, (proxy, auth) in stripped.items():
        if proxy.scheme == 'https':
            client_logger.warning(
                "HTTPS proxies %s are not supported, ignoring", proxy)
            continue

        if netrc_obj and auth is None:
            netrc_entry = netrc_obj.authenticators(proxy.host)
            if netrc_entry is not None:
                # netrc entries are (user, account, password); either of the
                # first two may hold the user name, prefer ``user``.
                user, account, password = netrc_entry
                auth = BasicAuth(user if user else account, password)

        result[protocol] = ProxyInfo(proxy, auth)
    return result
|
|
|
|
|
|
def current_task(loop=None):
    """Return the task currently running on *loop* (None outside a task).

    Uses ``asyncio.current_task`` on Python 3.7+ and the older
    ``asyncio.Task.current_task`` classmethod before that.
    """
    getter = asyncio.current_task if PY_37 else asyncio.Task.current_task
    return getter(loop=loop)
|
|
|
|
|
|
def isasyncgenfunction(obj):
    """Return whether *obj* is an async generator function.

    Falls back to False when ``inspect`` has no ``isasyncgenfunction``
    (interpreters without async generators).
    """
    checker = getattr(inspect, 'isasyncgenfunction', None)
    if checker is None:
        return False
    return checker(obj)
|
|
|
|
|
|
@attr.s(frozen=True, slots=True)
class MimeType:
    """Immutable components of a MIME type, as returned by parse_mimetype()."""

    # Major type, e.g. 'text' for 'text/html'.
    type = attr.ib(type=str)
    # Subtype without any '+suffix', e.g. 'html'.
    subtype = attr.ib(type=str)
    # Structured-syntax suffix after '+' (e.g. 'json' in
    # 'application/ld+json'), or '' when absent.
    suffix = attr.ib(type=str)
    # MIME parameters such as charset; declared as a MultiDict because a
    # parameter name may repeat.
    parameters = attr.ib(type=MultiDict)
|
|
|
|
|
|
def parse_mimetype(mimetype):
    """Parses a MIME type into its components.

    mimetype is a MIME type string.

    Returns a MimeType object.  Type, subtype, suffix and parameter names
    are lower-cased.  Parameter values have surrounding spaces and double
    quotes stripped.  A bare ``*`` is treated as ``*/*``, and a falsy
    *mimetype* yields a MimeType with all fields empty.

    Example:

    >>> parse_mimetype('text/html; charset=utf-8')
    MimeType(type='text', subtype='html', suffix='',
             parameters=<MultiDict('charset': 'utf-8')>)

    """
    if not mimetype:
        # Keep ``parameters`` a MultiDict, matching MimeType's declared type
        # and the non-empty path below (it still compares equal to {}).
        return MimeType(type='', subtype='', suffix='',
                        parameters=MultiDict())

    parts = mimetype.split(';')
    params = []
    for item in parts[1:]:
        item = item.strip()
        # Skip empty segments such as those produced by "text/html; " or
        # ";;" -- previously these yielded a bogus '' parameter.
        if not item:
            continue
        key, _, value = item.partition('=')
        params.append((key.lower().strip(), value.strip(' "')))
    params = MultiDict(params)

    fulltype = parts[0].strip().lower()
    if fulltype == '*':
        fulltype = '*/*'

    mtype, _, stype = fulltype.partition('/')
    stype, _, suffix = stype.partition('+')

    return MimeType(type=mtype, subtype=stype, suffix=suffix,
                    parameters=params)
|
|
|
|
|
|
def guess_filename(obj, default=None):
    """Guess a file name from ``obj.name`` (e.g. an open file object).

    Returns the final path component of ``obj.name`` when it is a non-empty
    string that neither starts with ``'<'`` nor ends with ``'>'``.  This
    rules out pseudo-names like ``'<stdin>'``.  Otherwise returns *default*.
    """
    candidate = getattr(obj, 'name', None)
    if not candidate or not isinstance(candidate, str):
        return default
    if candidate.startswith('<') or candidate.endswith('>'):
        return default
    return Path(candidate).name
|
|
|
|
|
|
def content_disposition_header(disptype, quote_fields=True, **params):
    """Build a ``Content-Disposition`` header value.

    disptype is a disposition type: inline, attachment, form-data.
    Should be valid extension token (see RFC 2183)

    quote_fields controls whether parameter values are percent-encoded with
    ``urllib.parse.quote`` before being wrapped in double quotes.

    params is a dict with disposition params.  A ``filename`` parameter is
    additionally emitted as ``filename*=utf-8''<value>``.

    Returns the header value as a string.  Raises ValueError when disptype
    or a parameter name is empty or contains non-token characters.
    """
    # Every character must be a token character.  This is a subset test; the
    # previous ``TOKEN > set(...)`` (proper superset) wrongly rejected a
    # value that happened to use every token character.
    if not disptype or not (set(disptype) <= TOKEN):
        raise ValueError('bad content disposition type {!r}'
                         ''.format(disptype))

    value = disptype
    if params:
        lparams = []
        for key, val in params.items():
            if not key or not (set(key) <= TOKEN):
                raise ValueError('bad content disposition parameter'
                                 ' {!r}={!r}'.format(key, val))
            # NOTE(review): with quote_fields=False the raw value is placed
            # inside quotes unescaped; callers must ensure it contains no '"'.
            qval = quote(val, '') if quote_fields else val
            lparams.append((key, '"%s"' % qval))
            if key == 'filename':
                lparams.append(('filename*', "utf-8''" + qval))
        sparams = '; '.join('='.join(pair) for pair in lparams)
        value = '; '.join((value, sparams))
    return value
|
|
|
|
|
|
# (log-record key, formatter callable) pair produced by
# AccessLogger.compile_format.
KeyMethod = namedtuple('KeyMethod', ['key', 'method'])
|
|
|
|
|
|
class AccessLogger(AbstractAccessLogger):
    """Helper object to log access.

    Usage:
        log = logging.getLogger("spam")
        log_format = "%a %{User-Agent}i"
        access_logger = AccessLogger(log, log_format)
        access_logger.log(request, response, time)

    Format:
        %%  The percent sign
        %a  Remote IP-address (IP-address of proxy if using reverse proxy)
        %t  Time when the request was started to process
        %P  The process ID of the child that serviced the request
        %r  First line of request
        %s  Response status code
        %b  Size of response in bytes, including HTTP headers
        %T  Time taken to serve the request, in seconds
        %Tf Time taken to serve the request, in seconds with floating fraction
            in .06f format
        %D  Time taken to serve the request, in microseconds
        %{FOO}i  request.headers['FOO']
        %{FOO}o  response.headers['FOO']
        %{FOO}e  os.environ['FOO']

    Any other ``%`` that does not start one of the directives above is
    emitted literally.
    """
    LOG_FORMAT_MAP = {
        'a': 'remote_address',
        't': 'request_start_time',
        'P': 'process_id',
        'r': 'first_request_line',
        's': 'response_status',
        'b': 'response_size',
        'T': 'request_time',
        'Tf': 'request_time_frac',
        'D': 'request_time_micro',
        'i': 'request_header',
        'o': 'response_header',
        # Backs the documented %{FOO}e directive (see _format_e).
        'e': 'environ',
    }

    LOG_FORMAT = '%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i"'
    # Only atoms that have both a LOG_FORMAT_MAP entry and a _format_* method
    # may appear here; anything else would raise KeyError in compile_format.
    FORMAT_RE = re.compile(r'%(\{([A-Za-z0-9\-_]+)\}([ioe])|[atPrsbD]|Tf?)')
    # Retained for backward compatibility; compile_format no longer uses it.
    CLEANUP_RE = re.compile(r'(%[^s])')
    _FORMAT_CACHE = {}  # type: Dict[str, Tuple[str, List[KeyMethod]]]

    def __init__(self, logger, log_format=LOG_FORMAT):
        """Initialise the logger.

        logger is a logger object to be used for logging.
        log_format is an string with apache compatible log format description.

        """
        super().__init__(logger, log_format=log_format)

        # Compiled formats are shared between instances, keyed by the raw
        # format string.
        _compiled_format = AccessLogger._FORMAT_CACHE.get(log_format)
        if not _compiled_format:
            _compiled_format = self.compile_format(log_format)
            AccessLogger._FORMAT_CACHE[log_format] = _compiled_format

        self._log_format, self._methods = _compiled_format

    def compile_format(self, log_format):
        """Translate log_format into form usable by modulo formatting

        All known atoms will be replaced with %s, ``%%`` is kept as an
        escaped percent sign, and every other ``%`` is escaped so that it is
        emitted literally.
        Also methods for formatting of those atoms will be added to
        _methods in appropriate order

        For example we have log_format = "%a %t"
        This format will be translated to "%s %s"
        Also contents of _methods will be
        [self._format_a, self._format_t]
        These method will be called and results will be passed
        to translated string format.

        Each _format_* method receive 'args' which is list of arguments
        given to self.log

        Exceptions are _format_e, _format_i and _format_o methods which
        also receive key name (by functools.partial)

        """
        # list of (key, method) tuples, we don't use an OrderedDict as users
        # can repeat the same key more than once
        methods = list()

        # Try '%%' before the directives so that e.g. '%%a' is the literal
        # text '%a' instead of '%' followed by the remote address.
        token_re = re.compile('%%|' + self.FORMAT_RE.pattern)
        pieces = []
        pos = 0

        for match in token_re.finditer(log_format):
            # Literal text between tokens.  Any '%' left in it does not start
            # a directive and must be doubled to survive %-formatting.
            pieces.append(log_format[pos:match.start()].replace('%', '%%'))
            pos = match.end()

            if match.group(0) == '%%':
                pieces.append('%%')
                continue

            atom, key, kind = match.group(1, 2, 3)
            if key is None:
                format_key = self.LOG_FORMAT_MAP[atom]
                m = getattr(AccessLogger, '_format_%s' % atom)
            else:
                format_key = (self.LOG_FORMAT_MAP[kind], key)
                m = getattr(AccessLogger, '_format_%s' % kind)
                m = functools.partial(m, key)

            methods.append(KeyMethod(format_key, m))
            pieces.append('%s')

        pieces.append(log_format[pos:].replace('%', '%%'))
        return ''.join(pieces), methods

    @staticmethod
    def _format_i(key, request, response, time):
        if request is None:
            return '(no headers)'

        # suboptimal, make istr(key) once
        return request.headers.get(key, '-')

    @staticmethod
    def _format_o(key, request, response, time):
        # suboptimal, make istr(key) once
        return response.headers.get(key, '-')

    @staticmethod
    def _format_e(key, request, response, time):
        # Same '-' placeholder as used for missing headers.
        return os.environ.get(key, '-')

    @staticmethod
    def _format_a(request, response, time):
        if request is None:
            return '-'
        ip = request.remote
        return ip if ip is not None else '-'

    @staticmethod
    def _format_t(request, response, time):
        # The request started ``time`` seconds before now (UTC).
        now = datetime.datetime.utcnow()
        start_time = now - datetime.timedelta(seconds=time)
        return start_time.strftime('[%d/%b/%Y:%H:%M:%S +0000]')

    @staticmethod
    def _format_P(request, response, time):
        return "<%s>" % os.getpid()

    @staticmethod
    def _format_r(request, response, time):
        if request is None:
            return '-'
        return '%s %s HTTP/%s.%s' % tuple((request.method,
                                           request.path_qs) + request.version)

    @staticmethod
    def _format_s(request, response, time):
        return response.status

    @staticmethod
    def _format_b(request, response, time):
        return response.body_length

    @staticmethod
    def _format_T(request, response, time):
        return round(time)

    @staticmethod
    def _format_Tf(request, response, time):
        return '%06f' % time

    @staticmethod
    def _format_D(request, response, time):
        return round(time * 1000000)

    def _format_line(self, request, response, time):
        return ((key, method(request, response, time))
                for key, method in self._methods)

    def log(self, request, response, time):
        """Emit one access-log line through ``self.logger.info``.

        Each formatted value is also exposed in the record's ``extra``.
        Header and environment values are grouped per kind, e.g.
        ``extra['request_header']['User-Agent']``.  Any error is logged
        instead of propagating.
        """
        try:
            fmt_info = self._format_line(request, response, time)

            values = list()
            extra = dict()
            for key, value in fmt_info:
                values.append(value)

                if key.__class__ is str:
                    extra[key] = value
                else:
                    k1, k2 = key
                    dct = extra.get(k1, {})
                    dct[k2] = value
                    extra[k1] = dct

            self.logger.info(self._log_format % tuple(values), extra=extra)
        except Exception:
            self.logger.exception("Error in logging")
|
|
|
|
|
|
class reify:
    """Cache-on-first-access descriptor, used like ``@property``.

    On first access the decorated method is called with the instance.  Its
    result is stored in the instance's ``_cache`` dict under the method's
    name and served from there afterwards.  Assignment is rejected, which
    also makes this a data descriptor.  Accessed on the class itself, the
    descriptor object is returned.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped
        self.__doc__ = wrapped.__doc__
        self.name = wrapped.__name__

    def __get__(self, inst, owner):
        try:
            cache = inst._cache
        except AttributeError:
            # Class-level access (inst is None) yields the descriptor itself;
            # an instance without ``_cache`` is a genuine error.
            if inst is None:
                return self
            raise

        try:
            return cache[self.name]
        except KeyError:
            computed = self.wrapped(inst)
            cache[self.name] = computed
            return computed

    def __set__(self, inst, value):
        raise AttributeError("reified property is read-only")
|
|
|
|
|
|
# Keep a handle on the pure-Python implementation regardless of what
# ``reify`` ends up bound to below.
reify_py = reify

try:
    from ._helpers import reify as reify_c
except ImportError:
    # No compiled extension available: keep the pure-Python reify.
    pass
else:
    if not NO_EXTENSIONS:
        reify = reify_c  # type: ignore
|
|
|
|
# Dotted-quad IPv4 address, each octet 0-255.  The octet alternative
# ``[01]?[0-9][0-9]?`` also accepts leading zeros such as '01' or '001'.
_ipv4_pattern = (r'^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}'
                 r'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$')
# IPv6 address.  The alternatives appear to cover the full 8-group form,
# '::'-compressed forms and forms ending in an embedded IPv4 address.  It is
# compiled case-insensitively below, so hex digits may be upper or lower case.
# NOTE(review): Python's ``$`` also matches just before a trailing newline,
# so e.g. '1.2.3.4\n' is accepted by both patterns -- confirm this is
# acceptable for callers.
_ipv6_pattern = (
    r'^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}'
    r'(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)'
    r'((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})'
    r'(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}'
    r'(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}'
    r'[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)'
    r'(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}'
    r':|:(:[A-F0-9]{1,4}){7})$')
# str variants, used by is_ip_address() for str hosts.
_ipv4_regex = re.compile(_ipv4_pattern)
_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE)
# bytes variants of the same patterns, used for bytes-like hosts.
_ipv4_regexb = re.compile(_ipv4_pattern.encode('ascii'))
_ipv6_regexb = re.compile(_ipv6_pattern.encode('ascii'), flags=re.IGNORECASE)
|
|
|
|
|
|
def is_ip_address(host):
    """Return True if *host* is a literal IPv4 or IPv6 address.

    *host* may be None (never an address), a ``str``, or a bytes-like
    object (``bytes``, ``bytearray`` or ``memoryview``).  Bytes-like hosts
    are matched with the bytes variants of the patterns.  Any other type
    raises TypeError.
    """
    if host is None:
        return False

    if isinstance(host, str):
        ipv4, ipv6 = _ipv4_regex, _ipv6_regex
    elif isinstance(host, (bytes, bytearray, memoryview)):
        ipv4, ipv6 = _ipv4_regexb, _ipv6_regexb
    else:
        raise TypeError("{} [{}] is not a str or bytes"
                        .format(host, type(host)))

    return bool(ipv4.match(host) or ipv6.match(host))
|
|
|
|
|
|
# Cache for rfc822_formatted_time(): the last whole second that was formatted
# and the string produced for it.
_cached_current_datetime = None
_cached_formatted_datetime = None


def rfc822_formatted_time():
    """Return the current UTC time in HTTP date format.

    Example output: ``'Sun, 06 Nov 1994 08:49:37 GMT'``.  Day and month
    names are always English, independent of locale.  The string is rebuilt
    at most once per wall-clock second.
    """
    global _cached_current_datetime
    global _cached_formatted_datetime

    now = int(time.time())
    if now == _cached_current_datetime:
        return _cached_formatted_datetime

    # Weekday and month names for HTTP date/time formatting; always English.
    weekday_names = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    month_names = ("",  # placeholder so months can be indexed 1-based
                   "Jan", "Feb", "Mar", "Apr", "May", "Jun",
                   "Jul", "Aug", "Sep", "Oct", "Nov", "Dec")

    tm = time.gmtime(now)
    _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        weekday_names[tm.tm_wday], tm.tm_mday, month_names[tm.tm_mon],
        tm.tm_year, tm.tm_hour, tm.tm_min, tm.tm_sec,
    )
    _cached_current_datetime = now
    return _cached_formatted_datetime
|
|
|
|
|
|
def _weakref_handle(info):
    """Timer callback: call ``getattr(target, name)()`` if target is alive.

    *info* is a ``(weakref, method_name)`` pair.  Nothing happens once the
    referent has been collected, and any exception raised by the lookup or
    the call is suppressed.
    """
    ref, name = info
    target = ref()
    if target is None:
        return
    with suppress(Exception):
        getattr(target, name)()
|
|
|
|
|
|
def weakref_handle(ob, name, timeout, loop, ceil_timeout=True):
    """Schedule ``getattr(ob, name)()`` on *loop* after *timeout* seconds.

    Only a weak reference to *ob* is held, so the timer does not keep it
    alive.  With *ceil_timeout* the deadline is rounded up to a whole number
    of loop-clock seconds.  Returns the timer handle, or None when *timeout*
    is None or not positive.
    """
    if timeout is None or not timeout > 0:
        return None
    deadline = loop.time() + timeout
    if ceil_timeout:
        deadline = ceil(deadline)
    return loop.call_at(deadline, _weakref_handle, (weakref.ref(ob), name))
|
|
|
|
|
|
def call_later(cb, timeout, loop):
    """Schedule *cb* on *loop* roughly *timeout* seconds from now.

    The deadline is rounded up to a whole number of loop-clock seconds.
    Returns the timer handle, or None when *timeout* is None or not positive.
    """
    if timeout is None or not timeout > 0:
        return None
    return loop.call_at(ceil(loop.time() + timeout), cb)
|
|
|
|
|
|
class TimeoutHandle:
    """ Timeout handle

    Collects callbacks and runs them all once the timeout fires.  The
    deadline is rounded up to whole loop-clock seconds.
    """

    def __init__(self, loop, timeout):
        self._timeout = timeout
        self._loop = loop
        self._callbacks = []

    def register(self, callback, *args, **kwargs):
        """Queue ``callback(*args, **kwargs)`` to run when the timeout fires."""
        self._callbacks.append((callback, args, kwargs))

    def close(self):
        """Drop every registered callback."""
        self._callbacks.clear()

    def start(self):
        """Arm the timer; returns the loop handle, or None without a timeout."""
        timeout = self._timeout
        if timeout is None or not timeout > 0:
            return None
        deadline = ceil(self._loop.time() + timeout)
        return self._loop.call_at(deadline, self.__call__)

    def timer(self):
        """Return a context manager cancelled by this handle's timeout.

        Without a positive timeout a no-op context manager is returned.
        """
        if self._timeout is None or not self._timeout > 0:
            return TimerNoop()
        ctx = TimerContext(self._loop)
        self.register(ctx.timeout)
        return ctx

    def __call__(self):
        # Run every callback, ignoring individual failures, then forget them.
        for callback, cb_args, cb_kwargs in self._callbacks:
            with suppress(Exception):
                callback(*cb_args, **cb_kwargs)

        self._callbacks.clear()
|
|
|
|
|
|
class TimerNoop:
    """Do-nothing stand-in for TimerContext when no timeout is configured."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning False lets any exception from the ``with`` body propagate.
        return False
|
|
|
|
|
|
class TimerContext:
    """ Low resolution timeout context manager

    ``timeout()`` cancels every task currently inside the context.  On exit,
    the resulting CancelledError is translated into asyncio.TimeoutError.
    Entering after the timeout has fired cancels the entering task
    immediately.
    """

    def __init__(self, loop):
        self._loop = loop
        self._tasks = []
        self._cancelled = False

    def __enter__(self):
        task = current_task(loop=self._loop)
        if task is None:
            raise RuntimeError('Timeout context manager should be used '
                               'inside a task')

        if self._cancelled:
            # The timeout already expired before this task entered.
            task.cancel()
            raise asyncio.TimeoutError from None

        self._tasks.append(task)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._tasks:
            self._tasks.pop()

        timed_out = exc_type is asyncio.CancelledError and self._cancelled
        if timed_out:
            raise asyncio.TimeoutError from None

    def timeout(self):
        """Cancel all tasks inside the context; only the first call acts."""
        if self._cancelled:
            return
        # set() so a task that entered more than once is cancelled only once.
        for task in set(self._tasks):
            task.cancel()
        self._cancelled = True
|
|
|
|
|
|
class CeilTimeout(async_timeout.timeout):
    """async_timeout.timeout whose deadline is rounded up to whole seconds.

    Rounding uses the loop clock.  A None timeout disables the deadline.
    """

    def __enter__(self):
        if self._timeout is None:
            return self

        self._task = current_task(loop=self._loop)
        if self._task is None:
            raise RuntimeError(
                'Timeout context manager should be used inside a task')

        deadline = ceil(self._loop.time() + self._timeout)
        self._cancel_handler = self._loop.call_at(deadline, self._cancel_task)
        return self
|
|
|
|
|
|
class HeadersMixin:
    """Content-Type / charset / Content-Length accessors over ``self._headers``.

    The parsed Content-Type is cached and only re-parsed when the raw header
    value changes.
    """

    ATTRS = frozenset([
        '_content_type', '_content_dict', '_stored_content_type'])

    _content_type = None
    _content_dict = None
    _stored_content_type = sentinel

    def _parse_content_type(self, raw):
        self._stored_content_type = raw
        if raw is not None:
            self._content_type, self._content_dict = cgi.parse_header(raw)
        else:
            # default value according to RFC 2616
            self._content_type = 'application/octet-stream'
            self._content_dict = {}

    @property
    def content_type(self):
        """The value of content part for Content-Type HTTP header."""
        header_value = self._headers.get(hdrs.CONTENT_TYPE)
        stale = self._stored_content_type != header_value
        if stale:
            self._parse_content_type(header_value)
        return self._content_type

    @property
    def charset(self):
        """The value of charset part for Content-Type HTTP header."""
        header_value = self._headers.get(hdrs.CONTENT_TYPE)
        stale = self._stored_content_type != header_value
        if stale:
            self._parse_content_type(header_value)
        return self._content_dict.get('charset')

    @property
    def content_length(self):
        """The value of Content-Length HTTP header (None when absent/empty)."""
        raw_length = self._headers.get(hdrs.CONTENT_LENGTH)
        return int(raw_length) if raw_length else None
|
|
|
|
|
|
def set_result(fut, result):
    """Resolve *fut* with *result* unless it already has an outcome."""
    if fut.done():
        return
    fut.set_result(result)
|
|
|
|
|
|
def set_exception(fut, exc):
    """Fail *fut* with *exc* unless it already has an outcome."""
    if fut.done():
        return
    fut.set_exception(exc)
|
|
|
|
|
|
class ChainMapProxy(ABCMapping):
    """Read-only mapping that looks keys up in several mappings in order.

    The first mapping containing a key wins.  Subclassing is forbidden.
    """

    __slots__ = ('_maps',)

    def __init__(self, maps):
        self._maps = tuple(maps)

    def __init_subclass__(cls):
        raise TypeError("Inheritance class {} from ChainMapProxy "
                        "is forbidden".format(cls.__name__))

    def __getitem__(self, key):
        for mapping in self._maps:
            try:
                return mapping[key]
            except KeyError:
                continue
        raise KeyError(key)

    def get(self, key, default=None):
        if key in self:
            return self[key]
        return default

    def __len__(self):
        # reuses stored hash values if possible
        return len(set().union(*self._maps))

    def __iter__(self):
        # Merge from the last mapping to the first.  Updating an existing key
        # keeps its position, so the iteration order matches the original.
        merged = {}
        for mapping in self._maps[::-1]:
            # reuses stored hash values if possible
            merged.update(mapping)
        return iter(merged)

    def __contains__(self, key):
        for mapping in self._maps:
            if key in mapping:
                return True
        return False

    def __bool__(self):
        for mapping in self._maps:
            if mapping:
                return True
        return False

    def __repr__(self):
        content = ", ".join(repr(mapping) for mapping in self._maps)
        return 'ChainMapProxy({})'.format(content)
|