# laywerrobot/lib/python3.6/site-packages/smart_open/tests/test_s3.py
# -*- coding: utf-8 -*-
import logging
import gzip
import io
import os
import uuid
import unittest

import boto3
import botocore.client
import boto.s3.bucket
import mock
import moto

import smart_open
import smart_open.s3

BUCKET_NAME = 'test-smartopen-{}'.format(uuid.uuid4().hex)  # random bucket name, to avoid race conditions in CI
KEY_NAME = 'test-key'
WRITE_KEY_NAME = 'test-write-key'

logger = logging.getLogger(__name__)


def maybe_mock_s3(func):
    """Wrap func in moto's S3 mock, unless SO_DISABLE_MOCKS=1 is set in the environment."""
    if os.environ.get('SO_DISABLE_MOCKS') == "1":
        return func
    else:
        return moto.mock_s3(func)
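

# A usage sketch for the toggle above (the exact module path is an
# assumption): by default these tests run against moto's in-memory S3 mock;
# with mocks disabled they talk to real S3, so valid AWS credentials must be
# available in the environment, e.g.:
#
#   SO_DISABLE_MOCKS=1 python -m unittest smart_open.tests.test_s3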


def cleanup_bucket(s3, delete_bucket=False):
    """Empty the test bucket if it exists, optionally deleting it too.

    Returns True if the bucket exists and was kept, False if it was deleted
    here or never existed.
    """
    for bucket in s3.buckets.all():
        if bucket.name == BUCKET_NAME:
            for key in bucket.objects.all():
                key.delete()
            if delete_bucket:
                bucket.delete()
                return False
            return True
    return False


def create_bucket_and_key(bucket_name=BUCKET_NAME, key_name=KEY_NAME, contents=None):
    """Create (or reuse) a bucket and key over the fake (or real) connection."""
    logger.debug('%r', locals())
    s3 = boto3.resource('s3')
    bucket_exists = cleanup_bucket(s3)
    if not bucket_exists:
        s3.create_bucket(Bucket=bucket_name)
    mybucket = s3.Bucket(bucket_name)
    mykey = s3.Object(bucket_name, key_name)
    if contents is not None:
        mykey.put(Body=contents)
    return mybucket, mykey
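

# Typical call in the tests below, shown here for orientation:
#
#   mybucket, mykey = create_bucket_and_key(contents=b'hello world')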


@maybe_mock_s3
class SeekableBufferedInputBaseTest(unittest.TestCase):
    def setUp(self):
        # lower the multipart upload size, to speed up these tests
        self.old_min_part_size = smart_open.s3.DEFAULT_MIN_PART_SIZE
        smart_open.s3.DEFAULT_MIN_PART_SIZE = 5 * 1024**2

    def tearDown(self):
        smart_open.s3.DEFAULT_MIN_PART_SIZE = self.old_min_part_size
        s3 = boto3.resource('s3')
        cleanup_bucket(s3, delete_bucket=True)

    def test_iter(self):
        """Are S3 files iterated over correctly?"""
        # some utf-8 encoded content to test with
        expected = u"hello wořld\nhow are you?".encode('utf8')
        create_bucket_and_key(contents=expected)

        # connect to the (possibly fake) s3 and read from the key we filled above
        fin = smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, KEY_NAME)
        output = [line.rstrip(b'\n') for line in fin]
        self.assertEqual(output, expected.split(b'\n'))

    def test_iter_context_manager(self):
        # same thing but using a context manager
        expected = u"hello wořld\nhow are you?".encode('utf8')
        create_bucket_and_key(contents=expected)
        with smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, KEY_NAME) as fin:
            output = [line.rstrip(b'\n') for line in fin]
            self.assertEqual(output, expected.split(b'\n'))

    def test_read(self):
        """Are S3 files read correctly?"""
        content = u"hello wořld\nhow are you?".encode('utf8')
        create_bucket_and_key(contents=content)
        logger.debug('content: %r len: %r', content, len(content))

        fin = smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, KEY_NAME)
        self.assertEqual(content[:6], fin.read(6))
        self.assertEqual(content[6:14], fin.read(8))  # ř is 2 bytes
        self.assertEqual(content[14:], fin.read())  # read the rest

    def test_seek_beginning(self):
        """Does seeking to the beginning of S3 files work correctly?"""
        content = u"hello wořld\nhow are you?".encode('utf8')
        create_bucket_and_key(contents=content)

        fin = smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, KEY_NAME)
        self.assertEqual(content[:6], fin.read(6))
        self.assertEqual(content[6:14], fin.read(8))  # ř is 2 bytes

        fin.seek(0)
        self.assertEqual(content, fin.read())  # no size given => read whole file

        fin.seek(0)
        self.assertEqual(content, fin.read(-1))  # same thing

    def test_seek_start(self):
        """Does seeking from the start of S3 files work correctly?"""
        content = u"hello wořld\nhow are you?".encode('utf8')
        create_bucket_and_key(contents=content)

        fin = smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, KEY_NAME)
        seek = fin.seek(6)
        self.assertEqual(seek, 6)
        self.assertEqual(fin.tell(), 6)
        self.assertEqual(fin.read(6), u'wořld'.encode('utf-8'))

    def test_seek_current(self):
        """Does seeking from the middle of S3 files work correctly?"""
        content = u"hello wořld\nhow are you?".encode('utf8')
        create_bucket_and_key(contents=content)

        fin = smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, KEY_NAME)
        self.assertEqual(fin.read(5), b'hello')
        seek = fin.seek(1, whence=smart_open.s3.CURRENT)
        self.assertEqual(seek, 6)
        self.assertEqual(fin.read(6), u'wořld'.encode('utf-8'))

    def test_seek_end(self):
        """Does seeking from the end of S3 files work correctly?"""
        content = u"hello wořld\nhow are you?".encode('utf8')
        create_bucket_and_key(contents=content)

        fin = smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, KEY_NAME)
        seek = fin.seek(-4, whence=smart_open.s3.END)
        self.assertEqual(seek, len(content) - 4)
        self.assertEqual(fin.read(), b'you?')

    def test_detect_eof(self):
        content = u"hello wořld\nhow are you?".encode('utf8')
        create_bucket_and_key(contents=content)

        fin = smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, KEY_NAME)
        fin.read()
        eof = fin.tell()
        self.assertEqual(eof, len(content))
        fin.seek(0, whence=smart_open.s3.END)
        self.assertEqual(eof, fin.tell())
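
    # Note: the CURRENT and END constants passed as `whence` above appear to
    # mirror the standard io.SEEK_CUR and io.SEEK_END values (an assumption
    # based on how the tests exercise them).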

    def test_read_gzip(self):
        expected = u'раcцветали яблони и груши, поплыли туманы над рекой...'.encode('utf-8')
        buf = io.BytesIO()
        buf.close = lambda: None  # keep buffer open so that we can .getvalue()
        with gzip.GzipFile(fileobj=buf, mode='w') as zipfile:
            zipfile.write(expected)
        create_bucket_and_key(contents=buf.getvalue())

        #
        # Make sure we're reading things correctly.
        #
        with smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, KEY_NAME) as fin:
            self.assertEqual(fin.read(), buf.getvalue())

        #
        # Make sure the buffer we wrote is legitimate gzip.
        #
        sanity_buf = io.BytesIO(buf.getvalue())
        with gzip.GzipFile(fileobj=sanity_buf) as zipfile:
            self.assertEqual(zipfile.read(), expected)

        logger.debug('starting actual test')
        with smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, KEY_NAME) as fin:
            with gzip.GzipFile(fileobj=fin) as zipfile:
                actual = zipfile.read()

        self.assertEqual(expected, actual)

    def test_readline(self):
        content = b'englishman\nin\nnew\nyork\n'
        create_bucket_and_key(contents=content)

        with smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, KEY_NAME) as fin:
            fin.readline()
            self.assertEqual(fin.tell(), content.index(b'\n') + 1)

            fin.seek(0)
            actual = list(fin)
            self.assertEqual(fin.tell(), len(content))

        expected = [b'englishman\n', b'in\n', b'new\n', b'york\n']
        self.assertEqual(expected, actual)

    def test_readline_tiny_buffer(self):
        content = b'englishman\nin\nnew\nyork\n'
        create_bucket_and_key(contents=content)

        with smart_open.s3.BufferedInputBase(BUCKET_NAME, KEY_NAME, buffer_size=8) as fin:
            actual = list(fin)

        expected = [b'englishman\n', b'in\n', b'new\n', b'york\n']
        self.assertEqual(expected, actual)

    def test_read0_does_not_return_data(self):
        content = b'englishman\nin\nnew\nyork\n'
        create_bucket_and_key(contents=content)

        with smart_open.s3.BufferedInputBase(BUCKET_NAME, KEY_NAME) as fin:
            data = fin.read(0)

        self.assertEqual(data, b'')


@maybe_mock_s3
class BufferedOutputBaseTest(unittest.TestCase):
    """Test writing into s3 files."""

    def tearDown(self):
        s3 = boto3.resource('s3')
        cleanup_bucket(s3, delete_bucket=True)

    def test_write_01(self):
        """Does writing into s3 work correctly?"""
        create_bucket_and_key()
        test_string = u"žluťoučký koníček".encode('utf8')

        # write into the key
        with smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout:
            fout.write(test_string)

        # read the key back and check its content
        output = list(smart_open.smart_open("s3://{}/{}".format(BUCKET_NAME, WRITE_KEY_NAME), "rb"))
        self.assertEqual(output, [test_string])

    def test_write_01a(self):
        """Does s3 write fail on incorrect input?"""
        create_bucket_and_key()
        try:
            with smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout:
                fout.write(None)
        except TypeError:
            pass
        else:
            self.fail()

    def test_write_02(self):
        """Does s3 write unicode-utf8 conversion work?"""
        create_bucket_and_key()
        smart_open_write = smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME)
        smart_open_write.tell()
        logger.info("smart_open_write: %r", smart_open_write)
        with smart_open_write as fout:
            fout.write(u"testžížáč".encode("utf-8"))
            self.assertEqual(fout.tell(), 14)

    def test_write_03(self):
        """Does s3 multipart chunking work correctly?"""
        create_bucket_and_key()

        # write
        smart_open_write = smart_open.s3.BufferedOutputBase(
            BUCKET_NAME, WRITE_KEY_NAME, min_part_size=10
        )
        with smart_open_write as fout:
            fout.write(b"test")
            self.assertEqual(fout._buf.tell(), 4)

            fout.write(b"test\n")
            self.assertEqual(fout._buf.tell(), 9)
            self.assertEqual(fout._total_parts, 0)

            fout.write(b"test")
            self.assertEqual(fout._buf.tell(), 0)
            self.assertEqual(fout._total_parts, 1)

        # read back the same key and check its content
        output = list(smart_open.smart_open("s3://{}/{}".format(BUCKET_NAME, WRITE_KEY_NAME)))
        self.assertEqual(output, [b"testtest\n", b"test"])

    def test_write_04(self):
        """Does writing no data cause a key with an empty value to be created?"""
        _ = create_bucket_and_key()
        smart_open_write = smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME)
        with smart_open_write as fout:  # noqa
            pass

        # read back the same key and check its content
        output = list(smart_open.smart_open("s3://{}/{}".format(BUCKET_NAME, WRITE_KEY_NAME)))
        self.assertEqual(output, [])

    def test_gzip(self):
        create_bucket_and_key()
        expected = u'а не спеть ли мне песню... о любви'.encode('utf-8')
        with smart_open.s3.BufferedOutputBase(BUCKET_NAME, WRITE_KEY_NAME) as fout:
            with gzip.GzipFile(fileobj=fout, mode='w') as zipfile:
                zipfile.write(expected)

        with smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, WRITE_KEY_NAME) as fin:
            with gzip.GzipFile(fileobj=fin) as zipfile:
                actual = zipfile.read()

        self.assertEqual(expected, actual)

    def test_binary_iterator(self):
        expected = u"выйду ночью в поле с конём".encode('utf-8').split(b' ')
        create_bucket_and_key(contents=b"\n".join(expected))
        with smart_open.s3.open(BUCKET_NAME, KEY_NAME, 'rb') as fin:
            actual = [line.rstrip() for line in fin]
        self.assertEqual(expected, actual)

    def test_nonexisting_bucket(self):
        expected = u"выйду ночью в поле с конём".encode('utf-8')
        with self.assertRaises(ValueError):
            with smart_open.s3.open('thisbucketdoesntexist', 'mykey', 'wb') as fout:
                fout.write(expected)


class ClampTest(unittest.TestCase):
    def test(self):
        self.assertEqual(smart_open.s3._clamp(5, 0, 10), 5)
        self.assertEqual(smart_open.s3._clamp(11, 0, 10), 10)
        self.assertEqual(smart_open.s3._clamp(-1, 0, 10), 0)
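

# For reference, a sketch of the clamping behaviour asserted above; the real
# _clamp lives in smart_open.s3 and may differ in detail:
#
#   def _clamp(val, minval, maxval):
#       return max(min(val, maxval), minval)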


ARBITRARY_CLIENT_ERROR = botocore.client.ClientError(error_response={}, operation_name='bar')


@maybe_mock_s3
class IterBucketTest(unittest.TestCase):
    def test_iter_bucket(self):
        populate_bucket()
        results = list(smart_open.s3.iter_bucket(BUCKET_NAME))
        self.assertEqual(len(results), 10)

    def test_accepts_boto3_bucket(self):
        populate_bucket()
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(BUCKET_NAME)
        results = list(smart_open.s3.iter_bucket(bucket))
        self.assertEqual(len(results), 10)

    def test_accepts_boto_bucket(self):
        populate_bucket()
        bucket = boto.s3.bucket.Bucket(name=BUCKET_NAME)
        results = list(smart_open.s3.iter_bucket(bucket))
        self.assertEqual(len(results), 10)

    def test_list_bucket(self):
        num_keys = 10
        populate_bucket()
        keys = list(smart_open.s3._list_bucket(BUCKET_NAME))
        self.assertEqual(len(keys), num_keys)

        expected = ['key_%d' % x for x in range(num_keys)]
        self.assertEqual(sorted(keys), sorted(expected))

    @unittest.skip('this test takes too long for some unknown reason')
    def test_list_bucket_long(self):
        num_keys = 1010
        populate_bucket(num_keys=num_keys)
        keys = list(smart_open.s3._list_bucket(BUCKET_NAME))
        self.assertEqual(len(keys), num_keys)

        expected = ['key_%d' % x for x in range(num_keys)]
        self.assertEqual(sorted(keys), sorted(expected))

    def test_old(self):
        """Does s3_iter_bucket work correctly?"""
        create_bucket_and_key()

        #
        # Use an old-school boto Bucket class for historical reasons.
        #
        mybucket = boto.s3.bucket.Bucket(name=BUCKET_NAME)

        # first, create some keys in the bucket
        expected = {}
        for key_no in range(200):
            key_name = "mykey%s" % key_no
            with smart_open.smart_open("s3://%s/%s" % (BUCKET_NAME, key_name), 'wb') as fout:
                content = '\n'.join("line%i%i" % (key_no, line_no) for line_no in range(10)).encode('utf8')
                fout.write(content)
                expected[key_name] = content

        # read all keys + their content back, in parallel, using s3_iter_bucket
        result = {}
        for k, c in smart_open.s3.iter_bucket(mybucket):
            result[k] = c
        self.assertEqual(expected, result)

        # read a subset of the keys back, in parallel, using an accept_key filter
        result = {}
        for k, c in smart_open.s3.iter_bucket(mybucket, accept_key=lambda fname: fname.endswith('4')):
            result[k] = c
        self.assertEqual(result, dict((k, c) for k, c in expected.items() if k.endswith('4')))

        # read a limited number of keys back, in parallel, using key_limit
        result = dict(smart_open.s3.iter_bucket(mybucket, key_limit=10))
        self.assertEqual(len(result), min(len(expected), 10))

        # exercise several worker pool sizes
        for workers in [1, 4, 8, 16, 64]:
            result = {}
            for k, c in smart_open.s3.iter_bucket(mybucket, workers=workers):
                result[k] = c
            self.assertEqual(result, expected)


@maybe_mock_s3
class DownloadKeyTest(unittest.TestCase):
    def test_happy(self):
        contents = b'hello'
        create_bucket_and_key(contents=contents)

        expected = (KEY_NAME, contents)
        actual = smart_open.s3._download_key(KEY_NAME, bucket_name=BUCKET_NAME)
        self.assertEqual(expected, actual)

    def test_intermittent_error(self):
        contents = b'hello'
        create_bucket_and_key(contents=contents)

        expected = (KEY_NAME, contents)
        side_effect = [ARBITRARY_CLIENT_ERROR, ARBITRARY_CLIENT_ERROR, contents]
        with mock.patch('smart_open.s3._download_fileobj', side_effect=side_effect):
            actual = smart_open.s3._download_key(KEY_NAME, bucket_name=BUCKET_NAME)
        self.assertEqual(expected, actual)

    def test_persistent_error(self):
        contents = b'hello'
        create_bucket_and_key(contents=contents)

        side_effect = [ARBITRARY_CLIENT_ERROR, ARBITRARY_CLIENT_ERROR,
                       ARBITRARY_CLIENT_ERROR, ARBITRARY_CLIENT_ERROR]
        with mock.patch('smart_open.s3._download_fileobj', side_effect=side_effect):
            self.assertRaises(botocore.client.ClientError, smart_open.s3._download_key,
                              KEY_NAME, bucket_name=BUCKET_NAME)

    def test_intermittent_error_retries(self):
        contents = b'hello'
        create_bucket_and_key(contents=contents)

        expected = (KEY_NAME, contents)
        side_effect = [ARBITRARY_CLIENT_ERROR, ARBITRARY_CLIENT_ERROR,
                       ARBITRARY_CLIENT_ERROR, ARBITRARY_CLIENT_ERROR, contents]
        with mock.patch('smart_open.s3._download_fileobj', side_effect=side_effect):
            actual = smart_open.s3._download_key(KEY_NAME, bucket_name=BUCKET_NAME, retries=4)
        self.assertEqual(expected, actual)

    def test_propagates_other_exception(self):
        contents = b'hello'
        create_bucket_and_key(contents=contents)

        with mock.patch('smart_open.s3._download_fileobj', side_effect=ValueError):
            self.assertRaises(ValueError, smart_open.s3._download_key,
                              KEY_NAME, bucket_name=BUCKET_NAME)
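

# For orientation, a sketch of the retry contract the tests above assume:
# _download_key retries _download_fileobj on botocore ClientError up to
# `retries` extra times (default 3) and propagates any other exception.
# A hypothetical simplification, not the real implementation:
#
#   def _download_key(key_name, bucket_name=None, retries=3):
#       for attempt in range(retries + 1):
#           try:
#               return key_name, _download_fileobj(bucket_name, key_name)
#           except botocore.client.ClientError:
#               if attempt == retries:
#                   raise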


@maybe_mock_s3
class OpenTest(unittest.TestCase):
    def test_read_never_returns_none(self):
        """read should never return None."""
        s3 = boto3.resource('s3')
        s3.create_bucket(Bucket=BUCKET_NAME)
        test_string = u"ветер по морю гуляет..."
        with smart_open.s3.open(BUCKET_NAME, KEY_NAME, "wb") as fout:
            fout.write(test_string.encode('utf8'))

        r = smart_open.s3.open(BUCKET_NAME, KEY_NAME, "rb")
        self.assertEqual(r.read(), test_string.encode("utf-8"))
        self.assertEqual(r.read(), b"")
        self.assertEqual(r.read(), b"")


def populate_bucket(bucket_name=BUCKET_NAME, num_keys=10):
    """Fill the (fake or real) test bucket with numbered keys."""
    logger.debug('%r', locals())
    s3 = boto3.resource('s3')
    bucket_exists = cleanup_bucket(s3)
    if not bucket_exists:
        s3.create_bucket(Bucket=bucket_name)
    for key_number in range(num_keys):
        key_name = 'key_%d' % key_number
        s3.Object(bucket_name, key_name).put(Body=str(key_number))


if __name__ == '__main__':
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
    unittest.main()