You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

334 lines
10 KiB

4 years ago
  1. import datetime
  2. import pathlib
  3. import pickle
  4. import re
  5. from collections import defaultdict
  6. from collections.abc import Mapping
  7. from http.cookies import Morsel, SimpleCookie
  8. from math import ceil
  9. from yarl import URL
  10. from .abc import AbstractCookieJar
  11. from .helpers import is_ip_address
  12. __all__ = ('CookieJar', 'DummyCookieJar')
  class CookieJar(AbstractCookieJar):
      """Implements cookie storage adhering to RFC 6265.

      Internal state:

      * ``_cookies`` -- ``defaultdict`` mapping a domain string to a
        ``SimpleCookie`` holding the morsels stored for that domain.
      * ``_host_only_cookies`` -- set of ``(domain, name)`` pairs that may
        only be sent back to exactly the host that set them.
      * ``_expirations`` -- ``(domain, name) -> deadline``; every deadline
        is expressed on the ``self._loop.time()`` clock.
      * ``_next_expiration`` -- earliest moment (same clock) at which
        ``_do_expiration`` has any work to do.

      ``self._loop`` is not assigned in this class; presumably it is set
      by ``AbstractCookieJar.__init__`` -- confirm against the base class.
      """

      # Tokenizer for cookie dates: skips a run of delimiter characters and
      # captures the following run of non-delimiter characters as ``token``.
      DATE_TOKENS_RE = re.compile(
          r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
          r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)")

      DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")

      DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")

      # One capture group per month, so ``match.lastindex`` is the month
      # number (1..12).
      DATE_MONTH_RE = re.compile("(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|"
                                 "(aug)|(sep)|(oct)|(nov)|(dec)", re.I)

      DATE_YEAR_RE = re.compile(r"(\d{2,4})")

      MAX_TIME = 2051215261.0  # so far in future (2035-01-01)

      def __init__(self, *, unsafe=False, loop=None):
          """Create an empty jar.

          :param unsafe: when true, cookies are also accepted from and
              sent to hosts that are IP addresses.
          :param loop: forwarded unchanged to ``AbstractCookieJar``.
          """
          super().__init__(loop=loop)
          self._cookies = defaultdict(SimpleCookie)
          self._host_only_cookies = set()
          self._unsafe = unsafe
          self._next_expiration = ceil(self._loop.time())
          self._expirations = {}

      def save(self, file_path):
          """Pickle the stored cookies to *file_path*.

          Only ``_cookies`` is written; the host-only flags and the
          expiration deadlines are not persisted.
          """
          file_path = pathlib.Path(file_path)
          with file_path.open(mode='wb') as f:
              pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)

      def load(self, file_path):
          """Replace the stored cookies with the ones pickled in *file_path*.

          SECURITY: ``pickle.load`` can execute arbitrary code embedded in
          the file. Only load files that come from a trusted source (for
          example ones previously written by ``save``).
          """
          file_path = pathlib.Path(file_path)
          with file_path.open(mode='rb') as f:
              self._cookies = pickle.load(f)

      def clear(self):
          """Drop every cookie together with its flags and deadlines."""
          self._cookies.clear()
          self._host_only_cookies.clear()
          self._next_expiration = ceil(self._loop.time())
          self._expirations.clear()

      def __iter__(self):
          """Iterate over all stored morsels, purging expired ones first."""
          self._do_expiration()
          for val in self._cookies.values():
              yield from val.values()

      def __len__(self):
          """Return the number of stored (non-expired) cookies."""
          return sum(1 for i in self)

      def _do_expiration(self):
          """Remove cookies whose deadline has passed.

          Does nothing until ``_next_expiration`` is reached, so calling it
          on every access is cheap.
          """
          now = self._loop.time()
          if self._next_expiration > now:
              return
          if not self._expirations:
              return
          next_expiration = self.MAX_TIME
          to_del = []
          cookies = self._cookies
          expirations = self._expirations
          for (domain, name), when in expirations.items():
              if when <= now:
                  cookies[domain].pop(name, None)
                  to_del.append((domain, name))
                  self._host_only_cookies.discard((domain, name))
              else:
                  next_expiration = min(next_expiration, when)
          # Deleted in a second pass: the dict cannot change size while it
          # is being iterated above.
          for key in to_del:
              del expirations[key]

          self._next_expiration = ceil(next_expiration)

      def _expire_cookie(self, when, domain, name):
          """Schedule cookie ``(domain, name)`` for removal at *when*.

          *when* must be expressed on the ``self._loop.time()`` clock,
          because that is what ``_do_expiration`` compares it against.
          """
          self._next_expiration = min(self._next_expiration, when)
          self._expirations[(domain, name)] = when

      def update_cookies(self, cookies, response_url=URL()):
          """Update cookies.

          :param cookies: a mapping, or an iterable of ``(name, value)``
              pairs; each value is either a ``Morsel`` or something a
              ``SimpleCookie`` accepts as a value.
          :param response_url: URL of the response that set the cookies;
              supplies the default domain and path.
          """
          hostname = response_url.raw_host

          if not self._unsafe and is_ip_address(hostname):
              # Don't accept cookies from IPs
              return

          if isinstance(cookies, Mapping):
              cookies = cookies.items()

          for name, cookie in cookies:
              if not isinstance(cookie, Morsel):
                  tmp = SimpleCookie()
                  tmp[name] = cookie
                  cookie = tmp[name]

              domain = cookie["domain"]

              # ignore domains with trailing dots
              if domain.endswith('.'):
                  domain = ""
                  # Blank the attribute instead of deleting the key: when
                  # ``hostname`` is None nothing below restores it, and
                  # ``filter_cookies`` reads ``cookie["domain"]``, which
                  # would raise KeyError on a morsel missing the key.
                  cookie["domain"] = ""

              if not domain and hostname is not None:
                  # Set the cookie's domain to the response hostname
                  # and set its host-only-flag
                  self._host_only_cookies.add((hostname, name))
                  domain = cookie["domain"] = hostname

              if domain.startswith("."):
                  # Remove leading dot
                  domain = domain[1:]
                  cookie["domain"] = domain

              if hostname and not self._is_domain_match(domain, hostname):
                  # Setting cookies for different domains is not allowed
                  continue

              path = cookie["path"]
              if not path or not path.startswith("/"):
                  # Set the cookie's path to the response path
                  path = response_url.path
                  if not path.startswith("/"):
                      path = "/"
                  else:
                      # Cut everything from the last slash to the end
                      path = "/" + path[1:path.rfind("/")]
                  cookie["path"] = path

              # The new cookie replaces any older one with the same key, so
              # a deadline left over from the older cookie must not apply
              # to it (otherwise a fresh session cookie would be removed on
              # the old cookie's schedule).
              self._expirations.pop((domain, name), None)

              max_age = cookie["max-age"]
              if max_age:
                  try:
                      delta_seconds = int(max_age)
                      self._expire_cookie(self._loop.time() + delta_seconds,
                                          domain, name)
                  except ValueError:
                      cookie["max-age"] = ""

              else:
                  expires = cookie["expires"]
                  if expires:
                      expire_time = self._parse_date(expires)
                      if expire_time:
                          # ``expire_time`` is an absolute UTC date, while
                          # deadlines are compared against the loop clock
                          # (monotonic on the stock asyncio loops, i.e.
                          # unrelated to the Unix epoch). Convert it to
                          # "seconds from now" on the loop clock.
                          now_utc = datetime.datetime.now(
                              datetime.timezone.utc)
                          remaining = (expire_time - now_utc).total_seconds()
                          self._expire_cookie(self._loop.time() + remaining,
                                              domain, name)
                      else:
                          cookie["expires"] = ""

              self._cookies[domain][name] = cookie

          self._do_expiration()

      def filter_cookies(self, request_url=URL()):
          """Returns this jar's cookies filtered by their attributes.

          :param request_url: URL of the outgoing request.
          :return: a ``SimpleCookie`` with the cookies whose domain, path
              and ``secure`` attributes allow them to be sent to
              *request_url*.
          """
          self._do_expiration()
          request_url = URL(request_url)
          filtered = SimpleCookie()
          hostname = request_url.raw_host or ""
          is_not_secure = request_url.scheme not in ("https", "wss")

          for cookie in self:
              name = cookie.key
              domain = cookie["domain"]

              # Send shared cookies
              if not domain:
                  filtered[name] = cookie.value
                  continue

              if not self._unsafe and is_ip_address(hostname):
                  continue

              if (domain, name) in self._host_only_cookies:
                  if domain != hostname:
                      continue
              elif not self._is_domain_match(domain, hostname):
                  continue

              if not self._is_path_match(request_url.path, cookie["path"]):
                  continue

              if is_not_secure and cookie["secure"]:
                  continue

              # It's critical we use the Morsel so the coded_value
              # (based on cookie version) is preserved
              mrsl_val = cookie.get(cookie.key, Morsel())
              mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
              filtered[name] = mrsl_val

          return filtered

      @staticmethod
      def _is_domain_match(domain, hostname):
          """Implements domain matching adhering to RFC 6265.

          True when *hostname* equals *domain*, or is a sub-domain of it
          (``"." + domain`` is a suffix) and is not an IP address.
          """
          if hostname == domain:
              return True

          if not hostname.endswith(domain):
              return False

          non_matching = hostname[:-len(domain)]

          if not non_matching.endswith("."):
              return False

          return not is_ip_address(hostname)

      @staticmethod
      def _is_path_match(req_path, cookie_path):
          """Implements path matching adhering to RFC 6265.

          True when *cookie_path* equals *req_path* or is a prefix of it
          that ends on a ``/`` boundary.
          """
          if not req_path.startswith("/"):
              req_path = "/"

          if req_path == cookie_path:
              return True

          if not req_path.startswith(cookie_path):
              return False

          if cookie_path.endswith("/"):
              return True

          non_matching = req_path[len(cookie_path):]

          return non_matching.startswith("/")

      @classmethod
      def _parse_date(cls, date_str):
          """Implements date string parsing adhering to RFC 6265.

          :return: a timezone-aware (UTC) ``datetime``, or ``None`` when
              *date_str* is empty, lacks one of time / day / month / year,
              or does not describe a valid date.
          """
          if not date_str:
              return None

          found_time = False
          found_day = False
          found_month = False
          found_year = False

          hour = minute = second = 0
          day = 0
          month = 0
          year = 0

          # Each token is tried against the still-missing components in a
          # fixed order; the first component it satisfies consumes it.
          for token_match in cls.DATE_TOKENS_RE.finditer(date_str):

              token = token_match.group("token")

              if not found_time:
                  time_match = cls.DATE_HMS_TIME_RE.match(token)
                  if time_match:
                      found_time = True
                      hour, minute, second = [
                          int(s) for s in time_match.groups()]
                      continue

              if not found_day:
                  day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
                  if day_match:
                      found_day = True
                      day = int(day_match.group())
                      continue

              if not found_month:
                  month_match = cls.DATE_MONTH_RE.match(token)
                  if month_match:
                      found_month = True
                      month = month_match.lastindex
                      continue

              if not found_year:
                  year_match = cls.DATE_YEAR_RE.match(token)
                  if year_match:
                      found_year = True
                      year = int(year_match.group())

          # Two-digit years: 70-99 -> 19xx, 00-69 -> 20xx.
          if 70 <= year <= 99:
              year += 1900
          elif 0 <= year <= 69:
              year += 2000

          if False in (found_day, found_month, found_year, found_time):
              return None

          if not 1 <= day <= 31:
              return None

          if year < 1601 or hour > 23 or minute > 59 or second > 59:
              return None

          try:
              return datetime.datetime(year, month, day,
                                       hour, minute, second,
                                       tzinfo=datetime.timezone.utc)
          except ValueError:
              # The range checks above still admit impossible dates such as
              # "31 Feb"; the input comes from a remote server, so treat it
              # as unparsable instead of propagating the error.
              return None
  class DummyCookieJar(AbstractCookieJar):
      """Cookie jar that stores nothing.

      It can be used with the ClientSession when no cookie processing is
      needed: it is always empty, ignores every update and filters to
      ``None``.
      """

      def __init__(self, *, loop=None):
          """Forward *loop* unchanged to ``AbstractCookieJar``."""
          super().__init__(loop=loop)

      def __iter__(self):
          """Return a generator that is exhausted immediately."""
          yield from ()

      def __len__(self):
          """Always zero: nothing is ever stored."""
          return 0

      def clear(self):
          """Nothing to clear."""
          return None

      def update_cookies(self, cookies, response_url=None):
          """Discard *cookies*; both arguments are ignored."""
          return None

      def filter_cookies(self, request_url):
          """Return ``None`` regardless of *request_url*."""
          return None