You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

428 lines
17 KiB

4 years ago
  1. #!/usr/bin/env python
  2. """
  3. Support for serialization of numpy data types with msgpack.
  4. """
  5. # Copyright (c) 2013-2018, Lev E. Givon
  6. # All rights reserved.
  7. # Distributed under the terms of the BSD license:
  8. # http://www.opensource.org/licenses/bsd-license
  9. import os
  10. import sys
  11. import functools
  12. import numpy as np
  13. import msgpack
  14. from msgpack import Packer as _Packer, Unpacker as _Unpacker, \
  15. unpack as _unpack, unpackb as _unpackb
  16. def encode(obj, chain=None):
  17. """
  18. Data encoder for serializing numpy data types.
  19. """
  20. if isinstance(obj, np.ndarray):
  21. # If the dtype is structured, store the interface description;
  22. # otherwise, store the corresponding array protocol type string:
  23. if obj.dtype.kind == 'V':
  24. kind = b'V'
  25. descr = obj.dtype.descr
  26. else:
  27. kind = b''
  28. descr = obj.dtype.str
  29. return {b'nd': True,
  30. b'type': descr,
  31. b'kind': kind,
  32. b'shape': obj.shape,
  33. b'data': obj.tobytes()}
  34. elif isinstance(obj, (np.bool_, np.number)):
  35. return {b'nd': False,
  36. b'type': obj.dtype.str,
  37. b'data': obj.tobytes()}
  38. elif isinstance(obj, complex):
  39. return {b'complex': True,
  40. b'data': obj.__repr__()}
  41. else:
  42. return obj if chain is None else chain(obj)
  43. def tostr(x):
  44. if sys.version_info >= (3, 0):
  45. if isinstance(x, bytes):
  46. return x.decode()
  47. else:
  48. return str(x)
  49. else:
  50. return x
  51. def decode(obj, chain=None):
  52. """
  53. Decoder for deserializing numpy data types.
  54. """
  55. try:
  56. if b'nd' in obj:
  57. if obj[b'nd'] is True:
  58. # Check if b'kind' is in obj to enable decoding of data
  59. # serialized with older versions (#20):
  60. if b'kind' in obj and obj[b'kind'] == b'V':
  61. descr = [tuple(tostr(t) if type(t) is bytes else t for t in d) \
  62. for d in obj[b'type']]
  63. else:
  64. descr = obj[b'type']
  65. return np.frombuffer(obj[b'data'],
  66. dtype=np.dtype(descr)).reshape(obj[b'shape'])
  67. else:
  68. descr = obj[b'type']
  69. return np.frombuffer(obj[b'data'],
  70. dtype=np.dtype(descr))[0]
  71. elif b'complex' in obj:
  72. return complex(tostr(obj[b'data']))
  73. else:
  74. return obj if chain is None else chain(obj)
  75. except KeyError:
  76. return obj if chain is None else chain(obj)
  77. # Maintain support for msgpack < 0.4.0:
  78. if msgpack.version < (0, 4, 0):
  79. class Packer(_Packer):
  80. def __init__(self, default=None,
  81. encoding='utf-8',
  82. unicode_errors='strict',
  83. use_single_float=False,
  84. autoreset=1):
  85. default = functools.partial(encode, chain=default)
  86. super(Packer, self).__init__(default=default,
  87. encoding=encoding,
  88. unicode_errors=unicode_errors,
  89. use_single_float=use_single_float,
  90. autoreset=autoreset)
  91. class Unpacker(_Unpacker):
  92. def __init__(self, file_like=None, read_size=0, use_list=None,
  93. object_hook=None,
  94. object_pairs_hook=None, list_hook=None, encoding='utf-8',
  95. unicode_errors='strict', max_buffer_size=0):
  96. object_hook = functools.partial(decode, chain=object_hook)
  97. super(Unpacker, self).__init__(file_like=file_like,
  98. read_size=read_size,
  99. use_list=use_list,
  100. object_hook=object_hook,
  101. object_pairs_hook=object_pairs_hook,
  102. list_hook=list_hook,
  103. encoding=encoding,
  104. unicode_errors=unicode_errors,
  105. max_buffer_size=max_buffer_size)
  106. else:
  107. class Packer(_Packer):
  108. def __init__(self, default=None,
  109. encoding='utf-8',
  110. unicode_errors='strict',
  111. use_single_float=False,
  112. autoreset=1,
  113. use_bin_type=0):
  114. default = functools.partial(encode, chain=default)
  115. super(Packer, self).__init__(default=default,
  116. encoding=encoding,
  117. unicode_errors=unicode_errors,
  118. use_single_float=use_single_float,
  119. autoreset=autoreset,
  120. use_bin_type=use_bin_type)
  121. class Unpacker(_Unpacker):
  122. def __init__(self, file_like=None, read_size=0, use_list=None,
  123. object_hook=None,
  124. object_pairs_hook=None, list_hook=None, encoding=None,
  125. unicode_errors='strict', max_buffer_size=0,
  126. ext_hook=msgpack.ExtType):
  127. object_hook = functools.partial(decode, chain=object_hook)
  128. super(Unpacker, self).__init__(file_like=file_like,
  129. read_size=read_size,
  130. use_list=use_list,
  131. object_hook=object_hook,
  132. object_pairs_hook=object_pairs_hook,
  133. list_hook=list_hook,
  134. encoding=encoding,
  135. unicode_errors=unicode_errors,
  136. max_buffer_size=max_buffer_size,
  137. ext_hook=ext_hook)
  138. def pack(o, stream, **kwargs):
  139. """
  140. Pack an object and write it to a stream.
  141. """
  142. packer = Packer(**kwargs)
  143. stream.write(packer.pack(o))
  144. def packb(o, **kwargs):
  145. """
  146. Pack an object and return the packed bytes.
  147. """
  148. return Packer(**kwargs).pack(o)
  149. def unpack(stream, **kwargs):
  150. """
  151. Unpack a packed object from a stream.
  152. """
  153. object_hook = kwargs.get('object_hook')
  154. kwargs['object_hook'] = functools.partial(decode, chain=object_hook)
  155. return _unpack(stream, **kwargs)
  156. def unpackb(packed, **kwargs):
  157. """
  158. Unpack a packed object.
  159. """
  160. object_hook = kwargs.get('object_hook')
  161. kwargs['object_hook'] = functools.partial(decode, chain=object_hook)
  162. return _unpackb(packed, **kwargs)
  163. load = unpack
  164. loads = unpackb
  165. dump = pack
  166. dumps = packb
  167. def patch():
  168. """
  169. Monkey patch msgpack module to enable support for serializing numpy types.
  170. """
  171. setattr(msgpack, 'Packer', Packer)
  172. setattr(msgpack, 'Unpacker', Unpacker)
  173. setattr(msgpack, 'load', unpack)
  174. setattr(msgpack, 'loads', unpackb)
  175. setattr(msgpack, 'dump', pack)
  176. setattr(msgpack, 'dumps', packb)
  177. setattr(msgpack, 'pack', pack)
  178. setattr(msgpack, 'packb', packb)
  179. setattr(msgpack, 'unpack', unpack)
  180. setattr(msgpack, 'unpackb', unpackb)
  181. if __name__ == '__main__':
  182. try:
  183. range = xrange # Python 2
  184. except NameError:
  185. pass # Python 3
  186. from unittest import main, TestCase, TestSuite
  187. from numpy.testing import assert_equal, assert_array_equal
  188. class ThirdParty(object):
  189. def __init__(self, foo=b'bar'):
  190. self.foo = foo
  191. def __eq__(self, other):
  192. return isinstance(other, ThirdParty) and self.foo == other.foo
  193. class test_numpy_msgpack(TestCase):
  194. def setUp(self):
  195. patch()
  196. def encode_decode(self, x, use_bin_type=False, encoding=None):
  197. x_enc = msgpack.packb(x, use_bin_type=use_bin_type)
  198. return msgpack.unpackb(x_enc, encoding=encoding)
  199. def encode_thirdparty(self, obj):
  200. return dict(__thirdparty__=True, foo=obj.foo)
  201. def decode_thirdparty(self, obj):
  202. if b'__thirdparty__' in obj:
  203. return ThirdParty(foo=obj[b'foo'])
  204. return obj
  205. def encode_decode_thirdparty(self, x, use_bin_type=False, encoding=None):
  206. x_enc = msgpack.packb(x, default=self.encode_thirdparty,
  207. use_bin_type=use_bin_type)
  208. return msgpack.unpackb(x_enc, object_hook=self.decode_thirdparty,
  209. encoding=encoding)
  210. def test_bin(self):
  211. # Since bytes == str in Python 2.7, the following
  212. # should pass on both 2.7 and 3.*
  213. assert_equal(type(self.encode_decode(b'foo')), bytes)
  214. def test_str(self):
  215. assert_equal(type(self.encode_decode('foo')), bytes)
  216. if sys.version_info.major == 2:
  217. assert_equal(type(self.encode_decode(u'foo')), str)
  218. # Test non-default string encoding/decoding:
  219. assert_equal(type(self.encode_decode(u'foo', True, 'utf=8')), unicode)
  220. def test_numpy_scalar_bool(self):
  221. x = np.bool_(True)
  222. x_rec = self.encode_decode(x)
  223. assert_equal(x, x_rec)
  224. assert_equal(type(x), type(x_rec))
  225. x = np.bool_(False)
  226. x_rec = self.encode_decode(x)
  227. assert_equal(x, x_rec)
  228. assert_equal(type(x), type(x_rec))
  229. def test_numpy_scalar_float(self):
  230. x = np.float32(np.random.rand())
  231. x_rec = self.encode_decode(x)
  232. assert_equal(x, x_rec)
  233. assert_equal(type(x), type(x_rec))
  234. def test_numpy_scalar_complex(self):
  235. x = np.complex64(np.random.rand()+1j*np.random.rand())
  236. x_rec = self.encode_decode(x)
  237. assert_equal(x, x_rec)
  238. assert_equal(type(x), type(x_rec))
  239. def test_scalar_float(self):
  240. x = np.random.rand()
  241. x_rec = self.encode_decode(x)
  242. assert_equal(x, x_rec)
  243. assert_equal(type(x), type(x_rec))
  244. def test_scalar_complex(self):
  245. x = np.random.rand()+1j*np.random.rand()
  246. x_rec = self.encode_decode(x)
  247. assert_equal(x, x_rec)
  248. assert_equal(type(x), type(x_rec))
  249. def test_list_numpy_float(self):
  250. x = [np.float32(np.random.rand()) for i in range(5)]
  251. x_rec = self.encode_decode(x)
  252. assert_array_equal(x, x_rec)
  253. assert_array_equal([type(e) for e in x],
  254. [type(e) for e in x_rec])
  255. def test_list_numpy_float_complex(self):
  256. x = [np.float32(np.random.rand()) for i in range(5)] + \
  257. [np.complex128(np.random.rand()+1j*np.random.rand()) for i in range(5)]
  258. x_rec = self.encode_decode(x)
  259. assert_array_equal(x, x_rec)
  260. assert_array_equal([type(e) for e in x],
  261. [type(e) for e in x_rec])
  262. def test_list_float(self):
  263. x = [np.random.rand() for i in range(5)]
  264. x_rec = self.encode_decode(x)
  265. assert_array_equal(x, x_rec)
  266. assert_array_equal([type(e) for e in x],
  267. [type(e) for e in x_rec])
  268. def test_list_float_complex(self):
  269. x = [(np.random.rand()+1j*np.random.rand()) for i in range(5)]
  270. x_rec = self.encode_decode(x)
  271. assert_array_equal(x, x_rec)
  272. assert_array_equal([type(e) for e in x],
  273. [type(e) for e in x_rec])
  274. def test_list_str(self):
  275. x = [b'x'*i for i in range(5)]
  276. x_rec = self.encode_decode(x)
  277. assert_array_equal(x, x_rec)
  278. assert_array_equal([type(e) for e in x_rec], [bytes]*5)
  279. def test_dict_float(self):
  280. x = {b'foo': 1.0, b'bar': 2.0}
  281. x_rec = self.encode_decode(x)
  282. assert_array_equal(sorted(x.values()), sorted(x_rec.values()))
  283. assert_array_equal([type(e) for e in sorted(x.values())],
  284. [type(e) for e in sorted(x_rec.values())])
  285. assert_array_equal(sorted(x.keys()), sorted(x_rec.keys()))
  286. assert_array_equal([type(e) for e in sorted(x.keys())],
  287. [type(e) for e in sorted(x_rec.keys())])
  288. def test_dict_complex(self):
  289. x = {b'foo': 1.0+1.0j, b'bar': 2.0+2.0j}
  290. x_rec = self.encode_decode(x)
  291. assert_array_equal(sorted(x.values(), key=np.linalg.norm),
  292. sorted(x_rec.values(), key=np.linalg.norm))
  293. assert_array_equal([type(e) for e in sorted(x.values(), key=np.linalg.norm)],
  294. [type(e) for e in sorted(x_rec.values(), key=np.linalg.norm)])
  295. assert_array_equal(sorted(x.keys()), sorted(x_rec.keys()))
  296. assert_array_equal([type(e) for e in sorted(x.keys())],
  297. [type(e) for e in sorted(x_rec.keys())])
  298. def test_dict_str(self):
  299. x = {b'foo': b'xxx', b'bar': b'yyyy'}
  300. x_rec = self.encode_decode(x)
  301. assert_array_equal(sorted(x.values()), sorted(x_rec.values()))
  302. assert_array_equal([type(e) for e in sorted(x.values())],
  303. [type(e) for e in sorted(x_rec.values())])
  304. assert_array_equal(sorted(x.keys()), sorted(x_rec.keys()))
  305. assert_array_equal([type(e) for e in sorted(x.keys())],
  306. [type(e) for e in sorted(x_rec.keys())])
  307. def test_dict_numpy_float(self):
  308. x = {b'foo': np.float32(1.0), b'bar': np.float32(2.0)}
  309. x_rec = self.encode_decode(x)
  310. assert_array_equal(sorted(x.values()), sorted(x_rec.values()))
  311. assert_array_equal([type(e) for e in sorted(x.values())],
  312. [type(e) for e in sorted(x_rec.values())])
  313. assert_array_equal(sorted(x.keys()), sorted(x_rec.keys()))
  314. assert_array_equal([type(e) for e in sorted(x.keys())],
  315. [type(e) for e in sorted(x_rec.keys())])
  316. def test_dict_numpy_complex(self):
  317. x = {b'foo': np.complex128(1.0+1.0j), b'bar': np.complex128(2.0+2.0j)}
  318. x_rec = self.encode_decode(x)
  319. assert_array_equal(sorted(x.values(), key=np.linalg.norm),
  320. sorted(x_rec.values(), key=np.linalg.norm))
  321. assert_array_equal([type(e) for e in sorted(x.values(), key=np.linalg.norm)],
  322. [type(e) for e in sorted(x_rec.values(), key=np.linalg.norm)])
  323. assert_array_equal(sorted(x.keys()), sorted(x_rec.keys()))
  324. assert_array_equal([type(e) for e in sorted(x.keys())],
  325. [type(e) for e in sorted(x_rec.keys())])
  326. def test_numpy_array_float(self):
  327. x = np.random.rand(5).astype(np.float32)
  328. x_rec = self.encode_decode(x)
  329. assert_array_equal(x, x_rec)
  330. assert_equal(x.dtype, x_rec.dtype)
  331. def test_numpy_array_complex(self):
  332. x = (np.random.rand(5)+1j*np.random.rand(5)).astype(np.complex128)
  333. x_rec = self.encode_decode(x)
  334. assert_array_equal(x, x_rec)
  335. assert_equal(x.dtype, x_rec.dtype)
  336. def test_numpy_array_float_2d(self):
  337. x = np.random.rand(5,5).astype(np.float32)
  338. x_rec = self.encode_decode(x)
  339. assert_array_equal(x, x_rec)
  340. assert_equal(x.dtype, x_rec.dtype)
  341. def test_numpy_array_str(self):
  342. x = np.array([b'aaa', b'bbbb', b'ccccc'])
  343. x_rec = self.encode_decode(x)
  344. assert_array_equal(x, x_rec)
  345. assert_equal(x.dtype, x_rec.dtype)
  346. def test_numpy_array_mixed(self):
  347. x = np.array([(1, 2, b'a', [1.0, 2.0])],
  348. np.dtype([('arg0', np.uint32),
  349. ('arg1', np.uint32),
  350. ('arg2', 'S1'),
  351. ('arg3', np.float32, (2,))]))
  352. x_rec = self.encode_decode(x)
  353. assert_array_equal(x, x_rec)
  354. assert_equal(x.dtype, x_rec.dtype)
  355. def test_list_mixed(self):
  356. x = [1.0, np.float32(3.5), np.complex128(4.25), b'foo']
  357. x_rec = self.encode_decode(x)
  358. assert_array_equal(x, x_rec)
  359. assert_array_equal([type(e) for e in x],
  360. [type(e) for e in x_rec])
  361. def test_chain(self):
  362. x = ThirdParty(foo=b'test marshal/unmarshal')
  363. x_rec = self.encode_decode_thirdparty(x)
  364. self.assertEqual(x, x_rec)
  365. main()