GCP – Introducing Dataflux Dataset for Cloud Storage to accelerate PyTorch AI training
Introduction
Machine learning (ML) models thrive on massive datasets, and fast data loading is key to cost-effective ML training. We recently launched a PyTorch Dataset abstraction, the Dataflux Dataset, that accelerates data loading from Google Cloud Storage. With small files, Dataflux delivers up to 3.5x faster training times than fsspec.
Today’s launch builds on Google’s commitment to open standards, which spans more than two decades of OSS contributions such as TensorFlow, JAX, TFX, MLIR, Kubeflow, and Kubernetes, as well as sponsorship of critical OSS data science initiatives such as Project Jupyter and NumFOCUS.
We also validated the Dataflux Dataset on Deep Learning I/O (DLIO) benchmarks and observed similar performance gains, even with larger files. Because of this broad performance boost, we recommend the Dataflux Dataset over other libraries or direct Cloud Storage API calls for training workflows.
Key Dataflux Dataset features include:
Direct Cloud Storage integration: Eliminate the need to download data locally first.
Performance optimization: Achieve up to 3.5x faster training times, especially with small files.
PyTorch Dataset primitive: Work seamlessly with familiar PyTorch concepts.
Checkpointing support: Save and load model checkpoints directly to/from Cloud Storage.
Using Dataflux Datasets
Prerequisites: Python 3.8+
Installation: $ pip install gcs-torch-dataflux
Authentication: Use Google Cloud Application Default Credentials (for example, run gcloud auth application-default login).
Example: Loading images for training
Enabling the Dataflux Dataset takes only a few changes. If you use PyTorch and have data in Cloud Storage, you have most likely written your own Dataset implementation. The snippet below shows how easy it is to create a Dataflux Dataset instead. For further details, check out our GitHub page.
```python
import io

import numpy
from PIL import Image
from dataflux_pytorch import dataflux_mapstyle_dataset

# Placeholders: replace with your own project, bucket, and object prefix.
PROJECT_NAME = "your-project-id"
BUCKET_NAME = "your-bucket"
PREFIX = "path/to/images/"

def transform(img_in_bytes):
    # Decode the raw object bytes into an image and convert it to a NumPy array.
    return numpy.asarray(Image.open(io.BytesIO(img_in_bytes)))

dataset = dataflux_mapstyle_dataset.DatafluxMapStyleDataset(
    project_name=PROJECT_NAME,
    bucket_name=BUCKET_NAME,
    config=dataflux_mapstyle_dataset.Config(prefix=PREFIX),
    data_format_fn=transform,
)

# Use "dataset" as usual in your ML training loop in combination with PyTorch DataLoader.
```
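PyTorch's DataLoader, with its default collation, can stack samples into a batch only when they all have the same shape, so a bucket of images with mixed sizes or color modes needs a transform that normalizes them. The sketch below is one hypothetical variant of the transform in the example above; the 224×224 target size and the scaling to [0, 1] are assumptions for illustration, not Dataflux requirements.

```python
import io

import numpy
from PIL import Image

TARGET_SIZE = (224, 224)  # hypothetical: use the input size your model expects


def transform(img_in_bytes):
    # Decode the raw object bytes, force 3 color channels, and resize so that
    # every sample ends up with the same (height, width, channels) shape.
    img = Image.open(io.BytesIO(img_in_bytes)).convert("RGB").resize(TARGET_SIZE)
    # Convert HWC uint8 pixels to float32 values in [0, 1].
    return numpy.asarray(img, dtype=numpy.float32) / 255.0
```

Passed as data_format_fn, this keeps every sample at shape (224, 224, 3), so torch.utils.data.DataLoader(dataset, batch_size=...) can batch samples with its default collation.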
Under the hood
To achieve these performance gains, we targeted the data-loading bottlenecks in ML training workflows. In a training run, data is loaded from storage in batches, processed on the CPU, and then sent to the GPU for training computations. If reading and constructing a batch takes longer than the GPU computation, the GPU stalls while it waits for data and is underutilized, which lengthens training time.
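The stall argument can be made concrete with a deliberately simple cost model. The sketch below assumes the loader prefetches the next batch while the GPU computes, so the slower of the two sets the step time, and it ignores every other overhead; the timings in the test are made-up illustration values, not Dataflux measurements.

```python
def step_time(load_s: float, compute_s: float, overlapped: bool = True) -> float:
    """Seconds per training step in a toy model.

    With prefetching (overlapped=True), loading the next batch runs concurrently
    with GPU compute, so the slower stage sets the pace. Without overlap, the
    two costs simply add up.
    """
    return max(load_s, compute_s) if overlapped else load_s + compute_s


def gpu_utilization(load_s: float, compute_s: float, overlapped: bool = True) -> float:
    """Fraction of each step the GPU spends computing instead of waiting for data."""
    return compute_s / step_time(load_s, compute_s, overlapped)
```

In this model, a batch that takes 0.3 s to load but only 0.1 s to compute leaves the GPU busy about a third of the time; cutting load time below compute time is what brings utilization to 100%.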
When data lives in a cloud-based object storage system such as Cloud Storage, fetching it takes longer than reading from a local disk, especially when the data is stored as many small objects. The cause is time-to-first-byte latency: every object request pays a fixed delay before any data arrives. Once an object is ‘opened,’ however, Cloud Storage delivers high throughput. Dataflux therefore uses a Cloud Storage feature called Compose Objects, which dynamically combines many smaller objects into a larger one. Instead of fetching, say, a batch of 1,024 small objects individually, Dataflux fetches only 30 larger composed objects and downloads them into memory. It then splits the larger objects back into their individual small objects and serves those as the dataset samples. Any temporary composed objects created in the process are cleaned up afterwards.
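The compose-then-split idea can be sketched without a real bucket. In the code below, a dict stands in for Cloud Storage and compose_in_memory stands in for the server-side Compose Objects operation (in the google-cloud-storage Python client, Blob.compose performs the real concatenation). Grouping at 32 sources reflects Cloud Storage's per-request compose limit; how Dataflux actually sizes its groups, and every function name here, are assumptions for illustration only.

```python
from typing import Dict, List, Tuple

# Cloud Storage's compose operation accepts at most 32 source objects per request.
MAX_COMPOSE_SOURCES = 32

Manifest = List[Tuple[str, int, int]]  # (object name, offset, length)


def plan_groups(names: List[str], max_sources: int = MAX_COMPOSE_SOURCES) -> List[List[str]]:
    """Split a batch of object names into groups small enough for one compose call."""
    return [names[i:i + max_sources] for i in range(0, len(names), max_sources)]


def compose_in_memory(objects: Dict[str, bytes], group: List[str]) -> Tuple[bytes, Manifest]:
    """Stand-in for a server-side compose: concatenate members in order and record
    where each one starts and how long it is, so it can be split out again."""
    manifest: Manifest = []
    parts: List[bytes] = []
    offset = 0
    for name in group:
        data = objects[name]
        manifest.append((name, offset, len(data)))
        parts.append(data)
        offset += len(data)
    return b"".join(parts), manifest


def decompose(blob: bytes, manifest: Manifest) -> Dict[str, bytes]:
    """Slice a downloaded composed object back into its original members."""
    return {name: blob[off:off + length] for name, off, length in manifest}


def fetch_batch(objects: Dict[str, bytes], names: List[str]) -> Tuple[Dict[str, bytes], int]:
    """Return the batch's samples plus the number of downloads it took:
    one per composed group rather than one per small object."""
    samples: Dict[str, bytes] = {}
    downloads = 0
    for group in plan_groups(names):
        blob, manifest = compose_in_memory(objects, group)
        downloads += 1  # against Cloud Storage, this would be one GET of the composed object
        samples.update(decompose(blob, manifest))
    return samples, downloads
```

Each group pays the time-to-first-byte cost once instead of once per small object, and the recorded offsets are what make it possible to serve the original samples from the composed object.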
Another optimization in Dataflux Datasets is high-throughput parallel listing, which speeds up retrieval of the object metadata the dataset needs at startup. Dataflux uses a work-stealing algorithm to significantly accelerate listing, so even the first epoch (the first full pass over the training data) is faster than it would be without parallel listing, even on datasets with tens of millions of objects.
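Work stealing can be illustrated with a small thread-based simulation. The sketch below is not the Dataflux listing algorithm: it models the bucket as a sorted Python list and each list call as a slice of at most page_size names, and it splits ranges by list position, whereas a real lister only sees object names and would split lexicographic name ranges. All names are hypothetical. One worker starts with the whole keyspace, and idle workers repeatedly take the upper half of the largest remaining range.

```python
import threading
from typing import List


def parallel_list(names: List[str], num_workers: int = 4, page_size: int = 100) -> List[str]:
    """Simulate listing a sorted keyspace with work-stealing workers.

    names stands in for the bucket's sorted object names; each simulated
    "list call" returns at most page_size consecutive names from a worker's range.
    """
    lock = threading.Lock()
    # ranges[w] = [next_position, end_position) still owned by worker w.
    # Worker 0 starts with the whole keyspace; the others start idle.
    ranges = [[0, len(names)]] + [[0, 0] for _ in range(num_workers - 1)]
    listed: List[List[str]] = [[] for _ in range(num_workers)]

    def steal(me: int) -> bool:
        # Called with the lock held: split the largest remaining range in half.
        victim = max(range(num_workers), key=lambda w: ranges[w][1] - ranges[w][0])
        lo, hi = ranges[victim]
        if hi - lo <= page_size:
            return False  # nothing worth splitting; its owner will finish it
        mid = (lo + hi) // 2
        ranges[victim][1] = mid  # victim keeps [lo, mid)
        ranges[me] = [mid, hi]   # thief takes [mid, hi)
        return True

    def work(me: int) -> None:
        while True:
            with lock:
                lo, hi = ranges[me]
                if lo >= hi and not steal(me):
                    return
                lo, hi = ranges[me]
                end = min(lo + page_size, hi)
                ranges[me][0] = end  # claim this page before releasing the lock
            # The simulated list call runs outside the lock, like a network request would.
            listed[me].extend(names[lo:end])

    threads = [threading.Thread(target=work, args=(w,)) for w in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return [name for part in listed for name in part]
```

Because every page is claimed under the lock and ranges are only ever split, the workers' ranges never overlap and each name is listed exactly once. With real list calls, the speedup comes from overlapping many I/O-bound requests.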
Together, fast listing and dynamic composition help minimize GPU stalls during ML training, which greatly reduces training time and increases accelerator utilization.
Fast listing and dynamic composition are part of the Dataflux Client Libraries, which are available on GitHub. The Dataflux Dataset uses these client libraries under the hood.
Dataflux is available now
Give the Dataflux Dataset for PyTorch (or the Dataflux Python client library if writing your own ML training dataset code) a try and let us know how it boosts your workflows!
You can learn more about this and our other AI-related storage capabilities in our recorded Google Cloud Next ’24 session, “How to define a storage infrastructure for AI and analytical workloads.”