Petastorm is an open source data access library developed at Uber ATG. This library enables single machine or distributed training and evaluation of deep learning models directly from datasets in Apache Parquet format. Petastorm supports popular Python-based machine learning (ML) frameworks such as TensorFlow, PyTorch, and PySpark. It can also be used from pure Python code.
Documentation web site: https://petastorm.readthedocs.io
```bash
pip install petastorm
```
The petastorm package defines several extra dependencies that are not installed automatically. The extras are: `tf`, `tf_gpu`, `torch`, `opencv`, `docs` and `test`.
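Because the extras are optional, code that relies on one of them may want to check at runtime whether it is available. The sketch below does that with the standard library only; the mapping from extra name to importable module (`tensorflow`, `torch`, `cv2`) is an assumption made for illustration and is not read from petastorm's own packaging metadata.

```python
import importlib.util

# Assumed mapping from a few petastorm extras to the module each one provides.
# This is an illustration only, not taken from petastorm's setup.py.
EXTRA_TO_MODULE = {
    "tf": "tensorflow",
    "torch": "torch",
    "opencv": "cv2",
}


def installed_extras(mapping=EXTRA_TO_MODULE):
    """Return {extra_name: bool} telling whether the extra's module can be found.

    importlib.util.find_spec locates a top-level module without importing it,
    so this check is cheap even for heavy packages such as tensorflow.
    """
    return {
        extra: importlib.util.find_spec(module) is not None
        for extra, module in mapping.items()
    }


if __name__ == "__main__":
    for extra, present in installed_extras().items():
        print(f"{extra:8s} {'installed' if present else 'missing'}")
```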
For example, to trigger installation of the GPU version of TensorFlow and OpenCV, use the following pip command:

```bash
pip install petastorm[opencv,tf_gpu]
```
A dataset created using Petastorm is stored in Apache Parquet format. On top of a Parquet schema, petastorm also stores higher-level schema information that makes multidimensional arrays a native part of a petastorm dataset.
Petastorm supports extensible data codecs. These enable a user to use one of the standard data compressions (jpeg, png) or implement their own.
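The codec idea can be illustrated without petastorm itself: a codec turns an in-memory value into bytes that fit in a Parquet binary column, and back again. The sketch below is a hypothetical codec (the class name and its `encode`/`decode` signatures are assumptions, not petastorm's codec API) that stores a multidimensional uint8 array as a small shape header followed by zlib-compressed data. It also shows why shape information has to travel alongside the raw bytes: without the header, the decoder could not restore the array's dimensions.

```python
import struct
import zlib
from array import array


class ZlibUint8Codec:
    """Hypothetical codec for uint8 arrays given as (shape, flat list of values).

    Byte layout: [ndim as uint32][each dimension as uint32 ...][zlib-compressed data]
    """

    def encode(self, shape, values):
        expected = 1
        for dim in shape:
            expected *= dim
        if len(values) != expected:
            raise ValueError(f"{len(values)} values do not fit shape {shape}")
        # Little-endian header: number of dimensions, then each dimension.
        header = struct.pack(f"<{1 + len(shape)}I", len(shape), *shape)
        # array("B", ...) rejects values outside 0..255, like a uint8 buffer would.
        return header + zlib.compress(array("B", values).tobytes())

    def decode(self, blob):
        (ndim,) = struct.unpack_from("<I", blob, 0)
        shape = struct.unpack_from(f"<{ndim}I", blob, 4)
        payload = blob[4 + 4 * ndim:]
        values = array("B", zlib.decompress(payload))
        return tuple(shape), values.tolist()


if __name__ == "__main__":
    codec = ZlibUint8Codec()
    blob = codec.encode((2, 3), [0, 1, 2, 3, 4, 5])
    print(codec.decode(blob))
```

A real codec would also report the Spark column type it writes to; that part is omitted here because it depends on petastorm's actual interface.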
Generating a dataset is done using PySpark. PySpark natively supports Parquet format, making it easy to run on a single machine or on a Spark compute cluster. Here is a minimalistic example writing out a table with some random data.
```python
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

from petastorm.codecs import ScalarCodec, CompressedImageCodec, NdarrayCodec
from petastorm.etl.dataset_metadata import materialize_dataset
from petastorm.unischema import dict_to_spark_row, Unischema, UnischemaField

# Each field: name, numpy type, shape (None = variable size), codec, nullable
HelloWorldSchema = Unischema('HelloWorldSchema', [
    UnischemaField('id', np.int32, (), ScalarCodec(IntegerType()), False),
    UnischemaField('image1', np.uint8, (128, 256, 3), CompressedImageCodec('png'), False),
    UnischemaField('other_data', np.uint8, (None, 128, 30, None), NdarrayCodec(), False),
])


def row_generator(x):
    """Returns a single entry in the generated dataset. Return a bunch of random values as an example."""
    return {'id': x,
            'image1': np.random.randint(0, 255, dtype=np.uint8, size=(128, 256, 3)),
            'other_data': np.random.randint(0, 255, dtype=np.uint8, size=(4, 128, 30, 3))}


def generate_hello_world_dataset(output_url='file:///tmp/hello_world_dataset'):
    rows_count = 10
    rowgroup_size_mb = 256

    spark = SparkSession.builder.config('spark.driver.memory', '2g').master('local[2]').getOrCreate()
    sc = spark.sparkContext

    # Wrap dataset materialization portion. Will take care of setting up spark environment variables as
    # well as save petastorm specific metadata
    with materialize_dataset(spark, output_url, HelloWorldSchema, rowgroup_size_mb):
        rows_rdd = sc.parallelize(range(rows_count))\
            .map(row_generator)\
            .map(lambda x: dict_to_spark_row(HelloWorldSchema, x))

        spark.createDataFrame(rows_rdd, HelloWorldSchema.as_spark_schema()) \
            .coalesce(10) \
            .write \
            .mode('overwrite') \
            .parquet(output_url)


if __name__ == '__main__':
    generate_hello_world_dataset()
```
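To get a feel for the `rowgroup_size_mb = 256` setting used in the example, the sketch below estimates how many `HelloWorldSchema` rows would fit into one row group. It is a back-of-the-envelope calculation under stated assumptions: sizes are uncompressed (PNG compression and Parquet encoding will change the real numbers), `other_data` always has the `(4, 128, 30, 3)` shape produced by `row_generator`, and 1 MB is taken to be 1024 × 1024 bytes.

```python
# Rough, uncompressed byte size of one HelloWorldSchema row (shapes copied from the example).
ROW_FIELDS_BYTES = {
    "id": 4,                          # np.int32 scalar
    "image1": 128 * 256 * 3,          # uint8 image before PNG compression
    "other_data": 4 * 128 * 30 * 3,   # uint8 array with the shape row_generator produces
}


def rows_per_row_group(rowgroup_size_mb, row_bytes):
    """How many rows of `row_bytes` bytes fit into a row group of `rowgroup_size_mb` MB."""
    return (rowgroup_size_mb * 1024 * 1024) // row_bytes


row_bytes = sum(ROW_FIELDS_BYTES.values())
print(row_bytes)                           # 144388
print(rows_per_row_group(256, row_bytes))  # 1859
```

With only 10 rows, the example dataset is far smaller than a single 256 MB row group; the setting matters once datasets grow.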
`HelloWorldSchema` is an instance of a `Unischema` object. `Unischema` is capable of rendering types of its fields into different framework-specific formats, such as Spark `StructType`, TensorFlow `tf.DType` and numpy `numpy.dtype`.

- To define a dataset field, you need to specify a name, a `type`, a `shape`, a `codec` instance and whether the field is nullable, for each field of the `Unischema`.
- We use PySpark for writing output Parquet files. In this example, we launch PySpark on a local box (`.master('local[2]')`). Of course, for a larger scale dataset generation we would need a real compute cluster.
- We wrap Spark dataset generation code with the `materialize_dataset` context manager. The context manager is responsible for configuring the row group size at the beginning and writing out petastorm-specific metadata at the end.
- The row generating code is expected to return a Python dictionary indexed by field name. We use the `row_generator` function for that.
- `dict_to_spark_row` converts the dictionary into a `pyspark.Row` object while ensuring `HelloWorldSchema` compliance (shape, type and is-nullable condition are tested).
- Once we have a `pyspark.DataFrame`, we write it out to Parquet storage. The Parquet schema is automatically derived from `HelloWorldSchema`.
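The compliance checks attributed to `dict_to_spark_row` (shape, type and nullability) can be sketched in plain Python to make the rules concrete. This is not petastorm's implementation: `Field`, `nested_shape` and `validate_row` are hypothetical names, arrays are modeled as rectangular nested lists instead of numpy arrays, and a `None` in a declared shape means "any length along this dimension", mirroring `(None, 128, 30, None)` in `HelloWorldSchema`.

```python
from collections import namedtuple

# Hypothetical stand-in for a schema field (the real UnischemaField also carries a codec).
Field = namedtuple("Field", ["name", "py_type", "shape", "nullable"])


def nested_shape(value):
    """Shape of a rectangular nested list; () for a scalar. Only the first element is inspected."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


def leaves(value):
    """Yield every scalar element of a nested list (or the value itself if it is a scalar)."""
    if isinstance(value, list):
        for item in value:
            yield from leaves(item)
    else:
        yield value


def validate_row(fields, row):
    """Raise ValueError if `row` does not match `fields` in names, nullability, shape or type."""
    if set(row) != {f.name for f in fields}:
        raise ValueError(f"row keys {sorted(row)} do not match the schema")
    for field in fields:
        value = row[field.name]
        if value is None:
            if not field.nullable:
                raise ValueError(f"{field.name} is not nullable")
            continue
        actual = nested_shape(value)
        if len(actual) != len(field.shape) or any(
            want is not None and want != got for want, got in zip(field.shape, actual)
        ):
            raise ValueError(f"{field.name}: shape {actual} does not match {field.shape}")
        if not all(isinstance(leaf, field.py_type) for leaf in leaves(value)):
            raise ValueError(f"{field.name}: expected elements of type {field.py_type.__name__}")


schema = [Field("id", int, (), False), Field("other_data", int, (None, 2, None), True)]
validate_row(schema, {"id": 1, "other_data": [[[1, 2, 3], [4, 5, 6]]]})  # shape (1, 2, 3) is accepted
```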
The `petastorm.reader.Reader` class is the main entry point for user code that accesses the data from an ML framework such as TensorFlow or PyTorch. The reader has multiple features such as:
- Selective column readout
- Multiple parallelism strategies: thread, process, single-threaded (for debug)
- N-grams readout support
- Row filtering (row predicates)
- Shuffling
- Partitioning for multi-GPU training
- Local caching
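N-gram readout returns several consecutive rows together, which is useful for sequence models. The pure-Python sketch below only illustrates the idea and is not how petastorm implements or configures n-grams: it assumes rows are dicts with a hypothetical `timestamp` key, sorts them, and emits windows of `length` rows in which each pair of neighbours differs by at most `delta_threshold`.

```python
def ngrams(rows, length, delta_threshold, key="timestamp"):
    """Yield tuples of `length` consecutive rows whose neighbouring timestamps are close enough."""
    ordered = sorted(rows, key=lambda r: r[key])
    for start in range(len(ordered) - length + 1):
        window = ordered[start:start + length]
        gaps_ok = all(
            later[key] - earlier[key] <= delta_threshold
            for earlier, later in zip(window, window[1:])
        )
        if gaps_ok:
            yield tuple(window)


rows = [{"timestamp": t} for t in (0, 1, 2, 10, 11)]
for gram in ngrams(rows, length=2, delta_threshold=1):
    print([r["timestamp"] for r in gram])  # [0, 1], then [1, 2], then [10, 11]
```

The gap check is what keeps a window from spanning a discontinuity in the data (here, the jump from 2 to 10).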
Reading a dataset is simple using the `petastorm.reader.Reader` class:

```python
from petastorm.reader import Reader

with Reader('hdfs://myhadoop/some_dataset') as reader:
    for row in reader:
        print(row)
```
`hdfs://...` and `file://...` are supported URL protocols.

Once a `Reader` is instantiated, you can use it as an iterator.
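The usage pattern above (a context manager that is also iterable, opened from an `hdfs://` or `file://` URL) can be sketched in plain Python. `ToyReader` is hypothetical: it reads rows from an in-memory list rather than Parquet files, and only mimics the URL scheme check and the `with`/`for` protocol.

```python
from urllib.parse import urlparse

SUPPORTED_SCHEMES = {"hdfs", "file"}


class ToyReader:
    """Hypothetical reader: validates the URL scheme and iterates over in-memory rows."""

    def __init__(self, dataset_url, rows):
        scheme = urlparse(dataset_url).scheme
        if scheme not in SUPPORTED_SCHEMES:
            raise ValueError(f"unsupported URL scheme: {scheme!r}")
        self.dataset_url = dataset_url
        self._rows = list(rows)
        self.closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # A real reader would release workers and file handles here.
        self.closed = True
        return False  # do not swallow exceptions raised inside the with-block

    def __iter__(self):
        return iter(self._rows)


with ToyReader("file:///tmp/hello_world_dataset", [{"id": 0}, {"id": 1}]) as reader:
    for row in reader:
        print(row)
```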
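To hook up the reader into a TensorFlow graph, you can use the `tf_tensors` function:

```python
import tensorflow as tf  # uses the TensorFlow 1.x tf.Session API

from petastorm.reader import Reader
from petastorm.tf_utils import tf_tensors

with Reader('file:///some/localpath/a_dataset') as reader:
    row_tensors = tf_tensors(reader)
    with tf.Session() as session:
        for _ in range(3):
            print(session.run(row_tensors))
```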
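Alternatively, you can use the newer `tf.data.Dataset` API:

```python
import tensorflow as tf  # uses the TensorFlow 1.x tf.Session / make_one_shot_iterator API

from petastorm.reader import Reader
from petastorm.tf_utils import make_petastorm_dataset

with Reader('file:///some/localpath/a_dataset') as reader:
    dataset = make_petastorm_dataset(reader)
    iterator = dataset.make_one_shot_iterator()
    tensor = iterator.get_next()
    with tf.Session() as sess:
        sample = sess.run(tensor)
        print(sample.id)
```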
As illustrated in pytorch_example.py, reading a petastorm dataset from PyTorch can be done via the adapter class `petastorm.pytorch.DataLoader`, which allows a custom PyTorch collating function and transforms to be supplied.
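What "transform" and "collate" mean in this context can be shown without PyTorch: a transform is applied to each row, and a collate function merges a list of transformed rows into one batch. The sketch below uses hypothetical names (`batches`, `collate_columns`) and plain lists; the real `petastorm.pytorch.DataLoader` produces PyTorch batches instead.

```python
def batches(rows, batch_size, transform=None, collate=list):
    """Group rows into batches, applying `transform` per row and `collate` per batch."""
    batch = []
    for row in rows:
        batch.append(transform(row) if transform else row)
        if len(batch) == batch_size:
            yield collate(batch)
            batch = []
    if batch:  # emit the last, possibly smaller, batch
        yield collate(batch)


def collate_columns(batch):
    """Turn a list of (image, label) tuples into ([images...], [labels...])."""
    images, labels = zip(*batch)
    return list(images), list(labels)


rows = [{"image": i, "digit": i % 10} for i in range(5)]
for images, labels in batches(rows, 2, transform=lambda r: (r["image"], r["digit"]),
                              collate=collate_columns):
    print(images, labels)
```

Keeping per-row transforms separate from per-batch collation lets each be swapped independently, which is the flexibility the adapter exposes.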
Be sure you have `torch` and `torchvision` installed:

```bash
pip install torchvision
```
The minimalist example below assumes the definition of a `Net` class and `train` and `test` functions, included in `pytorch_example`:
```python
import torch
from torchvision import transforms

from petastorm.pytorch import DataLoader
from petastorm.reader import Reader

# Net, train and test are defined in pytorch_example.py
torch.manual_seed(1)
device = torch.device('cpu')
model = Net().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)


def _transform_row(mnist_row):
    # ToTensor scales pixels to [0, 1]; Normalize then applies the MNIST mean and std
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    return (transform(mnist_row['image']), mnist_row['digit'])


with DataLoader(Reader('file:///localpath/mnist/train', num_epochs=10),
                batch_size=64, transform=_transform_row) as train_loader:
    train(model, device, train_loader, 10, optimizer, 1)

with DataLoader(Reader('file:///localpath/mnist/test', num_epochs=10),
                batch_size=1000, transform=_transform_row) as test_loader:
    test(model, device, test_loader)
```
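The arithmetic behind `_transform_row` can be checked by hand. Assuming the MNIST `image` field holds uint8 pixels, `ToTensor` divides each pixel by 255 and `Normalize((0.1307,), (0.3081,))` then computes `(x - mean) / std`. The sketch below reproduces that for a single pixel in plain Python; `to_normalized` is a hypothetical helper, not part of torchvision.

```python
MNIST_MEAN = 0.1307  # values passed to transforms.Normalize in the example
MNIST_STD = 0.3081


def to_normalized(pixel):
    """Mimic ToTensor + Normalize for one uint8 grayscale pixel."""
    scaled = pixel / 255.0                     # ToTensor: [0, 255] -> [0.0, 1.0]
    return (scaled - MNIST_MEAN) / MNIST_STD   # Normalize: (x - mean) / std


print(round(to_normalized(0), 4))    # -0.4242
print(round(to_normalized(255), 4))  # 2.8215
```

So normalized pixel values fall roughly in the range [-0.42, 2.82] rather than [0, 1].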
Using the Parquet data format, which is natively supported by Spark, makes it possible to use a wide range of Spark tools to analyze and manipulate the dataset. The example below shows how to read a Petastorm dataset as a Spark DataFrame object:
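```python
# Assumes `spark` is an existing SparkSession and `dataset_url` points at a
# petastorm dataset, e.g. 'file:///tmp/hello_world_dataset'

# Create a dataframe object from a parquet file
dataframe = spark.read.parquet(dataset_url)

# Show a schema
dataframe.printSchema()

# Count all
dataframe.count()

# Show a single column
dataframe.select('id').show()
```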
SQL can be used to query a Petastorm dataset:

```python
spark.sql(
    'SELECT count(id) '
    'from parquet.`file:///tmp/hello_world_dataset`').collect()
```
You can find a full code sample here: pyspark_hello_world.py.
See the Troubleshooting page and please submit a ticket if you can't find an answer.
See the Development page for instructions on how to develop Petastorm and run tests.
- Gruener, R., Cheng, O., and Litvin, Y. (2018) Introducing Petastorm: Uber ATG's Data Access Library for Deep Learning. URL: https://eng.uber.com/petastorm/