You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

204 lines
6.4 KiB

4 years ago
  1. import re
  2. from string import ascii_letters, ascii_lowercase, digits
  3. from typing import Optional, TYPE_CHECKING, cast
  4. BASCII_LOWERCASE = ascii_lowercase.encode("ascii")
  5. BPCT_ALLOWED = {"%{:02X}".format(i).encode("ascii") for i in range(256)}
  6. GEN_DELIMS = ":/?#[]@"
  7. SUB_DELIMS_WITHOUT_QS = "!$'()*,"
  8. SUB_DELIMS = SUB_DELIMS_WITHOUT_QS + "+&=;"
  9. RESERVED = GEN_DELIMS + SUB_DELIMS
  10. UNRESERVED = ascii_letters + digits + "-._~"
  11. ALLOWED = UNRESERVED + SUB_DELIMS_WITHOUT_QS
  12. _IS_HEX = re.compile(b"[A-Z0-9][A-Z0-9]")
  13. class _Quoter:
  14. def __init__(
  15. self, *, safe: str = "", protected: str = "", qs: bool = False
  16. ) -> None:
  17. self._safe = safe
  18. self._protected = protected
  19. self._qs = qs
  20. def __call__(self, val: Optional[str]) -> Optional[str]:
  21. if val is None:
  22. return None
  23. if not isinstance(val, str):
  24. raise TypeError("Argument should be str")
  25. if not val:
  26. return ""
  27. bval = cast(str, val).encode("utf8", errors="ignore")
  28. ret = bytearray()
  29. pct = bytearray()
  30. safe = self._safe
  31. safe += ALLOWED
  32. if not self._qs:
  33. safe += "+&=;"
  34. safe += self._protected
  35. bsafe = safe.encode("ascii")
  36. idx = 0
  37. while idx < len(bval):
  38. ch = bval[idx]
  39. idx += 1
  40. if pct:
  41. if ch in BASCII_LOWERCASE:
  42. ch = ch - 32 # convert to uppercase
  43. pct.append(ch)
  44. if len(pct) == 3: # pragma: no branch # peephole optimizer
  45. buf = pct[1:]
  46. if not _IS_HEX.match(buf):
  47. ret.extend(b"%25")
  48. pct.clear()
  49. idx -= 2
  50. continue
  51. try:
  52. unquoted = chr(int(pct[1:].decode("ascii"), base=16))
  53. except ValueError:
  54. ret.extend(b"%25")
  55. pct.clear()
  56. idx -= 2
  57. continue
  58. if unquoted in self._protected:
  59. ret.extend(pct)
  60. elif unquoted in safe:
  61. ret.append(ord(unquoted))
  62. else:
  63. ret.extend(pct)
  64. pct.clear()
  65. # special case, if we have only one char after "%"
  66. elif len(pct) == 2 and idx == len(bval):
  67. ret.extend(b"%25")
  68. pct.clear()
  69. idx -= 1
  70. continue
  71. elif ch == ord("%"):
  72. pct.clear()
  73. pct.append(ch)
  74. # special case if "%" is last char
  75. if idx == len(bval):
  76. ret.extend(b"%25")
  77. continue
  78. if self._qs:
  79. if ch == ord(" "):
  80. ret.append(ord("+"))
  81. continue
  82. if ch in bsafe:
  83. ret.append(ch)
  84. continue
  85. ret.extend(("%{:02X}".format(ch)).encode("ascii"))
  86. return ret.decode("ascii")
  87. class _Unquoter:
  88. def __init__(self, *, unsafe: str = "", qs: bool = False) -> None:
  89. self._unsafe = unsafe
  90. self._qs = qs
  91. self._quoter = _Quoter()
  92. self._qs_quoter = _Quoter(qs=True)
  93. def __call__(self, val: Optional[str]) -> Optional[str]:
  94. if val is None:
  95. return None
  96. if not isinstance(val, str):
  97. raise TypeError("Argument should be str")
  98. if not val:
  99. return ""
  100. pct = ""
  101. last_pct = ""
  102. pcts = bytearray()
  103. ret = []
  104. for ch in val:
  105. if pct:
  106. pct += ch
  107. if len(pct) == 3: # pragma: no branch # peephole optimizer
  108. pcts.append(int(pct[1:], base=16))
  109. last_pct = pct
  110. pct = ""
  111. continue
  112. if pcts:
  113. try:
  114. unquoted = pcts.decode("utf8")
  115. except UnicodeDecodeError:
  116. pass
  117. else:
  118. if self._qs and unquoted in "+=&;":
  119. to_add = self._qs_quoter(unquoted)
  120. if to_add is None: # pragma: no cover
  121. raise RuntimeError("Cannot quote None")
  122. ret.append(to_add)
  123. elif unquoted in self._unsafe:
  124. to_add = self._qs_quoter(unquoted)
  125. if to_add is None: # pragma: no cover
  126. raise RuntimeError("Cannot quote None")
  127. ret.append(to_add)
  128. else:
  129. ret.append(unquoted)
  130. del pcts[:]
  131. if ch == "%":
  132. pct = ch
  133. continue
  134. if pcts:
  135. ret.append(last_pct) # %F8ab
  136. last_pct = ""
  137. if ch == "+":
  138. if not self._qs or ch in self._unsafe:
  139. ret.append("+")
  140. else:
  141. ret.append(" ")
  142. continue
  143. if ch in self._unsafe:
  144. ret.append("%")
  145. h = hex(ord(ch)).upper()[2:]
  146. for ch in h:
  147. ret.append(ch)
  148. continue
  149. ret.append(ch)
  150. if pcts:
  151. try:
  152. unquoted = pcts.decode("utf8")
  153. except UnicodeDecodeError:
  154. ret.append(last_pct) # %F8
  155. else:
  156. if self._qs and unquoted in "+=&;":
  157. to_add = self._qs_quoter(unquoted)
  158. if to_add is None: # pragma: no cover
  159. raise RuntimeError("Cannot quote None")
  160. ret.append(to_add)
  161. elif unquoted in self._unsafe:
  162. to_add = self._qs_quoter(unquoted)
  163. if to_add is None: # pragma: no cover
  164. raise RuntimeError("Cannot quote None")
  165. ret.append(to_add)
  166. else:
  167. ret.append(unquoted)
  168. return "".join(ret)
  169. _PyQuoter = _Quoter
  170. _PyUnquoter = _Unquoter
  171. if not TYPE_CHECKING: # pragma: no branch
  172. try:
  173. from ._quoting import _Quoter, _Unquoter
  174. except ImportError: # pragma: no cover
  175. _Quoter = _PyQuoter
  176. _Unquoter = _PyUnquoter