import collections
import errno
import re
import socket
import struct
from base64 import b64decode, b64encode
from hashlib import sha1

# Fixed GUID defined by RFC 6455: it is appended to the client's
# Sec-WebSocket-Key before SHA-1 hashing to build Sec-WebSocket-Accept.
WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
def pack_message(message):
    """Pack the message inside ``00`` and ``FF``.

    As per the dataframing section (5.3) for the websocket spec.

    *message* may be text (encoded as UTF-8), a byte string (used as-is),
    or any other object (converted with ``str()`` first).

    Returns the framed message as a byte string.
    """
    text_type = type(u'')  # ``unicode`` on Python 2, ``str`` on Python 3
    if isinstance(message, text_type):
        message = message.encode('utf-8')
    elif isinstance(message, bytearray):
        message = bytes(message)
    elif not isinstance(message, bytes):
        message = str(message)
        # On Python 3 ``str()`` yields text, which still has to be encoded.
        if isinstance(message, text_type):
            message = message.encode('utf-8')
    return b'\x00' + message + b'\xff'
def encode_hybi(buf, opcode, base64=False):
    """Encode a HyBi style WebSocket frame.

    Optional opcode:
        0x0 - continuation
        0x1 - text frame (base64 encode buf)
        0x2 - binary frame (use raw buf)
        0x8 - connection close
        0x9 - ping
        0xA - pong

    *buf* is the payload; text is encoded as UTF-8 first so that the
    length written to the header is the byte length.  When *base64* is
    true the payload is base64-encoded before framing.  The frame is
    always sent with FIN set and without a mask.

    Returns a tuple ``(frame, header_length, 0)``.
    """
    if isinstance(buf, type(u'')):
        # Text cannot be concatenated with the binary header.
        buf = buf.encode('utf-8')
    if base64:
        buf = b64encode(buf)

    b1 = 0x80 | (opcode & 0x0f)  # FIN + opcode
    payload_len = len(buf)
    if payload_len <= 125:
        header = struct.pack('>BB', b1, payload_len)
    elif payload_len < 65536:
        # 126 marks a 16-bit extended payload length.
        header = struct.pack('>BBH', b1, 126, payload_len)
    else:
        # 127 marks a 64-bit extended payload length.
        header = struct.pack('>BBQ', b1, 127, payload_len)

    return header + buf, len(header), 0
def decode_hybi(buf, base64=False):
    """Decode one HyBi style WebSocket frame from the start of *buf*.

    Returns a dict with the keys:

        fin          - FIN bit (0 or 1)
        opcode       - frame opcode
        mask         - the 4 masking-key bytes, or 0 when unmasked
        hlen         - header length excluding the masking key (2, 4 or 10)
        length       - payload length in bytes
        payload      - unmasked payload, or None if the frame is incomplete
        left         - bytes in *buf* after this frame (all of *buf* when
                       the frame is incomplete)
        close_code   - status code of a close frame, else None
        close_reason - reason bytes of a close frame, else None

    When *base64* is true, text/binary payloads (opcodes 1 and 2) are
    base64-decoded after unmasking.
    """
    frame = {
        'fin': 0,
        'opcode': 0,
        'mask': 0,
        'hlen': 2,
        'length': 0,
        'payload': None,
        'left': 0,
        'close_code': None,
        'close_reason': None,
    }

    blen = len(buf)
    frame['left'] = blen

    if blen < frame['hlen']:
        return frame  # Incomplete frame header

    b1, b2 = struct.unpack_from('>BB', buf)
    frame['opcode'] = b1 & 0x0f
    frame['fin'] = (b1 & 0x80) >> 7
    has_mask = (b2 & 0x80) >> 7
    frame['length'] = b2 & 0x7f

    # 126 / 127 announce a 16-bit / 64-bit extended length field.
    if frame['length'] == 126:
        frame['hlen'] = 4
        if blen < frame['hlen']:
            return frame  # Incomplete frame header
        (frame['length'],) = struct.unpack_from('>xxH', buf)
    elif frame['length'] == 127:
        frame['hlen'] = 10
        if blen < frame['hlen']:
            return frame  # Incomplete frame header
        (frame['length'],) = struct.unpack_from('>xxQ', buf)

    payload_start = frame['hlen'] + has_mask * 4
    full_len = payload_start + frame['length']

    if blen < full_len:
        return frame  # Incomplete frame

    # Number of bytes that are part of the next frame(s)
    frame['left'] = blen - full_len

    # Process 1 frame
    if has_mask:
        frame['mask'] = buf[frame['hlen']:payload_start]
        key = bytearray(frame['mask'])
        masked = bytearray(buf[payload_start:full_len])
        # XOR each payload byte with the key byte at the same offset mod 4.
        # Working byte-wise avoids any dependence on host byte order.
        frame['payload'] = bytes(bytearray(
            octet ^ key[i % 4] for i, octet in enumerate(masked)))
    else:
        frame['payload'] = buf[payload_start:full_len]

    if base64 and frame['opcode'] in (1, 2):
        frame['payload'] = b64decode(frame['payload'])

    if frame['opcode'] == 0x08:
        if frame['length'] >= 2:
            # unpack_from returns a 1-tuple; store the integer itself.
            (frame['close_code'],) = struct.unpack_from('>H', frame['payload'])
        if frame['length'] > 2:
            frame['close_reason'] = frame['payload'][2:]

    return frame
class WebSocketWSGI(object):
    """WSGI application that performs the WebSocket opening handshake.

    The handshake reply is written directly to the raw socket found in
    ``environ['gunicorn.socket']``; afterwards *handler* is called with a
    ``WebSocket`` wrapping that socket.
    """

    def __init__(self, handler):
        """Store *handler*, a callable taking the connected ``WebSocket``."""
        self.handler = handler

    def verify_client(self, ws):
        """Hook that does nothing here; it is not called by ``__call__``."""
        pass

    def __call__(self, environ, start_response):
        """Handle one WSGI request.

        Responds ``400 Bad Request`` (returning ``[]``) when the request is
        not a WebSocket upgrade or carries a Sec-WebSocket-Key that does not
        decode to 16 bytes.  Otherwise sends the handshake reply on the raw
        socket and runs the handler; a broken pipe raised by the handler is
        swallowed, any other ``socket.error`` propagates.
        """
        # Missing headers are treated as "not an upgrade request" rather
        # than crashing; the Connection token is matched case-insensitively.
        connection = environ.get('HTTP_CONNECTION') or ''
        upgrade = environ.get('HTTP_UPGRADE') or ''
        if 'upgrade' not in connection.lower() or upgrade.lower() != 'websocket':
            # need to check a few more things here for true compliance
            start_response('400 Bad Request', [('Connection', 'close')])
            return []

        sock = environ['gunicorn.socket']
        # PATH_INFO may be omitted from the environ when it is empty.
        path = environ.get('PATH_INFO', '')
        key = environ.get('HTTP_SEC_WEBSOCKET_KEY')

        if key:
            try:
                decoded_key = b64decode(key)
            except (TypeError, ValueError):
                # Malformed base64 (TypeError on Py2, binascii.Error on Py3).
                decoded_key = b''
            if len(decoded_key) != 16:
                start_response('400 Bad Request', [('Connection', 'close')])
                return []

        ws = WebSocket(sock, environ)

        handshake_reply = ("HTTP/1.1 101 Switching Protocols\r\n"
                           "Upgrade: websocket\r\n"
                           "Connection: Upgrade\r\n")

        if key:
            # The supported lists are empty, so nothing the client offers
            # is ever echoed back.
            protocols = []
            subprotocols = environ.get('HTTP_SEC_WEBSOCKET_PROTOCOL')
            ws_protocols = []
            if subprotocols:
                ws_protocols = [s.strip() for s in subprotocols.split(',')
                                if s.strip() in protocols]
            if ws_protocols:
                handshake_reply += 'Sec-WebSocket-Protocol: %s\r\n' % ', '.join(ws_protocols)

            exts = []
            extensions = environ.get('HTTP_SEC_WEBSOCKET_EXTENSIONS')
            ws_extensions = []
            if extensions:
                ws_extensions = [ext.strip() for ext in extensions.split(',')
                                 if ext.strip() in exts]
            if ws_extensions:
                handshake_reply += 'Sec-WebSocket-Extensions: %s\r\n' % ', '.join(ws_extensions)

            # Hash bytes, and format the result as a native string so the
            # reply does not contain a ``b'...'`` repr on Python 3.
            key_bytes = key if isinstance(key, bytes) else key.encode('latin-1')
            accept = b64encode(sha1(key_bytes + WS_KEY).digest())
            if not isinstance(accept, str):
                accept = accept.decode('ascii')

            handshake_reply += (
                "Sec-WebSocket-Origin: %s\r\n"
                "Sec-WebSocket-Location: ws://%s%s\r\n"
                "Sec-WebSocket-Version: %s\r\n"
                "Sec-WebSocket-Accept: %s\r\n\r\n"
                % (
                    environ.get('HTTP_ORIGIN'),
                    environ.get('HTTP_HOST'),
                    path,
                    ws.version,
                    accept,
                ))
        else:
            handshake_reply += (
                "WebSocket-Origin: %s\r\n"
                "WebSocket-Location: ws://%s%s\r\n\r\n" % (
                    environ.get('HTTP_ORIGIN'),
                    environ.get('HTTP_HOST'),
                    path))

        # Sockets take bytes; the reply was assembled as a native string.
        if not isinstance(handshake_reply, bytes):
            handshake_reply = handshake_reply.encode('latin-1')
        sock.sendall(handshake_reply)

        try:
            self.handler(ws)
        except socket.error as e:
            if e.errno != errno.EPIPE:
                raise
        # NOTE(review): start_response is never called on the success path
        # and nothing is returned here (implicit None).  The comment that
        # used to sit here spoke of using an undocumented grainbows feature
        # so the server "doesn't barf on the fact that we didn't call
        # start_response" -- presumably a server-specific return value is
        # missing; confirm against the server in use.
class WebSocket(object):
    """One WebSocket connection on top of an already-upgraded socket.

    Versions 7, 8 and 13 use HyBi framing; every other version is framed
    with the older ``00 ... FF`` scheme.

    Attributes:
        version: protocol version in use.
        closed: True once a closing frame has been received or sent.
        accepted: initialised to False; not modified by this class.
    """

    def __init__(self, sock, environ, version=76):
        """Wrap *sock*.

        The version comes from the ``Sec-WebSocket-Version`` request header
        in *environ*; *version* is used when the header is missing or is
        not an integer.
        """
        self._socket = sock
        try:
            version = int(environ.get('HTTP_SEC_WEBSOCKET_VERSION'))
        except (ValueError, TypeError):
            # Keep the caller-supplied default.
            pass
        self.version = version
        self.closed = False
        self.accepted = False
        self._buf = b''
        self._msgs = collections.deque()

    def _parse_messages(self):
        """Parse every complete message currently held in ``self._buf``.

        It is assumed that the buffer starts at a frame boundary, but it
        may end with only part of a message.

        Returns the list of parsed messages.  The unparsed remainder is
        stored back into ``self._buf``.  Sets ``self.closed`` when a closing
        frame is seen.  HyBi payloads are returned as received; old-style
        messages are decoded from UTF-8.

        Raises ValueError on an old-style frame that cannot be understood.
        """
        msgs = []
        buf = self._buf
        while buf:
            if self.version in (7, 8, 13):
                frame = decode_hybi(buf, base64=False)
                if frame['payload'] is None:
                    break  # incomplete frame: wait for more data
                if frame['opcode'] == 0x8:  # connection close
                    self.closed = True
                    break
                # NOTE(review): ping/pong frames (0x9/0xA) are delivered
                # like ordinary messages and no pong is sent -- confirm
                # whether that is intended.
                msgs.append(frame['payload'])
                buf = buf[-frame['left']:] if frame['left'] else b''
            else:
                # Slice instead of indexing: indexing bytes yields an int
                # on Python 3 but a 1-char string on Python 2.
                frame_type = buf[0:1]
                if frame_type == b'\x00':
                    # Normal message.
                    end_idx = buf.find(b'\xff')
                    if end_idx == -1:  # pragma NO COVER
                        break
                    msgs.append(buf[1:end_idx].decode('utf-8', 'replace'))
                    buf = buf[end_idx + 1:]
                elif frame_type == b'\xff':
                    # Closing handshake.
                    if len(buf) < 2:
                        break  # second byte has not arrived yet
                    if buf[1:2] != b'\x00':
                        raise ValueError(
                            "Unexpected closing handshake: %r" % buf)
                    self.closed = True
                    break
                else:
                    raise ValueError("Don't understand how to parse this type of message: %r" % buf)
        self._buf = buf
        return msgs

    def send(self, message):
        """Send a message to the browser.

        *message* should be convertable to a string; unicode objects should be
        encodable as utf-8.  Raises socket.error with errno of 32
        (broken pipe) if the socket has already been closed by the client.
        """
        # Normalise to bytes up front so both framings get the same input.
        text_type = type(u'')
        if isinstance(message, text_type):
            message = message.encode('utf-8')
        elif not isinstance(message, bytes):
            message = str(message)
            if isinstance(message, text_type):
                message = message.encode('utf-8')

        if self.version in (7, 8, 13):
            packed, _, _ = encode_hybi(message, opcode=0x01, base64=False)
        else:
            packed = pack_message(message)
        self._socket.sendall(packed)

    def wait(self):
        """Waits for and deserializes messages.

        Returns a single message; the oldest not yet processed.  If the client
        has already closed the connection, returns None.  This is different
        from normal socket behavior because the empty string is a valid
        websocket message.
        """
        while not self._msgs:
            # Websocket might be closed already.
            if self.closed:
                return None
            # no parsed messages, must mean buf needs more data
            delta = self._socket.recv(8096)
            if not delta:
                # Empty read: the peer closed the connection.
                return None
            self._buf += delta
            self._msgs.extend(self._parse_messages())
        return self._msgs.popleft()

    def _send_closing_frame(self, ignore_send_errors=False):
        """Sends the closing frame to the client, if required.

        Nothing is sent when the connection is already closed or the
        version is neither HyBi (7, 8, 13) nor 76.  With
        *ignore_send_errors* a ``socket.error`` from the send is swallowed
        and the connection is still marked closed.
        """
        if self.closed:
            return
        if self.version in (7, 8, 13):
            frame, _, _ = encode_hybi(b'', opcode=0x08, base64=False)
        elif self.version == 76:
            frame = b"\xff\x00"
        else:
            return
        try:
            self._socket.sendall(frame)
        except socket.error:
            # Sometimes, like when the remote side cuts off the connection,
            # we don't care about this.
            if not ignore_send_errors:  # pragma NO COVER
                raise
        self.closed = True

    def close(self):
        """Forcibly close the websocket; generally it is preferable to
        return from the handler method.

        The socket is closed even when sending the closing frame or
        shutting down fails; the error still propagates.
        """
        try:
            self._send_closing_frame()
            # Same numeric value as the literal ``True`` used previously.
            self._socket.shutdown(socket.SHUT_WR)
        finally:
            self._socket.close()