You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

637 lines
21 KiB

4 years ago
import collections
import collections.abc
import datetime
import enum
import json
import math
import time
import warnings
import zlib
from email.utils import parsedate
from http.cookies import SimpleCookie

from multidict import CIMultiDict, CIMultiDictProxy

from . import hdrs, payload
from .helpers import HeadersMixin, rfc822_formatted_time, sentinel
from .http import RESPONSES, SERVER_SOFTWARE, HttpVersion10, HttpVersion11
  15. __all__ = ('ContentCoding', 'StreamResponse', 'Response', 'json_response')
  16. class ContentCoding(enum.Enum):
  17. # The content codings that we have support for.
  18. #
  19. # Additional registered codings are listed at:
  20. # https://www.iana.org/assignments/http-parameters/http-parameters.xhtml#content-coding
  21. deflate = 'deflate'
  22. gzip = 'gzip'
  23. identity = 'identity'
############################################################
# HTTP Response classes
############################################################
  27. class StreamResponse(collections.MutableMapping, HeadersMixin):
  28. _length_check = True
  29. def __init__(self, *, status=200, reason=None, headers=None):
  30. self._body = None
  31. self._keep_alive = None
  32. self._chunked = False
  33. self._compression = False
  34. self._compression_force = None
  35. self._cookies = SimpleCookie()
  36. self._req = None
  37. self._payload_writer = None
  38. self._eof_sent = False
  39. self._body_length = 0
  40. self._state = {}
  41. if headers is not None:
  42. self._headers = CIMultiDict(headers)
  43. else:
  44. self._headers = CIMultiDict()
  45. self.set_status(status, reason)
  46. @property
  47. def prepared(self):
  48. return self._payload_writer is not None
  49. @property
  50. def task(self):
  51. return getattr(self._req, 'task', None)
  52. @property
  53. def status(self):
  54. return self._status
  55. @property
  56. def chunked(self):
  57. return self._chunked
  58. @property
  59. def compression(self):
  60. return self._compression
  61. @property
  62. def reason(self):
  63. return self._reason
  64. def set_status(self, status, reason=None, _RESPONSES=RESPONSES):
  65. assert not self.prepared, \
  66. 'Cannot change the response status code after ' \
  67. 'the headers have been sent'
  68. self._status = int(status)
  69. if reason is None:
  70. try:
  71. reason = _RESPONSES[self._status][0]
  72. except Exception:
  73. reason = ''
  74. self._reason = reason
  75. @property
  76. def keep_alive(self):
  77. return self._keep_alive
  78. def force_close(self):
  79. self._keep_alive = False
  80. @property
  81. def body_length(self):
  82. return self._body_length
  83. @property
  84. def output_length(self):
  85. warnings.warn('output_length is deprecated', DeprecationWarning)
  86. return self._payload_writer.buffer_size
  87. def enable_chunked_encoding(self, chunk_size=None):
  88. """Enables automatic chunked transfer encoding."""
  89. self._chunked = True
  90. if hdrs.CONTENT_LENGTH in self._headers:
  91. raise RuntimeError("You can't enable chunked encoding when "
  92. "a content length is set")
  93. if chunk_size is not None:
  94. warnings.warn('Chunk size is deprecated #1615', DeprecationWarning)
  95. def enable_compression(self, force=None):
  96. """Enables response compression encoding."""
  97. # Backwards compatibility for when force was a bool <0.17.
  98. if type(force) == bool:
  99. force = ContentCoding.deflate if force else ContentCoding.identity
  100. elif force is not None:
  101. assert isinstance(force, ContentCoding), ("force should one of "
  102. "None, bool or "
  103. "ContentEncoding")
  104. self._compression = True
  105. self._compression_force = force
  106. @property
  107. def headers(self):
  108. return self._headers
  109. @property
  110. def cookies(self):
  111. return self._cookies
  112. def set_cookie(self, name, value, *, expires=None,
  113. domain=None, max_age=None, path='/',
  114. secure=None, httponly=None, version=None):
  115. """Set or update response cookie.
  116. Sets new cookie or updates existent with new value.
  117. Also updates only those params which are not None.
  118. """
  119. old = self._cookies.get(name)
  120. if old is not None and old.coded_value == '':
  121. # deleted cookie
  122. self._cookies.pop(name, None)
  123. self._cookies[name] = value
  124. c = self._cookies[name]
  125. if expires is not None:
  126. c['expires'] = expires
  127. elif c.get('expires') == 'Thu, 01 Jan 1970 00:00:00 GMT':
  128. del c['expires']
  129. if domain is not None:
  130. c['domain'] = domain
  131. if max_age is not None:
  132. c['max-age'] = max_age
  133. elif 'max-age' in c:
  134. del c['max-age']
  135. c['path'] = path
  136. if secure is not None:
  137. c['secure'] = secure
  138. if httponly is not None:
  139. c['httponly'] = httponly
  140. if version is not None:
  141. c['version'] = version
  142. def del_cookie(self, name, *, domain=None, path='/'):
  143. """Delete cookie.
  144. Creates new empty expired cookie.
  145. """
  146. # TODO: do we need domain/path here?
  147. self._cookies.pop(name, None)
  148. self.set_cookie(name, '', max_age=0,
  149. expires="Thu, 01 Jan 1970 00:00:00 GMT",
  150. domain=domain, path=path)
  151. @property
  152. def content_length(self):
  153. # Just a placeholder for adding setter
  154. return super().content_length
  155. @content_length.setter
  156. def content_length(self, value):
  157. if value is not None:
  158. value = int(value)
  159. if self._chunked:
  160. raise RuntimeError("You can't set content length when "
  161. "chunked encoding is enable")
  162. self._headers[hdrs.CONTENT_LENGTH] = str(value)
  163. else:
  164. self._headers.pop(hdrs.CONTENT_LENGTH, None)
  165. @property
  166. def content_type(self):
  167. # Just a placeholder for adding setter
  168. return super().content_type
  169. @content_type.setter
  170. def content_type(self, value):
  171. self.content_type # read header values if needed
  172. self._content_type = str(value)
  173. self._generate_content_type_header()
  174. @property
  175. def charset(self):
  176. # Just a placeholder for adding setter
  177. return super().charset
  178. @charset.setter
  179. def charset(self, value):
  180. ctype = self.content_type # read header values if needed
  181. if ctype == 'application/octet-stream':
  182. raise RuntimeError("Setting charset for application/octet-stream "
  183. "doesn't make sense, setup content_type first")
  184. if value is None:
  185. self._content_dict.pop('charset', None)
  186. else:
  187. self._content_dict['charset'] = str(value).lower()
  188. self._generate_content_type_header()
  189. @property
  190. def last_modified(self):
  191. """The value of Last-Modified HTTP header, or None.
  192. This header is represented as a `datetime` object.
  193. """
  194. httpdate = self.headers.get(hdrs.LAST_MODIFIED)
  195. if httpdate is not None:
  196. timetuple = parsedate(httpdate)
  197. if timetuple is not None:
  198. return datetime.datetime(*timetuple[:6],
  199. tzinfo=datetime.timezone.utc)
  200. return None
  201. @last_modified.setter
  202. def last_modified(self, value):
  203. if value is None:
  204. self.headers.pop(hdrs.LAST_MODIFIED, None)
  205. elif isinstance(value, (int, float)):
  206. self.headers[hdrs.LAST_MODIFIED] = time.strftime(
  207. "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value)))
  208. elif isinstance(value, datetime.datetime):
  209. self.headers[hdrs.LAST_MODIFIED] = time.strftime(
  210. "%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple())
  211. elif isinstance(value, str):
  212. self.headers[hdrs.LAST_MODIFIED] = value
  213. def _generate_content_type_header(self, CONTENT_TYPE=hdrs.CONTENT_TYPE):
  214. params = '; '.join("%s=%s" % i for i in self._content_dict.items())
  215. if params:
  216. ctype = self._content_type + '; ' + params
  217. else:
  218. ctype = self._content_type
  219. self.headers[CONTENT_TYPE] = ctype
  220. def _do_start_compression(self, coding):
  221. if coding != ContentCoding.identity:
  222. self.headers[hdrs.CONTENT_ENCODING] = coding.value
  223. self._payload_writer.enable_compression(coding.value)
  224. # Compressed payload may have different content length,
  225. # remove the header
  226. self._headers.popall(hdrs.CONTENT_LENGTH, None)
  227. def _start_compression(self, request):
  228. if self._compression_force:
  229. self._do_start_compression(self._compression_force)
  230. else:
  231. accept_encoding = request.headers.get(
  232. hdrs.ACCEPT_ENCODING, '').lower()
  233. for coding in ContentCoding:
  234. if coding.value in accept_encoding:
  235. self._do_start_compression(coding)
  236. return
  237. async def prepare(self, request):
  238. if self._eof_sent:
  239. return
  240. if self._payload_writer is not None:
  241. return self._payload_writer
  242. await request._prepare_hook(self)
  243. return await self._start(request)
  244. async def _start(self, request,
  245. HttpVersion10=HttpVersion10,
  246. HttpVersion11=HttpVersion11,
  247. CONNECTION=hdrs.CONNECTION,
  248. DATE=hdrs.DATE,
  249. SERVER=hdrs.SERVER,
  250. CONTENT_TYPE=hdrs.CONTENT_TYPE,
  251. CONTENT_LENGTH=hdrs.CONTENT_LENGTH,
  252. SET_COOKIE=hdrs.SET_COOKIE,
  253. SERVER_SOFTWARE=SERVER_SOFTWARE,
  254. TRANSFER_ENCODING=hdrs.TRANSFER_ENCODING):
  255. self._req = request
  256. keep_alive = self._keep_alive
  257. if keep_alive is None:
  258. keep_alive = request.keep_alive
  259. self._keep_alive = keep_alive
  260. version = request.version
  261. writer = self._payload_writer = request._payload_writer
  262. headers = self._headers
  263. for cookie in self._cookies.values():
  264. value = cookie.output(header='')[1:]
  265. headers.add(SET_COOKIE, value)
  266. if self._compression:
  267. self._start_compression(request)
  268. if self._chunked:
  269. if version != HttpVersion11:
  270. raise RuntimeError(
  271. "Using chunked encoding is forbidden "
  272. "for HTTP/{0.major}.{0.minor}".format(request.version))
  273. writer.enable_chunking()
  274. headers[TRANSFER_ENCODING] = 'chunked'
  275. if CONTENT_LENGTH in headers:
  276. del headers[CONTENT_LENGTH]
  277. elif self._length_check:
  278. writer.length = self.content_length
  279. if writer.length is None:
  280. if version >= HttpVersion11:
  281. writer.enable_chunking()
  282. headers[TRANSFER_ENCODING] = 'chunked'
  283. if CONTENT_LENGTH in headers:
  284. del headers[CONTENT_LENGTH]
  285. else:
  286. keep_alive = False
  287. headers.setdefault(CONTENT_TYPE, 'application/octet-stream')
  288. headers.setdefault(DATE, rfc822_formatted_time())
  289. headers.setdefault(SERVER, SERVER_SOFTWARE)
  290. # connection header
  291. if CONNECTION not in headers:
  292. if keep_alive:
  293. if version == HttpVersion10:
  294. headers[CONNECTION] = 'keep-alive'
  295. else:
  296. if version == HttpVersion11:
  297. headers[CONNECTION] = 'close'
  298. # status line
  299. status_line = 'HTTP/{}.{} {} {}'.format(
  300. version[0], version[1], self._status, self._reason)
  301. await writer.write_headers(status_line, headers)
  302. return writer
  303. async def write(self, data):
  304. assert isinstance(data, (bytes, bytearray, memoryview)), \
  305. "data argument must be byte-ish (%r)" % type(data)
  306. if self._eof_sent:
  307. raise RuntimeError("Cannot call write() after write_eof()")
  308. if self._payload_writer is None:
  309. raise RuntimeError("Cannot call write() before prepare()")
  310. await self._payload_writer.write(data)
  311. async def drain(self):
  312. assert not self._eof_sent, "EOF has already been sent"
  313. assert self._payload_writer is not None, \
  314. "Response has not been started"
  315. warnings.warn("drain method is deprecated, use await resp.write()",
  316. DeprecationWarning,
  317. stacklevel=2)
  318. await self._payload_writer.drain()
  319. async def write_eof(self, data=b''):
  320. assert isinstance(data, (bytes, bytearray, memoryview)), \
  321. "data argument must be byte-ish (%r)" % type(data)
  322. if self._eof_sent:
  323. return
  324. assert self._payload_writer is not None, \
  325. "Response has not been started"
  326. await self._payload_writer.write_eof(data)
  327. self._eof_sent = True
  328. self._req = None
  329. self._body_length = self._payload_writer.output_size
  330. self._payload_writer = None
  331. def __repr__(self):
  332. if self._eof_sent:
  333. info = "eof"
  334. elif self.prepared:
  335. info = "{} {} ".format(self._req.method, self._req.path)
  336. else:
  337. info = "not prepared"
  338. return "<{} {} {}>".format(self.__class__.__name__,
  339. self.reason, info)
  340. def __getitem__(self, key):
  341. return self._state[key]
  342. def __setitem__(self, key, value):
  343. self._state[key] = value
  344. def __delitem__(self, key):
  345. del self._state[key]
  346. def __len__(self):
  347. return len(self._state)
  348. def __iter__(self):
  349. return iter(self._state)
  350. def __hash__(self):
  351. return hash(id(self))
  352. def __eq__(self, other):
  353. return self is other
  354. class Response(StreamResponse):
  355. def __init__(self, *, body=None, status=200,
  356. reason=None, text=None, headers=None, content_type=None,
  357. charset=None):
  358. if body is not None and text is not None:
  359. raise ValueError("body and text are not allowed together")
  360. if headers is None:
  361. headers = CIMultiDict()
  362. elif not isinstance(headers, (CIMultiDict, CIMultiDictProxy)):
  363. headers = CIMultiDict(headers)
  364. if content_type is not None and "charset" in content_type:
  365. raise ValueError("charset must not be in content_type "
  366. "argument")
  367. if text is not None:
  368. if hdrs.CONTENT_TYPE in headers:
  369. if content_type or charset:
  370. raise ValueError("passing both Content-Type header and "
  371. "content_type or charset params "
  372. "is forbidden")
  373. else:
  374. # fast path for filling headers
  375. if not isinstance(text, str):
  376. raise TypeError("text argument must be str (%r)" %
  377. type(text))
  378. if content_type is None:
  379. content_type = 'text/plain'
  380. if charset is None:
  381. charset = 'utf-8'
  382. headers[hdrs.CONTENT_TYPE] = (
  383. content_type + '; charset=' + charset)
  384. body = text.encode(charset)
  385. text = None
  386. else:
  387. if hdrs.CONTENT_TYPE in headers:
  388. if content_type is not None or charset is not None:
  389. raise ValueError("passing both Content-Type header and "
  390. "content_type or charset params "
  391. "is forbidden")
  392. else:
  393. if content_type is not None:
  394. if charset is not None:
  395. content_type += '; charset=' + charset
  396. headers[hdrs.CONTENT_TYPE] = content_type
  397. super().__init__(status=status, reason=reason, headers=headers)
  398. if text is not None:
  399. self.text = text
  400. else:
  401. self.body = body
  402. self._compressed_body = None
  403. @property
  404. def body(self):
  405. return self._body
  406. @body.setter
  407. def body(self, body,
  408. CONTENT_TYPE=hdrs.CONTENT_TYPE,
  409. CONTENT_LENGTH=hdrs.CONTENT_LENGTH):
  410. if body is None:
  411. self._body = None
  412. self._body_payload = False
  413. elif isinstance(body, (bytes, bytearray)):
  414. self._body = body
  415. self._body_payload = False
  416. else:
  417. try:
  418. self._body = body = payload.PAYLOAD_REGISTRY.get(body)
  419. except payload.LookupError:
  420. raise ValueError('Unsupported body type %r' % type(body))
  421. self._body_payload = True
  422. headers = self._headers
  423. # set content-length header if needed
  424. if not self._chunked and CONTENT_LENGTH not in headers:
  425. size = body.size
  426. if size is not None:
  427. headers[CONTENT_LENGTH] = str(size)
  428. # set content-type
  429. if CONTENT_TYPE not in headers:
  430. headers[CONTENT_TYPE] = body.content_type
  431. # copy payload headers
  432. if body.headers:
  433. for (key, value) in body.headers.items():
  434. if key not in headers:
  435. headers[key] = value
  436. self._compressed_body = None
  437. @property
  438. def text(self):
  439. if self._body is None:
  440. return None
  441. return self._body.decode(self.charset or 'utf-8')
  442. @text.setter
  443. def text(self, text):
  444. assert text is None or isinstance(text, str), \
  445. "text argument must be str (%r)" % type(text)
  446. if self.content_type == 'application/octet-stream':
  447. self.content_type = 'text/plain'
  448. if self.charset is None:
  449. self.charset = 'utf-8'
  450. self._body = text.encode(self.charset)
  451. self._body_payload = False
  452. self._compressed_body = None
  453. @property
  454. def content_length(self):
  455. if self._chunked:
  456. return None
  457. if hdrs.CONTENT_LENGTH in self.headers:
  458. return super().content_length
  459. if self._compressed_body is not None:
  460. # Return length of the compressed body
  461. return len(self._compressed_body)
  462. elif self._body_payload:
  463. # A payload without content length, or a compressed payload
  464. return None
  465. elif self._body is not None:
  466. return len(self._body)
  467. else:
  468. return 0
  469. @content_length.setter
  470. def content_length(self, value):
  471. raise RuntimeError("Content length is set automatically")
  472. async def write_eof(self):
  473. if self._eof_sent:
  474. return
  475. if self._compressed_body is not None:
  476. body = self._compressed_body
  477. else:
  478. body = self._body
  479. if body is not None:
  480. if (self._req._method == hdrs.METH_HEAD or
  481. self._status in [204, 304]):
  482. await super().write_eof()
  483. elif self._body_payload:
  484. await body.write(self._payload_writer)
  485. await super().write_eof()
  486. else:
  487. await super().write_eof(body)
  488. else:
  489. await super().write_eof()
  490. async def _start(self, request):
  491. if not self._chunked and hdrs.CONTENT_LENGTH not in self._headers:
  492. if not self._body_payload:
  493. if self._body is not None:
  494. self._headers[hdrs.CONTENT_LENGTH] = str(len(self._body))
  495. else:
  496. self._headers[hdrs.CONTENT_LENGTH] = '0'
  497. return await super()._start(request)
  498. def _do_start_compression(self, coding):
  499. if self._body_payload or self._chunked:
  500. return super()._do_start_compression(coding)
  501. if coding != ContentCoding.identity:
  502. # Instead of using _payload_writer.enable_compression,
  503. # compress the whole body
  504. zlib_mode = (16 + zlib.MAX_WBITS
  505. if coding.value == 'gzip' else -zlib.MAX_WBITS)
  506. compressobj = zlib.compressobj(wbits=zlib_mode)
  507. self._compressed_body = compressobj.compress(self._body) +\
  508. compressobj.flush()
  509. self._headers[hdrs.CONTENT_ENCODING] = coding.value
  510. self._headers[hdrs.CONTENT_LENGTH] = \
  511. str(len(self._compressed_body))
  512. def json_response(data=sentinel, *, text=None, body=None, status=200,
  513. reason=None, headers=None, content_type='application/json',
  514. dumps=json.dumps):
  515. if data is not sentinel:
  516. if text or body:
  517. raise ValueError(
  518. "only one of data, text, or body should be specified"
  519. )
  520. else:
  521. text = dumps(data)
  522. return Response(text=text, body=body, status=status, reason=reason,
  523. headers=headers, content_type=content_type)