You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

456 lines
14 KiB

4 years ago
  1. import asyncio
  2. import enum
  3. import io
  4. import json
  5. import mimetypes
  6. import os
  7. import warnings
  8. from abc import ABC, abstractmethod
  9. from itertools import chain
  10. from typing import (
  11. IO,
  12. TYPE_CHECKING,
  13. Any,
  14. ByteString,
  15. Dict,
  16. Iterable,
  17. Optional,
  18. Text,
  19. TextIO,
  20. Tuple,
  21. Type,
  22. Union,
  23. )
  24. from multidict import CIMultiDict
  25. from . import hdrs
  26. from .abc import AbstractStreamWriter
  27. from .helpers import (
  28. PY_36,
  29. content_disposition_header,
  30. guess_filename,
  31. parse_mimetype,
  32. sentinel,
  33. )
  34. from .streams import DEFAULT_LIMIT, StreamReader
  35. from .typedefs import JSONEncoder, _CIMultiDict
  36. __all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'payload_type', 'Payload',
  37. 'BytesPayload', 'StringPayload',
  38. 'IOBasePayload', 'BytesIOPayload', 'BufferedReaderPayload',
  39. 'TextIOPayload', 'StringIOPayload', 'JsonPayload',
  40. 'AsyncIterablePayload')
  41. TOO_LARGE_BYTES_BODY = 2 ** 20 # 1 MB
  42. if TYPE_CHECKING: # pragma: no cover
  43. from typing import List # noqa
  44. class LookupError(Exception):
  45. pass
  46. class Order(str, enum.Enum):
  47. normal = 'normal'
  48. try_first = 'try_first'
  49. try_last = 'try_last'
  50. def get_payload(data: Any, *args: Any, **kwargs: Any) -> 'Payload':
  51. return PAYLOAD_REGISTRY.get(data, *args, **kwargs)
  52. def register_payload(factory: Type['Payload'],
  53. type: Any,
  54. *,
  55. order: Order=Order.normal) -> None:
  56. PAYLOAD_REGISTRY.register(factory, type, order=order)
  57. class payload_type:
  58. def __init__(self, type: Any, *, order: Order=Order.normal) -> None:
  59. self.type = type
  60. self.order = order
  61. def __call__(self, factory: Type['Payload']) -> Type['Payload']:
  62. register_payload(factory, self.type, order=self.order)
  63. return factory
  64. class PayloadRegistry:
  65. """Payload registry.
  66. note: we need zope.interface for more efficient adapter search
  67. """
  68. def __init__(self) -> None:
  69. self._first = [] # type: List[Tuple[Type[Payload], Any]]
  70. self._normal = [] # type: List[Tuple[Type[Payload], Any]]
  71. self._last = [] # type: List[Tuple[Type[Payload], Any]]
  72. def get(self,
  73. data: Any,
  74. *args: Any,
  75. _CHAIN: Any=chain,
  76. **kwargs: Any) -> 'Payload':
  77. if isinstance(data, Payload):
  78. return data
  79. for factory, type in _CHAIN(self._first, self._normal, self._last):
  80. if isinstance(data, type):
  81. return factory(data, *args, **kwargs)
  82. raise LookupError()
  83. def register(self,
  84. factory: Type['Payload'],
  85. type: Any,
  86. *,
  87. order: Order=Order.normal) -> None:
  88. if order is Order.try_first:
  89. self._first.append((factory, type))
  90. elif order is Order.normal:
  91. self._normal.append((factory, type))
  92. elif order is Order.try_last:
  93. self._last.append((factory, type))
  94. else:
  95. raise ValueError("Unsupported order {!r}".format(order))
  96. class Payload(ABC):
  97. _default_content_type = 'application/octet-stream' # type: str
  98. _size = None # type: Optional[int]
  99. def __init__(self,
  100. value: Any,
  101. headers: Optional[
  102. Union[
  103. _CIMultiDict,
  104. Dict[str, str],
  105. Iterable[Tuple[str, str]]
  106. ]
  107. ] = None,
  108. content_type: Optional[str]=sentinel,
  109. filename: Optional[str]=None,
  110. encoding: Optional[str]=None,
  111. **kwargs: Any) -> None:
  112. self._encoding = encoding
  113. self._filename = filename
  114. self._headers = CIMultiDict() # type: _CIMultiDict
  115. self._value = value
  116. if content_type is not sentinel and content_type is not None:
  117. self._headers[hdrs.CONTENT_TYPE] = content_type
  118. elif self._filename is not None:
  119. content_type = mimetypes.guess_type(self._filename)[0]
  120. if content_type is None:
  121. content_type = self._default_content_type
  122. self._headers[hdrs.CONTENT_TYPE] = content_type
  123. else:
  124. self._headers[hdrs.CONTENT_TYPE] = self._default_content_type
  125. self._headers.update(headers or {})
  126. @property
  127. def size(self) -> Optional[int]:
  128. """Size of the payload."""
  129. return self._size
  130. @property
  131. def filename(self) -> Optional[str]:
  132. """Filename of the payload."""
  133. return self._filename
  134. @property
  135. def headers(self) -> _CIMultiDict:
  136. """Custom item headers"""
  137. return self._headers
  138. @property
  139. def _binary_headers(self) -> bytes:
  140. return ''.join(
  141. [k + ': ' + v + '\r\n' for k, v in self.headers.items()]
  142. ).encode('utf-8') + b'\r\n'
  143. @property
  144. def encoding(self) -> Optional[str]:
  145. """Payload encoding"""
  146. return self._encoding
  147. @property
  148. def content_type(self) -> str:
  149. """Content type"""
  150. return self._headers[hdrs.CONTENT_TYPE]
  151. def set_content_disposition(self,
  152. disptype: str,
  153. quote_fields: bool=True,
  154. **params: Any) -> None:
  155. """Sets ``Content-Disposition`` header."""
  156. self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header(
  157. disptype, quote_fields=quote_fields, **params)
  158. @abstractmethod
  159. async def write(self, writer: AbstractStreamWriter) -> None:
  160. """Write payload.
  161. writer is an AbstractStreamWriter instance:
  162. """
  163. class BytesPayload(Payload):
  164. def __init__(self,
  165. value: ByteString,
  166. *args: Any,
  167. **kwargs: Any) -> None:
  168. if not isinstance(value, (bytes, bytearray, memoryview)):
  169. raise TypeError("value argument must be byte-ish, not (!r)"
  170. .format(type(value)))
  171. if 'content_type' not in kwargs:
  172. kwargs['content_type'] = 'application/octet-stream'
  173. super().__init__(value, *args, **kwargs)
  174. self._size = len(value)
  175. if self._size > TOO_LARGE_BYTES_BODY:
  176. if PY_36:
  177. kwargs = {'source': self}
  178. else:
  179. kwargs = {}
  180. warnings.warn("Sending a large body directly with raw bytes might"
  181. " lock the event loop. You should probably pass an "
  182. "io.BytesIO object instead", ResourceWarning,
  183. **kwargs)
  184. async def write(self, writer: AbstractStreamWriter) -> None:
  185. await writer.write(self._value)
  186. class StringPayload(BytesPayload):
  187. def __init__(self,
  188. value: Text,
  189. *args: Any,
  190. encoding: Optional[str]=None,
  191. content_type: Optional[str]=None,
  192. **kwargs: Any) -> None:
  193. if encoding is None:
  194. if content_type is None:
  195. real_encoding = 'utf-8'
  196. content_type = 'text/plain; charset=utf-8'
  197. else:
  198. mimetype = parse_mimetype(content_type)
  199. real_encoding = mimetype.parameters.get('charset', 'utf-8')
  200. else:
  201. if content_type is None:
  202. content_type = 'text/plain; charset=%s' % encoding
  203. real_encoding = encoding
  204. super().__init__(
  205. value.encode(real_encoding),
  206. encoding=real_encoding,
  207. content_type=content_type,
  208. *args,
  209. **kwargs,
  210. )
  211. class StringIOPayload(StringPayload):
  212. def __init__(self,
  213. value: IO[str],
  214. *args: Any,
  215. **kwargs: Any) -> None:
  216. super().__init__(value.read(), *args, **kwargs)
  217. class IOBasePayload(Payload):
  218. def __init__(self,
  219. value: IO[Any],
  220. disposition: str='attachment',
  221. *args: Any,
  222. **kwargs: Any) -> None:
  223. if 'filename' not in kwargs:
  224. kwargs['filename'] = guess_filename(value)
  225. super().__init__(value, *args, **kwargs)
  226. if self._filename is not None and disposition is not None:
  227. if hdrs.CONTENT_DISPOSITION not in self.headers:
  228. self.set_content_disposition(
  229. disposition, filename=self._filename
  230. )
  231. async def write(self, writer: AbstractStreamWriter) -> None:
  232. loop = asyncio.get_event_loop()
  233. try:
  234. chunk = await loop.run_in_executor(
  235. None, self._value.read, DEFAULT_LIMIT
  236. )
  237. while chunk:
  238. await writer.write(chunk)
  239. chunk = await loop.run_in_executor(
  240. None, self._value.read, DEFAULT_LIMIT
  241. )
  242. finally:
  243. await loop.run_in_executor(None, self._value.close)
  244. class TextIOPayload(IOBasePayload):
  245. def __init__(self,
  246. value: TextIO,
  247. *args: Any,
  248. encoding: Optional[str]=None,
  249. content_type: Optional[str]=None,
  250. **kwargs: Any) -> None:
  251. if encoding is None:
  252. if content_type is None:
  253. encoding = 'utf-8'
  254. content_type = 'text/plain; charset=utf-8'
  255. else:
  256. mimetype = parse_mimetype(content_type)
  257. encoding = mimetype.parameters.get('charset', 'utf-8')
  258. else:
  259. if content_type is None:
  260. content_type = 'text/plain; charset=%s' % encoding
  261. super().__init__(
  262. value,
  263. content_type=content_type,
  264. encoding=encoding,
  265. *args,
  266. **kwargs,
  267. )
  268. @property
  269. def size(self) -> Optional[int]:
  270. try:
  271. return os.fstat(self._value.fileno()).st_size - self._value.tell()
  272. except OSError:
  273. return None
  274. async def write(self, writer: AbstractStreamWriter) -> None:
  275. loop = asyncio.get_event_loop()
  276. try:
  277. chunk = await loop.run_in_executor(
  278. None, self._value.read, DEFAULT_LIMIT
  279. )
  280. while chunk:
  281. await writer.write(chunk.encode(self._encoding))
  282. chunk = await loop.run_in_executor(
  283. None, self._value.read, DEFAULT_LIMIT
  284. )
  285. finally:
  286. await loop.run_in_executor(None, self._value.close)
  287. class BytesIOPayload(IOBasePayload):
  288. @property
  289. def size(self) -> int:
  290. position = self._value.tell()
  291. end = self._value.seek(0, os.SEEK_END)
  292. self._value.seek(position)
  293. return end - position
  294. class BufferedReaderPayload(IOBasePayload):
  295. @property
  296. def size(self) -> Optional[int]:
  297. try:
  298. return os.fstat(self._value.fileno()).st_size - self._value.tell()
  299. except OSError:
  300. # data.fileno() is not supported, e.g.
  301. # io.BufferedReader(io.BytesIO(b'data'))
  302. return None
  303. class JsonPayload(BytesPayload):
  304. def __init__(self,
  305. value: Any,
  306. encoding: str='utf-8',
  307. content_type: str='application/json',
  308. dumps: JSONEncoder=json.dumps,
  309. *args: Any,
  310. **kwargs: Any) -> None:
  311. super().__init__(
  312. dumps(value).encode(encoding),
  313. content_type=content_type, encoding=encoding, *args, **kwargs)
  314. if TYPE_CHECKING: # pragma: no cover
  315. from typing import AsyncIterator, AsyncIterable
  316. _AsyncIterator = AsyncIterator[bytes]
  317. _AsyncIterable = AsyncIterable[bytes]
  318. else:
  319. from collections.abc import AsyncIterable, AsyncIterator
  320. _AsyncIterator = AsyncIterator
  321. _AsyncIterable = AsyncIterable
  322. class AsyncIterablePayload(Payload):
  323. _iter = None # type: Optional[_AsyncIterator]
  324. def __init__(self,
  325. value: _AsyncIterable,
  326. *args: Any,
  327. **kwargs: Any) -> None:
  328. if not isinstance(value, AsyncIterable):
  329. raise TypeError("value argument must support "
  330. "collections.abc.AsyncIterablebe interface, "
  331. "got {!r}".format(type(value)))
  332. if 'content_type' not in kwargs:
  333. kwargs['content_type'] = 'application/octet-stream'
  334. super().__init__(value, *args, **kwargs)
  335. self._iter = value.__aiter__()
  336. async def write(self, writer: AbstractStreamWriter) -> None:
  337. if self._iter:
  338. try:
  339. # iter is not None check prevents rare cases
  340. # when the case iterable is used twice
  341. while True:
  342. chunk = await self._iter.__anext__()
  343. await writer.write(chunk)
  344. except StopAsyncIteration:
  345. self._iter = None
  346. class StreamReaderPayload(AsyncIterablePayload):
  347. def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
  348. super().__init__(value.iter_any(), *args, **kwargs)
  349. PAYLOAD_REGISTRY = PayloadRegistry()
  350. PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview))
  351. PAYLOAD_REGISTRY.register(StringPayload, str)
  352. PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO)
  353. PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase)
  354. PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO)
  355. PAYLOAD_REGISTRY.register(
  356. BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom))
  357. PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase)
  358. PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader)
  359. # try_last for giving a chance to more specialized async interables like
  360. # multidict.BodyPartReaderPayload override the default
  361. PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable,
  362. order=Order.try_last)