You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

420 lines
15 KiB

4 years ago
  1. import asyncio
  2. import base64
  3. import binascii
  4. import hashlib
  5. import json
  6. import async_timeout
  7. import attr
  8. from multidict import CIMultiDict
  9. from . import hdrs
  10. from .helpers import call_later, set_result
  11. from .http import (WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE, WS_KEY,
  12. WebSocketError, WebSocketReader, WebSocketWriter, WSMessage,
  13. WSMsgType, ws_ext_gen, ws_ext_parse)
  14. from .log import ws_logger
  15. from .streams import EofStream, FlowControlDataQueue
  16. from .web_exceptions import HTTPBadRequest, HTTPException, HTTPMethodNotAllowed
  17. from .web_response import StreamResponse
  18. __all__ = ('WebSocketResponse', 'WebSocketReady', 'WSMsgType',)
  19. THRESHOLD_CONNLOST_ACCESS = 5
  20. @attr.s(frozen=True, slots=True)
  21. class WebSocketReady:
  22. ok = attr.ib(type=bool)
  23. protocol = attr.ib(type=str)
  24. def __bool__(self):
  25. return self.ok
  26. class WebSocketResponse(StreamResponse):
  27. def __init__(self, *,
  28. timeout=10.0, receive_timeout=None,
  29. autoclose=True, autoping=True, heartbeat=None,
  30. protocols=(), compress=True, max_msg_size=4*1024*1024):
  31. super().__init__(status=101)
  32. self._protocols = protocols
  33. self._ws_protocol = None
  34. self._writer = None
  35. self._reader = None
  36. self._closed = False
  37. self._closing = False
  38. self._conn_lost = 0
  39. self._close_code = None
  40. self._loop = None
  41. self._waiting = None
  42. self._exception = None
  43. self._timeout = timeout
  44. self._receive_timeout = receive_timeout
  45. self._autoclose = autoclose
  46. self._autoping = autoping
  47. self._heartbeat = heartbeat
  48. self._heartbeat_cb = None
  49. if heartbeat is not None:
  50. self._pong_heartbeat = heartbeat / 2.0
  51. self._pong_response_cb = None
  52. self._compress = compress
  53. self._max_msg_size = max_msg_size
  54. def _cancel_heartbeat(self):
  55. if self._pong_response_cb is not None:
  56. self._pong_response_cb.cancel()
  57. self._pong_response_cb = None
  58. if self._heartbeat_cb is not None:
  59. self._heartbeat_cb.cancel()
  60. self._heartbeat_cb = None
  61. def _reset_heartbeat(self):
  62. self._cancel_heartbeat()
  63. if self._heartbeat is not None:
  64. self._heartbeat_cb = call_later(
  65. self._send_heartbeat, self._heartbeat, self._loop)
  66. def _send_heartbeat(self):
  67. if self._heartbeat is not None and not self._closed:
  68. # fire-and-forget a task is not perfect but maybe ok for
  69. # sending ping. Otherwise we need a long-living heartbeat
  70. # task in the class.
  71. self._loop.create_task(self._writer.ping())
  72. if self._pong_response_cb is not None:
  73. self._pong_response_cb.cancel()
  74. self._pong_response_cb = call_later(
  75. self._pong_not_received, self._pong_heartbeat, self._loop)
  76. def _pong_not_received(self):
  77. if self._req is not None and self._req.transport is not None:
  78. self._closed = True
  79. self._close_code = 1006
  80. self._exception = asyncio.TimeoutError()
  81. self._req.transport.close()
  82. async def prepare(self, request):
  83. # make pre-check to don't hide it by do_handshake() exceptions
  84. if self._payload_writer is not None:
  85. return self._payload_writer
  86. protocol, writer = self._pre_start(request)
  87. payload_writer = await super().prepare(request)
  88. self._post_start(request, protocol, writer)
  89. await payload_writer.drain()
  90. return payload_writer
  91. def _handshake(self, request):
  92. headers = request.headers
  93. if request.method != hdrs.METH_GET:
  94. raise HTTPMethodNotAllowed(request.method, [hdrs.METH_GET])
  95. if 'websocket' != headers.get(hdrs.UPGRADE, '').lower().strip():
  96. raise HTTPBadRequest(
  97. text=('No WebSocket UPGRADE hdr: {}\n Can '
  98. '"Upgrade" only to "WebSocket".')
  99. .format(headers.get(hdrs.UPGRADE)))
  100. if 'upgrade' not in headers.get(hdrs.CONNECTION, '').lower():
  101. raise HTTPBadRequest(
  102. text='No CONNECTION upgrade hdr: {}'.format(
  103. headers.get(hdrs.CONNECTION)))
  104. # find common sub-protocol between client and server
  105. protocol = None
  106. if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
  107. req_protocols = [str(proto.strip()) for proto in
  108. headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]
  109. for proto in req_protocols:
  110. if proto in self._protocols:
  111. protocol = proto
  112. break
  113. else:
  114. # No overlap found: Return no protocol as per spec
  115. ws_logger.warning(
  116. 'Client protocols %r don’t overlap server-known ones %r',
  117. req_protocols, self._protocols)
  118. # check supported version
  119. version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, '')
  120. if version not in ('13', '8', '7'):
  121. raise HTTPBadRequest(
  122. text='Unsupported version: {}'.format(version))
  123. # check client handshake for validity
  124. key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
  125. try:
  126. if not key or len(base64.b64decode(key)) != 16:
  127. raise HTTPBadRequest(
  128. text='Handshake error: {!r}'.format(key))
  129. except binascii.Error:
  130. raise HTTPBadRequest(
  131. text='Handshake error: {!r}'.format(key)) from None
  132. accept_val = base64.b64encode(
  133. hashlib.sha1(key.encode() + WS_KEY).digest()).decode()
  134. response_headers = CIMultiDict({hdrs.UPGRADE: 'websocket',
  135. hdrs.CONNECTION: 'upgrade',
  136. hdrs.TRANSFER_ENCODING: 'chunked',
  137. hdrs.SEC_WEBSOCKET_ACCEPT: accept_val})
  138. notakeover = False
  139. compress = self._compress
  140. if compress:
  141. extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS)
  142. # Server side always get return with no exception.
  143. # If something happened, just drop compress extension
  144. compress, notakeover = ws_ext_parse(extensions, isserver=True)
  145. if compress:
  146. enabledext = ws_ext_gen(compress=compress, isserver=True,
  147. server_notakeover=notakeover)
  148. response_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = enabledext
  149. if protocol:
  150. response_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = protocol
  151. return (response_headers, protocol, compress, notakeover)
  152. def _pre_start(self, request):
  153. self._loop = request.loop
  154. headers, protocol, compress, notakeover = self._handshake(
  155. request)
  156. self._reset_heartbeat()
  157. self.set_status(101)
  158. self.headers.update(headers)
  159. self.force_close()
  160. self._compress = compress
  161. writer = WebSocketWriter(request._protocol,
  162. request._protocol.transport,
  163. compress=compress,
  164. notakeover=notakeover)
  165. return protocol, writer
  166. def _post_start(self, request, protocol, writer):
  167. self._ws_protocol = protocol
  168. self._writer = writer
  169. self._reader = FlowControlDataQueue(
  170. request._protocol, limit=2 ** 16, loop=self._loop)
  171. request.protocol.set_parser(WebSocketReader(
  172. self._reader, self._max_msg_size, compress=self._compress))
  173. # disable HTTP keepalive for WebSocket
  174. request.protocol.keep_alive(False)
  175. def can_prepare(self, request):
  176. if self._writer is not None:
  177. raise RuntimeError('Already started')
  178. try:
  179. _, protocol, _, _ = self._handshake(request)
  180. except HTTPException:
  181. return WebSocketReady(False, None)
  182. else:
  183. return WebSocketReady(True, protocol)
  184. @property
  185. def closed(self):
  186. return self._closed
  187. @property
  188. def close_code(self):
  189. return self._close_code
  190. @property
  191. def ws_protocol(self):
  192. return self._ws_protocol
  193. @property
  194. def compress(self):
  195. return self._compress
  196. def exception(self):
  197. return self._exception
  198. async def ping(self, message='b'):
  199. if self._writer is None:
  200. raise RuntimeError('Call .prepare() first')
  201. await self._writer.ping(message)
  202. async def pong(self, message='b'):
  203. # unsolicited pong
  204. if self._writer is None:
  205. raise RuntimeError('Call .prepare() first')
  206. await self._writer.pong(message)
  207. async def send_str(self, data, compress=None):
  208. if self._writer is None:
  209. raise RuntimeError('Call .prepare() first')
  210. if not isinstance(data, str):
  211. raise TypeError('data argument must be str (%r)' % type(data))
  212. await self._writer.send(data, binary=False, compress=compress)
  213. async def send_bytes(self, data, compress=None):
  214. if self._writer is None:
  215. raise RuntimeError('Call .prepare() first')
  216. if not isinstance(data, (bytes, bytearray, memoryview)):
  217. raise TypeError('data argument must be byte-ish (%r)' %
  218. type(data))
  219. await self._writer.send(data, binary=True, compress=compress)
  220. async def send_json(self, data, compress=None, *, dumps=json.dumps):
  221. await self.send_str(dumps(data), compress=compress)
  222. async def write_eof(self):
  223. if self._eof_sent:
  224. return
  225. if self._payload_writer is None:
  226. raise RuntimeError("Response has not been started")
  227. await self.close()
  228. self._eof_sent = True
  229. async def close(self, *, code=1000, message=b''):
  230. if self._writer is None:
  231. raise RuntimeError('Call .prepare() first')
  232. self._cancel_heartbeat()
  233. # we need to break `receive()` cycle first,
  234. # `close()` may be called from different task
  235. if self._waiting is not None and not self._closed:
  236. self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
  237. await self._waiting
  238. if not self._closed:
  239. self._closed = True
  240. try:
  241. await self._writer.close(code, message)
  242. await self._payload_writer.drain()
  243. except (asyncio.CancelledError, asyncio.TimeoutError):
  244. self._close_code = 1006
  245. raise
  246. except Exception as exc:
  247. self._close_code = 1006
  248. self._exception = exc
  249. return True
  250. if self._closing:
  251. return True
  252. try:
  253. with async_timeout.timeout(self._timeout, loop=self._loop):
  254. msg = await self._reader.read()
  255. except asyncio.CancelledError:
  256. self._close_code = 1006
  257. raise
  258. except Exception as exc:
  259. self._close_code = 1006
  260. self._exception = exc
  261. return True
  262. if msg.type == WSMsgType.CLOSE:
  263. self._close_code = msg.data
  264. return True
  265. self._close_code = 1006
  266. self._exception = asyncio.TimeoutError()
  267. return True
  268. else:
  269. return False
  270. async def receive(self, timeout=None):
  271. if self._reader is None:
  272. raise RuntimeError('Call .prepare() first')
  273. while True:
  274. if self._waiting is not None:
  275. raise RuntimeError(
  276. 'Concurrent call to receive() is not allowed')
  277. if self._closed:
  278. self._conn_lost += 1
  279. if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS:
  280. raise RuntimeError('WebSocket connection is closed.')
  281. return WS_CLOSED_MESSAGE
  282. elif self._closing:
  283. return WS_CLOSING_MESSAGE
  284. try:
  285. self._waiting = self._loop.create_future()
  286. try:
  287. with async_timeout.timeout(
  288. timeout or self._receive_timeout, loop=self._loop):
  289. msg = await self._reader.read()
  290. self._reset_heartbeat()
  291. finally:
  292. waiter = self._waiting
  293. set_result(waiter, True)
  294. self._waiting = None
  295. except (asyncio.CancelledError, asyncio.TimeoutError):
  296. self._close_code = 1006
  297. raise
  298. except EofStream:
  299. self._close_code = 1000
  300. await self.close()
  301. return WSMessage(WSMsgType.CLOSED, None, None)
  302. except WebSocketError as exc:
  303. self._close_code = exc.code
  304. await self.close(code=exc.code)
  305. return WSMessage(WSMsgType.ERROR, exc, None)
  306. except Exception as exc:
  307. self._exception = exc
  308. self._closing = True
  309. self._close_code = 1006
  310. await self.close()
  311. return WSMessage(WSMsgType.ERROR, exc, None)
  312. if msg.type == WSMsgType.CLOSE:
  313. self._closing = True
  314. self._close_code = msg.data
  315. if not self._closed and self._autoclose:
  316. await self.close()
  317. elif msg.type == WSMsgType.CLOSING:
  318. self._closing = True
  319. elif msg.type == WSMsgType.PING and self._autoping:
  320. await self.pong(msg.data)
  321. continue
  322. elif msg.type == WSMsgType.PONG and self._autoping:
  323. continue
  324. return msg
  325. async def receive_str(self, *, timeout=None):
  326. msg = await self.receive(timeout)
  327. if msg.type != WSMsgType.TEXT:
  328. raise TypeError(
  329. "Received message {}:{!r} is not WSMsgType.TEXT".format(
  330. msg.type, msg.data))
  331. return msg.data
  332. async def receive_bytes(self, *, timeout=None):
  333. msg = await self.receive(timeout)
  334. if msg.type != WSMsgType.BINARY:
  335. raise TypeError(
  336. "Received message {}:{!r} is not bytes".format(msg.type,
  337. msg.data))
  338. return msg.data
  339. async def receive_json(self, *, loads=json.loads, timeout=None):
  340. data = await self.receive_str(timeout=timeout)
  341. return loads(data)
  342. async def write(self, data):
  343. raise RuntimeError("Cannot call .write() for websocket")
  344. def __aiter__(self):
  345. return self
  346. async def __anext__(self):
  347. msg = await self.receive()
  348. if msg.type in (WSMsgType.CLOSE,
  349. WSMsgType.CLOSING,
  350. WSMsgType.CLOSED):
  351. raise StopAsyncIteration # NOQA
  352. return msg