# cython: profile=True
# cython: infer_types=True
# coding: utf8
from __future__ import unicode_literals

import ujson

from cymem.cymem cimport Pool
from preshed.maps cimport PreshMap
from libcpp.vector cimport vector
from libcpp.pair cimport pair
from murmurhash.mrmr cimport hash64
from libc.stdint cimport int32_t

from .typedefs cimport attr_t
from .typedefs cimport hash_t
from .structs cimport TokenC
from .tokens.doc cimport Doc, get_token_attr
from .vocab cimport Vocab
from .errors import Errors, TempErrors

from .attrs import IDS
from .attrs cimport attr_id_t, ID, NULL_ATTR
from .attrs import FLAG61 as U_ENT
from .attrs import FLAG60 as B2_ENT
from .attrs import FLAG59 as B3_ENT
from .attrs import FLAG58 as B4_ENT
from .attrs import FLAG57 as B5_ENT
from .attrs import FLAG56 as B6_ENT
from .attrs import FLAG55 as B7_ENT
from .attrs import FLAG54 as B8_ENT
from .attrs import FLAG53 as B9_ENT
from .attrs import FLAG52 as B10_ENT
from .attrs import FLAG51 as I3_ENT
from .attrs import FLAG50 as I4_ENT
from .attrs import FLAG49 as I5_ENT
from .attrs import FLAG48 as I6_ENT
from .attrs import FLAG47 as I7_ENT
from .attrs import FLAG46 as I8_ENT
from .attrs import FLAG45 as I9_ENT
from .attrs import FLAG44 as I10_ENT
from .attrs import FLAG43 as L2_ENT
from .attrs import FLAG42 as L3_ENT
from .attrs import FLAG41 as L4_ENT
from .attrs import FLAG40 as L5_ENT
from .attrs import FLAG39 as L6_ENT
from .attrs import FLAG38 as L7_ENT
from .attrs import FLAG37 as L8_ENT
from .attrs import FLAG36 as L9_ENT
from .attrs import FLAG35 as L10_ENT

DELIMITER = '||'


cpdef enum quantifier_t:
    _META
    ONE
    ZERO
    ZERO_ONE
    ZERO_PLUS


cdef enum action_t:
    REJECT
    ADVANCE
    REPEAT
    ACCEPT
    ADVANCE_ZERO
    ACCEPT_PREV
    PANIC


# A "match expression" consists of one or more token patterns.
# Each token pattern consists of a quantifier and 0+ (attr, value) pairs.
# A state is an (int, pattern pointer) pair, where the int is the start
# position, and the pattern pointer shows where we're up to
# in the pattern.
cdef struct AttrValueC:
    attr_id_t attr
    attr_t value


cdef struct TokenPatternC:
    AttrValueC* attrs
    int32_t nr_attr
    quantifier_t quantifier


ctypedef TokenPatternC* TokenPatternC_ptr
ctypedef pair[int, TokenPatternC_ptr] StateC


DEF PADDING = 5


cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
                                 object token_specs) except NULL:
    pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
    cdef int i
    for i, (quantifier, spec) in enumerate(token_specs):
        pattern[i].quantifier = quantifier
        pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
        pattern[i].nr_attr = len(spec)
        for j, (attr, value) in enumerate(spec):
            pattern[i].attrs[j].attr = attr
            pattern[i].attrs[j].value = value
    i = len(token_specs)
    pattern[i].attrs = <AttrValueC*>mem.alloc(2, sizeof(AttrValueC))
    pattern[i].attrs[0].attr = ID
    pattern[i].attrs[0].value = entity_id
    pattern[i].nr_attr = 0
    return pattern
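
# Note (added for clarity): init_pattern allocates one extra slot beyond the
# token patterns. That final slot acts as a terminator: its nr_attr is 0 and
# its first attribute stores the rule's key under ID, which is how
# get_pattern_key below and the ACCEPT actions in the matching loop recover
# the match ID.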


cdef attr_t get_pattern_key(const TokenPatternC* pattern) except 0:
    while pattern.nr_attr != 0:
        pattern += 1
    id_attr = pattern[0].attrs[0]
    if id_attr.attr != ID:
        raise ValueError(Errors.E074.format(attr=ID, bad_attr=id_attr.attr))
    return id_attr.value


cdef int get_action(const TokenPatternC* pattern, const TokenC* token) nogil:
    lookahead = &pattern[1]
    for attr in pattern.attrs[:pattern.nr_attr]:
        if get_token_attr(token, attr.attr) != attr.value:
            if pattern.quantifier == ONE:
                return REJECT
            elif pattern.quantifier == ZERO:
                return ACCEPT if lookahead.nr_attr == 0 else ADVANCE
            elif pattern.quantifier in (ZERO_ONE, ZERO_PLUS):
                return ACCEPT_PREV if lookahead.nr_attr == 0 else ADVANCE_ZERO
            else:
                return PANIC
    if pattern.quantifier == ZERO:
        return REJECT
    elif lookahead.nr_attr == 0:
        return ACCEPT
    elif pattern.quantifier in (ONE, ZERO_ONE):
        return ADVANCE
    elif pattern.quantifier == ZERO_PLUS:
        # This is a bandaid over the 'shadowing' problem described here:
        # https://github.com/explosion/spaCy/issues/864
        next_action = get_action(lookahead, token)
        if next_action is REJECT:
            return REPEAT
        else:
            return ADVANCE_ZERO
    else:
        return PANIC
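
# Summary of get_action, added as a reading aid and derived from the code
# above:
#
#   If the token fails one of the (attr, value) checks:
#       ONE                  -> REJECT
#       ZERO                 -> ACCEPT if the next slot is the terminator,
#                               else ADVANCE
#       ZERO_ONE, ZERO_PLUS  -> ACCEPT_PREV if the next slot is the
#                               terminator, else ADVANCE_ZERO
#   If the token satisfies every (attr, value) check:
#       ZERO                 -> REJECT
#       terminator next      -> ACCEPT
#       ONE, ZERO_ONE        -> ADVANCE
#       ZERO_PLUS            -> REPEAT if the lookahead pattern rejects this
#                               token, else ADVANCE_ZERO (the issue #864
#                               workaround)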


def _convert_strings(token_specs, string_store):
    # Support 'syntactic sugar' operator '+', as combination of ONE, ZERO_PLUS
    operators = {'!': (ZERO,), '*': (ZERO_PLUS,), '+': (ONE, ZERO_PLUS),
                 '?': (ZERO_ONE,), '1': (ONE,)}
    tokens = []
    op = ONE
    for spec in token_specs:
        if not spec:
            # Signifier for 'any token'
            tokens.append((ONE, [(NULL_ATTR, 0)]))
            continue
        token = []
        ops = (ONE,)
        for attr, value in spec.items():
            if isinstance(attr, basestring) and attr.upper() == 'OP':
                if value in operators:
                    ops = operators[value]
                else:
                    keys = ', '.join(operators.keys())
                    raise KeyError(Errors.E011.format(op=value, opts=keys))
            if isinstance(attr, basestring):
                attr = IDS.get(attr.upper())
            if isinstance(value, basestring):
                value = string_store.add(value)
            if isinstance(value, bool):
                value = int(value)
            if attr is not None:
                token.append((attr, value))
        for op in ops:
            tokens.append((op, token))
    return tokens
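
# Illustrative note (added; the attribute IDs and hashes below are
# hypothetical values): _convert_strings turns a user-facing spec such as
#
#     [{'LOWER': 'hello'}, {'IS_PUNCT': True, 'OP': '?'}]
#
# into the (quantifier, [(attr_id, value), ...]) pairs consumed by
# init_pattern, roughly
#
#     [(ONE, [(<LOWER id>, <hash of "hello">)]),
#      (ZERO_ONE, [(<IS_PUNCT id>, 1)])]
#
# String attribute names are looked up in IDS, string values are interned via
# string_store.add(), booleans become ints, and the '+' operator emits two
# entries (ONE followed by ZERO_PLUS) for the same token.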


def merge_phrase(matcher, doc, i, matches):
    """Callback to merge a phrase on match."""
    ent_id, label, start, end = matches[i]
    span = doc[start:end]
    span.merge(ent_type=label, ent_id=ent_id)


def unpickle_matcher(vocab, patterns, callbacks):
    matcher = Matcher(vocab)
    for key, specs in patterns.items():
        callback = callbacks.get(key, None)
        matcher.add(key, callback, *specs)
    return matcher


cdef class Matcher:
    """Match sequences of tokens, based on pattern rules."""
    cdef Pool mem
    cdef vector[TokenPatternC*] patterns
    cdef readonly Vocab vocab
    cdef public object _patterns
    cdef public object _entities
    cdef public object _callbacks

    def __init__(self, vocab):
        """Create the Matcher.

        vocab (Vocab): The vocabulary object, which must be shared with the
            documents the matcher will operate on.
        RETURNS (Matcher): The newly constructed object.
        """
        self._patterns = {}
        self._entities = {}
        self._callbacks = {}
        self.vocab = vocab
        self.mem = Pool()

    def __reduce__(self):
        data = (self.vocab, self._patterns, self._callbacks)
        return (unpickle_matcher, data, None, None)

    def __len__(self):
        """Get the number of rules added to the matcher. Note that this only
        returns the number of rules (identical with the number of IDs), not
        the number of individual patterns.

        RETURNS (int): The number of rules.
        """
        return len(self._patterns)

    def __contains__(self, key):
        """Check whether the matcher contains rules for a match ID.

        key (unicode): The match ID.
        RETURNS (bool): Whether the matcher contains rules for this match ID.
        """
        return self._normalize_key(key) in self._patterns

    def add(self, key, on_match, *patterns):
        """Add a match-rule to the matcher. A match-rule consists of: an ID
        key, an on_match callback, and one or more patterns.

        If the key exists, the patterns are appended to the previous ones, and
        the previous on_match callback is replaced. The `on_match` callback
        will receive the arguments `(matcher, doc, i, matches)`. You can also
        set `on_match` to `None` to not perform any actions.

        A pattern consists of one or more `token_specs`, where a `token_spec`
        is a dictionary mapping attribute IDs to values, and optionally a
        quantifier operator under the key "op". The available quantifiers are:

        '!': Negate the pattern, by requiring it to match exactly 0 times.
        '?': Make the pattern optional, by allowing it to match 0 or 1 times.
        '+': Require the pattern to match 1 or more times.
        '*': Allow the pattern to match zero or more times.

        The + and * operators are usually interpreted "greedily", i.e. longer
        matches are returned where possible. However, if you specify two '+'
        and '*' patterns in a row and their matches overlap, the first
        operator will behave non-greedily. This quirk in the semantics makes
        the matcher more efficient, by avoiding the need for back-tracking.

        key (unicode): The match ID.
        on_match (callable): Callback executed on match.
        *patterns (list): List of token descriptions.
        """
        for pattern in patterns:
            if len(pattern) == 0:
                raise ValueError(Errors.E012.format(key=key))
        key = self._normalize_key(key)
        for pattern in patterns:
            specs = _convert_strings(pattern, self.vocab.strings)
            self.patterns.push_back(init_pattern(self.mem, key, specs))
        self._patterns.setdefault(key, [])
        self._callbacks[key] = on_match
        self._patterns[key].extend(patterns)
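
    # Example usage (a minimal sketch, not part of the original module; it
    # assumes a loaded pipeline `nlp` whose vocab is shared with the
    # documents being matched):
    #
    #     matcher = Matcher(nlp.vocab)
    #     pattern = [{'LOWER': 'hello'}, {'IS_PUNCT': True, 'OP': '?'},
    #                {'LOWER': 'world'}]
    #     matcher.add('HelloWorld', None, pattern)
    #     doc = nlp(u'Hello, world!')
    #     matches = matcher(doc)  # list of (match_id, start, end) tuples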

    def remove(self, key):
        """Remove a rule from the matcher. A KeyError is raised if the key
        does not exist.

        key (unicode): The ID of the match rule.
        """
        key = self._normalize_key(key)
        self._patterns.pop(key)
        self._callbacks.pop(key)
        cdef int i = 0
        while i < self.patterns.size():
            pattern_key = get_pattern_key(self.patterns.at(i))
            if pattern_key == key:
                self.patterns.erase(self.patterns.begin()+i)
            else:
                i += 1

    def has_key(self, key):
        """Check whether the matcher has a rule with a given key.

        key (string or int): The key to check.
        RETURNS (bool): Whether the matcher has the rule.
        """
        key = self._normalize_key(key)
        return key in self._patterns

    def get(self, key, default=None):
        """Retrieve the pattern stored for a key.

        key (unicode or int): The key to retrieve.
        RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
        """
        key = self._normalize_key(key)
        if key not in self._patterns:
            return default
        return (self._callbacks[key], self._patterns[key])

    def pipe(self, docs, batch_size=1000, n_threads=2):
        """Match a stream of documents, yielding them in turn.

        docs (iterable): A stream of documents.
        batch_size (int): Number of documents to accumulate into a working set.
        n_threads (int): The number of threads with which to work on the buffer
            in parallel, if the implementation supports multi-threading.
        YIELDS (Doc): Documents, in order.
        """
        for doc in docs:
            self(doc)
            yield doc

    def __call__(self, Doc doc):
        """Find all token sequences matching the supplied pattern.

        doc (Doc): The document to match over.
        RETURNS (list): A list of `(key, start, end)` tuples, describing the
            matches. A match tuple describes a span `doc[start:end]`. The
            `label_id` and `key` are both integers.
        """
        cdef vector[StateC] partials
        cdef int n_partials = 0
        cdef int q = 0
        cdef int i, token_i
        cdef const TokenC* token
        cdef StateC state
        matches = []
        for token_i in range(doc.length):
            token = &doc.c[token_i]
            q = 0
            # Go over the open matches, extending or finalizing if able.
            # Otherwise, we over-write them (q doesn't advance).
            for state in partials:
                action = get_action(state.second, token)
                if action == PANIC:
                    raise ValueError(Errors.E013)
                while action == ADVANCE_ZERO:
                    state.second += 1
                    action = get_action(state.second, token)
                if action == PANIC:
                    raise ValueError(Errors.E013)
                if action == REPEAT:
                    # Leave the state in the queue, and advance to next slot
                    # (i.e. we don't overwrite -- we want to greedily match
                    # more pattern.)
                    q += 1
                elif action == REJECT:
                    pass
                elif action == ADVANCE:
                    partials[q] = state
                    partials[q].second += 1
                    q += 1
                elif action in (ACCEPT, ACCEPT_PREV):
                    # TODO: What to do about patterns starting with ZERO? Need
                    # to adjust the start position.
                    start = state.first
                    end = token_i+1 if action == ACCEPT else token_i
                    ent_id = state.second[1].attrs[0].value
                    label = state.second[1].attrs[1].value
                    matches.append((ent_id, start, end))
            partials.resize(q)
            # Check whether we open any new patterns on this token
            for pattern in self.patterns:
                action = get_action(pattern, token)
                if action == PANIC:
                    raise ValueError(Errors.E013)
                while action == ADVANCE_ZERO:
                    pattern += 1
                    action = get_action(pattern, token)
                if action == REPEAT:
                    state.first = token_i
                    state.second = pattern
                    partials.push_back(state)
                elif action == ADVANCE:
                    # TODO: What to do about patterns starting with ZERO? Need
                    # to adjust the start position.
                    state.first = token_i
                    state.second = pattern + 1
                    partials.push_back(state)
                elif action in (ACCEPT, ACCEPT_PREV):
                    start = token_i
                    end = token_i+1 if action == ACCEPT else token_i
                    ent_id = pattern[1].attrs[0].value
                    label = pattern[1].attrs[1].value
                    matches.append((ent_id, start, end))
        # Look for open patterns that are actually satisfied
        for state in partials:
            while state.second.quantifier in (ZERO, ZERO_ONE, ZERO_PLUS):
                state.second += 1
            if state.second.nr_attr == 0:
                start = state.first
                end = len(doc)
                ent_id = state.second.attrs[0].value
                label = state.second.attrs[0].value
                matches.append((ent_id, start, end))
        for i, (ent_id, start, end) in enumerate(matches):
            on_match = self._callbacks.get(ent_id)
            if on_match is not None:
                on_match(self, doc, i, matches)
        return matches

    def _normalize_key(self, key):
        if isinstance(key, basestring):
            return self.vocab.strings.add(key)
        else:
            return key


def get_bilou(length):
    if length == 1:
        return [U_ENT]
    elif length == 2:
        return [B2_ENT, L2_ENT]
    elif length == 3:
        return [B3_ENT, I3_ENT, L3_ENT]
    elif length == 4:
        return [B4_ENT, I4_ENT, I4_ENT, L4_ENT]
    elif length == 5:
        return [B5_ENT, I5_ENT, I5_ENT, I5_ENT, L5_ENT]
    elif length == 6:
        return [B6_ENT, I6_ENT, I6_ENT, I6_ENT, I6_ENT, L6_ENT]
    elif length == 7:
        return [B7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, I7_ENT, L7_ENT]
    elif length == 8:
        return [B8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, I8_ENT, L8_ENT]
    elif length == 9:
        return [B9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT, I9_ENT,
                L9_ENT]
    elif length == 10:
        return [B10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT, I10_ENT,
                I10_ENT, I10_ENT, L10_ENT]
    else:
        raise ValueError(TempErrors.T001)


cdef class PhraseMatcher:
    cdef Pool mem
    cdef Vocab vocab
    cdef Matcher matcher
    cdef PreshMap phrase_ids
    cdef int max_length
    cdef attr_t* _phrase_key
    cdef public object _callbacks
    cdef public object _patterns

    def __init__(self, Vocab vocab, max_length=10):
        self.mem = Pool()
        self._phrase_key = <attr_t*>self.mem.alloc(max_length, sizeof(attr_t))
        self.max_length = max_length
        self.vocab = vocab
        self.matcher = Matcher(self.vocab)
        self.phrase_ids = PreshMap()
        abstract_patterns = []
        for length in range(1, max_length):
            abstract_patterns.append([{tag: True}
                                      for tag in get_bilou(length)])
        self.matcher.add('Candidate', None, *abstract_patterns)
        self._callbacks = {}
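
    # Note (added for clarity): the abstract patterns built above match on
    # the BILOU-style lexeme flags returned by get_bilou(). For a phrase of
    # length 3 the pattern is effectively
    # [{B3_ENT: True}, {I3_ENT: True}, {L3_ENT: True}]; add() below sets
    # those flags on the lexemes of each registered phrase, and
    # accept_match() then verifies the exact token sequence against the
    # hashed phrase keys.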

    def __len__(self):
        """Get the number of rules added to the matcher. Note that this only
        returns the number of rules (identical with the number of IDs), not
        the number of individual patterns.

        RETURNS (int): The number of rules.
        """
        return len(self.phrase_ids)

    def __contains__(self, key):
        """Check whether the matcher contains rules for a match ID.

        key (unicode): The match ID.
        RETURNS (bool): Whether the matcher contains rules for this match ID.
        """
        cdef hash_t ent_id = self.matcher._normalize_key(key)
        return ent_id in self._callbacks

    def __reduce__(self):
        return (self.__class__, (self.vocab,), None, None)

    def add(self, key, on_match, *docs):
        """Add a match-rule to the matcher. A match-rule consists of: an ID
        key, an on_match callback, and one or more patterns.

        key (unicode): The match ID.
        on_match (callable): Callback executed on match.
        *docs (Doc): `Doc` objects representing match patterns.
        """
        cdef Doc doc
        for doc in docs:
            if len(doc) >= self.max_length:
                raise ValueError(TempErrors.T002.format(doc_len=len(doc),
                                                        max_len=self.max_length))
        cdef hash_t ent_id = self.matcher._normalize_key(key)
        self._callbacks[ent_id] = on_match
        cdef int length
        cdef int i
        cdef hash_t phrase_hash
        for doc in docs:
            length = doc.length
            tags = get_bilou(length)
            for i in range(self.max_length):
                self._phrase_key[i] = 0
            for i, tag in enumerate(tags):
                lexeme = self.vocab[doc.c[i].lex.orth]
                lexeme.set_flag(tag, True)
                self._phrase_key[i] = lexeme.orth
            phrase_hash = hash64(self._phrase_key,
                                 self.max_length * sizeof(attr_t), 0)
            self.phrase_ids.set(phrase_hash, <void*>ent_id)
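
    # Example usage (a minimal sketch, not part of the original module;
    # `nlp` is assumed to be a loaded pipeline):
    #
    #     matcher = PhraseMatcher(nlp.vocab)
    #     matcher.add('OBAMA', None, nlp(u'Barack Obama'))
    #     doc = nlp(u'Barack Obama visited Berlin.')
    #     matches = matcher(doc)  # list of (ent_id, start, end) tuples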

    def __call__(self, Doc doc):
        """Find all sequences matching the supplied patterns on the `Doc`.

        doc (Doc): The document to match over.
        RETURNS (list): A list of `(key, start, end)` tuples, describing the
            matches. A match tuple describes a span `doc[start:end]`. The
            `label_id` and `key` are both integers.
        """
        matches = []
        for _, start, end in self.matcher(doc):
            ent_id = self.accept_match(doc, start, end)
            if ent_id is not None:
                matches.append((ent_id, start, end))
        for i, (ent_id, start, end) in enumerate(matches):
            on_match = self._callbacks.get(ent_id)
            if on_match is not None:
                on_match(self, doc, i, matches)
        return matches

    def pipe(self, stream, batch_size=1000, n_threads=2):
        """Match a stream of documents, yielding them in turn.

        docs (iterable): A stream of documents.
        batch_size (int): Number of documents to accumulate into a working set.
        n_threads (int): The number of threads with which to work on the buffer
            in parallel, if the implementation supports multi-threading.
        YIELDS (Doc): Documents, in order.
        """
        for doc in stream:
            self(doc)
            yield doc
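
    # Note (added for clarity): accept_match() below re-hashes the ORTH
    # values of a candidate span with the same key layout used in add(), so
    # a candidate produced by the abstract BILOU patterns only yields a
    # match if that exact phrase was registered.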

    def accept_match(self, Doc doc, int start, int end):
        if (end - start) >= self.max_length:
            raise ValueError(Errors.E075.format(length=end - start,
                                                max_len=self.max_length))
        cdef int i, j
        for i in range(self.max_length):
            self._phrase_key[i] = 0
        for i, j in enumerate(range(start, end)):
            self._phrase_key[i] = doc.c[j].lex.orth
        cdef hash_t key = hash64(self._phrase_key,
                                 self.max_length * sizeof(attr_t), 0)
        ent_id = <hash_t>self.phrase_ids.get(key)
        if ent_id == 0:
            return None
        else:
            return ent_id


cdef class DependencyTreeMatcher:
    """Match dependency parse tree based on pattern rules."""
    cdef Pool mem
    cdef readonly Vocab vocab
    cdef readonly Matcher token_matcher
    cdef public object _patterns
    cdef public object _keys_to_token
    cdef public object _root
    cdef public object _entities
    cdef public object _callbacks
    cdef public object _nodes
    cdef public object _tree

    def __init__(self, vocab):
        """Create the DependencyTreeMatcher.

        vocab (Vocab): The vocabulary object, which must be shared with the
            documents the matcher will operate on.
        RETURNS (DependencyTreeMatcher): The newly constructed object.
        """
        size = 20
        self.token_matcher = Matcher(vocab)
        self._keys_to_token = {}
        self._patterns = {}
        self._root = {}
        self._nodes = {}
        self._tree = {}
        self._entities = {}
        self._callbacks = {}
        self.vocab = vocab
        self.mem = Pool()

    def __reduce__(self):
        data = (self.vocab, self._patterns, self._tree, self._callbacks)
        return (unpickle_matcher, data, None, None)

    def __len__(self):
        """Get the number of rules (i.e. edges) added to the dependency tree
        matcher.

        RETURNS (int): The number of rules.
        """
        return len(self._patterns)

    def __contains__(self, key):
        """Check whether the matcher contains rules for a match ID.

        key (unicode): The match ID.
        RETURNS (bool): Whether the matcher contains rules for this match ID.
        """
        return self._normalize_key(key) in self._patterns

    def add(self, key, on_match, *patterns):
        # TODO: validations
        # 1. check if input pattern is connected
        # 2. check if pattern format is correct
        # 3. check if at least one root node is present
        # 4. check if node names are not repeated
        # 5. check if each node has only one head
        for pattern in patterns:
            if len(pattern) == 0:
                raise ValueError(Errors.E012.format(key=key))
        key = self._normalize_key(key)
        _patterns = []
        for pattern in patterns:
            token_patterns = []
            for i in range(len(pattern)):
                token_pattern = [pattern[i]['PATTERN']]
                token_patterns.append(token_pattern)
            # self.patterns.append(token_patterns)
            _patterns.append(token_patterns)
        self._patterns.setdefault(key, [])
        self._callbacks[key] = on_match
        self._patterns[key].extend(_patterns)
        # Add each node pattern of all the input patterns individually to the
        # matcher. This enables only a single instance of Matcher to be used.
        # Multiple adds are required to track each node pattern.
        _keys_to_token_list = []
        for i in range(len(_patterns)):
            _keys_to_token = {}
            # TODO: Better ways to hash edges in pattern?
            for j in range(len(_patterns[i])):
                k = self._normalize_key(unicode(key) + DELIMITER +
                                        unicode(i) + DELIMITER + unicode(j))
                self.token_matcher.add(k, None, _patterns[i][j])
                _keys_to_token[k] = j
            _keys_to_token_list.append(_keys_to_token)
        self._keys_to_token.setdefault(key, [])
        self._keys_to_token[key].extend(_keys_to_token_list)
        _nodes_list = []
        for pattern in patterns:
            nodes = {}
            for i in range(len(pattern)):
                nodes[pattern[i]['SPEC']['NODE_NAME']] = i
            _nodes_list.append(nodes)
        self._nodes.setdefault(key, [])
        self._nodes[key].extend(_nodes_list)
        # Create an object tree to traverse later on. This data structure
        # enables an easy tree pattern match. A Doc/Token-based tree cannot
        # be reused, since it is memory-heavy and tightly coupled with the
        # doc. A worked pattern example is given after retrieve_tree() below.
        self.retrieve_tree(patterns, _nodes_list, key)

    def retrieve_tree(self, patterns, _nodes_list, key):
        _heads_list = []
        _root_list = []
        for i in range(len(patterns)):
            heads = {}
            root = -1
            for j in range(len(patterns[i])):
                token_pattern = patterns[i][j]
                if 'NBOR_RELOP' not in token_pattern['SPEC']:
                    heads[j] = j
                    root = j
                else:
                    # TODO: Add semgrex rules
                    # 1. >
                    if token_pattern['SPEC']['NBOR_RELOP'] == '>':
                        heads[j] = _nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]
                    # 2. <
                    if token_pattern['SPEC']['NBOR_RELOP'] == '<':
                        heads[_nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]] = j
            _heads_list.append(heads)
            _root_list.append(root)
        _tree_list = []
        for i in range(len(patterns)):
            tree = {}
            for j in range(len(patterns[i])):
                if j == _heads_list[i][j]:
                    continue
                head = _heads_list[i][j]
                if head not in tree:
                    tree[head] = []
                tree[head].append(j)
            _tree_list.append(tree)
        self._tree.setdefault(key, [])
        self._tree[key].extend(_tree_list)
        self._root.setdefault(key, [])
        self._root[key].extend(_root_list)
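
    # Example pattern (a hedged sketch, not part of the original module; it
    # illustrates the 'SPEC'/'PATTERN' format consumed by add() and
    # retrieve_tree() above). The node without 'NBOR_RELOP' is the root.
    # With 'NBOR_RELOP': '>' the node named by 'NBOR_NAME' is treated as
    # this node's head; with '<' this node is treated as the head of the
    # 'NBOR_NAME' node.
    #
    #     pattern = [
    #         {'SPEC': {'NODE_NAME': 'founded'},
    #          'PATTERN': {'ORTH': 'founded'}},
    #         {'SPEC': {'NODE_NAME': 'founder', 'NBOR_RELOP': '>',
    #                   'NBOR_NAME': 'founded'},
    #          'PATTERN': {'DEP': 'nsubj'}},
    #     ]
    #     matcher = DependencyTreeMatcher(nlp.vocab)
    #     matcher.add('FOUNDED', None, pattern)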

    def has_key(self, key):
        """Check whether the matcher has a rule with a given key.

        key (string or int): The key to check.
        RETURNS (bool): Whether the matcher has the rule.
        """
        key = self._normalize_key(key)
        return key in self._patterns

    def get(self, key, default=None):
        """Retrieve the pattern stored for a key.

        key (unicode or int): The key to retrieve.
        RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
        """
        key = self._normalize_key(key)
        if key not in self._patterns:
            return default
        return (self._callbacks[key], self._patterns[key])

    def __call__(self, Doc doc):
        matched_trees = []
        matches = self.token_matcher(doc)
        for key in list(self._patterns.keys()):
            _patterns_list = self._patterns[key]
            _keys_to_token_list = self._keys_to_token[key]
            _root_list = self._root[key]
            _tree_list = self._tree[key]
            _nodes_list = self._nodes[key]
            length = len(_patterns_list)
            for i in range(length):
                _keys_to_token = _keys_to_token_list[i]
                _root = _root_list[i]
                _tree = _tree_list[i]
                _nodes = _nodes_list[i]
                id_to_position = {}
                # This could be taken outside to improve running time..?
                for match_id, start, end in matches:
                    if match_id in _keys_to_token:
                        if _keys_to_token[match_id] not in id_to_position:
                            id_to_position[_keys_to_token[match_id]] = []
                        id_to_position[_keys_to_token[match_id]].append(start)
                length = len(_nodes)
                if _root in id_to_position:
                    candidates = id_to_position[_root]
                    for candidate in candidates:
                        isVisited = {}
                        self.dfs(candidate, _root, _tree, id_to_position,
                                 doc, isVisited)
                        # Check if the subtree pattern has been completely
                        # identified.
                        if len(isVisited) == length:
                            matched_trees.append((key, list(isVisited)))
        for i, (ent_id, nodes) in enumerate(matched_trees):
            on_match = self._callbacks.get(ent_id)
            if on_match is not None:
                on_match(self, doc, i, matches)
        return matched_trees

    def dfs(self, candidate, root, tree, id_to_position, doc, isVisited):
        if root in id_to_position and candidate in id_to_position[root]:
            # Color the node, since it is valid.
            isVisited[candidate] = True
            candidate_children = doc[candidate].children
            for candidate_child in candidate_children:
                if root in tree:
                    for root_child in tree[root]:
                        self.dfs(candidate_child.i, root_child, tree,
                                 id_to_position, doc, isVisited)

    def _normalize_key(self, key):
        if isinstance(key, basestring):
            return self.vocab.strings.add(key)
        else:
            return key