Synchronize reading new samples from Reader.
Lack of synchronization in `PyDictReaderWorkerResultsQueueReader` was resulting in some samples being missed, or even a crash, when a Reader was read from multiple threads.
Yevgeni Litvin authored and selitvin committed Mar 22, 2019
1 parent b865cce commit 1a24452
Showing 2 changed files with 33 additions and 17 deletions.
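
For context, a minimal sketch of why the unsynchronized check-then-refill-then-pop pattern misbehaves when several threads call next() on the same reader, and how holding one lock across the whole sequence, as this commit does, fixes it. The classes and the fetch_chunk callable below are illustrative only and are not petastorm APIs: two threads can both observe an empty buffer and both fetch a chunk, so the later assignment silently discards the earlier chunk (samples dropped); a thread can also pop() a buffer that another thread has just drained, raising IndexError.

import threading


class UnsafeBufferedReader(object):
    """Illustrative only, not petastorm code: mirrors the old check/refill/pop shape."""

    def __init__(self, fetch_chunk):
        self._fetch_chunk = fetch_chunk  # a callable returning a list of decoded samples
        self._buffer = []

    def next(self):
        if not self._buffer:
            # Two threads may both reach this refill; the later assignment
            # silently replaces (and loses) the chunk fetched by the earlier thread.
            self._buffer = list(reversed(self._fetch_chunk()))
        # May raise IndexError if another thread drained the buffer in between.
        return self._buffer.pop()


class SafeBufferedReader(object):
    """Illustrative only: the whole check/refill/pop sequence runs under one lock."""

    def __init__(self, fetch_chunk):
        self._fetch_chunk = fetch_chunk
        self._buffer = []
        self._lock = threading.Lock()

    def next(self):
        with self._lock:
            if not self._buffer:
                self._buffer = list(reversed(self._fetch_chunk()))
            return self._buffer.pop()
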
37 changes: 20 additions & 17 deletions petastorm/py_dict_reader_worker.py
@@ -14,6 +14,7 @@
 from __future__ import division

 import hashlib
+import threading

 import numpy as np
 from pyarrow import parquet as pq
@@ -44,6 +45,7 @@ def _select_cols(a_dict, keys):

 class PyDictReaderWorkerResultsQueueReader(object):
     def __init__(self):
+        self._result_buffer_lock = threading.Lock()
         self._result_buffer = []

     @property
@@ -54,23 +56,24 @@ def read_next(self, workers_pool, schema, ngram):
         try:
             # We are receiving decoded rows from the worker in chunks. We store the list internally
             # and return a single item upon each consequent call to __next__
-            if not self._result_buffer:
-                # Reverse order, so we can pop from the end of the list in O(1) while maintaining
-                # order the items are returned from the worker
-                rows_as_dict = list(reversed(workers_pool.get_results()))
-
-                if ngram:
-                    for ngram_row in rows_as_dict:
-                        for timestamp in ngram_row.keys():
-                            row = ngram_row[timestamp]
-                            schema_at_timestamp = ngram.get_schema_at_timestep(schema, timestamp)
-
-                            ngram_row[timestamp] = schema_at_timestamp.make_namedtuple(**row)
-                    self._result_buffer = rows_as_dict
-                else:
-                    self._result_buffer = [schema.make_namedtuple(**row) for row in rows_as_dict]
-
-            return self._result_buffer.pop()
+            with self._result_buffer_lock:
+                if not self._result_buffer:
+                    # Reverse order, so we can pop from the end of the list in O(1) while maintaining
+                    # order the items are returned from the worker
+                    list_of_rows = list(reversed(workers_pool.get_results()))
+
+                    if ngram:
+                        for ngram_row in list_of_rows:
+                            for timestamp in ngram_row.keys():
+                                row = ngram_row[timestamp]
+                                schema_at_timestamp = ngram.get_schema_at_timestep(schema, timestamp)
+
+                                ngram_row[timestamp] = schema_at_timestamp.make_namedtuple(**row)
+                        self._result_buffer = list_of_rows
+                    else:
+                        self._result_buffer = [schema.make_namedtuple(**row) for row in list_of_rows]
+
+                return self._result_buffer.pop()

         except EmptyResultError:
             raise StopIteration
13 changes: 13 additions & 0 deletions petastorm/tests/test_end_to_end.py
@@ -18,6 +18,7 @@
 import numpy as np
 import pyarrow.hdfs
 import pytest
+from concurrent.futures import ThreadPoolExecutor
 from pyspark.sql import SparkSession
 from pyspark.sql.types import LongType, ShortType, StringType

@@ -612,3 +613,15 @@ def test_dataset_path_is_a_unicode(synthetic_dataset, reader_factory):
 def test_make_reader_fails_loading_non_petastrom_dataset(scalar_dataset):
     with pytest.raises(RuntimeError, match='use make_batch_reader'):
         make_reader(scalar_dataset.url)
+
+
+def test_multithreaded_reads(synthetic_dataset):
+    with make_reader(synthetic_dataset.url, workers_count=5, num_epochs=1) as reader:
+        with ThreadPoolExecutor(max_workers=10) as executor:
+            def read_one_row():
+                return next(reader)
+
+            futures = [executor.submit(read_one_row) for _ in range(100)]
+            results = [f.result() for f in futures]
+            assert len(results) == len(synthetic_dataset.data)
+            assert set(r.id for r in results) == set(d['id'] for d in synthetic_dataset.data)
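
With the fix in place, the pattern exercised by the test also works from ordinary application code. A hedged sketch, assuming a petastorm dataset exists at the hypothetical DATASET_URL below, of several threads pulling samples from one Reader and handing them to a queue; the make_reader arguments mirror the ones used in the test above:

import threading
from queue import Queue

from petastorm import make_reader

# Hypothetical dataset URL; substitute a real petastorm dataset path.
DATASET_URL = 'file:///tmp/my_petastorm_dataset'


def consume(reader, out_queue):
    # Each thread keeps pulling the next decoded sample until the epoch ends.
    while True:
        try:
            out_queue.put(next(reader))
        except StopIteration:
            break


samples = Queue()
with make_reader(DATASET_URL, workers_count=5, num_epochs=1) as reader:
    threads = [threading.Thread(target=consume, args=(reader, samples)) for _ in range(4)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

print('read %d samples' % samples.qsize())
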
