diff --git a/petastorm/py_dict_reader_worker.py b/petastorm/py_dict_reader_worker.py
index cee5f62f2..278102599 100644
--- a/petastorm/py_dict_reader_worker.py
+++ b/petastorm/py_dict_reader_worker.py
@@ -14,6 +14,7 @@
 from __future__ import division
 
 import hashlib
+import threading
 
 import numpy as np
 from pyarrow import parquet as pq
@@ -44,6 +45,7 @@ def _select_cols(a_dict, keys):
 
 class PyDictReaderWorkerResultsQueueReader(object):
     def __init__(self):
+        self._result_buffer_lock = threading.Lock()
         self._result_buffer = []
 
     @property
@@ -54,23 +56,24 @@ def read_next(self, workers_pool, schema, ngram):
         try:
             # We are receiving decoded rows from the worker in chunks. We store the list internally
             # and return a single item upon each consequent call to __next__
-            if not self._result_buffer:
-                # Reverse order, so we can pop from the end of the list in O(1) while maintaining
-                # order the items are returned from the worker
-                rows_as_dict = list(reversed(workers_pool.get_results()))
-
-                if ngram:
-                    for ngram_row in rows_as_dict:
-                        for timestamp in ngram_row.keys():
-                            row = ngram_row[timestamp]
-                            schema_at_timestamp = ngram.get_schema_at_timestep(schema, timestamp)
-
-                            ngram_row[timestamp] = schema_at_timestamp.make_namedtuple(**row)
-                    self._result_buffer = rows_as_dict
-                else:
-                    self._result_buffer = [schema.make_namedtuple(**row) for row in rows_as_dict]
-
-            return self._result_buffer.pop()
+            with self._result_buffer_lock:
+                if not self._result_buffer:
+                    # Reverse order, so we can pop from the end of the list in O(1) while maintaining
+                    # order the items are returned from the worker
+                    list_of_rows = list(reversed(workers_pool.get_results()))
+
+                    if ngram:
+                        for ngram_row in list_of_rows:
+                            for timestamp in ngram_row.keys():
+                                row = ngram_row[timestamp]
+                                schema_at_timestamp = ngram.get_schema_at_timestep(schema, timestamp)
+
+                                ngram_row[timestamp] = schema_at_timestamp.make_namedtuple(**row)
+                        self._result_buffer = list_of_rows
+                    else:
+                        self._result_buffer = [schema.make_namedtuple(**row) for row in list_of_rows]
+
+                return self._result_buffer.pop()
         except EmptyResultError:
             raise StopIteration
 
diff --git a/petastorm/tests/test_end_to_end.py b/petastorm/tests/test_end_to_end.py
index 88cd38f55..0de4f3c6a 100644
--- a/petastorm/tests/test_end_to_end.py
+++ b/petastorm/tests/test_end_to_end.py
@@ -18,6 +18,7 @@
 import numpy as np
 import pyarrow.hdfs
 import pytest
+from concurrent.futures import ThreadPoolExecutor
 from pyspark.sql import SparkSession
 from pyspark.sql.types import LongType, ShortType, StringType
 
@@ -612,3 +613,15 @@ def test_dataset_path_is_a_unicode(synthetic_dataset, reader_factory):
 def test_make_reader_fails_loading_non_petastrom_dataset(scalar_dataset):
     with pytest.raises(RuntimeError, match='use make_batch_reader'):
         make_reader(scalar_dataset.url)
+
+
+def test_multithreaded_reads(synthetic_dataset):
+    with make_reader(synthetic_dataset.url, workers_count=5, num_epochs=1) as reader:
+        with ThreadPoolExecutor(max_workers=10) as executor:
+            def read_one_row():
+                return next(reader)
+
+            futures = [executor.submit(read_one_row) for _ in range(len(synthetic_dataset.data))]
+            results = [f.result() for f in futures]
+            assert len(results) == len(synthetic_dataset.data)
+            assert set(r.id for r in results) == set(d['id'] for d in synthetic_dataset.data)