You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

169 lines
5.0 KiB

4 years ago
  1. """Http related parsers and protocol."""
  2. import asyncio
  3. import collections
  4. import zlib
  5. from typing import Any, Awaitable, Callable, Optional, Union # noqa
  6. from .abc import AbstractStreamWriter
  7. from .base_protocol import BaseProtocol
  8. from .helpers import NO_EXTENSIONS
  9. __all__ = ('StreamWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11')
  10. HttpVersion = collections.namedtuple('HttpVersion', ['major', 'minor'])
  11. HttpVersion10 = HttpVersion(1, 0)
  12. HttpVersion11 = HttpVersion(1, 1)
  13. _T_Data = Union[bytes, bytearray, memoryview]
  14. _T_OnChunkSent = Optional[Callable[[_T_Data], Awaitable[None]]]
  15. class StreamWriter(AbstractStreamWriter):
  16. def __init__(self,
  17. protocol: BaseProtocol,
  18. loop: asyncio.AbstractEventLoop,
  19. on_chunk_sent: _T_OnChunkSent = None) -> None:
  20. self._protocol = protocol
  21. self._transport = protocol.transport
  22. self.loop = loop
  23. self.length = None
  24. self.chunked = False
  25. self.buffer_size = 0
  26. self.output_size = 0
  27. self._eof = False
  28. self._compress = None # type: Any
  29. self._drain_waiter = None
  30. self._on_chunk_sent = on_chunk_sent # type: _T_OnChunkSent
  31. @property
  32. def transport(self) -> asyncio.Transport:
  33. return self._transport
  34. @property
  35. def protocol(self) -> BaseProtocol:
  36. return self._protocol
  37. def enable_chunking(self) -> None:
  38. self.chunked = True
  39. def enable_compression(self, encoding: str='deflate') -> None:
  40. zlib_mode = (16 + zlib.MAX_WBITS
  41. if encoding == 'gzip' else -zlib.MAX_WBITS)
  42. self._compress = zlib.compressobj(wbits=zlib_mode)
  43. def _write(self, chunk) -> None:
  44. size = len(chunk)
  45. self.buffer_size += size
  46. self.output_size += size
  47. if self._transport is None or self._transport.is_closing():
  48. raise ConnectionResetError('Cannot write to closing transport')
  49. self._transport.write(chunk)
  50. async def write(self, chunk, *, drain=True, LIMIT=0x10000) -> None:
  51. """Writes chunk of data to a stream.
  52. write_eof() indicates end of stream.
  53. writer can't be used after write_eof() method being called.
  54. write() return drain future.
  55. """
  56. if self._on_chunk_sent is not None:
  57. await self._on_chunk_sent(chunk)
  58. if self._compress is not None:
  59. chunk = self._compress.compress(chunk)
  60. if not chunk:
  61. return
  62. if self.length is not None:
  63. chunk_len = len(chunk)
  64. if self.length >= chunk_len:
  65. self.length = self.length - chunk_len
  66. else:
  67. chunk = chunk[:self.length]
  68. self.length = 0
  69. if not chunk:
  70. return
  71. if chunk:
  72. if self.chunked:
  73. chunk_len = ('%x\r\n' % len(chunk)).encode('ascii')
  74. chunk = chunk_len + chunk + b'\r\n'
  75. self._write(chunk)
  76. if self.buffer_size > LIMIT and drain:
  77. self.buffer_size = 0
  78. await self.drain()
  79. async def write_headers(self, status_line, headers) -> None:
  80. """Write request/response status and headers."""
  81. # status + headers
  82. buf = _serialize_headers(status_line, headers)
  83. self._write(buf)
  84. async def write_eof(self, chunk=b'') -> None:
  85. if self._eof:
  86. return
  87. if chunk and self._on_chunk_sent is not None:
  88. await self._on_chunk_sent(chunk)
  89. if self._compress:
  90. if chunk:
  91. chunk = self._compress.compress(chunk)
  92. chunk = chunk + self._compress.flush()
  93. if chunk and self.chunked:
  94. chunk_len = ('%x\r\n' % len(chunk)).encode('ascii')
  95. chunk = chunk_len + chunk + b'\r\n0\r\n\r\n'
  96. else:
  97. if self.chunked:
  98. if chunk:
  99. chunk_len = ('%x\r\n' % len(chunk)).encode('ascii')
  100. chunk = chunk_len + chunk + b'\r\n0\r\n\r\n'
  101. else:
  102. chunk = b'0\r\n\r\n'
  103. if chunk:
  104. self._write(chunk)
  105. await self.drain()
  106. self._eof = True
  107. self._transport = None
  108. async def drain(self) -> None:
  109. """Flush the write buffer.
  110. The intended use is to write
  111. await w.write(data)
  112. await w.drain()
  113. """
  114. if self._protocol.transport is not None:
  115. await self._protocol._drain_helper()
  116. def _py_serialize_headers(status_line, headers):
  117. headers = status_line + '\r\n' + ''.join(
  118. [k + ': ' + v + '\r\n' for k, v in headers.items()])
  119. return headers.encode('utf-8') + b'\r\n'
  120. _serialize_headers = _py_serialize_headers
  121. try:
  122. import aiohttp._http_writer as _http_writer # type: ignore
  123. _c_serialize_headers = _http_writer._serialize_headers
  124. if not NO_EXTENSIONS: # pragma: no cover
  125. _serialize_headers = _c_serialize_headers
  126. except ImportError:
  127. pass