diff --git a/tests/test_multipart_encoder.py b/tests/test_multipart_encoder.py index 575f54c..501079a 100644 --- a/tests/test_multipart_encoder.py +++ b/tests/test_multipart_encoder.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import contextlib import unittest import io @@ -233,24 +234,34 @@ def test_regresion_1(self): "test": "t" * 100 } - for x in range(30): - fields['f%d' % x] = ( - 'test', open('tests/test_multipart_encoder.py', 'rb') - ) - - m = MultipartEncoder(fields=fields) - total_size = m.len - - blocksize = 8192 - read_so_far = 0 - - while True: - data = m.read(blocksize) - if not data: - break - read_so_far += len(data) - - assert read_so_far == total_size + @contextlib.contextmanager + def open_many(): + files = [] + try: + for x in range(30): + fp = open(__file__, 'rb') + files.append(fp) + fields['f%d' % x] = ('test', fp) + yield + finally: + while files: + fp = files.pop() + fp.close() + + with open_many(): + m = MultipartEncoder(fields=fields) + total_size = m.len + + blocksize = 8192 + read_so_far = 0 + + while True: + data = m.read(blocksize) + if not data: + break + read_so_far += len(data) + + assert read_so_far == total_size def test_regression_2(self): """Ensure issue #31 doesn't ever happen again."""