You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

390 lines
11 KiB

4 years ago
  1. # cython: language_level=3
  2. from libc.stdint cimport uint8_t, uint64_t
  3. from libc.string cimport memcpy, memset
  4. from cpython.exc cimport PyErr_NoMemory
  5. from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
  6. from cpython.unicode cimport PyUnicode_DecodeASCII
  7. from string import ascii_letters, digits
  # Character classes follow RFC 3986, sections 2.2 (reserved) and
  # 2.3 (unreserved).
  cdef str GEN_DELIMS = ":/?#[]@"
  # Sub-delimiters that are not listed in QS below.
  cdef str SUB_DELIMS_WITHOUT_QS = "!$'()*,"
  # Complete sub-delims set.  '&' is a sub-delimiter; '?' is a
  # gen-delimiter and is already part of GEN_DELIMS, so it must not be
  # repeated here in place of '&'.
  cdef str SUB_DELIMS = SUB_DELIMS_WITHOUT_QS + '+&=;'
  cdef str RESERVED = GEN_DELIMS + SUB_DELIMS
  cdef str UNRESERVED = ascii_letters + digits + '-._~'
  # Characters that are left unencoded in both quoting modes.
  cdef str ALLOWED = UNRESERVED + SUB_DELIMS_WITHOUT_QS
  # Characters that are additionally left unencoded when NOT in
  # query-string mode.
  cdef str QS = '+&=;'

  DEF BUF_SIZE = 8 * 1024  # 8KiB
  # Static scratch buffer used by the writer before it falls back to the heap.
  cdef char BUFFER[BUF_SIZE]
  cdef inline Py_UCS4 _to_hex(uint8_t v):
      """Return the uppercase hexadecimal digit character for nibble *v*.

      Values below 10 map onto '0'..'9'; every other value is offset from
      'A' (so 10 maps to 'A').
      """
      cdef int offset
      if v >= 10:
          offset = 0x41 - 10  # ord('A') == 0x41
      else:
          offset = 0x30  # ord('0') == 0x30
      return <Py_UCS4>(v + offset)
  cdef inline int _from_hex(Py_UCS4 v):
      """Return the numeric value (0..15) of hex digit *v*, or -1.

      Both uppercase and lowercase letters are accepted; -1 signals that
      *v* is not a hexadecimal digit.
      """
      cdef int code = <int>(v)
      if 0x30 <= code <= 0x39:  # '0'..'9'
          return code - 0x30
      if 0x41 <= code <= 0x46:  # 'A'..'F'
          return code - 0x41 + 10
      if 0x61 <= code <= 0x66:  # 'a'..'f'
          return code - 0x61 + 10
      return -1
  cdef inline Py_UCS4 _restore_ch(Py_UCS4 d1, Py_UCS4 d2):
      """Combine two hex digit characters into the value they encode.

      *d1* is the high nibble and *d2* the low nibble.  Returns
      ``<Py_UCS4>-1`` when either character is not a hexadecimal digit.
      """
      cdef int high = _from_hex(d1)
      cdef int low
      if high >= 0:
          low = _from_hex(d2)
          if low >= 0:
              return <Py_UCS4>(high << 4 | low)
      return <Py_UCS4>-1
  # 128-bit bitmaps (16 bytes, one bit per ASCII code point) accessed
  # through bit_at()/set_bit().
  cdef:
      uint8_t ALLOWED_TABLE[16]
      uint8_t ALLOWED_NOTQS_TABLE[16]
  cdef inline bint bit_at(uint8_t array[], uint64_t ch):
      """Return whether the bit numbered *ch* is set in bitmap *array*.

      Bit ``ch`` lives in byte ``ch >> 3`` at position ``ch & 7``.
      """
      cdef uint8_t cell = array[ch >> 3]
      return (cell >> (ch & 7)) & 1
  cdef inline void set_bit(uint8_t array[], uint64_t ch):
      """Set the bit numbered *ch* in bitmap *array* (see bit_at())."""
      cdef uint64_t cell = ch >> 3
      cdef uint64_t shift = ch & 7
      array[cell] |= (1 << shift)
  # Build both bitmaps once at import time.
  memset(ALLOWED_NOTQS_TABLE, 0, sizeof(ALLOWED_NOTQS_TABLE))
  memset(ALLOWED_TABLE, 0, sizeof(ALLOWED_TABLE))

  for i in range(128):
      # ALLOWED characters are set in both tables ...
      if chr(i) in ALLOWED:
          set_bit(ALLOWED_TABLE, i)
      # ... while the QS characters are only set in the NOTQS table.
      if chr(i) in ALLOWED or chr(i) in QS:
          set_bit(ALLOWED_NOTQS_TABLE, i)
  53. # ----------------- writer ---------------------------
  cdef struct Writer:
      # Output area: the static BUFFER at first, a PyMem block after growth.
      char *buf
      # Capacity of ``buf`` in bytes.
      Py_ssize_t size
      # Index of the next byte to be written.
      Py_ssize_t pos
      # Accumulated OR of the ``changed`` flags passed to _write_char().
      bint changed
  cdef inline void _init_writer(Writer* writer):
      """Point *writer* at the start of the static BUFFER and clear its state."""
      writer.changed = 0
      writer.pos = 0
      writer.size = BUF_SIZE
      writer.buf = &BUFFER[0]
  cdef inline void _release_writer(Writer* writer):
      """Free the writer's buffer unless it is still the static BUFFER."""
      if writer.buf == BUFFER:
          # Nothing was allocated on the heap.
          return
      PyMem_Free(writer.buf)
  cdef inline int _write_char(Writer* writer, Py_UCS4 ch, bint changed):
      """Append *ch* (truncated to a C ``char``) to the writer's buffer.

      When the buffer is full it is enlarged by BUF_SIZE bytes first: the
      static BUFFER is copied into a fresh PyMem block, a heap block is
      reallocated.  *changed* is OR-ed into ``writer.changed``.

      Returns 0 on success; on allocation failure calls PyErr_NoMemory()
      and returns -1.
      """
      cdef char *grown
      cdef Py_ssize_t new_size

      if writer.pos == writer.size:
          new_size = writer.size + BUF_SIZE
          if writer.buf != BUFFER:
              # Already on the heap: let the allocator move the data.
              grown = <char*>PyMem_Realloc(writer.buf, new_size)
          else:
              # First overflow: migrate the static buffer's content.
              grown = <char*>PyMem_Malloc(new_size)
              if grown != NULL:
                  memcpy(grown, writer.buf, writer.size)
          if grown == NULL:
              PyErr_NoMemory()
              return -1
          writer.buf = grown
          writer.size = new_size

      writer.buf[writer.pos] = <char>ch
      writer.pos += 1
      writer.changed |= changed
      return 0
  cdef inline int _write_pct(Writer* writer, uint8_t ch, bint changed):
      """Write byte *ch* as a three-character ``%XX`` escape (uppercase hex).

      *changed* is forwarded to every _write_char() call.  Returns 0 on
      success and -1 as soon as one of the writes fails.
      """
      cdef Py_UCS4 high = _to_hex(<uint8_t>ch >> 4)
      cdef Py_UCS4 low = _to_hex(<uint8_t>ch & 0x0f)

      if _write_char(writer, '%', changed) < 0:
          return -1
      if _write_char(writer, high, changed) < 0:
          return -1
      return _write_char(writer, low, changed)
  cdef inline int _write_percent(Writer* writer):
      """Write an escaped percent sign (``%25``) and mark the output changed.

      Returns 0 on success, -1 if a write fails.
      """
      # 0x25 == ord('%'); its escape is exactly the characters '%', '2', '5'.
      return _write_pct(writer, 0x25, True)
  cdef inline int _write_pct_check(Writer* writer, Py_UCS4 ch, Py_UCS4 pct[]):
      """Write *ch* as ``%XX`` and flag a change only if the text differs.

      *pct* holds the two hex digit characters as they appeared in the
      input; the output counts as changed when they differ from the
      uppercase digits produced here.  Returns 0 on success, -1 on a
      failed write.
      """
      cdef uint8_t byte = <uint8_t>ch
      cdef bint changed = (pct[0] != _to_hex(byte >> 4)
                           or pct[1] != _to_hex(byte & 0x0f))
      return _write_pct(writer, byte, changed)
  cdef inline int _write_utf8(Writer* writer, Py_UCS4 symbol):
      """Write the percent-encoded UTF-8 bytes of code point *symbol*.

      Surrogate code points (0xD800..0xDFFF) and values above 0x10FFFF
      produce no output and report success.  Every write marks the output
      as changed.  Returns 0 on success, -1 if a write fails.
      """
      cdef uint64_t code = <uint64_t> symbol
      cdef uint8_t lead
      cdef int tail  # number of continuation bytes still to write

      if code < 0x80:
          return _write_pct(writer, <uint8_t>code, True)
      if 0xD800 <= code <= 0xDFFF or code > 0x10FFFF:
          # surrogate half or out-of-range symbol, ignored
          return 0

      if code < 0x800:
          lead = <uint8_t>(0xc0 | (code >> 6))
          tail = 1
      elif code < 0x10000:
          lead = <uint8_t>(0xe0 | (code >> 12))
          tail = 2
      else:
          lead = <uint8_t>(0xf0 | (code >> 18))
          tail = 3

      if _write_pct(writer, lead, True) < 0:
          return -1
      while tail > 0:
          tail -= 1
          # Each continuation byte carries the next six bits, high to low.
          if _write_pct(writer,
                        <uint8_t>(0x80 | ((code >> (6 * tail)) & 0x3f)),
                        True) < 0:
              return -1
      return 0
  142. # --------------------- end writer --------------------------
  cdef class _Quoter:
      """Callable that percent-encodes a ``str``.

      Constructor arguments (keyword-only):

      safe
          Extra ASCII characters that are emitted unencoded.
      protected
          ASCII characters that keep the form they have in the input: a
          literal occurrence stays literal and a ``%XX`` escape of such a
          character stays an escape.
      qs
          Query-string mode: a space is written as ``+`` and the base
          safe set is the one without ``+``, ``&``, ``=`` and ``;``.

      ``ValueError`` is raised for non-ASCII characters in *safe* or
      *protected*.
      """
      cdef bint _qs
      cdef uint8_t _safe_table[16]
      cdef uint8_t _protected_table[16]

      def __init__(self, *, str safe='', str protected='', bint qs=False):
          cdef Py_UCS4 ch

          self._qs = qs

          # Start from the module-level bitmap that matches the mode.
          if not self._qs:
              memcpy(self._safe_table,
                     ALLOWED_NOTQS_TABLE,
                     sizeof(self._safe_table))
          else:
              memcpy(self._safe_table,
                     ALLOWED_TABLE,
                     sizeof(self._safe_table))
          for ch in safe:
              if ord(ch) > 127:
                  raise ValueError("Only safe symbols with ORD < 128 are allowed")
              set_bit(self._safe_table, ch)

          memset(self._protected_table, 0, sizeof(self._protected_table))
          for ch in protected:
              if ord(ch) > 127:
                  raise ValueError("Only safe symbols with ORD < 128 are allowed")
              # Protected characters are also safe, so literal occurrences
              # are copied through unchanged by _write().
              set_bit(self._safe_table, ch)
              set_bit(self._protected_table, ch)

      def __call__(self, val):
          """Return *val* percent-encoded; ``None`` is passed through.

          Raises ``TypeError`` when *val* is neither ``None`` nor a ``str``
          (instances of ``str`` subclasses are converted with ``str()``).
          """
          cdef Writer writer
          if val is None:
              return None
          if type(val) is not str:
              if isinstance(val, str):
                  # derived from str
                  val = str(val)
              else:
                  raise TypeError("Argument should be str")
          _init_writer(&writer)
          try:
              return self._do_quote(<str>val, &writer)
          finally:
              _release_writer(&writer)

      cdef str _do_quote(self, str val, Writer *writer):
          """Encode *val* through *writer* and return the resulting string.

          A ``%`` followed by two hexadecimal digits is treated as an
          existing escape and normalised; any other ``%`` is a stray
          percent sign and is written as ``%25``.  Only validated escapes
          consume the two characters after the ``%``, so a valid escape
          directly after a stray ``%`` (``%%41``) is still recognised.

          Returns *val* itself when nothing had to be changed.
          """
          cdef Py_UCS4 ch
          cdef Py_UCS4 decoded
          cdef Py_UCS4 pct[2]
          cdef Py_ssize_t idx = 0
          cdef Py_ssize_t length = len(val)

          # NOTE(review): the bare ``raise`` statements below are kept from
          # the original code; they rely on the writer helpers having set
          # the error (PyErr_NoMemory) before returning -1.
          while idx < length:
              ch = val[idx]
              idx += 1

              if ch != '%':
                  if self._write(writer, ch) < 0:
                      raise
                  continue

              if idx + 1 < length:
                  pct[0] = val[idx]
                  pct[1] = val[idx + 1]
                  decoded = _restore_ch(pct[0], pct[1])
                  if decoded != <Py_UCS4>-1:
                      idx += 2
                      if decoded < 128:
                          if bit_at(self._protected_table, decoded):
                              # keep the escape, normalised to uppercase
                              if _write_pct(writer, <uint8_t>decoded, True) < 0:
                                  raise
                              continue

                          if bit_at(self._safe_table, decoded):
                              # needless escape of a safe character
                              if _write_char(writer, decoded, True) < 0:
                                  raise
                              continue

                      if _write_pct_check(writer, decoded, pct) < 0:
                          raise
                      continue

              # Stray '%': not followed by two hex digits.  The following
              # characters are NOT consumed and get processed on their own.
              if _write_percent(writer) < 0:
                  raise

          if not writer.changed:
              return val
          else:
              return PyUnicode_DecodeASCII(writer.buf, writer.pos, "strict")

      cdef inline int _write(self, Writer *writer, Py_UCS4 ch):
          """Write one literal input character, encoding it if necessary.

          Returns 0 on success, -1 if the writer failed.
          """
          if self._qs:
              if ch == ' ':
                  return _write_char(writer, '+', True)

          if ch < 128 and bit_at(self._safe_table, ch):
              return _write_char(writer, ch, False)

          return _write_utf8(writer, ch)
  cdef class _Unquoter:
      """Callable that decodes ``%XX`` escapes in a ``str``.

      Constructor arguments (keyword-only):

      unsafe
          Characters that must stay (or become) percent-encoded in the
          result.
      qs
          Query-string mode: a literal ``+`` becomes a space (unless ``+``
          is in *unsafe*), and escapes that decode to ``+``, ``=``, ``&``
          or ``;`` are kept encoded.
      """
      cdef str _unsafe
      cdef bint _qs
      cdef _Quoter _quoter
      cdef _Quoter _qs_quoter

      def __init__(self, *, unsafe='', qs=False):
          self._unsafe = unsafe
          self._qs = qs
          self._quoter = _Quoter()
          self._qs_quoter = _Quoter(qs=True)

      def __call__(self, val):
          """Return *val* with escapes decoded; ``None`` is passed through.

          Raises ``TypeError`` when *val* is neither ``None`` nor a ``str``
          (instances of ``str`` subclasses are converted with ``str()``).
          """
          if val is None:
              return None
          if type(val) is not str:
              if isinstance(val, str):
                  # derived from str
                  val = str(val)
              else:
                  raise TypeError("Argument should be str")
          return self._do_unquote(<str>val)

      cdef str _do_unquote(self, str val):
          """Decode *val*.

          Consecutive ``%XX`` escapes are collected as bytes and decoded as
          UTF-8 together.  A ``%`` that is not followed by two hexadecimal
          digits (including one at the very end of the input) is ordinary
          text and is never dropped.  Escaped bytes that are not valid
          UTF-8 are copied to the output in their original ``%XX`` form.
          """
          cdef Py_ssize_t length = len(val)
          cdef Py_ssize_t idx = 0
          cdef Py_ssize_t pos
          cdef Py_UCS4 byte
          cdef bytearray pcts = bytearray()  # pending escaped bytes
          cdef list raw = []                 # their original '%XX' texts
          cdef list ret = []
          cdef str escaped

          if length == 0:
              return val

          while idx < length:
              ch = val[idx]
              if ch == '%' and idx + 2 < length:
                  byte = _restore_ch(val[idx + 1], val[idx + 2])
                  if byte != <Py_UCS4>-1:
                      pcts.append(<int>byte)
                      raw.append(val[idx:idx + 3])
                      idx += 3
                      continue
              idx += 1

              # An ordinary character ends the current run of escapes.
              if pcts:
                  self._flush_pcts(pcts, raw, ret)

              if ch == '+':
                  if not self._qs or ch in self._unsafe:
                      ret.append('+')
                  else:
                      ret.append(' ')
                  continue

              if ch in self._unsafe:
                  # Escape every UTF-8 byte as two uppercase hex digits.
                  # 'surrogatepass' keeps lone surrogates from raising.
                  escaped = ch.encode('utf8', 'surrogatepass').hex().upper()
                  for pos in range(0, len(escaped), 2):
                      ret.append('%' + escaped[pos:pos + 2])
                  continue

              ret.append(ch)

          if pcts:
              self._flush_pcts(pcts, raw, ret)

          return ''.join(ret)

      cdef _flush_pcts(self, bytearray pcts, list raw, list ret):
          """Decode the pending bytes into *ret* and clear *pcts* and *raw*.

          Valid UTF-8 stretches are decoded; each undecodable stretch is
          emitted as the original escape text stored in *raw*.
          """
          cdef Py_ssize_t start = 0
          cdef Py_ssize_t total = len(pcts)
          cdef Py_ssize_t bad_from
          cdef Py_ssize_t bad_to
          cdef str unquoted

          while start < total:
              try:
                  unquoted = pcts[start:].decode('utf8')
              except UnicodeDecodeError as exc:
                  # exc.start/exc.end are relative to the decoded slice.
                  bad_from = start + exc.start
                  bad_to = start + exc.end
                  if bad_to <= bad_from:
                      bad_to = bad_from + 1  # always make progress
                  if bad_from > start:
                      self._append_unquoted(
                          pcts[start:bad_from].decode('utf8'), ret)
                  ret.extend(raw[bad_from:bad_to])  # e.g. %F8
                  start = bad_to
              else:
                  self._append_unquoted(unquoted, ret)
                  break

          del pcts[:]
          del raw[:]

      cdef _append_unquoted(self, str unquoted, list ret):
          """Append decoded text to *ret*, re-encoding what must stay encoded.

          The check is made per character so that a decoded run such as
          ``&=`` is handled the same way as its individual characters.
          """
          for ch in unquoted:
              if self._qs and ch in '+=&;':
                  ret.append(self._qs_quoter(ch))
              elif ch in self._unsafe:
                  ret.append(self._quoter(ch))
              else:
                  ret.append(ch)