You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

428 lines
17 KiB

#!/usr/bin/env python
"""
Support for serialization of numpy data types with msgpack.
"""
# Copyright (c) 2013-2018, Lev E. Givon
# All rights reserved.
# Distributed under the terms of the BSD license:
# http://www.opensource.org/licenses/bsd-license
import os
import sys
import functools
import numpy as np
import msgpack
from msgpack import Packer as _Packer, Unpacker as _Unpacker, \
unpack as _unpack, unpackb as _unpackb
def encode(obj, chain=None):
    """
    Data encoder for serializing numpy data types.

    Parameters
    ----------
    obj : object
        Object to encode.
    chain : callable, optional
        Fallback encoder invoked with `obj` when `obj` is not a numpy
        array, a numpy bool/number scalar, or a Python complex number.

    Returns
    -------
    dict or object
        A dict with bytes keys describing the numpy/complex value, or
        `obj` itself (or `chain(obj)` if `chain` is given) for anything
        else.
    """
    if isinstance(obj, np.ndarray):
        # If the dtype is structured, store the interface description;
        # otherwise, store the corresponding array protocol type string:
        if obj.dtype.kind == 'V':
            kind = b'V'
            descr = obj.dtype.descr
        else:
            kind = b''
            descr = obj.dtype.str
        return {b'nd': True,
                b'type': descr,
                b'kind': kind,
                b'shape': obj.shape,
                b'data': obj.tobytes()}

    # Numpy scalars are checked before plain complex numbers.
    if isinstance(obj, (np.bool_, np.number)):
        return {b'nd': False,
                b'type': obj.dtype.str,
                b'data': obj.tobytes()}

    if isinstance(obj, complex):
        # The textual form is stored; decode() rebuilds it with complex().
        return {b'complex': True,
                b'data': repr(obj)}

    return obj if chain is None else chain(obj)
def tostr(x):
    """
    Convert `x` to a native string.

    On Python 3, bytes are decoded with the default codec and any other
    object is passed through `str()`. On Python 2, `x` is returned
    unchanged.
    """
    if sys.version_info < (3, 0):
        return x
    if isinstance(x, bytes):
        return x.decode()
    return str(x)
def decode(obj, chain=None):
    """
    Decoder for deserializing numpy data types.

    Parameters
    ----------
    obj : dict
        Mapping to inspect. Mappings carrying the b'nd' or b'complex'
        marker written by `encode` are converted back; anything else is
        passed through.
    chain : callable, optional
        Fallback decoder invoked with `obj` when `obj` is not recognized
        or is missing an expected key.

    Returns
    -------
    object
        The reconstructed numpy array, numpy scalar or complex number, or
        `obj` itself (or `chain(obj)` if `chain` is given).
    """
    try:
        if b'nd' in obj:
            if obj[b'nd'] is True:
                # Check if b'kind' is in obj to enable decoding of data
                # serialized with older versions (#20):
                if b'kind' in obj and obj[b'kind'] == b'V':
                    # Structured dtype: convert any bytes entries of each
                    # field description to native strings.
                    descr = [tuple(tostr(t) if isinstance(t, bytes) else t
                                   for t in d)
                             for d in obj[b'type']]
                else:
                    descr = obj[b'type']
                return np.frombuffer(obj[b'data'],
                                     dtype=np.dtype(descr)).reshape(obj[b'shape'])
            else:
                # Scalar: the buffer holds exactly one element.
                descr = obj[b'type']
                return np.frombuffer(obj[b'data'],
                                     dtype=np.dtype(descr))[0]
        elif b'complex' in obj:
            return complex(tostr(obj[b'data']))
    except KeyError:
        # Marker present but an expected key is missing: treat the mapping
        # as an ordinary object and fall through.
        pass

    # The fallback is deliberately outside the try block so that a KeyError
    # raised by `chain` propagates instead of triggering a second call.
    return obj if chain is None else chain(obj)
# Maintain support for msgpack < 0.4.0:
if msgpack.version < (0, 4, 0):
    class Packer(_Packer):
        """
        msgpack Packer that serializes numpy data types via `encode`.

        A user-supplied `default` callable is chained after the numpy
        encoder; every other argument is forwarded to the base class
        unchanged.
        """
        def __init__(self, default=None,
                     encoding='utf-8',
                     unicode_errors='strict',
                     use_single_float=False,
                     autoreset=1):
            default = functools.partial(encode, chain=default)
            super(Packer, self).__init__(default=default,
                                         encoding=encoding,
                                         unicode_errors=unicode_errors,
                                         use_single_float=use_single_float,
                                         autoreset=autoreset)

    class Unpacker(_Unpacker):
        """
        msgpack Unpacker that deserializes numpy data types via `decode`.

        A user-supplied `object_hook` callable is chained after the numpy
        decoder; every other argument is forwarded to the base class
        unchanged.
        """
        def __init__(self, file_like=None, read_size=0, use_list=None,
                     object_hook=None,
                     object_pairs_hook=None, list_hook=None, encoding='utf-8',
                     unicode_errors='strict', max_buffer_size=0):
            object_hook = functools.partial(decode, chain=object_hook)
            super(Unpacker, self).__init__(file_like=file_like,
                                           read_size=read_size,
                                           use_list=use_list,
                                           object_hook=object_hook,
                                           object_pairs_hook=object_pairs_hook,
                                           list_hook=list_hook,
                                           encoding=encoding,
                                           unicode_errors=unicode_errors,
                                           max_buffer_size=max_buffer_size)
else:
    class Packer(_Packer):
        """
        msgpack Packer that serializes numpy data types via `encode`.

        A user-supplied `default` callable is chained after the numpy
        encoder; every other argument (including `use_bin_type`) is
        forwarded to the base class unchanged.
        """
        def __init__(self, default=None,
                     encoding='utf-8',
                     unicode_errors='strict',
                     use_single_float=False,
                     autoreset=1,
                     use_bin_type=0):
            default = functools.partial(encode, chain=default)
            super(Packer, self).__init__(default=default,
                                         encoding=encoding,
                                         unicode_errors=unicode_errors,
                                         use_single_float=use_single_float,
                                         autoreset=autoreset,
                                         use_bin_type=use_bin_type)

    class Unpacker(_Unpacker):
        """
        msgpack Unpacker that deserializes numpy data types via `decode`.

        A user-supplied `object_hook` callable is chained after the numpy
        decoder; every other argument (including `ext_hook`) is forwarded
        to the base class unchanged. Note that `encoding` defaults to None
        in this branch, unlike the msgpack < 0.4.0 variant.
        """
        def __init__(self, file_like=None, read_size=0, use_list=None,
                     object_hook=None,
                     object_pairs_hook=None, list_hook=None, encoding=None,
                     unicode_errors='strict', max_buffer_size=0,
                     ext_hook=msgpack.ExtType):
            object_hook = functools.partial(decode, chain=object_hook)
            super(Unpacker, self).__init__(file_like=file_like,
                                           read_size=read_size,
                                           use_list=use_list,
                                           object_hook=object_hook,
                                           object_pairs_hook=object_pairs_hook,
                                           list_hook=list_hook,
                                           encoding=encoding,
                                           unicode_errors=unicode_errors,
                                           max_buffer_size=max_buffer_size,
                                           ext_hook=ext_hook)
def pack(o, stream, **kwargs):
    """
    Pack an object and write it to a stream.

    Parameters
    ----------
    o : object
        Object to serialize.
    stream : file-like
        Object with a `write` method that receives the packed bytes.
    **kwargs
        Forwarded to the numpy-aware `Packer` constructor.
    """
    packer = Packer(**kwargs)
    stream.write(packer.pack(o))
def packb(o, **kwargs):
    """
    Pack an object and return the packed bytes.

    Parameters
    ----------
    o : object
        Object to serialize.
    **kwargs
        Forwarded to the numpy-aware `Packer` constructor.
    """
    packer = Packer(**kwargs)
    return packer.pack(o)
def unpack(stream, **kwargs):
    """
    Unpack a packed object from a stream.

    Parameters
    ----------
    stream : file-like
        Source handed to msgpack's `unpack`.
    **kwargs
        Forwarded to msgpack's `unpack`. Any `object_hook` supplied is
        chained after the numpy decoder.
    """
    user_hook = kwargs.get('object_hook')
    kwargs['object_hook'] = functools.partial(decode, chain=user_hook)
    return _unpack(stream, **kwargs)
def unpackb(packed, **kwargs):
    """
    Unpack a packed object.

    Parameters
    ----------
    packed : bytes
        Serialized data handed to msgpack's `unpackb`.
    **kwargs
        Forwarded to msgpack's `unpackb`. Any `object_hook` supplied is
        chained after the numpy decoder.
    """
    user_hook = kwargs.get('object_hook')
    kwargs['object_hook'] = functools.partial(decode, chain=user_hook)
    return _unpackb(packed, **kwargs)
# Aliases for the stream-based functions:
dump = pack
load = unpack
# Aliases for the bytes-based functions:
dumps = packb
loads = unpackb
def patch():
    """
    Monkey patch msgpack module to enable support for serializing numpy types.

    Replaces the Packer/Unpacker classes and the pack/unpack style
    functions (and their load/dump aliases) on the imported `msgpack`
    module with the numpy-aware versions defined in this module.
    """
    replacements = (
        ('Packer', Packer),
        ('Unpacker', Unpacker),
        ('load', unpack),
        ('loads', unpackb),
        ('dump', pack),
        ('dumps', packb),
        ('pack', pack),
        ('packb', packb),
        ('unpack', unpack),
        ('unpackb', unpackb),
    )
    for attr_name, replacement in replacements:
        setattr(msgpack, attr_name, replacement)
if __name__ == '__main__':
    # Self-test: round-trip a variety of values through the patched msgpack.
    try:
        range = xrange  # Python 2
    except NameError:
        pass  # Python 3

    from unittest import main, TestCase, TestSuite
    from numpy.testing import assert_equal, assert_array_equal

    class ThirdParty(object):
        """Simple user-defined type used to exercise hook chaining."""

        def __init__(self, foo=b'bar'):
            self.foo = foo

        def __eq__(self, other):
            return isinstance(other, ThirdParty) and self.foo == other.foo

    class test_numpy_msgpack(TestCase):
        """Round-trip tests for the numpy-aware msgpack functions."""

        def setUp(self):
            patch()

        # -- helpers -------------------------------------------------------

        def encode_decode(self, x, use_bin_type=False, encoding=None):
            x_enc = msgpack.packb(x, use_bin_type=use_bin_type)
            return msgpack.unpackb(x_enc, encoding=encoding)

        def encode_thirdparty(self, obj):
            return dict(__thirdparty__=True, foo=obj.foo)

        def decode_thirdparty(self, obj):
            if b'__thirdparty__' in obj:
                return ThirdParty(foo=obj[b'foo'])
            return obj

        def encode_decode_thirdparty(self, x, use_bin_type=False, encoding=None):
            x_enc = msgpack.packb(x, default=self.encode_thirdparty,
                                  use_bin_type=use_bin_type)
            return msgpack.unpackb(x_enc, object_hook=self.decode_thirdparty,
                                   encoding=encoding)

        def _assert_same_types(self, expected, actual):
            # Compare element types position by position.
            assert_array_equal([type(e) for e in expected],
                               [type(e) for e in actual])

        def _check_scalar(self, x):
            # Value and exact type must survive the round trip.
            x_rec = self.encode_decode(x)
            assert_equal(x, x_rec)
            assert_equal(type(x), type(x_rec))

        def _check_list(self, x):
            # Element values and element types must survive the round trip.
            x_rec = self.encode_decode(x)
            assert_array_equal(x, x_rec)
            self._assert_same_types(x, x_rec)

        def _check_dict(self, x, key=None):
            # Keys and values (sorted, optionally by `key`) must survive
            # the round trip, including their types.
            x_rec = self.encode_decode(x)
            vals = sorted(x.values(), key=key)
            vals_rec = sorted(x_rec.values(), key=key)
            assert_array_equal(vals, vals_rec)
            self._assert_same_types(vals, vals_rec)
            keys = sorted(x.keys())
            keys_rec = sorted(x_rec.keys())
            assert_array_equal(keys, keys_rec)
            self._assert_same_types(keys, keys_rec)

        def _check_array(self, x):
            # Contents and dtype must survive the round trip.
            x_rec = self.encode_decode(x)
            assert_array_equal(x, x_rec)
            assert_equal(x.dtype, x_rec.dtype)

        # -- strings and bytes ---------------------------------------------

        def test_bin(self):
            # Since bytes == str in Python 2.7, the following
            # should pass on both 2.7 and 3.*
            assert_equal(type(self.encode_decode(b'foo')), bytes)

        def test_str(self):
            assert_equal(type(self.encode_decode('foo')), bytes)
            if sys.version_info.major == 2:
                assert_equal(type(self.encode_decode(u'foo')), str)
                # Test non-default string encoding/decoding:
                assert_equal(type(self.encode_decode(u'foo', True, 'utf-8')),
                             unicode)

        # -- scalars -------------------------------------------------------

        def test_numpy_scalar_bool(self):
            self._check_scalar(np.bool_(True))
            self._check_scalar(np.bool_(False))

        def test_numpy_scalar_float(self):
            self._check_scalar(np.float32(np.random.rand()))

        def test_numpy_scalar_complex(self):
            self._check_scalar(
                np.complex64(np.random.rand()+1j*np.random.rand()))

        def test_scalar_float(self):
            self._check_scalar(np.random.rand())

        def test_scalar_complex(self):
            self._check_scalar(np.random.rand()+1j*np.random.rand())

        # -- lists ---------------------------------------------------------

        def test_list_numpy_float(self):
            self._check_list([np.float32(np.random.rand()) for i in range(5)])

        def test_list_numpy_float_complex(self):
            x = [np.float32(np.random.rand()) for i in range(5)] + \
                [np.complex128(np.random.rand()+1j*np.random.rand())
                 for i in range(5)]
            self._check_list(x)

        def test_list_float(self):
            self._check_list([np.random.rand() for i in range(5)])

        def test_list_float_complex(self):
            self._check_list([(np.random.rand()+1j*np.random.rand())
                              for i in range(5)])

        def test_list_str(self):
            x = [b'x'*i for i in range(5)]
            x_rec = self.encode_decode(x)
            assert_array_equal(x, x_rec)
            assert_array_equal([type(e) for e in x_rec], [bytes]*5)

        def test_list_mixed(self):
            self._check_list([1.0, np.float32(3.5), np.complex128(4.25),
                              b'foo'])

        # -- dicts ---------------------------------------------------------

        def test_dict_float(self):
            self._check_dict({b'foo': 1.0, b'bar': 2.0})

        def test_dict_complex(self):
            # Complex values are not orderable, so sort by magnitude.
            self._check_dict({b'foo': 1.0+1.0j, b'bar': 2.0+2.0j},
                             key=np.linalg.norm)

        def test_dict_str(self):
            self._check_dict({b'foo': b'xxx', b'bar': b'yyyy'})

        def test_dict_numpy_float(self):
            self._check_dict({b'foo': np.float32(1.0),
                              b'bar': np.float32(2.0)})

        def test_dict_numpy_complex(self):
            # Complex values are not orderable, so sort by magnitude.
            self._check_dict({b'foo': np.complex128(1.0+1.0j),
                              b'bar': np.complex128(2.0+2.0j)},
                             key=np.linalg.norm)

        # -- arrays --------------------------------------------------------

        def test_numpy_array_float(self):
            self._check_array(np.random.rand(5).astype(np.float32))

        def test_numpy_array_complex(self):
            self._check_array(
                (np.random.rand(5)+1j*np.random.rand(5)).astype(np.complex128))

        def test_numpy_array_float_2d(self):
            self._check_array(np.random.rand(5, 5).astype(np.float32))

        def test_numpy_array_str(self):
            self._check_array(np.array([b'aaa', b'bbbb', b'ccccc']))

        def test_numpy_array_mixed(self):
            x = np.array([(1, 2, b'a', [1.0, 2.0])],
                         np.dtype([('arg0', np.uint32),
                                   ('arg1', np.uint32),
                                   ('arg2', 'S1'),
                                   ('arg3', np.float32, (2,))]))
            self._check_array(x)

        # -- chaining ------------------------------------------------------

        def test_chain(self):
            x = ThirdParty(foo=b'test marshal/unmarshal')
            x_rec = self.encode_decode_thirdparty(x)
            self.assertEqual(x, x_rec)

    main()