You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

640 lines
23 KiB

4 years ago
  1. """WebSocket protocol versions 13 and 8."""
  2. import collections
  3. import json
  4. import random
  5. import re
  6. import sys
  7. import zlib
  8. from enum import IntEnum
  9. from struct import Struct
  10. from .helpers import NO_EXTENSIONS
  11. from .log import ws_logger
  12. __all__ = ('WS_CLOSED_MESSAGE', 'WS_CLOSING_MESSAGE', 'WS_KEY',
  13. 'WebSocketReader', 'WebSocketWriter', 'WSMessage',
  14. 'WebSocketError', 'WSMsgType', 'WSCloseCode')
  15. class WSCloseCode(IntEnum):
  16. OK = 1000
  17. GOING_AWAY = 1001
  18. PROTOCOL_ERROR = 1002
  19. UNSUPPORTED_DATA = 1003
  20. INVALID_TEXT = 1007
  21. POLICY_VIOLATION = 1008
  22. MESSAGE_TOO_BIG = 1009
  23. MANDATORY_EXTENSION = 1010
  24. INTERNAL_ERROR = 1011
  25. SERVICE_RESTART = 1012
  26. TRY_AGAIN_LATER = 1013
  27. ALLOWED_CLOSE_CODES = {int(i) for i in WSCloseCode}
  28. class WSMsgType(IntEnum):
  29. # websocket spec types
  30. CONTINUATION = 0x0
  31. TEXT = 0x1
  32. BINARY = 0x2
  33. PING = 0x9
  34. PONG = 0xa
  35. CLOSE = 0x8
  36. # aiohttp specific types
  37. CLOSING = 0x100
  38. CLOSED = 0x101
  39. ERROR = 0x102
  40. text = TEXT
  41. binary = BINARY
  42. ping = PING
  43. pong = PONG
  44. close = CLOSE
  45. closing = CLOSING
  46. closed = CLOSED
  47. error = ERROR
  48. WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
  49. UNPACK_LEN2 = Struct('!H').unpack_from
  50. UNPACK_LEN3 = Struct('!Q').unpack_from
  51. UNPACK_CLOSE_CODE = Struct('!H').unpack
  52. PACK_LEN1 = Struct('!BB').pack
  53. PACK_LEN2 = Struct('!BBH').pack
  54. PACK_LEN3 = Struct('!BBQ').pack
  55. PACK_CLOSE_CODE = Struct('!H').pack
  56. MSG_SIZE = 2 ** 14
  57. DEFAULT_LIMIT = 2 ** 16
  58. _WSMessageBase = collections.namedtuple('_WSMessageBase',
  59. ['type', 'data', 'extra'])
  60. class WSMessage(_WSMessageBase):
  61. def json(self, *, loads=json.loads):
  62. """Return parsed JSON data.
  63. .. versionadded:: 0.22
  64. """
  65. return loads(self.data)
  66. WS_CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None)
  67. WS_CLOSING_MESSAGE = WSMessage(WSMsgType.CLOSING, None, None)
  68. class WebSocketError(Exception):
  69. """WebSocket protocol parser error."""
  70. def __init__(self, code, message):
  71. self.code = code
  72. super().__init__(message)
  73. class WSHandshakeError(Exception):
  74. """WebSocket protocol handshake error."""
  75. native_byteorder = sys.byteorder
  76. # Used by _websocket_mask_python
  77. _XOR_TABLE = [bytes(a ^ b for a in range(256)) for b in range(256)]
  78. def _websocket_mask_python(mask, data):
  79. """Websocket masking function.
  80. `mask` is a `bytes` object of length 4; `data` is a `bytearray`
  81. object of any length. The contents of `data` are masked with `mask`,
  82. as specified in section 5.3 of RFC 6455.
  83. Note that this function mutates the `data` argument.
  84. This pure-python implementation may be replaced by an optimized
  85. version when available.
  86. """
  87. assert isinstance(data, bytearray), data
  88. assert len(mask) == 4, mask
  89. if data:
  90. a, b, c, d = (_XOR_TABLE[n] for n in mask)
  91. data[::4] = data[::4].translate(a)
  92. data[1::4] = data[1::4].translate(b)
  93. data[2::4] = data[2::4].translate(c)
  94. data[3::4] = data[3::4].translate(d)
  95. if NO_EXTENSIONS:
  96. _websocket_mask = _websocket_mask_python
  97. else:
  98. try:
  99. from ._websocket import _websocket_mask_cython # type: ignore
  100. _websocket_mask = _websocket_mask_cython
  101. except ImportError: # pragma: no cover
  102. _websocket_mask = _websocket_mask_python
  103. _WS_DEFLATE_TRAILING = bytes([0x00, 0x00, 0xff, 0xff])
  104. _WS_EXT_RE = re.compile(r'^(?:;\s*(?:'
  105. r'(server_no_context_takeover)|'
  106. r'(client_no_context_takeover)|'
  107. r'(server_max_window_bits(?:=(\d+))?)|'
  108. r'(client_max_window_bits(?:=(\d+))?)))*$')
  109. _WS_EXT_RE_SPLIT = re.compile(r'permessage-deflate([^,]+)?')
  110. def ws_ext_parse(extstr, isserver=False):
  111. if not extstr:
  112. return 0, False
  113. compress = 0
  114. notakeover = False
  115. for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
  116. defext = ext.group(1)
  117. # Return compress = 15 when get `permessage-deflate`
  118. if not defext:
  119. compress = 15
  120. break
  121. match = _WS_EXT_RE.match(defext)
  122. if match:
  123. compress = 15
  124. if isserver:
  125. # Server never fail to detect compress handshake.
  126. # Server does not need to send max wbit to client
  127. if match.group(4):
  128. compress = int(match.group(4))
  129. # Group3 must match if group4 matches
  130. # Compress wbit 8 does not support in zlib
  131. # If compress level not support,
  132. # CONTINUE to next extension
  133. if compress > 15 or compress < 9:
  134. compress = 0
  135. continue
  136. if match.group(1):
  137. notakeover = True
  138. # Ignore regex group 5 & 6 for client_max_window_bits
  139. break
  140. else:
  141. if match.group(6):
  142. compress = int(match.group(6))
  143. # Group5 must match if group6 matches
  144. # Compress wbit 8 does not support in zlib
  145. # If compress level not support,
  146. # FAIL the parse progress
  147. if compress > 15 or compress < 9:
  148. raise WSHandshakeError('Invalid window size')
  149. if match.group(2):
  150. notakeover = True
  151. # Ignore regex group 5 & 6 for client_max_window_bits
  152. break
  153. # Return Fail if client side and not match
  154. elif not isserver:
  155. raise WSHandshakeError('Extension for deflate not supported' +
  156. ext.group(1))
  157. return compress, notakeover
  158. def ws_ext_gen(compress=15, isserver=False,
  159. server_notakeover=False):
  160. # client_notakeover=False not used for server
  161. # compress wbit 8 does not support in zlib
  162. if compress < 9 or compress > 15:
  163. raise ValueError('Compress wbits must between 9 and 15, '
  164. 'zlib does not support wbits=8')
  165. enabledext = ['permessage-deflate']
  166. if not isserver:
  167. enabledext.append('client_max_window_bits')
  168. if compress < 15:
  169. enabledext.append('server_max_window_bits=' + str(compress))
  170. if server_notakeover:
  171. enabledext.append('server_no_context_takeover')
  172. # if client_notakeover:
  173. # enabledext.append('client_no_context_takeover')
  174. return '; '.join(enabledext)
  175. class WSParserState(IntEnum):
  176. READ_HEADER = 1
  177. READ_PAYLOAD_LENGTH = 2
  178. READ_PAYLOAD_MASK = 3
  179. READ_PAYLOAD = 4
  180. class WebSocketReader:
  181. def __init__(self, queue, max_msg_size, compress=True):
  182. self.queue = queue
  183. self._max_msg_size = max_msg_size
  184. self._exc = None
  185. self._partial = bytearray()
  186. self._state = WSParserState.READ_HEADER
  187. self._opcode = None
  188. self._frame_fin = False
  189. self._frame_opcode = None
  190. self._frame_payload = bytearray()
  191. self._tail = b''
  192. self._has_mask = False
  193. self._frame_mask = None
  194. self._payload_length = 0
  195. self._payload_length_flag = 0
  196. self._compressed = None
  197. self._decompressobj = None
  198. self._compress = compress
  199. def feed_eof(self):
  200. self.queue.feed_eof()
  201. def feed_data(self, data):
  202. if self._exc:
  203. return True, data
  204. try:
  205. return self._feed_data(data)
  206. except Exception as exc:
  207. self._exc = exc
  208. self.queue.set_exception(exc)
  209. return True, b''
  210. def _feed_data(self, data):
  211. for fin, opcode, payload, compressed in self.parse_frame(data):
  212. if compressed and not self._decompressobj:
  213. self._decompressobj = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
  214. if opcode == WSMsgType.CLOSE:
  215. if len(payload) >= 2:
  216. close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
  217. if (close_code < 3000 and
  218. close_code not in ALLOWED_CLOSE_CODES):
  219. raise WebSocketError(
  220. WSCloseCode.PROTOCOL_ERROR,
  221. 'Invalid close code: {}'.format(close_code))
  222. try:
  223. close_message = payload[2:].decode('utf-8')
  224. except UnicodeDecodeError as exc:
  225. raise WebSocketError(
  226. WSCloseCode.INVALID_TEXT,
  227. 'Invalid UTF-8 text message') from exc
  228. msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
  229. elif payload:
  230. raise WebSocketError(
  231. WSCloseCode.PROTOCOL_ERROR,
  232. 'Invalid close frame: {} {} {!r}'.format(
  233. fin, opcode, payload))
  234. else:
  235. msg = WSMessage(WSMsgType.CLOSE, 0, '')
  236. self.queue.feed_data(msg, 0)
  237. elif opcode == WSMsgType.PING:
  238. self.queue.feed_data(
  239. WSMessage(WSMsgType.PING, payload, ''), len(payload))
  240. elif opcode == WSMsgType.PONG:
  241. self.queue.feed_data(
  242. WSMessage(WSMsgType.PONG, payload, ''), len(payload))
  243. elif opcode not in (
  244. WSMsgType.TEXT, WSMsgType.BINARY) and self._opcode is None:
  245. raise WebSocketError(
  246. WSCloseCode.PROTOCOL_ERROR,
  247. "Unexpected opcode={!r}".format(opcode))
  248. else:
  249. # load text/binary
  250. if not fin:
  251. # got partial frame payload
  252. if opcode != WSMsgType.CONTINUATION:
  253. self._opcode = opcode
  254. self._partial.extend(payload)
  255. if (self._max_msg_size and
  256. len(self._partial) >= self._max_msg_size):
  257. raise WebSocketError(
  258. WSCloseCode.MESSAGE_TOO_BIG,
  259. "Message size {} exceeds limit {}".format(
  260. len(self._partial), self._max_msg_size))
  261. else:
  262. # previous frame was non finished
  263. # we should get continuation opcode
  264. if self._partial:
  265. if opcode != WSMsgType.CONTINUATION:
  266. raise WebSocketError(
  267. WSCloseCode.PROTOCOL_ERROR,
  268. 'The opcode in non-fin frame is expected '
  269. 'to be zero, got {!r}'.format(opcode))
  270. if opcode == WSMsgType.CONTINUATION:
  271. opcode = self._opcode
  272. self._opcode = None
  273. self._partial.extend(payload)
  274. if (self._max_msg_size and
  275. len(self._partial) >= self._max_msg_size):
  276. raise WebSocketError(
  277. WSCloseCode.MESSAGE_TOO_BIG,
  278. "Message size {} exceeds limit {}".format(
  279. len(self._partial), self._max_msg_size))
  280. # Decompress process must to be done after all packets
  281. # received.
  282. if compressed:
  283. self._partial.extend(_WS_DEFLATE_TRAILING)
  284. payload_merged = self._decompressobj.decompress(
  285. self._partial, self._max_msg_size)
  286. if self._decompressobj.unconsumed_tail:
  287. left = len(self._decompressobj.unconsumed_tail)
  288. raise WebSocketError(
  289. WSCloseCode.MESSAGE_TOO_BIG,
  290. "Decompressed message size exceeds limit {}".
  291. format(self._max_msg_size + left,
  292. self._max_msg_size))
  293. else:
  294. payload_merged = bytes(self._partial)
  295. self._partial.clear()
  296. if opcode == WSMsgType.TEXT:
  297. try:
  298. text = payload_merged.decode('utf-8')
  299. self.queue.feed_data(
  300. WSMessage(WSMsgType.TEXT, text, ''), len(text))
  301. except UnicodeDecodeError as exc:
  302. raise WebSocketError(
  303. WSCloseCode.INVALID_TEXT,
  304. 'Invalid UTF-8 text message') from exc
  305. else:
  306. self.queue.feed_data(
  307. WSMessage(WSMsgType.BINARY, payload_merged, ''),
  308. len(payload_merged))
  309. return False, b''
  310. def parse_frame(self, buf):
  311. """Return the next frame from the socket."""
  312. frames = []
  313. if self._tail:
  314. buf, self._tail = self._tail + buf, b''
  315. start_pos = 0
  316. buf_length = len(buf)
  317. while True:
  318. # read header
  319. if self._state == WSParserState.READ_HEADER:
  320. if buf_length - start_pos >= 2:
  321. data = buf[start_pos:start_pos+2]
  322. start_pos += 2
  323. first_byte, second_byte = data
  324. fin = (first_byte >> 7) & 1
  325. rsv1 = (first_byte >> 6) & 1
  326. rsv2 = (first_byte >> 5) & 1
  327. rsv3 = (first_byte >> 4) & 1
  328. opcode = first_byte & 0xf
  329. # frame-fin = %x0 ; more frames of this message follow
  330. # / %x1 ; final frame of this message
  331. # frame-rsv1 = %x0 ;
  332. # 1 bit, MUST be 0 unless negotiated otherwise
  333. # frame-rsv2 = %x0 ;
  334. # 1 bit, MUST be 0 unless negotiated otherwise
  335. # frame-rsv3 = %x0 ;
  336. # 1 bit, MUST be 0 unless negotiated otherwise
  337. #
  338. # Remove rsv1 from this test for deflate development
  339. if rsv2 or rsv3 or (rsv1 and not self._compress):
  340. raise WebSocketError(
  341. WSCloseCode.PROTOCOL_ERROR,
  342. 'Received frame with non-zero reserved bits')
  343. if opcode > 0x7 and fin == 0:
  344. raise WebSocketError(
  345. WSCloseCode.PROTOCOL_ERROR,
  346. 'Received fragmented control frame')
  347. has_mask = (second_byte >> 7) & 1
  348. length = second_byte & 0x7f
  349. # Control frames MUST have a payload
  350. # length of 125 bytes or less
  351. if opcode > 0x7 and length > 125:
  352. raise WebSocketError(
  353. WSCloseCode.PROTOCOL_ERROR,
  354. 'Control frame payload cannot be '
  355. 'larger than 125 bytes')
  356. # Set compress status if last package is FIN
  357. # OR set compress status if this is first fragment
  358. # Raise error if not first fragment with rsv1 = 0x1
  359. if self._frame_fin or self._compressed is None:
  360. self._compressed = True if rsv1 else False
  361. elif rsv1:
  362. raise WebSocketError(
  363. WSCloseCode.PROTOCOL_ERROR,
  364. 'Received frame with non-zero reserved bits')
  365. self._frame_fin = fin
  366. self._frame_opcode = opcode
  367. self._has_mask = has_mask
  368. self._payload_length_flag = length
  369. self._state = WSParserState.READ_PAYLOAD_LENGTH
  370. else:
  371. break
  372. # read payload length
  373. if self._state == WSParserState.READ_PAYLOAD_LENGTH:
  374. length = self._payload_length_flag
  375. if length == 126:
  376. if buf_length - start_pos >= 2:
  377. data = buf[start_pos:start_pos+2]
  378. start_pos += 2
  379. length = UNPACK_LEN2(data)[0]
  380. self._payload_length = length
  381. self._state = (
  382. WSParserState.READ_PAYLOAD_MASK
  383. if self._has_mask
  384. else WSParserState.READ_PAYLOAD)
  385. else:
  386. break
  387. elif length > 126:
  388. if buf_length - start_pos >= 8:
  389. data = buf[start_pos:start_pos+8]
  390. start_pos += 8
  391. length = UNPACK_LEN3(data)[0]
  392. self._payload_length = length
  393. self._state = (
  394. WSParserState.READ_PAYLOAD_MASK
  395. if self._has_mask
  396. else WSParserState.READ_PAYLOAD)
  397. else:
  398. break
  399. else:
  400. self._payload_length = length
  401. self._state = (
  402. WSParserState.READ_PAYLOAD_MASK
  403. if self._has_mask
  404. else WSParserState.READ_PAYLOAD)
  405. # read payload mask
  406. if self._state == WSParserState.READ_PAYLOAD_MASK:
  407. if buf_length - start_pos >= 4:
  408. self._frame_mask = buf[start_pos:start_pos+4]
  409. start_pos += 4
  410. self._state = WSParserState.READ_PAYLOAD
  411. else:
  412. break
  413. if self._state == WSParserState.READ_PAYLOAD:
  414. length = self._payload_length
  415. payload = self._frame_payload
  416. chunk_len = buf_length - start_pos
  417. if length >= chunk_len:
  418. self._payload_length = length - chunk_len
  419. payload.extend(buf[start_pos:])
  420. start_pos = buf_length
  421. else:
  422. self._payload_length = 0
  423. payload.extend(buf[start_pos:start_pos+length])
  424. start_pos = start_pos + length
  425. if self._payload_length == 0:
  426. if self._has_mask:
  427. _websocket_mask(self._frame_mask, payload)
  428. frames.append((
  429. self._frame_fin,
  430. self._frame_opcode,
  431. payload,
  432. self._compressed))
  433. self._frame_payload = bytearray()
  434. self._state = WSParserState.READ_HEADER
  435. else:
  436. break
  437. self._tail = buf[start_pos:]
  438. return frames
  439. class WebSocketWriter:
  440. def __init__(self, protocol, transport, *,
  441. use_mask=False, limit=DEFAULT_LIMIT, random=random.Random(),
  442. compress=0, notakeover=False):
  443. self.protocol = protocol
  444. self.transport = transport
  445. self.use_mask = use_mask
  446. self.randrange = random.randrange
  447. self.compress = compress
  448. self.notakeover = notakeover
  449. self._closing = False
  450. self._limit = limit
  451. self._output_size = 0
  452. self._compressobj = None
  453. async def _send_frame(self, message, opcode, compress=None):
  454. """Send a frame over the websocket with message as its payload."""
  455. if self._closing:
  456. ws_logger.warning('websocket connection is closing.')
  457. rsv = 0
  458. # Only compress larger packets (disabled)
  459. # Does small packet needs to be compressed?
  460. # if self.compress and opcode < 8 and len(message) > 124:
  461. if (compress or self.compress) and opcode < 8:
  462. if compress:
  463. # Do not set self._compress if compressing is for this frame
  464. compressobj = zlib.compressobj(wbits=-compress)
  465. else: # self.compress
  466. if not self._compressobj:
  467. self._compressobj = zlib.compressobj(wbits=-self.compress)
  468. compressobj = self._compressobj
  469. message = compressobj.compress(message)
  470. message = message + compressobj.flush(
  471. zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH)
  472. if message.endswith(_WS_DEFLATE_TRAILING):
  473. message = message[:-4]
  474. rsv = rsv | 0x40
  475. msg_length = len(message)
  476. use_mask = self.use_mask
  477. if use_mask:
  478. mask_bit = 0x80
  479. else:
  480. mask_bit = 0
  481. if msg_length < 126:
  482. header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit)
  483. elif msg_length < (1 << 16):
  484. header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length)
  485. else:
  486. header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length)
  487. if use_mask:
  488. mask = self.randrange(0, 0xffffffff)
  489. mask = mask.to_bytes(4, 'big')
  490. message = bytearray(message)
  491. _websocket_mask(mask, message)
  492. self.transport.write(header + mask + message)
  493. self._output_size += len(header) + len(mask) + len(message)
  494. else:
  495. if len(message) > MSG_SIZE:
  496. self.transport.write(header)
  497. self.transport.write(message)
  498. else:
  499. self.transport.write(header + message)
  500. self._output_size += len(header) + len(message)
  501. if self._output_size > self._limit:
  502. self._output_size = 0
  503. await self.protocol._drain_helper()
  504. async def pong(self, message=b''):
  505. """Send pong message."""
  506. if isinstance(message, str):
  507. message = message.encode('utf-8')
  508. return await self._send_frame(message, WSMsgType.PONG)
  509. async def ping(self, message=b''):
  510. """Send ping message."""
  511. if isinstance(message, str):
  512. message = message.encode('utf-8')
  513. return await self._send_frame(message, WSMsgType.PING)
  514. async def send(self, message, binary=False, compress=None):
  515. """Send a frame over the websocket with message as its payload."""
  516. if isinstance(message, str):
  517. message = message.encode('utf-8')
  518. if binary:
  519. return await self._send_frame(message, WSMsgType.BINARY, compress)
  520. else:
  521. return await self._send_frame(message, WSMsgType.TEXT, compress)
  522. async def close(self, code=1000, message=b''):
  523. """Close the websocket, sending the specified code and message."""
  524. if isinstance(message, str):
  525. message = message.encode('utf-8')
  526. try:
  527. return await self._send_frame(
  528. PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE)
  529. finally:
  530. self._closing = True