You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

457 lines
14 KiB

4 years ago
"""
This module is an API for downloading, getting information about, and loading datasets/models.

Get information about available models/datasets:

>>> import gensim.downloader as api
>>>
>>> api.info()  # return dict with info about available models/datasets
>>> api.info("text8")  # return dict with info about "text8" dataset

Model example:

>>> import gensim.downloader as api
>>>
>>> model = api.load("glove-twitter-25")  # load glove vectors
>>> model.most_similar("cat")  # show words that are similar to the word 'cat'

Dataset example:

>>> import gensim.downloader as api
>>> from gensim.models import Word2Vec
>>>
>>> dataset = api.load("text8")  # load dataset as iterable
>>> model = Word2Vec(dataset)  # train w2v model

This API is also available via the CLI::

    python -m gensim.downloader --info  # same as api.info()
    python -m gensim.downloader --info <dataname>  # same as api.info(dataname)
    python -m gensim.downloader --info name  # same as api.info(name_only=True)
    python -m gensim.downloader --download <dataname>  # same as api.load(dataname, return_path=True)

"""
  24. from __future__ import absolute_import
  25. import argparse
  26. import os
  27. import json
  28. import logging
  29. import sys
  30. import errno
  31. import hashlib
  32. import math
  33. import shutil
  34. import tempfile
  35. from functools import partial
  36. if sys.version_info[0] == 2:
  37. import urllib
  38. from urllib2 import urlopen
  39. else:
  40. import urllib.request as urllib
  41. from urllib.request import urlopen
  42. user_dir = os.path.expanduser('~')
  43. base_dir = os.path.join(user_dir, 'gensim-data')
  44. logger = logging.getLogger('gensim.api')
  45. DATA_LIST_URL = "https://raw.githubusercontent.com/RaRe-Technologies/gensim-data/master/list.json"
  46. DOWNLOAD_BASE_URL = "https://github.com/RaRe-Technologies/gensim-data/releases/download"
  47. def _progress(chunks_downloaded, chunk_size, total_size, part=1, total_parts=1):
  48. """Reporthook for :func:`urllib.urlretrieve`, code from [1]_.
  49. Parameters
  50. ----------
  51. chunks_downloaded : int
  52. Number of chunks of data that have been downloaded.
  53. chunk_size : int
  54. Size of each chunk of data.
  55. total_size : int
  56. Total size of the dataset/model.
  57. part : int, optional
  58. Number of current part, used only if `no_parts` > 1.
  59. total_parts : int, optional
  60. Total number of parts.
  61. References
  62. ----------
  63. [1] https://gist.github.com/vladignatyev/06860ec2040cb497f0f3
  64. """
  65. bar_len = 50
  66. size_downloaded = float(chunks_downloaded * chunk_size)
  67. filled_len = int(math.floor((bar_len * size_downloaded) / total_size))
  68. percent_downloaded = round(((size_downloaded * 100) / total_size), 1)
  69. bar = '=' * filled_len + '-' * (bar_len - filled_len)
  70. if total_parts == 1:
  71. sys.stdout.write(
  72. '\r[%s] %s%s %s/%sMB downloaded' % (
  73. bar, percent_downloaded, "%",
  74. round(size_downloaded / (1024 * 1024), 1),
  75. round(float(total_size) / (1024 * 1024), 1))
  76. )
  77. sys.stdout.flush()
  78. else:
  79. sys.stdout.write(
  80. '\r Part %s/%s [%s] %s%s %s/%sMB downloaded' % (
  81. part + 1, total_parts, bar, percent_downloaded, "%",
  82. round(size_downloaded / (1024 * 1024), 1),
  83. round(float(total_size) / (1024 * 1024), 1))
  84. )
  85. sys.stdout.flush()
  86. def _create_base_dir():
  87. """Create the gensim-data directory in home directory, if it has not been already created.
  88. Raises
  89. ------
  90. Exception
  91. An exception is raised when read/write permissions are not available or a file named gensim-data
  92. already exists in the home directory.
  93. """
  94. if not os.path.isdir(base_dir):
  95. try:
  96. logger.info("Creating %s", base_dir)
  97. os.makedirs(base_dir)
  98. except OSError as e:
  99. if e.errno == errno.EEXIST:
  100. raise Exception(
  101. "Not able to create folder gensim-data in {}. File gensim-data "
  102. "exists in the direcory already.".format(user_dir)
  103. )
  104. else:
  105. raise Exception(
  106. "Can't create {}. Make sure you have the read/write permissions "
  107. "to the directory or you can try creating the folder manually"
  108. .format(base_dir)
  109. )
  110. def _calculate_md5_checksum(fname):
  111. """Calculate the checksum of the file, exactly same as md5-sum linux util.
  112. Parameters
  113. ----------
  114. fname : str
  115. Path to the file.
  116. Returns
  117. -------
  118. str
  119. MD5-hash of file names as `fname`.
  120. """
  121. hash_md5 = hashlib.md5()
  122. with open(fname, "rb") as f:
  123. for chunk in iter(lambda: f.read(4096), b""):
  124. hash_md5.update(chunk)
  125. return hash_md5.hexdigest()
  126. def info(name=None, show_only_latest=True, name_only=False):
  127. """Provide the information related to model/dataset.
  128. Parameters
  129. ----------
  130. name : str, optional
  131. Name of model/dataset. If not set - shows all available data.
  132. show_only_latest : bool, optional
  133. If storage contains different versions for one data/model, this flag allow to hide outdated versions.
  134. Affects only if `name` is None.
  135. name_only : bool, optional
  136. If True, will return only the names of available models and corpora.
  137. Returns
  138. -------
  139. dict
  140. Detailed information about one or all models/datasets.
  141. If name is specified, return full information about concrete dataset/model,
  142. otherwise, return information about all available datasets/models.
  143. Raises
  144. ------
  145. Exception
  146. If name that has been passed is incorrect.
  147. Examples
  148. --------
  149. >>> import gensim.downloader as api
  150. >>> api.info("text8") # retrieve information about text8 dataset
  151. {u'checksum': u'68799af40b6bda07dfa47a32612e5364',
  152. u'description': u'Cleaned small sample from wikipedia',
  153. u'file_name': u'text8.gz',
  154. u'parts': 1,
  155. u'source': u'http://mattmahoney.net/dc/text8.zip'}
  156. >>>
  157. >>> api.info() # retrieve information about all available datasets and models
  158. """
  159. information = json.loads(urlopen(DATA_LIST_URL).read().decode("utf-8"))
  160. if name is not None:
  161. corpora = information['corpora']
  162. models = information['models']
  163. if name in corpora:
  164. return information['corpora'][name]
  165. elif name in models:
  166. return information['models'][name]
  167. else:
  168. raise ValueError("Incorrect model/corpus name")
  169. if not show_only_latest:
  170. return information
  171. if name_only:
  172. return {"corpora": list(information['corpora'].keys()), "models": list(information['models'])}
  173. return {
  174. "corpora": {name: data for (name, data) in information['corpora'].items() if data.get("latest", True)},
  175. "models": {name: data for (name, data) in information['models'].items() if data.get("latest", True)}
  176. }
  177. def _get_checksum(name, part=None):
  178. """Retrieve the checksum of the model/dataset from gensim-data repository.
  179. Parameters
  180. ----------
  181. name : str
  182. Dataset/model name.
  183. part : int, optional
  184. Number of part (for multipart data only).
  185. Returns
  186. -------
  187. str
  188. Retrieved checksum of dataset/model.
  189. """
  190. information = info()
  191. corpora = information['corpora']
  192. models = information['models']
  193. if part is None:
  194. if name in corpora:
  195. return information['corpora'][name]["checksum"]
  196. elif name in models:
  197. return information['models'][name]["checksum"]
  198. else:
  199. if name in corpora:
  200. return information['corpora'][name]["checksum-{}".format(part)]
  201. elif name in models:
  202. return information['models'][name]["checksum-{}".format(part)]
  203. def _get_parts(name):
  204. """Retrieve the number of parts in which dataset/model has been split.
  205. Parameters
  206. ----------
  207. name: str
  208. Dataset/model name.
  209. Returns
  210. -------
  211. int
  212. Number of parts in which dataset/model has been split.
  213. """
  214. information = info()
  215. corpora = information['corpora']
  216. models = information['models']
  217. if name in corpora:
  218. return information['corpora'][name]["parts"]
  219. elif name in models:
  220. return information['models'][name]["parts"]
  221. def _download(name):
  222. """Download and extract the dataset/model.
  223. Parameters
  224. ----------
  225. name: str
  226. Dataset/model name which has to be downloaded.
  227. Raises
  228. ------
  229. Exception
  230. If md5sum on client and in repo are different.
  231. """
  232. url_load_file = "{base}/{fname}/__init__.py".format(base=DOWNLOAD_BASE_URL, fname=name)
  233. data_folder_dir = os.path.join(base_dir, name)
  234. data_folder_dir_tmp = data_folder_dir + '_tmp'
  235. tmp_dir = tempfile.mkdtemp()
  236. init_path = os.path.join(tmp_dir, "__init__.py")
  237. urllib.urlretrieve(url_load_file, init_path)
  238. total_parts = _get_parts(name)
  239. if total_parts > 1:
  240. concatenated_folder_name = "{fname}.gz".format(fname=name)
  241. concatenated_folder_dir = os.path.join(tmp_dir, concatenated_folder_name)
  242. for part in range(0, total_parts):
  243. url_data = "{base}/{fname}/{fname}.gz_0{part}".format(base=DOWNLOAD_BASE_URL, fname=name, part=part)
  244. fname = "{f}.gz_0{p}".format(f=name, p=part)
  245. dst_path = os.path.join(tmp_dir, fname)
  246. urllib.urlretrieve(
  247. url_data, dst_path,
  248. reporthook=partial(_progress, part=part, total_parts=total_parts)
  249. )
  250. if _calculate_md5_checksum(dst_path) == _get_checksum(name, part):
  251. sys.stdout.write("\n")
  252. sys.stdout.flush()
  253. logger.info("Part %s/%s downloaded", part + 1, total_parts)
  254. else:
  255. shutil.rmtree(tmp_dir)
  256. raise Exception("Checksum comparison failed, try again")
  257. with open(concatenated_folder_dir, 'wb') as wfp:
  258. for part in range(0, total_parts):
  259. part_path = os.path.join(tmp_dir, "{fname}.gz_0{part}".format(fname=name, part=part))
  260. with open(part_path, "rb") as rfp:
  261. shutil.copyfileobj(rfp, wfp)
  262. os.remove(part_path)
  263. else:
  264. url_data = "{base}/{fname}/{fname}.gz".format(base=DOWNLOAD_BASE_URL, fname=name)
  265. fname = "{fname}.gz".format(fname=name)
  266. dst_path = os.path.join(tmp_dir, fname)
  267. urllib.urlretrieve(url_data, dst_path, reporthook=_progress)
  268. if _calculate_md5_checksum(dst_path) == _get_checksum(name):
  269. sys.stdout.write("\n")
  270. sys.stdout.flush()
  271. logger.info("%s downloaded", name)
  272. else:
  273. shutil.rmtree(tmp_dir)
  274. raise Exception("Checksum comparison failed, try again")
  275. if os.path.exists(data_folder_dir_tmp):
  276. os.remove(data_folder_dir_tmp)
  277. shutil.move(tmp_dir, data_folder_dir_tmp)
  278. os.rename(data_folder_dir_tmp, data_folder_dir)
  279. def _get_filename(name):
  280. """Retrieve the filename of the dataset/model.
  281. Parameters
  282. ----------
  283. name: str
  284. Name of dataset/model.
  285. Returns
  286. -------
  287. str:
  288. Filename of the dataset/model.
  289. """
  290. information = info()
  291. corpora = information['corpora']
  292. models = information['models']
  293. if name in corpora:
  294. return information['corpora'][name]["file_name"]
  295. elif name in models:
  296. return information['models'][name]["file_name"]
  297. def load(name, return_path=False):
  298. """Download (if needed) dataset/model and load it to memory (unless `return_path` is set).
  299. Parameters
  300. ----------
  301. name: str
  302. Name of the model/dataset.
  303. return_path: bool, optional
  304. If True, return full path to file, otherwise, return loaded model / iterable dataset.
  305. Returns
  306. -------
  307. Model
  308. Requested model, if `name` is model and `return_path` == False.
  309. Dataset (iterable)
  310. Requested dataset, if `name` is dataset and `return_path` == False.
  311. str
  312. Path to file with dataset / model, only when `return_path` == True.
  313. Raises
  314. ------
  315. Exception
  316. Raised if `name` is incorrect.
  317. Examples
  318. --------
  319. Model example:
  320. >>> import gensim.downloader as api
  321. >>>
  322. >>> model = api.load("glove-twitter-25") # load glove vectors
  323. >>> model.most_similar("cat") # show words that similar to word 'cat'
  324. Dataset example:
  325. >>> import gensim.downloader as api
  326. >>>
  327. >>> wiki = api.load("wiki-en") # load extracted Wikipedia dump, around 6 Gb
  328. >>> for article in wiki: # iterate over all wiki script
  329. >>> ...
  330. Download only example
  331. >>> import gensim.downloader as api
  332. >>>
  333. >>> print(api.load("wiki-en", return_path=True)) # output: /home/user/gensim-data/wiki-en/wiki-en.gz
  334. """
  335. _create_base_dir()
  336. file_name = _get_filename(name)
  337. if file_name is None:
  338. raise ValueError("Incorrect model/corpus name")
  339. folder_dir = os.path.join(base_dir, name)
  340. path = os.path.join(folder_dir, file_name)
  341. if not os.path.exists(folder_dir):
  342. _download(name)
  343. if return_path:
  344. return path
  345. else:
  346. sys.path.insert(0, base_dir)
  347. module = __import__(name)
  348. return module.load_data()
  349. if __name__ == '__main__':
  350. logging.basicConfig(
  351. format='%(asctime)s : %(name)s : %(levelname)s : %(message)s', stream=sys.stdout, level=logging.INFO
  352. )
  353. parser = argparse.ArgumentParser(
  354. description="Gensim console API",
  355. usage="python -m gensim.api.downloader [-h] [-d data_name | -i data_name | -c]"
  356. )
  357. group = parser.add_mutually_exclusive_group()
  358. group.add_argument(
  359. "-d", "--download", metavar="data_name", nargs=1,
  360. help="To download a corpus/model : python -m gensim.downloader -d <dataname>"
  361. )
  362. full_information = 1
  363. group.add_argument(
  364. "-i", "--info", metavar="data_name", nargs='?', const=full_information,
  365. help="To get information about a corpus/model : python -m gensim.downloader -i <dataname>"
  366. )
  367. args = parser.parse_args()
  368. if args.download is not None:
  369. data_path = load(args.download[0], return_path=True)
  370. logger.info("Data has been installed and data path is %s", data_path)
  371. elif args.info is not None:
  372. if args.info == 'name':
  373. print(json.dumps(info(name_only=True), indent=4))
  374. else:
  375. output = info() if (args.info == full_information) else info(name=args.info)
  376. print(json.dumps(output, indent=4))