|
|
- """WebSocket protocol versions 13 and 8."""
-
- import collections
- import json
- import random
- import re
- import sys
- import zlib
- from enum import IntEnum
- from struct import Struct
-
- from .helpers import NO_EXTENSIONS
- from .log import ws_logger
-
-
# Names exported by a star-import of this module.
__all__ = ('WS_CLOSED_MESSAGE', 'WS_CLOSING_MESSAGE', 'WS_KEY',
           'WebSocketReader', 'WebSocketWriter', 'WSMessage',
           'WebSocketError', 'WSMsgType', 'WSCloseCode')
-
-
class WSCloseCode(IntEnum):
    """Close status codes known to this module.

    The integer values of these members make up ``ALLOWED_CLOSE_CODES``;
    ``WebSocketReader`` rejects a received close code below 3000 that is
    not one of them.
    """
    OK = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    INVALID_TEXT = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013
-
-
# Plain-int set of every WSCloseCode value, used for membership tests on
# close codes received from the peer.
ALLOWED_CLOSE_CODES = {int(i) for i in WSCloseCode}
-
-
class WSMsgType(IntEnum):
    """Message types: frame opcodes plus aiohttp-internal pseudo-types.

    The lower-case names are bound to the values of the upper-case
    members, so the enum exposes them as aliases of those members.
    """
    # websocket spec types
    CONTINUATION = 0x0
    TEXT = 0x1
    BINARY = 0x2
    PING = 0x9
    PONG = 0xa
    CLOSE = 0x8

    # aiohttp specific types
    # (values above 0xf cannot collide with a 4-bit frame opcode)
    CLOSING = 0x100
    CLOSED = 0x101
    ERROR = 0x102

    # lower-case aliases of the members above
    text = TEXT
    binary = BINARY
    ping = PING
    pong = PONG
    close = CLOSE
    closing = CLOSING
    closed = CLOSED
    error = ERROR
-
-
# GUID that RFC 6455 defines for the opening handshake; it is exported but
# not referenced anywhere else in this part of the file.
WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'


# Pre-bound struct codecs; '!' selects network (big-endian) byte order.
UNPACK_LEN2 = Struct('!H').unpack_from  # 16-bit extended payload length
UNPACK_LEN3 = Struct('!Q').unpack_from  # 64-bit extended payload length
UNPACK_CLOSE_CODE = Struct('!H').unpack  # 16-bit close status code
PACK_LEN1 = Struct('!BB').pack  # two header bytes, 7-bit length
PACK_LEN2 = Struct('!BBH').pack  # two header bytes + 16-bit length
PACK_LEN3 = Struct('!BBQ').pack  # two header bytes + 64-bit length
PACK_CLOSE_CODE = Struct('!H').pack  # 16-bit close status code
# Unmasked payloads longer than this are written to the transport
# separately from their header (see WebSocketWriter._send_frame).
MSG_SIZE = 2 ** 14
# Default WebSocketWriter ``limit``: bytes written before it drains.
DEFAULT_LIMIT = 2 ** 16
-
-
_WSMessageBase = collections.namedtuple(
    '_WSMessageBase', ['type', 'data', 'extra'])


class WSMessage(_WSMessageBase):
    """WebSocket message triple ``(type, data, extra)``."""

    def json(self, *, loads=json.loads):
        """Decode ``self.data`` with *loads* and return the result.

        *loads* defaults to :func:`json.loads` and may be replaced by any
        callable taking the raw data.

        .. versionadded:: 0.22
        """
        decoded = loads(self.data)
        return decoded
-
-
# Shared, pre-built messages for the CLOSED / CLOSING pseudo-types; both
# carry no data and no extra.
WS_CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None)
WS_CLOSING_MESSAGE = WSMessage(WSMsgType.CLOSING, None, None)
-
-
class WebSocketError(Exception):
    """Raised when received WebSocket data cannot be parsed.

    ``code`` carries the close code associated with the failure; the
    message becomes the regular exception text.
    """

    def __init__(self, code, message):
        super().__init__(message)
        self.code = code
-
-
class WSHandshakeError(Exception):
    """Raised when the WebSocket handshake cannot be completed."""
-
-
# Byte order of the running interpreter ('little' or 'big'); nothing in
# this part of the file reads it.
native_byteorder = sys.byteorder
-
-
# Lookup tables for _websocket_mask_python: _XOR_TABLE[k][x] == x ^ k.
_XOR_TABLE = [bytes(a ^ b for a in range(256)) for b in range(256)]


def _websocket_mask_python(mask, data):
    """Apply the 4-byte WebSocket *mask* to *data* in place.

    `mask` is a `bytes` object of length 4 and `data` is a `bytearray`
    of any length.  Byte ``i`` of `data` is XOR-ed with
    ``mask[i % 4]`` as described in section 5.3 of RFC 6455.

    The function returns nothing; `data` itself is modified.

    An optimized implementation may be used instead of this
    pure-python one when it is available.

    """
    assert isinstance(data, bytearray), data
    assert len(mask) == 4, mask

    if not data:
        return

    # Resolve all four translation tables before touching the buffer.
    tables = [_XOR_TABLE[key] for key in mask]
    for offset, table in enumerate(tables):
        # Every fourth byte starting at `offset` shares one mask byte.
        data[offset::4] = data[offset::4].translate(table)
-
-
# Choose the masking implementation: the pure-Python one when extensions
# are disabled through NO_EXTENSIONS or when the compiled module cannot be
# imported, the compiled one otherwise.
if NO_EXTENSIONS:
    _websocket_mask = _websocket_mask_python
else:
    try:
        from ._websocket import _websocket_mask_cython  # type: ignore
        _websocket_mask = _websocket_mask_cython
    except ImportError:  # pragma: no cover
        _websocket_mask = _websocket_mask_python

# Four bytes that WebSocketWriter strips from the end of compressed output
# and that WebSocketReader appends again before decompressing.
_WS_DEFLATE_TRAILING = bytes([0x00, 0x00, 0xff, 0xff])
-
-
# Parameter list of one permessage-deflate offer.  Groups:
#   1 server_no_context_takeover   2 client_no_context_takeover
#   3 server_max_window_bits[=N]   4 N
#   5 client_max_window_bits[=N]   6 N
_WS_EXT_RE = re.compile(r'^(?:;\s*(?:'
                        r'(server_no_context_takeover)|'
                        r'(client_no_context_takeover)|'
                        r'(server_max_window_bits(?:=(\d+))?)|'
                        r'(client_max_window_bits(?:=(\d+))?)))*$')

# One permessage-deflate offer; group 1 is its raw parameter list, if any.
_WS_EXT_RE_SPLIT = re.compile(r'permessage-deflate([^,]+)?')


def ws_ext_parse(extstr, isserver=False):
    """Parse an extensions header value for ``permessage-deflate``.

    Returns ``(compress, notakeover)``: *compress* is the window size to
    use (``0`` when no usable offer was found, ``15`` when none was
    requested explicitly) and *notakeover* tells whether the relevant
    ``*_no_context_takeover`` parameter was present.

    With ``isserver`` true the ``server_*`` parameters are read and
    unusable offers are skipped.  Otherwise the ``client_*`` parameters
    are read and an unusable offer raises ``WSHandshakeError``.
    """
    if not extstr:
        return 0, False

    # Regex groups holding the window bits / no-takeover flag for this side.
    if isserver:
        bits_group, takeover_group = 4, 1
    else:
        bits_group, takeover_group = 6, 2

    compress = 0
    notakeover = False
    for offer in _WS_EXT_RE_SPLIT.finditer(extstr):
        params = offer.group(1)
        if not params:
            # Bare `permessage-deflate`: accept with the default window.
            compress = 15
            break

        parsed = _WS_EXT_RE.match(params)
        if parsed is None:
            if isserver:
                # A server just moves on to the next offer.
                continue
            raise WSHandshakeError('Extension for deflate not supported' +
                                   offer.group(1))

        compress = 15
        wbits = parsed.group(bits_group)
        if wbits:
            compress = int(wbits)
            # zlib does not support a window of 8 bits.
            if not 9 <= compress <= 15:
                if not isserver:
                    raise WSHandshakeError('Invalid window size')
                compress = 0
                continue
        if parsed.group(takeover_group):
            notakeover = True
        # The other side's window-bits parameter is deliberately ignored.
        break

    return compress, notakeover
-
-
def ws_ext_gen(compress=15, isserver=False,
               server_notakeover=False):
    """Build a ``permessage-deflate`` extension header value.

    *compress* is the window size and must lie in 9..15, otherwise
    ``ValueError`` is raised.  ``client_max_window_bits`` is added unless
    *isserver* is true, ``server_max_window_bits=N`` is added when
    *compress* is below 15, and ``server_no_context_takeover`` is added
    when *server_notakeover* is true.  There is no client-side
    no-takeover option.
    """
    # zlib cannot work with wbits=8, hence the lower bound of 9.
    if compress > 15 or compress < 9:
        raise ValueError('Compress wbits must between 9 and 15, '
                         'zlib does not support wbits=8')

    optional = (
        (not isserver, 'client_max_window_bits'),
        (compress < 15, 'server_max_window_bits=' + str(compress)),
        (server_notakeover, 'server_no_context_takeover'),
    )
    tokens = ['permessage-deflate']
    tokens.extend(token for wanted, token in optional if wanted)
    return '; '.join(tokens)
-
-
class WSParserState(IntEnum):
    """Stages ``WebSocketReader.parse_frame`` goes through for one frame."""
    READ_HEADER = 1  # waiting for the two fixed header bytes
    READ_PAYLOAD_LENGTH = 2  # waiting for the extended length, if any
    READ_PAYLOAD_MASK = 3  # waiting for the 4-byte masking key
    READ_PAYLOAD = 4  # collecting payload bytes
-
-
class WebSocketReader:
    """Incremental parser turning raw bytes into ``WSMessage`` objects.

    Bytes are pushed in through :meth:`feed_data`.  Every complete message
    is handed to ``queue.feed_data(message, size)``.  The first parsing
    error is remembered, reported through ``queue.set_exception`` and
    makes all later :meth:`feed_data` calls return without parsing.
    """

    def __init__(self, queue, max_msg_size, compress=True):
        """Set up an idle parser.

        :param queue: object providing ``feed_data(msg, size)``,
            ``feed_eof()`` and ``set_exception(exc)``.
        :param max_msg_size: size limit for an assembled message; a falsy
            value disables the check.
        :param compress: when false, a frame with the RSV1 bit set is
            rejected as a protocol error.
        """
        self.queue = queue
        self._max_msg_size = max_msg_size

        self._exc = None  # first parsing error, if any
        self._partial = bytearray()  # payload of the message being assembled
        self._state = WSParserState.READ_HEADER

        self._opcode = None  # opcode of the fragmented message in progress
        self._frame_fin = False
        self._frame_opcode = None
        self._frame_payload = bytearray()

        self._tail = b''  # bytes left unparsed by the previous parse_frame
        self._has_mask = False
        self._frame_mask = None
        self._payload_length = 0  # payload bytes still missing
        self._payload_length_flag = 0  # 7-bit length field of the header
        self._compressed = None  # RSV1 state of the current message
        self._decompressobj = None  # created on the first compressed frame
        self._compress = compress

    def feed_eof(self):
        """Forward end-of-stream to the queue."""
        self.queue.feed_eof()

    def feed_data(self, data):
        """Parse *data* and queue every message completed by it.

        Returns a 2-tuple.  Its first item is ``True`` once parsing has
        failed (now or in an earlier call) and ``False`` otherwise.  The
        second item is *data* itself when an earlier error short-circuits
        the call, and ``b''`` in all other cases.

        Exceptions are not propagated: the first one is stored and passed
        to ``queue.set_exception``.
        """
        if self._exc:
            return True, data

        try:
            return self._feed_data(data)
        except Exception as exc:
            self._exc = exc
            self.queue.set_exception(exc)
            return True, b''

    def _check_partial_size(self):
        """Raise MESSAGE_TOO_BIG once the buffered message hits the limit."""
        if (self._max_msg_size and
                len(self._partial) >= self._max_msg_size):
            raise WebSocketError(
                WSCloseCode.MESSAGE_TOO_BIG,
                "Message size {} exceeds limit {}".format(
                    len(self._partial), self._max_msg_size))

    def _feed_data(self, data):
        """Turn the frames found in *data* into queued messages.

        Always returns ``(False, b'')``.

        :raises WebSocketError: on an invalid close frame, an unexpected
            opcode, an oversized message or invalid UTF-8 text.
        """
        for fin, opcode, payload, compressed in self.parse_frame(data):
            if compressed and not self._decompressobj:
                # Negative wbits selects a raw deflate stream.
                self._decompressobj = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
            if opcode == WSMsgType.CLOSE:
                if len(payload) >= 2:
                    close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
                    if (close_code < 3000 and
                            close_code not in ALLOWED_CLOSE_CODES):
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            'Invalid close code: {}'.format(close_code))
                    try:
                        close_message = payload[2:].decode('utf-8')
                    except UnicodeDecodeError as exc:
                        raise WebSocketError(
                            WSCloseCode.INVALID_TEXT,
                            'Invalid UTF-8 text message') from exc
                    msg = WSMessage(WSMsgType.CLOSE, close_code, close_message)
                elif payload:
                    # A single payload byte cannot hold a close code.
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        'Invalid close frame: {} {} {!r}'.format(
                            fin, opcode, payload))
                else:
                    msg = WSMessage(WSMsgType.CLOSE, 0, '')

                self.queue.feed_data(msg, 0)

            elif opcode == WSMsgType.PING:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PING, payload, ''), len(payload))

            elif opcode == WSMsgType.PONG:
                self.queue.feed_data(
                    WSMessage(WSMsgType.PONG, payload, ''), len(payload))

            elif opcode not in (
                    WSMsgType.TEXT, WSMsgType.BINARY) and self._opcode is None:
                # Neither a data frame nor the continuation of one.
                raise WebSocketError(
                    WSCloseCode.PROTOCOL_ERROR,
                    "Unexpected opcode={!r}".format(opcode))
            else:
                # load text/binary
                if not fin:
                    # got partial frame payload
                    if opcode != WSMsgType.CONTINUATION:
                        self._opcode = opcode
                    self._partial.extend(payload)
                    self._check_partial_size()
                else:
                    # previous frame was non finished
                    # we should get continuation opcode
                    if self._partial:
                        if opcode != WSMsgType.CONTINUATION:
                            raise WebSocketError(
                                WSCloseCode.PROTOCOL_ERROR,
                                'The opcode in non-fin frame is expected '
                                'to be zero, got {!r}'.format(opcode))

                    if opcode == WSMsgType.CONTINUATION:
                        # The message type comes from its first fragment.
                        opcode = self._opcode
                        self._opcode = None

                    self._partial.extend(payload)
                    self._check_partial_size()

                    # Decompression has to wait until every fragment of
                    # the message has been received.
                    if compressed:
                        self._partial.extend(_WS_DEFLATE_TRAILING)
                        payload_merged = self._decompressobj.decompress(
                            self._partial, self._max_msg_size)
                        if self._decompressobj.unconsumed_tail:
                            # `left` counts input bytes not yet inflated, so
                            # the reported size is only an approximation.
                            left = len(self._decompressobj.unconsumed_tail)
                            raise WebSocketError(
                                WSCloseCode.MESSAGE_TOO_BIG,
                                "Decompressed message size {} exceeds "
                                "limit {}".format(
                                    self._max_msg_size + left,
                                    self._max_msg_size))
                    else:
                        payload_merged = bytes(self._partial)

                    self._partial.clear()

                    if opcode == WSMsgType.TEXT:
                        try:
                            text = payload_merged.decode('utf-8')
                            self.queue.feed_data(
                                WSMessage(WSMsgType.TEXT, text, ''), len(text))
                        except UnicodeDecodeError as exc:
                            raise WebSocketError(
                                WSCloseCode.INVALID_TEXT,
                                'Invalid UTF-8 text message') from exc
                    else:
                        self.queue.feed_data(
                            WSMessage(WSMsgType.BINARY, payload_merged, ''),
                            len(payload_merged))

        return False, b''

    def parse_frame(self, buf):
        """Parse *buf* and return the frames completed by it.

        Bytes left over from the previous call are prepended to *buf*.
        Parser state is kept on the instance, so a frame may arrive split
        over any number of calls.

        Returns a list of ``(fin, opcode, payload, compressed)`` tuples;
        the list is empty when no frame was completed.

        :raises WebSocketError: on reserved bits that are not allowed, a
            fragmented control frame or a control frame payload above
            125 bytes.
        """
        frames = []
        if self._tail:
            buf, self._tail = self._tail + buf, b''

        start_pos = 0
        buf_length = len(buf)

        while True:
            # read header
            if self._state == WSParserState.READ_HEADER:
                if buf_length - start_pos >= 2:
                    data = buf[start_pos:start_pos+2]
                    start_pos += 2
                    first_byte, second_byte = data

                    fin = (first_byte >> 7) & 1
                    rsv1 = (first_byte >> 6) & 1
                    rsv2 = (first_byte >> 5) & 1
                    rsv3 = (first_byte >> 4) & 1
                    opcode = first_byte & 0xf

                    # frame-fin = %x0 ; more frames of this message follow
                    #           / %x1 ; final frame of this message
                    # frame-rsv1 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    # frame-rsv2 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    # frame-rsv3 = %x0 ;
                    #    1 bit, MUST be 0 unless negotiated otherwise
                    #
                    # RSV1 is tolerated only when this reader was created
                    # with compression enabled.
                    if rsv2 or rsv3 or (rsv1 and not self._compress):
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            'Received frame with non-zero reserved bits')

                    if opcode > 0x7 and fin == 0:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            'Received fragmented control frame')

                    has_mask = (second_byte >> 7) & 1
                    length = second_byte & 0x7f

                    # Control frames MUST have a payload
                    # length of 125 bytes or less
                    if opcode > 0x7 and length > 125:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            'Control frame payload cannot be '
                            'larger than 125 bytes')

                    # Take the compression flag from this frame when the
                    # previous frame had FIN set or no frame has been seen
                    # yet.  On any other frame RSV1 must be clear.
                    if self._frame_fin or self._compressed is None:
                        self._compressed = True if rsv1 else False
                    elif rsv1:
                        raise WebSocketError(
                            WSCloseCode.PROTOCOL_ERROR,
                            'Received frame with non-zero reserved bits')

                    self._frame_fin = fin
                    self._frame_opcode = opcode
                    self._has_mask = has_mask
                    self._payload_length_flag = length
                    self._state = WSParserState.READ_PAYLOAD_LENGTH
                else:
                    break

            # read payload length
            if self._state == WSParserState.READ_PAYLOAD_LENGTH:
                length = self._payload_length_flag
                if length == 126:
                    # 16-bit extended length follows
                    if buf_length - start_pos >= 2:
                        data = buf[start_pos:start_pos+2]
                        start_pos += 2
                        length = UNPACK_LEN2(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD)
                    else:
                        break
                elif length > 126:
                    # 64-bit extended length follows
                    if buf_length - start_pos >= 8:
                        data = buf[start_pos:start_pos+8]
                        start_pos += 8
                        length = UNPACK_LEN3(data)[0]
                        self._payload_length = length
                        self._state = (
                            WSParserState.READ_PAYLOAD_MASK
                            if self._has_mask
                            else WSParserState.READ_PAYLOAD)
                    else:
                        break
                else:
                    self._payload_length = length
                    self._state = (
                        WSParserState.READ_PAYLOAD_MASK
                        if self._has_mask
                        else WSParserState.READ_PAYLOAD)

            # read payload mask
            if self._state == WSParserState.READ_PAYLOAD_MASK:
                if buf_length - start_pos >= 4:
                    self._frame_mask = buf[start_pos:start_pos+4]
                    start_pos += 4
                    self._state = WSParserState.READ_PAYLOAD
                else:
                    break

            if self._state == WSParserState.READ_PAYLOAD:
                length = self._payload_length
                payload = self._frame_payload

                chunk_len = buf_length - start_pos
                if length >= chunk_len:
                    # The rest of the buffer belongs to this frame.
                    self._payload_length = length - chunk_len
                    payload.extend(buf[start_pos:])
                    start_pos = buf_length
                else:
                    self._payload_length = 0
                    payload.extend(buf[start_pos:start_pos+length])
                    start_pos = start_pos + length

                if self._payload_length == 0:
                    if self._has_mask:
                        # Unmask in place once the payload is complete.
                        _websocket_mask(self._frame_mask, payload)

                    frames.append((
                        self._frame_fin,
                        self._frame_opcode,
                        payload,
                        self._compressed))

                    self._frame_payload = bytearray()
                    self._state = WSParserState.READ_HEADER
                else:
                    break

        # Keep whatever could not be parsed for the next call.
        self._tail = buf[start_pos:]

        return frames
-
-
class WebSocketWriter:
    """Builds WebSocket frames and writes them to a transport."""

    def __init__(self, protocol, transport, *,
                 use_mask=False, limit=DEFAULT_LIMIT, random=random.Random(),
                 compress=0, notakeover=False):
        """Store the output objects and the framing options.

        :param protocol: its ``_drain_helper()`` is awaited once more than
            *limit* bytes have been written since the last drain.
        :param transport: receives the frame bytes through ``write()``.
        :param use_mask: mask every payload with a random 4-byte key.
        :param limit: written-bytes threshold that triggers a drain.
        :param random: source of ``randrange`` for masking keys.
        :param compress: deflate window bits for data frames; a falsy
            value disables compression unless requested per call.
        :param notakeover: flush with ``Z_FULL_FLUSH`` instead of
            ``Z_SYNC_FLUSH`` after each compressed message.
        """
        # where the bytes go
        self.transport = transport
        self.protocol = protocol
        self._limit = limit
        self._output_size = 0
        # framing options
        self.use_mask = use_mask
        self.randrange = random.randrange
        # compression state
        self.compress = compress
        self.notakeover = notakeover
        self._compressobj = None
        self._closing = False

    async def _send_frame(self, message, opcode, compress=None):
        """Write *message* as a single final frame with the given *opcode*.

        A truthy *compress* compresses this frame with a throw-away
        compressor using that many window bits; otherwise the writer-wide
        setting and its shared compressor apply.  Only opcodes below 8
        are ever compressed.
        """
        if self._closing:
            ws_logger.warning('websocket connection is closing.')

        rsv1 = 0

        # Small payloads are compressed as well; there is no size cut-off.
        if (compress or self.compress) and opcode < 8:
            if compress:
                # One-off compressor: the shared one stays untouched.
                deflater = zlib.compressobj(wbits=-compress)
            else:
                if not self._compressobj:
                    self._compressobj = zlib.compressobj(wbits=-self.compress)
                deflater = self._compressobj

            flush_mode = (
                zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH)
            message = deflater.compress(message) + deflater.flush(flush_mode)
            if message.endswith(_WS_DEFLATE_TRAILING):
                message = message[:-4]
            rsv1 = 0x40

        # FIN is always set: every message goes out as one frame.
        first_byte = 0x80 | rsv1 | opcode
        size = len(message)

        masked = self.use_mask
        mask_bit = 0x80 if masked else 0

        if size < 126:
            header = PACK_LEN1(first_byte, size | mask_bit)
        elif size < (1 << 16):
            header = PACK_LEN2(first_byte, 126 | mask_bit, size)
        else:
            header = PACK_LEN3(first_byte, 127 | mask_bit, size)

        if masked:
            mask = self.randrange(0, 0xffffffff).to_bytes(4, 'big')
            body = bytearray(message)
            _websocket_mask(mask, body)
            self.transport.write(header + mask + body)
            self._output_size += len(header) + len(mask) + len(body)
        elif len(message) > MSG_SIZE:
            # Large payloads are written separately from the header.
            self.transport.write(header)
            self.transport.write(message)
            self._output_size += len(header) + len(message)
        else:
            self.transport.write(header + message)
            self._output_size += len(header) + len(message)

        if self._output_size > self._limit:
            self._output_size = 0
            await self.protocol._drain_helper()

    async def pong(self, message=b''):
        """Send a PONG frame; a ``str`` payload is UTF-8 encoded first."""
        payload = message.encode('utf-8') if isinstance(message, str) \
            else message
        return await self._send_frame(payload, WSMsgType.PONG)

    async def ping(self, message=b''):
        """Send a PING frame; a ``str`` payload is UTF-8 encoded first."""
        payload = message.encode('utf-8') if isinstance(message, str) \
            else message
        return await self._send_frame(payload, WSMsgType.PING)

    async def send(self, message, binary=False, compress=None):
        """Send *message* as a BINARY frame if *binary*, else as TEXT.

        A ``str`` message is UTF-8 encoded first; *compress* is passed
        on to the frame writer unchanged.
        """
        if isinstance(message, str):
            message = message.encode('utf-8')
        opcode = WSMsgType.BINARY if binary else WSMsgType.TEXT
        return await self._send_frame(message, opcode, compress)

    async def close(self, code=1000, message=b''):
        """Send a CLOSE frame carrying *code* followed by *message*.

        The writer is marked as closing afterwards, whether or not the
        frame could be sent.
        """
        if isinstance(message, str):
            message = message.encode('utf-8')
        try:
            return await self._send_frame(
                PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE)
        finally:
            self._closing = True
|