You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

357 lines
11 KiB

4 years ago
  1. import asyncio
  2. import datetime
  3. import os # noqa
  4. import pathlib
  5. import pickle
  6. import re
  7. from collections import defaultdict
  8. from http.cookies import BaseCookie, Morsel, SimpleCookie # noqa
  9. from math import ceil
  10. from typing import ( # noqa
  11. DefaultDict,
  12. Dict,
  13. Iterable,
  14. Iterator,
  15. Mapping,
  16. Optional,
  17. Set,
  18. Tuple,
  19. Union,
  20. cast,
  21. )
  22. from yarl import URL
  23. from .abc import AbstractCookieJar
  24. from .helpers import is_ip_address
  25. from .typedefs import LooseCookies, PathLike
# Public names exported by ``from <module> import *``.
__all__ = ('CookieJar', 'DummyCookieJar')
# A single cookie item: either a plain string value or a ``Morsel``.
CookieItem = Union[str, 'Morsel[str]']
  28. class CookieJar(AbstractCookieJar):
  29. """Implements cookie storage adhering to RFC 6265."""
  30. DATE_TOKENS_RE = re.compile(
  31. r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
  32. r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)")
  33. DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")
  34. DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")
  35. DATE_MONTH_RE = re.compile("(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|"
  36. "(aug)|(sep)|(oct)|(nov)|(dec)", re.I)
  37. DATE_YEAR_RE = re.compile(r"(\d{2,4})")
  38. MAX_TIME = 2051215261.0 # so far in future (2035-01-01)
  39. def __init__(self, *, unsafe: bool=False,
  40. loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
  41. super().__init__(loop=loop)
  42. self._cookies = defaultdict(SimpleCookie) #type: DefaultDict[str, SimpleCookie] # noqa
  43. self._host_only_cookies = set() # type: Set[Tuple[str, str]]
  44. self._unsafe = unsafe
  45. self._next_expiration = ceil(self._loop.time())
  46. self._expirations = {} # type: Dict[Tuple[str, str], int]
  47. def save(self, file_path: PathLike) -> None:
  48. file_path = pathlib.Path(file_path)
  49. with file_path.open(mode='wb') as f:
  50. pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL)
  51. def load(self, file_path: PathLike) -> None:
  52. file_path = pathlib.Path(file_path)
  53. with file_path.open(mode='rb') as f:
  54. self._cookies = pickle.load(f)
  55. def clear(self) -> None:
  56. self._cookies.clear()
  57. self._host_only_cookies.clear()
  58. self._next_expiration = ceil(self._loop.time())
  59. self._expirations.clear()
  60. def __iter__(self) -> 'Iterator[Morsel[str]]':
  61. self._do_expiration()
  62. for val in self._cookies.values():
  63. yield from val.values()
  64. def __len__(self) -> int:
  65. return sum(1 for i in self)
  66. def _do_expiration(self) -> None:
  67. now = self._loop.time()
  68. if self._next_expiration > now:
  69. return
  70. if not self._expirations:
  71. return
  72. next_expiration = self.MAX_TIME
  73. to_del = []
  74. cookies = self._cookies
  75. expirations = self._expirations
  76. for (domain, name), when in expirations.items():
  77. if when <= now:
  78. cookies[domain].pop(name, None)
  79. to_del.append((domain, name))
  80. self._host_only_cookies.discard((domain, name))
  81. else:
  82. next_expiration = min(next_expiration, when)
  83. for key in to_del:
  84. del expirations[key]
  85. self._next_expiration = ceil(next_expiration)
  86. def _expire_cookie(self, when: float, domain: str, name: str) -> None:
  87. iwhen = int(when)
  88. self._next_expiration = min(self._next_expiration, iwhen)
  89. self._expirations[(domain, name)] = iwhen
  90. def update_cookies(self,
  91. cookies: LooseCookies,
  92. response_url: URL=URL()) -> None:
  93. """Update cookies."""
  94. hostname = response_url.raw_host
  95. if not self._unsafe and is_ip_address(hostname):
  96. # Don't accept cookies from IPs
  97. return
  98. if isinstance(cookies, Mapping):
  99. cookies = cookies.items() # type: ignore
  100. for name, cookie in cookies:
  101. if not isinstance(cookie, Morsel):
  102. tmp = SimpleCookie()
  103. tmp[name] = cookie # type: ignore
  104. cookie = tmp[name]
  105. domain = cookie["domain"]
  106. # ignore domains with trailing dots
  107. if domain.endswith('.'):
  108. domain = ""
  109. del cookie["domain"]
  110. if not domain and hostname is not None:
  111. # Set the cookie's domain to the response hostname
  112. # and set its host-only-flag
  113. self._host_only_cookies.add((hostname, name))
  114. domain = cookie["domain"] = hostname
  115. if domain.startswith("."):
  116. # Remove leading dot
  117. domain = domain[1:]
  118. cookie["domain"] = domain
  119. if hostname and not self._is_domain_match(domain, hostname):
  120. # Setting cookies for different domains is not allowed
  121. continue
  122. path = cookie["path"]
  123. if not path or not path.startswith("/"):
  124. # Set the cookie's path to the response path
  125. path = response_url.path
  126. if not path.startswith("/"):
  127. path = "/"
  128. else:
  129. # Cut everything from the last slash to the end
  130. path = "/" + path[1:path.rfind("/")]
  131. cookie["path"] = path
  132. max_age = cookie["max-age"]
  133. if max_age:
  134. try:
  135. delta_seconds = int(max_age)
  136. self._expire_cookie(self._loop.time() + delta_seconds,
  137. domain, name)
  138. except ValueError:
  139. cookie["max-age"] = ""
  140. else:
  141. expires = cookie["expires"]
  142. if expires:
  143. expire_time = self._parse_date(expires)
  144. if expire_time:
  145. self._expire_cookie(expire_time.timestamp(),
  146. domain, name)
  147. else:
  148. cookie["expires"] = ""
  149. self._cookies[domain][name] = cookie
  150. self._do_expiration()
  151. def filter_cookies(self, request_url: URL=URL()) -> 'BaseCookie[str]':
  152. """Returns this jar's cookies filtered by their attributes."""
  153. self._do_expiration()
  154. request_url = URL(request_url)
  155. filtered = SimpleCookie()
  156. hostname = request_url.raw_host or ""
  157. is_not_secure = request_url.scheme not in ("https", "wss")
  158. for cookie in self:
  159. name = cookie.key
  160. domain = cookie["domain"]
  161. # Send shared cookies
  162. if not domain:
  163. filtered[name] = cookie.value
  164. continue
  165. if not self._unsafe and is_ip_address(hostname):
  166. continue
  167. if (domain, name) in self._host_only_cookies:
  168. if domain != hostname:
  169. continue
  170. elif not self._is_domain_match(domain, hostname):
  171. continue
  172. if not self._is_path_match(request_url.path, cookie["path"]):
  173. continue
  174. if is_not_secure and cookie["secure"]:
  175. continue
  176. # It's critical we use the Morsel so the coded_value
  177. # (based on cookie version) is preserved
  178. mrsl_val = cast('Morsel[str]', cookie.get(cookie.key, Morsel()))
  179. mrsl_val.set(cookie.key, cookie.value, cookie.coded_value)
  180. filtered[name] = mrsl_val
  181. return filtered
  182. @staticmethod
  183. def _is_domain_match(domain: str, hostname: str) -> bool:
  184. """Implements domain matching adhering to RFC 6265."""
  185. if hostname == domain:
  186. return True
  187. if not hostname.endswith(domain):
  188. return False
  189. non_matching = hostname[:-len(domain)]
  190. if not non_matching.endswith("."):
  191. return False
  192. return not is_ip_address(hostname)
  193. @staticmethod
  194. def _is_path_match(req_path: str, cookie_path: str) -> bool:
  195. """Implements path matching adhering to RFC 6265."""
  196. if not req_path.startswith("/"):
  197. req_path = "/"
  198. if req_path == cookie_path:
  199. return True
  200. if not req_path.startswith(cookie_path):
  201. return False
  202. if cookie_path.endswith("/"):
  203. return True
  204. non_matching = req_path[len(cookie_path):]
  205. return non_matching.startswith("/")
  206. @classmethod
  207. def _parse_date(cls, date_str: str) -> Optional[datetime.datetime]:
  208. """Implements date string parsing adhering to RFC 6265."""
  209. if not date_str:
  210. return None
  211. found_time = False
  212. found_day = False
  213. found_month = False
  214. found_year = False
  215. hour = minute = second = 0
  216. day = 0
  217. month = 0
  218. year = 0
  219. for token_match in cls.DATE_TOKENS_RE.finditer(date_str):
  220. token = token_match.group("token")
  221. if not found_time:
  222. time_match = cls.DATE_HMS_TIME_RE.match(token)
  223. if time_match:
  224. found_time = True
  225. hour, minute, second = [
  226. int(s) for s in time_match.groups()]
  227. continue
  228. if not found_day:
  229. day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
  230. if day_match:
  231. found_day = True
  232. day = int(day_match.group())
  233. continue
  234. if not found_month:
  235. month_match = cls.DATE_MONTH_RE.match(token)
  236. if month_match:
  237. found_month = True
  238. month = month_match.lastindex
  239. continue
  240. if not found_year:
  241. year_match = cls.DATE_YEAR_RE.match(token)
  242. if year_match:
  243. found_year = True
  244. year = int(year_match.group())
  245. if 70 <= year <= 99:
  246. year += 1900
  247. elif 0 <= year <= 69:
  248. year += 2000
  249. if False in (found_day, found_month, found_year, found_time):
  250. return None
  251. if not 1 <= day <= 31:
  252. return None
  253. if year < 1601 or hour > 23 or minute > 59 or second > 59:
  254. return None
  255. return datetime.datetime(year, month, day,
  256. hour, minute, second,
  257. tzinfo=datetime.timezone.utc)
  258. class DummyCookieJar(AbstractCookieJar):
  259. """Implements a dummy cookie storage.
  260. It can be used with the ClientSession when no cookie processing is needed.
  261. """
  262. def __init__(self, *,
  263. loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
  264. super().__init__(loop=loop)
  265. def __iter__(self) -> 'Iterator[Morsel[str]]':
  266. while False:
  267. yield None
  268. def __len__(self) -> int:
  269. return 0
  270. def clear(self) -> None:
  271. pass
  272. def update_cookies(self,
  273. cookies: LooseCookies,
  274. response_url: URL=URL()) -> None:
  275. pass
  276. def filter_cookies(self, request_url: URL) -> 'BaseCookie[str]':
  277. return SimpleCookie()