You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

203 lines
6.4 KiB

4 years ago
  1. import re
  2. from string import ascii_letters, ascii_lowercase, digits
  3. from typing import Optional, TYPE_CHECKING, cast
  4. BASCII_LOWERCASE = ascii_lowercase.encode('ascii')
  5. BPCT_ALLOWED = {'%{:02X}'.format(i).encode('ascii') for i in range(256)}
  6. GEN_DELIMS = ":/?#[]@"
  7. SUB_DELIMS_WITHOUT_QS = "!$'()*,"
  8. SUB_DELIMS = SUB_DELIMS_WITHOUT_QS + '+&=;'
  9. RESERVED = GEN_DELIMS + SUB_DELIMS
  10. UNRESERVED = ascii_letters + digits + '-._~'
  11. ALLOWED = UNRESERVED + SUB_DELIMS_WITHOUT_QS
  12. _IS_HEX = re.compile(b'[A-Z0-9][A-Z0-9]')
  13. class _Quoter:
  14. def __init__(self, *,
  15. safe: str='', protected: str='', qs: bool=False) -> None:
  16. self._safe = safe
  17. self._protected = protected
  18. self._qs = qs
  19. def __call__(self, val: Optional[str]) -> Optional[str]:
  20. if val is None:
  21. return None
  22. if not isinstance(val, str):
  23. raise TypeError("Argument should be str")
  24. if not val:
  25. return ''
  26. bval = cast(str, val).encode('utf8', errors='ignore')
  27. ret = bytearray()
  28. pct = bytearray()
  29. safe = self._safe
  30. safe += ALLOWED
  31. if not self._qs:
  32. safe += '+&=;'
  33. safe += self._protected
  34. bsafe = safe.encode('ascii')
  35. idx = 0
  36. while idx < len(bval):
  37. ch = bval[idx]
  38. idx += 1
  39. if pct:
  40. if ch in BASCII_LOWERCASE:
  41. ch = ch - 32 # convert to uppercase
  42. pct.append(ch)
  43. if len(pct) == 3: # pragma: no branch # peephole optimizer
  44. buf = pct[1:]
  45. if not _IS_HEX.match(buf):
  46. ret.extend(b'%25')
  47. pct.clear()
  48. idx -= 2
  49. continue
  50. try:
  51. unquoted = chr(int(pct[1:].decode('ascii'), base=16))
  52. except ValueError:
  53. ret.extend(b'%25')
  54. pct.clear()
  55. idx -= 2
  56. continue
  57. if unquoted in self._protected:
  58. ret.extend(pct)
  59. elif unquoted in safe:
  60. ret.append(ord(unquoted))
  61. else:
  62. ret.extend(pct)
  63. pct.clear()
  64. # special case, if we have only one char after "%"
  65. elif len(pct) == 2 and idx == len(bval):
  66. ret.extend(b'%25')
  67. pct.clear()
  68. idx -= 1
  69. continue
  70. elif ch == ord('%'):
  71. pct.clear()
  72. pct.append(ch)
  73. # special case if "%" is last char
  74. if idx == len(bval):
  75. ret.extend(b'%25')
  76. continue
  77. if self._qs:
  78. if ch == ord(' '):
  79. ret.append(ord('+'))
  80. continue
  81. if ch in bsafe:
  82. ret.append(ch)
  83. continue
  84. ret.extend(('%{:02X}'.format(ch)).encode('ascii'))
  85. return ret.decode('ascii')
  86. class _Unquoter:
  87. def __init__(self, *, unsafe: str='', qs: bool=False) -> None:
  88. self._unsafe = unsafe
  89. self._qs = qs
  90. self._quoter = _Quoter()
  91. self._qs_quoter = _Quoter(qs=True)
  92. def __call__(self, val: Optional[str]) -> Optional[str]:
  93. if val is None:
  94. return None
  95. if not isinstance(val, str):
  96. raise TypeError("Argument should be str")
  97. if not val:
  98. return ''
  99. pct = ''
  100. last_pct = ''
  101. pcts = bytearray()
  102. ret = []
  103. for ch in val:
  104. if pct:
  105. pct += ch
  106. if len(pct) == 3: # pragma: no branch # peephole optimizer
  107. pcts.append(int(pct[1:], base=16))
  108. last_pct = pct
  109. pct = ''
  110. continue
  111. if pcts:
  112. try:
  113. unquoted = pcts.decode('utf8')
  114. except UnicodeDecodeError:
  115. pass
  116. else:
  117. if self._qs and unquoted in '+=&;':
  118. to_add = self._qs_quoter(unquoted)
  119. if to_add is None: # pragma: no cover
  120. raise RuntimeError("Cannot quote None")
  121. ret.append(to_add)
  122. elif unquoted in self._unsafe:
  123. to_add = self._qs_quoter(unquoted)
  124. if to_add is None: # pragma: no cover
  125. raise RuntimeError("Cannot quote None")
  126. ret.append(to_add)
  127. else:
  128. ret.append(unquoted)
  129. del pcts[:]
  130. if ch == '%':
  131. pct = ch
  132. continue
  133. if pcts:
  134. ret.append(last_pct) # %F8ab
  135. last_pct = ''
  136. if ch == '+':
  137. if not self._qs or ch in self._unsafe:
  138. ret.append('+')
  139. else:
  140. ret.append(' ')
  141. continue
  142. if ch in self._unsafe:
  143. ret.append('%')
  144. h = hex(ord(ch)).upper()[2:]
  145. for ch in h:
  146. ret.append(ch)
  147. continue
  148. ret.append(ch)
  149. if pcts:
  150. try:
  151. unquoted = pcts.decode('utf8')
  152. except UnicodeDecodeError:
  153. ret.append(last_pct) # %F8
  154. else:
  155. if self._qs and unquoted in '+=&;':
  156. to_add = self._qs_quoter(unquoted)
  157. if to_add is None: # pragma: no cover
  158. raise RuntimeError("Cannot quote None")
  159. ret.append(to_add)
  160. elif unquoted in self._unsafe:
  161. to_add = self._qs_quoter(unquoted)
  162. if to_add is None: # pragma: no cover
  163. raise RuntimeError("Cannot quote None")
  164. ret.append(to_add)
  165. else:
  166. ret.append(unquoted)
  167. return ''.join(ret)
# Keep handles on the pure-Python implementations defined in this module
# before the public names below may be rebound to the accelerated ones.
_PyQuoter = _Quoter
_PyUnquoter = _Unquoter
# Static type checkers skip this block and so only ever see the pure-Python
# classes.
if not TYPE_CHECKING:  # pragma: no branch
    # Prefer the implementations from the sibling ``_quoting`` module when it
    # can be imported (presumably a compiled extension -- confirm in build).
    try:
        from ._quoting import _Quoter, _Unquoter
    except ImportError:  # pragma: no cover
        # ``from m import a, b`` binds the names one at a time, so a failure
        # on the second name could leave ``_Quoter`` already rebound; restore
        # both to the pure-Python versions.
        _Quoter = _PyQuoter
        _Unquoter = _PyUnquoter