# cython: profile=True
# coding: utf8
from __future__ import unicode_literals, print_function

import re
import ujson
import random
import cytoolz
import itertools

from .syntax import nonproj
from .tokens import Doc
from .errors import Errors
from . import util
from .util import minibatch
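
# Pool, attr_t, Transition and the GoldParseC struct behind `self.c` are used
# below without a Python-level import; they are assumed to be cimported via
# the accompanying gold.pxd, as is usual for spaCy's Cython modules.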

def tags_to_entities(tags):
    entities = []
    start = None
    for i, tag in enumerate(tags):
        if tag is None:
            continue
        if tag.startswith('O'):
            # TODO: We shouldn't be getting these malformed inputs. Fix this.
            if start is not None:
                start = None
            continue
        elif tag == '-':
            continue
        elif tag.startswith('I'):
            if start is None:
                raise ValueError(Errors.E067.format(tags=tags[:i+1]))
            continue
        if tag.startswith('U'):
            entities.append((tag[2:], i, i))
        elif tag.startswith('B'):
            start = i
        elif tag.startswith('L'):
            entities.append((tag[2:], start, i))
            start = None
        else:
            raise ValueError(Errors.E068.format(tag=tag))
    return entities
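
# Illustrative call: tags_to_entities turns a BILUO tag sequence into
# (label, start_token, end_token) triples, e.g.
#   tags_to_entities(['O', 'B-ORG', 'L-ORG', 'U-PER'])
#   -> [('ORG', 1, 2), ('PER', 3, 3)]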

def merge_sents(sents):
    m_deps = [[], [], [], [], [], []]
    m_brackets = []
    i = 0
    for (ids, words, tags, heads, labels, ner), brackets in sents:
        m_deps[0].extend(id_ + i for id_ in ids)
        m_deps[1].extend(words)
        m_deps[2].extend(tags)
        m_deps[3].extend(head + i for head in heads)
        m_deps[4].extend(labels)
        m_deps[5].extend(ner)
        m_brackets.extend((b['first'] + i, b['last'] + i, b['label'])
                          for b in brackets)
        i += len(ids)
    return [(m_deps, m_brackets)]
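
# merge_sents concatenates the per-sentence annotation tuples of a paragraph
# into one pseudo-sentence, offsetting token ids, head indices and bracket
# boundaries by the running token count `i`.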

def align(cand_words, gold_words):
    cost, edit_path = _min_edit_path(cand_words, gold_words)
    alignment = []
    i_of_gold = 0
    for move in edit_path:
        if move == 'M':
            alignment.append(i_of_gold)
            i_of_gold += 1
        elif move == 'S':
            alignment.append(None)
            i_of_gold += 1
        elif move == 'D':
            alignment.append(None)
        elif move == 'I':
            i_of_gold += 1
        else:
            raise Exception(move)
    return alignment
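
# align maps each candidate token to the index of its gold token, or None
# where the edit path marks a substitution or deletion. Illustrative call:
#   align(['I', 'like', 'London'], ['I', 'like', 'London']) == [0, 1, 2]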

punct_re = re.compile(r'\W')

def _min_edit_path(cand_words, gold_words):
    cdef:
        Pool mem
        int i, j, n_cand, n_gold
        int* curr_costs
        int* prev_costs
    # TODO: Fix this --- just do it properly, make the full edit matrix and
    # then walk back over it...
    # Preprocess inputs
    cand_words = [punct_re.sub('', w).lower() for w in cand_words]
    gold_words = [punct_re.sub('', w).lower() for w in gold_words]
    if cand_words == gold_words:
        return 0, ''.join(['M' for _ in gold_words])
    mem = Pool()
    n_cand = len(cand_words)
    n_gold = len(gold_words)
    # Levenshtein distance, except we need the history, and we may want
    # different costs. Mark operations with a string, and score the history
    # using _edit_cost.
    previous_row = []
    prev_costs = <int*>mem.alloc(n_gold + 1, sizeof(int))
    curr_costs = <int*>mem.alloc(n_gold + 1, sizeof(int))
    for i in range(n_gold + 1):
        cell = ''
        for j in range(i):
            cell += 'I'
        previous_row.append('I' * i)
        prev_costs[i] = i
    for i, cand in enumerate(cand_words):
        current_row = ['D' * (i + 1)]
        curr_costs[0] = i+1
        for j, gold in enumerate(gold_words):
            if gold.lower() == cand.lower():
                s_cost = prev_costs[j]
                i_cost = curr_costs[j] + 1
                d_cost = prev_costs[j + 1] + 1
            else:
                s_cost = prev_costs[j] + 1
                i_cost = curr_costs[j] + 1
                d_cost = prev_costs[j + 1] + (1 if cand else 0)
            if s_cost <= i_cost and s_cost <= d_cost:
                best_cost = s_cost
                best_hist = previous_row[j] + ('M' if gold == cand else 'S')
            elif i_cost <= s_cost and i_cost <= d_cost:
                best_cost = i_cost
                best_hist = current_row[j] + 'I'
            else:
                best_cost = d_cost
                best_hist = previous_row[j + 1] + 'D'
            current_row.append(best_hist)
            curr_costs[j+1] = best_cost
        previous_row = current_row
        for j in range(len(gold_words) + 1):
            prev_costs[j] = curr_costs[j]
            curr_costs[j] = 0
    return prev_costs[n_gold], previous_row[-1]

class GoldCorpus(object):
    """An annotated corpus, using the JSON file format. Manages
    annotations for tagging, dependency parsing and NER."""
    def __init__(self, train_path, dev_path, gold_preproc=True, limit=None):
        """Create a GoldCorpus.

        train_path (unicode or Path): File or directory of training data.
        dev_path (unicode or Path): File or directory of development data.
        limit (int): Upper bound on the number of annotated examples to use,
            or None for no limit.
        RETURNS (GoldCorpus): The newly created object.
        """
        self.train_path = util.ensure_path(train_path)
        self.dev_path = util.ensure_path(dev_path)
        self.limit = limit
        self.train_locs = self.walk_corpus(self.train_path)
        self.dev_locs = self.walk_corpus(self.dev_path)

    @property
    def train_tuples(self):
        i = 0
        for loc in self.train_locs:
            gold_tuples = read_json_file(loc)
            for item in gold_tuples:
                yield item
                i += len(item[1])
                if self.limit and i >= self.limit:
                    break

    @property
    def dev_tuples(self):
        i = 0
        for loc in self.dev_locs:
            gold_tuples = read_json_file(loc)
            for item in gold_tuples:
                yield item
                i += len(item[1])
                if self.limit and i >= self.limit:
                    break

    def count_train(self):
        n = 0
        i = 0
        for raw_text, paragraph_tuples in self.train_tuples:
            n += sum([len(s[0][1]) for s in paragraph_tuples])
            if self.limit and i >= self.limit:
                break
            i += len(paragraph_tuples)
        return n

    def train_docs(self, nlp, gold_preproc=False,
                   projectivize=False, max_length=None,
                   noise_level=0.0):
        train_tuples = self.train_tuples
        if projectivize:
            train_tuples = nonproj.preprocess_training_data(
                self.train_tuples, label_freq_cutoff=100)
        # train_tuples may be a generator, so materialise it before shuffling
        train_tuples = list(train_tuples)
        random.shuffle(train_tuples)
        gold_docs = self.iter_gold_docs(nlp, train_tuples, gold_preproc,
                                        max_length=max_length,
                                        noise_level=noise_level)
        yield from gold_docs

    def dev_docs(self, nlp, gold_preproc=False):
        gold_docs = self.iter_gold_docs(nlp, self.dev_tuples, gold_preproc)
        yield from gold_docs

    @classmethod
    def iter_gold_docs(cls, nlp, tuples, gold_preproc, max_length=None,
                       noise_level=0.0):
        for raw_text, paragraph_tuples in tuples:
            if gold_preproc:
                raw_text = None
            else:
                paragraph_tuples = merge_sents(paragraph_tuples)
            docs = cls._make_docs(nlp, raw_text, paragraph_tuples,
                                  gold_preproc, noise_level=noise_level)
            golds = cls._make_golds(docs, paragraph_tuples)
            for doc, gold in zip(docs, golds):
                if (not max_length) or len(doc) < max_length:
                    yield doc, gold

    @classmethod
    def _make_docs(cls, nlp, raw_text, paragraph_tuples, gold_preproc,
                   noise_level=0.0):
        if raw_text is not None:
            raw_text = add_noise(raw_text, noise_level)
            return [nlp.make_doc(raw_text)]
        else:
            return [Doc(nlp.vocab,
                        words=add_noise(sent_tuples[1], noise_level))
                    for (sent_tuples, brackets) in paragraph_tuples]

    @classmethod
    def _make_golds(cls, docs, paragraph_tuples):
        if len(docs) != len(paragraph_tuples):
            raise ValueError(Errors.E070.format(n_docs=len(docs),
                                                n_annots=len(paragraph_tuples)))
        if len(docs) == 1:
            return [GoldParse.from_annot_tuples(docs[0],
                                                paragraph_tuples[0][0])]
        else:
            return [GoldParse.from_annot_tuples(doc, sent_tuples)
                    for doc, (sent_tuples, brackets)
                    in zip(docs, paragraph_tuples)]

    @staticmethod
    def walk_corpus(path):
        if not path.is_dir():
            return [path]
        paths = [path]
        locs = []
        seen = set()
        for path in paths:
            if str(path) in seen:
                continue
            seen.add(str(path))
            if path.parts[-1].startswith('.'):
                continue
            elif path.is_dir():
                paths.extend(path.iterdir())
            elif path.parts[-1].endswith('.json'):
                locs.append(path)
        return locs
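
# Illustrative use of GoldCorpus (paths and the nlp object are placeholders):
#   corpus = GoldCorpus('train.json', 'dev.json', limit=1000)
#   n_words = corpus.count_train()
#   for doc, gold in corpus.train_docs(nlp, projectivize=True):
#       ...  # doc is a Doc, gold is the matching GoldParse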

def add_noise(orig, noise_level):
    if random.random() >= noise_level:
        return orig
    elif type(orig) == list:
        corrupted = [_corrupt(word, noise_level) for word in orig]
        corrupted = [w for w in corrupted if w]
        return corrupted
    else:
        return ''.join(_corrupt(c, noise_level) for c in orig)

def _corrupt(c, noise_level):
    if random.random() >= noise_level:
        return c
    elif c == ' ':
        return '\n'
    elif c == '\n':
        return ' '
    elif c in ['.', "'", "!", "?"]:
        return ''
    else:
        return c.lower()
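
# The JSON training format consumed below looks, in outline, like this
# (field names as accessed by read_json_file; values are illustrative):
#   [{"paragraphs": [{"raw": "I like London.",
#                     "sentences": [{"tokens": [{"orth": "I", "tag": "PRP",
#                                                "head": 1, "dep": "nsubj",
#                                                "ner": "O"}, ...],
#                                    "brackets": []}]}]}]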

def read_json_file(loc, docs_filter=None, limit=None):
    loc = util.ensure_path(loc)
    if loc.is_dir():
        for filename in loc.iterdir():
            # iterdir() already yields paths prefixed with `loc`
            yield from read_json_file(filename, limit=limit)
    else:
        with loc.open('r', encoding='utf8') as file_:
            docs = ujson.load(file_)
        if limit is not None:
            docs = docs[:limit]
        for doc in docs:
            if docs_filter is not None and not docs_filter(doc):
                continue
            paragraphs = []
            for paragraph in doc['paragraphs']:
                sents = []
                for sent in paragraph['sentences']:
                    words = []
                    ids = []
                    tags = []
                    heads = []
                    labels = []
                    ner = []
                    for i, token in enumerate(sent['tokens']):
                        words.append(token['orth'])
                        ids.append(i)
                        tags.append(token.get('tag', '-'))
                        heads.append(token.get('head', 0) + i)
                        labels.append(token.get('dep', ''))
                        # Ensure ROOT label is case-insensitive
                        if labels[-1].lower() == 'root':
                            labels[-1] = 'ROOT'
                        ner.append(token.get('ner', '-'))
                    sents.append([
                        [ids, words, tags, heads, labels, ner],
                        sent.get('brackets', [])])
                if sents:
                    yield [paragraph.get('raw', None), sents]

def iob_to_biluo(tags):
    out = []
    curr_label = None
    tags = list(tags)
    while tags:
        out.extend(_consume_os(tags))
        out.extend(_consume_ent(tags))
    return out

def _consume_os(tags):
    while tags and tags[0] == 'O':
        yield tags.pop(0)

def _consume_ent(tags):
    if not tags:
        return []
    tag = tags.pop(0)
    target_in = 'I' + tag[1:]
    target_last = 'L' + tag[1:]
    length = 1
    while tags and tags[0] in {target_in, target_last}:
        length += 1
        tags.pop(0)
    label = tag[2:]
    if length == 1:
        return ['U-' + label]
    else:
        start = 'B-' + label
        end = 'L-' + label
        middle = ['I-%s' % label for _ in range(1, length - 1)]
        return [start] + middle + [end]
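
# Illustrative conversion: iob_to_biluo rewrites IOB tags into the BILUO
# scheme used elsewhere in this module, e.g.
#   iob_to_biluo(['B-ORG', 'I-ORG', 'O', 'B-PER'])
#   -> ['B-ORG', 'L-ORG', 'O', 'U-PER']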

cdef class GoldParse:
    """Collection for training annotations."""
    @classmethod
    def from_annot_tuples(cls, doc, annot_tuples, make_projective=False):
        _, words, tags, heads, deps, entities = annot_tuples
        return cls(doc, words=words, tags=tags, heads=heads, deps=deps,
                   entities=entities, make_projective=make_projective)

    def __init__(self, doc, annot_tuples=None, words=None, tags=None,
                 heads=None, deps=None, entities=None, make_projective=False,
                 cats=None):
        """Create a GoldParse.

        doc (Doc): The document the annotations refer to.
        words (iterable): A sequence of unicode word strings.
        tags (iterable): A sequence of strings, representing tag annotations.
        heads (iterable): A sequence of integers, representing syntactic
            head offsets.
        deps (iterable): A sequence of strings, representing the syntactic
            relation types.
        entities (iterable): A sequence of named entity annotations, either as
            BILUO tag strings, or as `(start_char, end_char, label)` tuples,
            representing the entity positions.
        cats (dict): Labels for text classification. Each key in the dictionary
            may be a string or an int, or a `(start_char, end_char, label)`
            tuple, indicating that the label is applied to only part of the
            document (usually a sentence). Unlike entity annotations, label
            annotations can overlap, i.e. a single word can be covered by
            multiple labelled spans. The TextCategorizer component expects
            true examples of a label to have the value 1.0, and negative
            examples of a label to have the value 0.0. Labels not in the
            dictionary are treated as missing - the gradient for those labels
            will be zero.
        RETURNS (GoldParse): The newly constructed object.
        """
        if words is None:
            words = [token.text for token in doc]
        if tags is None:
            tags = [None for _ in doc]
        if heads is None:
            heads = [None for token in doc]
        if deps is None:
            deps = [None for _ in doc]
        if entities is None:
            entities = [None for _ in doc]
        elif len(entities) == 0:
            entities = ['O' for _ in doc]
        elif not isinstance(entities[0], basestring):
            # Assume we have entities specified by character offset.
            entities = biluo_tags_from_offsets(doc, entities)
        self.mem = Pool()
        self.loss = 0
        self.length = len(doc)
        # These are filled by the tagger/parser/entity recogniser
        self.c.tags = <int*>self.mem.alloc(len(doc), sizeof(int))
        self.c.heads = <int*>self.mem.alloc(len(doc), sizeof(int))
        self.c.labels = <attr_t*>self.mem.alloc(len(doc), sizeof(attr_t))
        self.c.has_dep = <int*>self.mem.alloc(len(doc), sizeof(int))
        self.c.sent_start = <int*>self.mem.alloc(len(doc), sizeof(int))
        self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
        self.cats = {} if cats is None else dict(cats)
        self.words = [None] * len(doc)
        self.tags = [None] * len(doc)
        self.heads = [None] * len(doc)
        self.labels = [None] * len(doc)
        self.ner = [None] * len(doc)
        self.cand_to_gold = align([t.orth_ for t in doc], words)
        self.gold_to_cand = align(words, [t.orth_ for t in doc])
        annot_tuples = (range(len(words)), words, tags, heads, deps, entities)
        self.orig_annot = list(zip(*annot_tuples))
        for i, gold_i in enumerate(self.cand_to_gold):
            if doc[i].text.isspace():
                self.words[i] = doc[i].text
                self.tags[i] = '_SP'
                self.heads[i] = None
                self.labels[i] = None
                self.ner[i] = 'O'
            if gold_i is None:
                pass
            else:
                self.words[i] = words[gold_i]
                self.tags[i] = tags[gold_i]
                if heads[gold_i] is None:
                    self.heads[i] = None
                else:
                    self.heads[i] = self.gold_to_cand[heads[gold_i]]
                self.labels[i] = deps[gold_i]
                self.ner[i] = entities[gold_i]
        cycle = nonproj.contains_cycle(self.heads)
        if cycle is not None:
            raise ValueError(Errors.E069.format(cycle=cycle))
        if make_projective:
            proj_heads, _ = nonproj.projectivize(self.heads, self.labels)
            self.heads = proj_heads

    def __len__(self):
        """Get the number of gold-standard tokens.

        RETURNS (int): The number of gold-standard tokens.
        """
        return self.length

    @property
    def is_projective(self):
        """Whether the provided syntactic annotations form a projective
        dependency tree.
        """
        return not nonproj.is_nonproj_tree(self.heads)

    @property
    def sent_starts(self):
        return [self.c.sent_start[i] for i in range(self.length)]
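
# Illustrative use of GoldParse (the nlp object and offsets are placeholders,
# assuming 'I like London.' tokenizes as ['I', 'like', 'London', '.']):
#   doc = nlp.make_doc('I like London.')
#   gold = GoldParse(doc, entities=[(7, 13, 'LOC')])
#   assert gold.ner == ['O', 'O', 'U-LOC', 'O']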

def biluo_tags_from_offsets(doc, entities, missing='O'):
    """Encode labelled spans into per-token tags, using the
    Begin/In/Last/Unit/Out scheme (BILUO).

    doc (Doc): The document that the entity offsets refer to. The output tags
        will refer to the token boundaries within the document.
    entities (iterable): A sequence of `(start, end, label)` triples. `start`
        and `end` should be character-offset integers denoting the slice into
        the original string.
    RETURNS (list): A list of unicode strings, describing the tags. Each tag
        string will be of the form either "", "O" or "{action}-{label}", where
        action is one of "B", "I", "L", "U". The string "-" is used where the
        entity offsets don't align with the tokenization in the `Doc` object.
        The training algorithm will view these as missing values. "O" denotes a
        non-entity token. "B" denotes the beginning of a multi-token entity,
        "I" the inside of an entity of three or more tokens, and "L" the end
        of an entity of two or more tokens. "U" denotes a single-token entity.

    EXAMPLE:
        >>> text = 'I like London.'
        >>> entities = [(len('I like '), len('I like London'), 'LOC')]
        >>> doc = nlp.tokenizer(text)
        >>> tags = biluo_tags_from_offsets(doc, entities)
        >>> assert tags == ['O', 'O', 'U-LOC', 'O']
    """
    starts = {token.idx: token.i for token in doc}
    ends = {token.idx+len(token): token.i for token in doc}
    biluo = ['-' for _ in doc]
    # Handle entity cases
    for start_char, end_char, label in entities:
        start_token = starts.get(start_char)
        end_token = ends.get(end_char)
        # Only interested if the tokenization is correct
        if start_token is not None and end_token is not None:
            if start_token == end_token:
                biluo[start_token] = 'U-%s' % label
            else:
                biluo[start_token] = 'B-%s' % label
                for i in range(start_token+1, end_token):
                    biluo[i] = 'I-%s' % label
                biluo[end_token] = 'L-%s' % label
    # Now distinguish the O cases from ones where we miss the tokenization
    entity_chars = set()
    for start_char, end_char, label in entities:
        for i in range(start_char, end_char):
            entity_chars.add(i)
    for token in doc:
        for i in range(token.idx, token.idx+len(token)):
            if i in entity_chars:
                break
        else:
            biluo[token.i] = missing
    return biluo

def offsets_from_biluo_tags(doc, tags):
    """Encode per-token tags following the BILUO scheme into entity offsets.

    doc (Doc): The document that the BILUO tags refer to.
    tags (iterable): A sequence of BILUO tags, with each tag describing one
        token. Each tag string will be of the form of either "", "O" or
        "{action}-{label}", where action is one of "B", "I", "L", "U".
    RETURNS (list): A sequence of `(start, end, label)` triples. `start` and
        `end` will be character-offset integers denoting the slice into the
        original string.
    """
    token_offsets = tags_to_entities(tags)
    offsets = []
    for label, start_idx, end_idx in token_offsets:
        span = doc[start_idx : end_idx + 1]
        offsets.append((span.start_char, span.end_char, label))
    return offsets
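
# Illustrative round trip with the example from biluo_tags_from_offsets above
# (the doc is a placeholder tokenized as ['I', 'like', 'London', '.']):
#   offsets_from_biluo_tags(doc, ['O', 'O', 'U-LOC', 'O'])
#   -> [(7, 13, 'LOC')]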

def is_punct_label(label):
    return label == 'P' or label.lower() == 'punct'