---
layout: post
title: PYTORCH
permalink: /docs/pytorch
redirect_from:
---
In AIStore, PyTorch integration is a growing set of datasets (both iterable and map-style), samplers, and dataloaders. This readme illustrates the taxonomy of the associated abstractions and provides API reference documentation.
For usage examples, please see:
- base_map_dataset
- base_iter_dataset
- map_dataset
- iter_dataset
- shard_reader
- worker_request_client
- multishard_dataset
- aisio
Base class for AIS Map Style Datasets
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
class AISBaseMapDataset(ABC, Dataset)
A base class for creating map-style AIS datasets. It should not be instantiated directly. Subclasses should implement `__getitem__`, which fetches a sample given a key from the dataset, and can optionally override other methods from torch `Dataset` such as `__len__` and `__getitems__`.
Arguments:
- `ais_source_list` (Union[AISSource, List[AISSource]]): Single AISSource object or list of AISSource objects to load data from
- `prefix_map` (Dict(AISSource, List[str]), optional): Map of AISSource objects to lists of prefixes; only objects with the specified prefixes are used from each source
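To make the subclass contract concrete, here is a minimal sketch of a map-style dataset that follows the same pattern: it resolves the selected objects once (applying per-source prefixes) and then implements `__getitem__` and `__len__`. It deliberately depends on neither torch nor the AIStore SDK; `InMemorySource` and `SketchMapDataset` are hypothetical stand-ins, not part of the AIStore API, and the real base class may resolve objects differently.

```python
from typing import Dict, List, Optional, Tuple


class InMemorySource:
    """Hypothetical stand-in for an AISSource: a named set of objects held in memory."""

    def __init__(self, name: str, objects: Dict[str, bytes]):
        self.name = name
        self.objects = objects


class SketchMapDataset:
    """Hypothetical map-style dataset following the AISBaseMapDataset contract."""

    def __init__(
        self,
        sources: List[InMemorySource],
        prefix_map: Optional[Dict[InMemorySource, List[str]]] = None,
    ):
        prefix_map = prefix_map or {}
        # Resolve (source, object_name) pairs up front so that indexing is O(1).
        self._index: List[Tuple[InMemorySource, str]] = []
        for source in sources:
            prefixes = prefix_map.get(source)
            for name in sorted(source.objects):
                # With no prefixes configured for a source, keep all of its objects.
                if prefixes is None or any(name.startswith(p) for p in prefixes):
                    self._index.append((source, name))

    def __len__(self) -> int:
        return len(self._index)

    def __getitem__(self, index: int) -> Tuple[str, bytes]:
        # Each sample is an (object_name, object_content) tuple.
        source, name = self._index[index]
        return name, source.objects[name]
```

For example, with a source holding `train/0.jpg`, `train/1.jpg`, and `val/0.jpg` and `prefix_map={source: ["train/"]}`, the sketch has length 2 and index 0 returns the `train/0.jpg` sample.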
Base class for AIS Iterable Style Datasets
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
class AISBaseIterDataset(ABC, IterableDataset)
A base class for creating AIS iterable datasets. It should not be instantiated directly. Subclasses should implement `__iter__`, which returns the samples from the dataset, and can optionally override other methods from torch `IterableDataset` such as `__len__`.
Arguments:
- `ais_source_list` (Union[AISSource, List[AISSource]]): Single AISSource object or list of AISSource objects to load data from
- `prefix_map` (Dict(AISSource, List[str]), optional): Map of AISSource objects to lists of prefixes; only objects with the specified prefixes are used from each source
@abstractmethod
def __iter__() -> Iterator
Returns an iterator over the samples in this dataset.
Returns:
- `Iterator`: Iterator of samples
def __len__()
Returns the length of the dataset. Note that calling this iterates through the entire dataset, taking O(N) time.

NOTE: If you need the length of the dataset and are iterating through it anyway, count the samples during that pass (for example, with `for i, data in enumerate(dataset)`) instead of calling `len()` separately.
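The O(N) cost follows from an iterable dataset having no precomputed size: the only way to count its samples is to walk the iterator. The sketch below shows that pattern with a hypothetical `SketchIterDataset` (no torch or AIStore dependency), plus a hypothetical `iterate_and_count` helper that obtains the count during a single pass, which avoids iterating twice when you need both the samples and the length.

```python
from typing import Iterator, List


class SketchIterDataset:
    """Hypothetical iterable dataset following the AISBaseIterDataset contract."""

    def __init__(self, names: List[str]):
        self._names = names

    def __iter__(self) -> Iterator[str]:
        # Samples are produced lazily, one at a time.
        for name in self._names:
            yield name

    def __len__(self) -> int:
        # No precomputed size: count by walking the whole iterator, which is O(N).
        return sum(1 for _ in self)


def iterate_and_count(dataset) -> int:
    """Consume the dataset once and report how many samples it produced."""
    count = 0
    for count, _sample in enumerate(dataset, start=1):
        pass  # a real training loop would process _sample here
    return count
```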
PyTorch Dataset for AIS.
Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
class AISMapDataset(AISBaseMapDataset)
A map-style dataset for objects in AIS.
If `etl_name` is provided, that ETL must already exist on the AIStore cluster.
Arguments:
- `ais_source_list` (Union[AISSource, List[AISSource]]): Single AISSource object or list of AISSource objects to load data from
- `prefix_map` (Dict(AISSource, List[str]), optional): Map of AISSource objects to lists of prefixes; only objects with the specified prefixes are used from each source
- `etl_name` (str, optional): Optional ETL on the AIS cluster to apply to each object

NOTE: Each object is represented as a tuple of `object_name` (str) and `object_content` (bytes).
Iterable Dataset for AIS
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
class AISIterDataset(AISBaseIterDataset)
An iterable-style dataset that iterates over objects in AIS and yields
samples represented as a tuple of object_name (str) and object_content (bytes).
If `etl_name` is provided, that ETL must already exist on the AIStore cluster.
Arguments:
- `ais_source_list` (Union[AISSource, List[AISSource]]): Single AISSource object or list of AISSource objects to load data from
- `prefix_map` (Dict(AISSource, Union[str, List[str]]), optional): Map of AISSource objects to a prefix or list of prefixes; only objects with the specified prefixes are used from each source
- `etl_name` (str, optional): Optional ETL on the AIS cluster to apply to each object
- `show_progress` (bool, optional): Enables a console progress indicator for dataset reading
Yields:
Tuple[str, bytes]: Each item is a tuple where the first element is the name of the object and the second element is the byte representation of the object data.
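Because `prefix_map` values may be either a single prefix string or a list of prefixes, an implementation has to normalize them before filtering. The sketch below illustrates that step and the lazy, source-by-source yielding of `(object_name, object_content)` tuples. It is a hypothetical model only: sources are plain dictionaries keyed by a made-up source name, and `normalize_prefixes` and `iter_samples` are not AIStore SDK functions.

```python
from typing import Dict, Iterator, List, Optional, Tuple, Union

PrefixSpec = Union[str, List[str]]


def normalize_prefixes(spec: PrefixSpec) -> List[str]:
    """Accept either a single prefix string or a list of prefixes."""
    return [spec] if isinstance(spec, str) else list(spec)


def iter_samples(
    sources: Dict[str, Dict[str, bytes]],
    prefix_map: Optional[Dict[str, PrefixSpec]] = None,
) -> Iterator[Tuple[str, bytes]]:
    """Lazily yield (object_name, object_content) for each selected object, source by source."""
    prefix_map = prefix_map or {}
    for source_name, objects in sources.items():
        spec = prefix_map.get(source_name)
        prefixes = None if spec is None else normalize_prefixes(spec)
        for name, content in objects.items():
            # No entry in prefix_map means every object of this source is selected.
            if prefixes is None or any(name.startswith(p) for p in prefixes):
                yield name, content
```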
AIS Shard Reader for PyTorch
PyTorch Dataset and DataLoader for AIS.
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
class AISShardReader(AISBaseIterDataset)
An iterable-style dataset that iterates over objects stored as WebDataset shards and yields samples represented as a tuple of basename (str) and contents (dictionary).
Arguments:
- `bucket_list` (Union[Bucket, List[Bucket]]): Single Bucket object or list of Bucket objects to load data from
- `prefix_map` (Dict(AISSource, Union[str, List[str]]), optional): Map of Bucket objects to a prefix or list of prefixes; only objects with the specified prefixes are used from each source
- `etl_name` (str, optional): Optional ETL on the AIS cluster to apply to each object
- `show_progress` (bool, optional): Enables a console progress indicator for shard reading
Yields:
Tuple[str, Dict(str, bytes)]: Each item is a tuple where the first element is the basename of the shard and the second element is a dictionary mapping strings of file extensions to bytes.
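The shape of these samples follows the WebDataset convention, in which files that share a basename inside a tar shard belong to the same sample. Below is a minimal sketch of that grouping using only the standard `tarfile` module. It assumes the convention that everything before the first dot of a file name is the sample key; `split_key` and `read_shard` are hypothetical helpers, and for simplicity the sketch buffers a whole shard in memory rather than streaming it as a production reader would.

```python
import io
import tarfile
from typing import Dict, Iterator, Tuple


def split_key(member_name: str) -> Tuple[str, str]:
    """Split 'dir/sample0001.seg.png' into ('dir/sample0001', 'seg.png')."""
    dirname, _, filename = member_name.rpartition("/")
    stem, _, ext = filename.partition(".")
    key = f"{dirname}/{stem}" if dirname else stem
    return key, ext


def read_shard(shard_bytes: bytes) -> Iterator[Tuple[str, Dict[str, bytes]]]:
    """Yield (basename, {extension: bytes}) for each sample in a WebDataset-style tar."""
    samples: Dict[str, Dict[str, bytes]] = {}
    with tarfile.open(fileobj=io.BytesIO(shard_bytes), mode="r") as tar:
        for member in tar.getmembers():
            if not member.isfile():
                continue  # skip directories, links, etc.
            key, ext = split_key(member.name)
            samples.setdefault(key, {})[ext] = tar.extractfile(member).read()
    # Dicts preserve insertion order, so samples come out in tar order.
    yield from samples.items()
```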
def __len__()
Returns the length of the dataset. Note that calling this iterates through the entire dataset, taking O(N) time.

NOTE: If you need the length of the dataset and are iterating through it anyway, count the samples during that pass (for example, with `for i, data in enumerate(dataset)`) instead of calling `len()` separately.
class ZeroDict(dict)
When `collate_fn` is called while using ShardReader with a DataLoader, the content dictionaries for each sample are merged into a single dictionary with file extensions as keys and lists of contents as values. However, this means every sample in the batch must have a value for each file extension at iteration time, or else collation fails. To avoid forcing the user to pass in a custom collation function, we work around the default collation implementation.

To do so, we define a dictionary that has a default value of `b""` (zero bytes) for every key that we have seen so far. We cannot use `None`, as collation does not accept `None`. Initially, when we open a shard tar, we collect every file type from its members (a pre-processing pass) and cache those. Then we read the shard files. Lastly, before yielding the sample, we wrap its content dictionary with this custom dictionary to insert any keys that it does not contain, ensuring consistent keys across samples.
NOTE: For our use case, `defaultdict` does not work because it would require a `lambda`, which cannot be pickled, as required in multiprocess contexts such as DataLoader workers.
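The padding idea can be sketched as a plain `dict` subclass that fills every key observed in the shard with `b""`, paired with a simplified collator that, like the default one, takes its keys from the first sample. Both `ZeroDictSketch` and `collate_dicts` are hypothetical illustrations, not the library's `ZeroDict` or torch's `default_collate`. Without the wrapper, `collate_dicts` on `[{"jpg": ..., "cls": ...}, {"jpg": ...}]` raises `KeyError`; with it, the missing entry becomes zero bytes. Since the class is defined at module level and holds no lambda, its instances pickle like ordinary dictionaries.

```python
from typing import Dict, Iterable, List


class ZeroDictSketch(dict):
    """Hypothetical re-creation of the ZeroDict idea: fill every known key with b''."""

    def __init__(self, content: Dict[str, bytes], keys: Iterable[str]):
        super().__init__(content)
        # Insert zero bytes for every extension observed in the shard that this sample lacks.
        for key in keys:
            self.setdefault(key, b"")


def collate_dicts(batch: List[Dict[str, bytes]]) -> Dict[str, List[bytes]]:
    """Simplified stand-in for default collation: keys come from the first sample."""
    return {key: [sample[key] for sample in batch] for key in batch[0]}
```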
Worker Supported Request Client for PyTorch
This client gives PyTorch workers separate request sessions per thread, which is needed to use workers in a DataLoader because the default implementations of RequestClient and `requests` are not thread-safe.
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
class WorkerRequestClient(RequestClient)
An extension of the internal client (used by buckets, objects, jobs, etc. to make requests to an AIS cluster) that supports PyTorch and multiple workers.
Arguments:
- `client` (RequestClient): Existing RequestClient to replace
@property
def session()
Returns: Active request session acquired for a specific PyTorch dataloader worker
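The per-worker session property follows a lazy, per-context caching pattern. Here is a minimal sketch of that pattern, assuming sessions are keyed by thread via `threading.local`; the real class may key by DataLoader worker instead (for example, through `torch.utils.data.get_worker_info()`), and in practice the factory would be something like `requests.Session`. `PerThreadSession` is a hypothetical helper, not part of the SDK.

```python
import threading
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class PerThreadSession(Generic[T]):
    """Hypothetical helper: lazily create one session object per thread."""

    def __init__(self, factory: Callable[[], T]):
        self._factory = factory
        self._local = threading.local()

    @property
    def session(self) -> T:
        # threading.local gives each thread its own attribute namespace,
        # so each thread creates (and then reuses) its own session.
        if not hasattr(self._local, "session"):
            self._local.session = self._factory()
        return self._local.session
```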
Multishard Stream Dataset for AIS.
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
class AISMultiShardStream(IterableDataset)
An iterable-style dataset that iterates over multiple shard streams and yields combined samples.
Arguments:
- `data_sources` (List[DataShard]): List of DataShard objects

Returns:
- `Iterable`: Iterable over the combined samples, where each sample is a tuple containing the bytes of one object from each shard stream
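Combining shard streams amounts to advancing every stream in lockstep and emitting one item from each as a tuple, for example pairing an image shard with a matching label shard. Below is a minimal sketch with a hypothetical `combine_streams` helper. It assumes that iteration stops when the shortest stream is exhausted, as `zip` does; the actual dataset's behavior for streams of unequal length is not stated here.

```python
from typing import Iterable, Iterator, List, Tuple


def combine_streams(streams: List[Iterable[bytes]]) -> Iterator[Tuple[bytes, ...]]:
    """Yield tuples holding one item from each stream, in lockstep."""
    # zip stops as soon as the shortest stream is exhausted.
    yield from zip(*streams)
```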
AIS IO Datapipe
Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
@functional_datapipe("ais_list_files")
class AISFileListerIterDataPipe(IterDataPipe[str])
Iterable DataPipe that lists files from the AIStore backends with the given URL prefixes (functional name: `list_files_by_ais`).
Acceptable prefixes include, but are not limited to, `ais://bucket-name` and `ais://bucket-name/`.
Notes:
- This function also supports files from multiple backends (`aws://..`, `gcp://..`, etc.).
- Input must be a list; direct URLs are not supported.
- Length is -1 by default; all calls to `len()` are invalid because not all items are iterated at the start.
- This internally uses the AIStore Python SDK.
Arguments:
- `source_datapipe` (IterDataPipe[str]): A DataPipe that contains URLs/URL prefixes to objects on AIS
- `length` (int): Length of the DataPipe
- `url` (str): AIStore endpoint
Example:
```python
from torchdata.datapipes.iter import IterableWrapper, AISFileLister

ais_prefixes = IterableWrapper(['gcp://bucket-name/folder/', 'aws://bucket-name/folder/', 'ais://bucket-name/folder/', ...])
dp_ais_urls = AISFileLister(url='http://localhost:8080', source_datapipe=ais_prefixes)
for url in dp_ais_urls:
    pass

# Functional form
dp_ais_urls = ais_prefixes.list_files_by_ais(url='http://localhost:8080')
for url in dp_ais_urls:
    pass
```
Notes:
The `http://localhost:8080` address (above and elsewhere) is used purely for demonstration purposes and must be understood as a placeholder for an arbitrary AIStore endpoint (`AIS_ENDPOINT`).
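Each listed prefix has the form `provider://bucket[/object-prefix]`. The sketch below shows how such a prefix decomposes into provider, bucket, and object prefix using the standard `urllib.parse.urlparse`. `split_prefix_url` is a hypothetical helper, not the SDK's own parser. Note that a malformed prefix such as `aws:bucket-name/folder/` (missing `//`) yields no bucket and is rejected here.

```python
from typing import Tuple
from urllib.parse import urlparse


def split_prefix_url(url: str) -> Tuple[str, str, str]:
    """Split 'provider://bucket/prefix' into (provider, bucket, object_prefix)."""
    parsed = urlparse(url)
    if not parsed.scheme or not parsed.netloc:
        raise ValueError(f"expected provider://bucket[/prefix], got {url!r}")
    # The path keeps a leading '/', which is not part of the object prefix.
    return parsed.scheme, parsed.netloc, parsed.path.lstrip("/")
```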
@functional_datapipe("ais_load_files")
class AISFileLoaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]])
Iterable DataPipe that loads files from AIStore with the given URLs (functional name: `load_files_by_ais`).
Iterates over all files in BytesIO format and returns a tuple (url, BytesIO).
Notes:
- This function also supports files from multiple backends (`aws://..`, `gcp://..`, etc.).
- Input must be a list; direct URLs are not supported.
- This internally uses the AIStore Python SDK.
- An `etl_name` can be provided to run an existing ETL on the AIS cluster. See https://github.com/NVIDIA/aistore/blob/main/docs/etl.md for more info on AIStore ETL.
Arguments:
- `source_datapipe` (IterDataPipe[str]): A DataPipe that contains URLs/URL prefixes to objects
- `length` (int): Length of the DataPipe
- `url` (str): AIStore endpoint
- `etl_name` (str, optional): Optional ETL on the AIS cluster to apply to each object
Example:
```python
from torchdata.datapipes.iter import IterableWrapper, AISFileLister, AISFileLoader

ais_prefixes = IterableWrapper(['gcp://bucket-name/folder/', 'aws://bucket-name/folder/', 'ais://bucket-name/folder/', ...])
dp_ais_urls = AISFileLister(url='http://localhost:8080', source_datapipe=ais_prefixes)
dp_cloud_files = AISFileLoader(url='http://localhost:8080', source_datapipe=dp_ais_urls)
for url, file in dp_cloud_files:
    pass

# Functional form
dp_cloud_files = dp_ais_urls.load_files_by_ais(url='http://localhost:8080')
for url, file in dp_cloud_files:
    pass
```
@functional_datapipe("ais_list_sources")
class AISSourceLister(IterDataPipe[str])
def __init__(ais_sources: List[AISSource], prefix="", etl_name=None)
Iterable DataPipe over the full URLs for each of the provided AIS source object types.
Arguments:
- `ais_sources` (List[AISSource]): List of types implementing the AISSource interface: Bucket, ObjectGroup, Object, etc.
- `prefix` (str, optional): Filter results to only include objects with names starting with this prefix
- `etl_name` (str, optional): Pre-existing ETL on AIS to apply to all selected objects on the cluster side
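Conceptually, the lister expands each source into per-object URLs and applies the `prefix` filter on object names. A minimal sketch follows, assuming sources are modeled as a plain mapping from bucket name to object names and that full URLs take the `provider://bucket/object` form used elsewhere on this page. `list_source_urls` is a hypothetical helper; the real DataPipe works with AISSource objects and also handles `etl_name`, which this sketch omits.

```python
from typing import Dict, Iterator, List


def list_source_urls(
    sources: Dict[str, List[str]], provider: str = "ais", prefix: str = ""
) -> Iterator[str]:
    """Yield 'provider://bucket/object' for every object name starting with prefix."""
    for bucket, object_names in sources.items():
        for name in object_names:
            # An empty prefix matches every object name.
            if name.startswith(prefix):
                yield f"{provider}://{bucket}/{name}"
```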