You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

347 lines
10 KiB

4 years ago
  1. import enum
  2. import io
  3. import json
  4. import mimetypes
  5. import os
  6. import warnings
  7. from abc import ABC, abstractmethod
  8. from collections.abc import AsyncIterable
  9. from itertools import chain
  10. from multidict import CIMultiDict
  11. from . import hdrs
  12. from .helpers import (PY_36, content_disposition_header, guess_filename,
  13. parse_mimetype, sentinel)
  14. from .streams import DEFAULT_LIMIT, StreamReader
# Names exported by a star-import of this module.
__all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'payload_type', 'Payload',
           'BytesPayload', 'StringPayload',
           'IOBasePayload', 'BytesIOPayload', 'BufferedReaderPayload',
           'TextIOPayload', 'StringIOPayload', 'JsonPayload',
           'AsyncIterablePayload')

# Size threshold above which BytesPayload emits a ResourceWarning.
TOO_LARGE_BYTES_BODY = 2 ** 20  # 1 MB
  21. class LookupError(Exception):
  22. pass
  23. class Order(enum.Enum):
  24. normal = 'normal'
  25. try_first = 'try_first'
  26. try_last = 'try_last'
  27. def get_payload(data, *args, **kwargs):
  28. return PAYLOAD_REGISTRY.get(data, *args, **kwargs)
  29. def register_payload(factory, type, *, order=Order.normal):
  30. PAYLOAD_REGISTRY.register(factory, type, order=order)
  31. class payload_type:
  32. def __init__(self, type, *, order=Order.normal):
  33. self.type = type
  34. self.order = order
  35. def __call__(self, factory):
  36. register_payload(factory, self.type, order=self.order)
  37. return factory
  38. class PayloadRegistry:
  39. """Payload registry.
  40. note: we need zope.interface for more efficient adapter search
  41. """
  42. def __init__(self):
  43. self._first = []
  44. self._normal = []
  45. self._last = []
  46. def get(self, data, *args, _CHAIN=chain, **kwargs):
  47. if isinstance(data, Payload):
  48. return data
  49. for factory, type in _CHAIN(self._first, self._normal, self._last):
  50. if isinstance(data, type):
  51. return factory(data, *args, **kwargs)
  52. raise LookupError()
  53. def register(self, factory, type, *, order=Order.normal):
  54. if order is Order.try_first:
  55. self._first.append((factory, type))
  56. elif order is Order.normal:
  57. self._normal.append((factory, type))
  58. elif order is Order.try_last:
  59. self._last.append((factory, type))
  60. else:
  61. raise ValueError("Unsupported order {!r}".format(order))
  62. class Payload(ABC):
  63. _size = None
  64. _headers = None
  65. _content_type = 'application/octet-stream'
  66. def __init__(self, value, *, headers=None, content_type=sentinel,
  67. filename=None, encoding=None, **kwargs):
  68. self._value = value
  69. self._encoding = encoding
  70. self._filename = filename
  71. if headers is not None:
  72. self._headers = CIMultiDict(headers)
  73. if content_type is sentinel and hdrs.CONTENT_TYPE in self._headers:
  74. content_type = self._headers[hdrs.CONTENT_TYPE]
  75. if content_type is sentinel:
  76. content_type = None
  77. self._content_type = content_type
  78. @property
  79. def size(self):
  80. """Size of the payload."""
  81. return self._size
  82. @property
  83. def filename(self):
  84. """Filename of the payload."""
  85. return self._filename
  86. @property
  87. def headers(self):
  88. """Custom item headers"""
  89. return self._headers
  90. @property
  91. def encoding(self):
  92. """Payload encoding"""
  93. return self._encoding
  94. @property
  95. def content_type(self):
  96. """Content type"""
  97. if self._content_type is not None:
  98. return self._content_type
  99. elif self._filename is not None:
  100. mime = mimetypes.guess_type(self._filename)[0]
  101. return 'application/octet-stream' if mime is None else mime
  102. else:
  103. return Payload._content_type
  104. def set_content_disposition(self, disptype, quote_fields=True, **params):
  105. """Sets ``Content-Disposition`` header."""
  106. if self._headers is None:
  107. self._headers = CIMultiDict()
  108. self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header(
  109. disptype, quote_fields=quote_fields, **params)
  110. @abstractmethod
  111. async def write(self, writer):
  112. """Write payload.
  113. writer is an AbstractStreamWriter instance:
  114. """
  115. class BytesPayload(Payload):
  116. def __init__(self, value, *args, **kwargs):
  117. if not isinstance(value, (bytes, bytearray, memoryview)):
  118. raise TypeError("value argument must be byte-ish, not (!r)"
  119. .format(type(value)))
  120. if 'content_type' not in kwargs:
  121. kwargs['content_type'] = 'application/octet-stream'
  122. super().__init__(value, *args, **kwargs)
  123. self._size = len(value)
  124. if self._size > TOO_LARGE_BYTES_BODY:
  125. if PY_36:
  126. kwargs = {'source': self}
  127. else:
  128. kwargs = {}
  129. warnings.warn("Sending a large body directly with raw bytes might"
  130. " lock the event loop. You should probably pass an "
  131. "io.BytesIO object instead", ResourceWarning,
  132. **kwargs)
  133. async def write(self, writer):
  134. await writer.write(self._value)
  135. class StringPayload(BytesPayload):
  136. def __init__(self, value, *args,
  137. encoding=None, content_type=None, **kwargs):
  138. if encoding is None:
  139. if content_type is None:
  140. encoding = 'utf-8'
  141. content_type = 'text/plain; charset=utf-8'
  142. else:
  143. mimetype = parse_mimetype(content_type)
  144. encoding = mimetype.parameters.get('charset', 'utf-8')
  145. else:
  146. if content_type is None:
  147. content_type = 'text/plain; charset=%s' % encoding
  148. super().__init__(
  149. value.encode(encoding),
  150. encoding=encoding, content_type=content_type, *args, **kwargs)
  151. class StringIOPayload(StringPayload):
  152. def __init__(self, value, *args, **kwargs):
  153. super().__init__(value.read(), *args, **kwargs)
  154. class IOBasePayload(Payload):
  155. def __init__(self, value, disposition='attachment', *args, **kwargs):
  156. if 'filename' not in kwargs:
  157. kwargs['filename'] = guess_filename(value)
  158. super().__init__(value, *args, **kwargs)
  159. if self._filename is not None and disposition is not None:
  160. self.set_content_disposition(disposition, filename=self._filename)
  161. async def write(self, writer):
  162. try:
  163. chunk = self._value.read(DEFAULT_LIMIT)
  164. while chunk:
  165. await writer.write(chunk)
  166. chunk = self._value.read(DEFAULT_LIMIT)
  167. finally:
  168. self._value.close()
  169. class TextIOPayload(IOBasePayload):
  170. def __init__(self, value, *args,
  171. encoding=None, content_type=None, **kwargs):
  172. if encoding is None:
  173. if content_type is None:
  174. encoding = 'utf-8'
  175. content_type = 'text/plain; charset=utf-8'
  176. else:
  177. mimetype = parse_mimetype(content_type)
  178. encoding = mimetype.parameters.get('charset', 'utf-8')
  179. else:
  180. if content_type is None:
  181. content_type = 'text/plain; charset=%s' % encoding
  182. super().__init__(
  183. value,
  184. content_type=content_type, encoding=encoding, *args, **kwargs)
  185. @property
  186. def size(self):
  187. try:
  188. return os.fstat(self._value.fileno()).st_size - self._value.tell()
  189. except OSError:
  190. return None
  191. async def write(self, writer):
  192. try:
  193. chunk = self._value.read(DEFAULT_LIMIT)
  194. while chunk:
  195. await writer.write(chunk.encode(self._encoding))
  196. chunk = self._value.read(DEFAULT_LIMIT)
  197. finally:
  198. self._value.close()
  199. class BytesIOPayload(IOBasePayload):
  200. @property
  201. def size(self):
  202. position = self._value.tell()
  203. end = self._value.seek(0, os.SEEK_END)
  204. self._value.seek(position)
  205. return end - position
  206. class BufferedReaderPayload(IOBasePayload):
  207. @property
  208. def size(self):
  209. try:
  210. return os.fstat(self._value.fileno()).st_size - self._value.tell()
  211. except OSError:
  212. # data.fileno() is not supported, e.g.
  213. # io.BufferedReader(io.BytesIO(b'data'))
  214. return None
  215. class JsonPayload(BytesPayload):
  216. def __init__(self, value,
  217. encoding='utf-8', content_type='application/json',
  218. dumps=json.dumps, *args, **kwargs):
  219. super().__init__(
  220. dumps(value).encode(encoding),
  221. content_type=content_type, encoding=encoding, *args, **kwargs)
  222. class AsyncIterablePayload(Payload):
  223. def __init__(self, value, *args, **kwargs):
  224. if not isinstance(value, AsyncIterable):
  225. raise TypeError("value argument must support "
  226. "collections.abc.AsyncIterablebe interface, "
  227. "got {!r}".format(type(value)))
  228. if 'content_type' not in kwargs:
  229. kwargs['content_type'] = 'application/octet-stream'
  230. super().__init__(value, *args, **kwargs)
  231. self._iter = value.__aiter__()
  232. async def write(self, writer):
  233. try:
  234. # iter is not None check prevents rare cases
  235. # when the case iterable is used twice
  236. while True:
  237. chunk = await self._iter.__anext__()
  238. await writer.write(chunk)
  239. except StopAsyncIteration:
  240. self._iter = None
  241. class StreamReaderPayload(AsyncIterablePayload):
  242. def __init__(self, value, *args, **kwargs):
  243. super().__init__(value.iter_any(), *args, **kwargs)
# Module-level registry used by get_payload() and register_payload().
PAYLOAD_REGISTRY = PayloadRegistry()

# PayloadRegistry.get() returns the first factory whose type matches, so the
# registration order below matters: more specific types (e.g. io.StringIO)
# come before their more general bases (io.TextIOBase, io.IOBase).
PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview))
PAYLOAD_REGISTRY.register(StringPayload, str)
PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO)
PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase)
PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO)
PAYLOAD_REGISTRY.register(
    BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom))
PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase)
PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader)
# try_last gives more specialized async iterables, like
# multidict.BodyPartReaderPayload, a chance to override the default.
# NOTE(review): BodyPartReaderPayload presumably lives in the multipart
# module rather than multidict -- confirm.
PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable,
                          order=Order.try_last)