2023-10-23
If you’re training small to medium size vision models and are loading data from an HDD (or worse, over a network), dataloading is probably a bottleneck for you.
Oftentimes, accelerating dataloading requires significant engineering effort. Tricks like storing data in contiguous raw blocks or writing custom collate functions can work really nicely, but often require you to tailor the tricks for each dataset you work with. It would be great if we could get the benefits of hand-engineered dataloading pipelines without actually having to do the engineering!
This post walks you through a trick that I call stochastic caching. I’ll introduce the concept and take you from a naive solution to a fully-fledged library. Importantly, implementing stochastic caching with this library requires minimal code changes and works on any dataset or machine. I’ve released the library on PyPI under the name stocaching. You can find it on GitHub, or install it with pip:
pip install stocaching
Stocaching is dead-simple (only 1 file!) – if you don’t want to install it with pip, you can just copy-paste it into your projects.
What is stochastic caching?
We all know that (for multi-epoch training) we should cache small datasets in RAM to speed up dataloading. However, once we get to larger datasets, we often give up on caching completely. This seems like a waste!
If we are training on large, shuffled datasets, we can just cache a subset of the data that fits in RAM. The cached data will then be randomly distributed throughout each batch (as shown in the ASCII figure below). This reduces the I/O workload on the disk and, if we cache a large enough proportion of our data, should move the bottleneck off the dataloading pipeline.
┌─────────────────┐┌─────────────────┐┌─────────────────┐
│ Batch 1 ││ Batch 2 ││ Batch 3 │
│ ┌──┬──┬──┬──┐ ││ ┌──┬──┬──┬──┐ ││ ┌──┬──┬──┬──┐ │
│ │┼┼│ │ │┼┼│ ││ │ │┼┼│ │┼┼│ ││ │ │┼┼│┼┼│┼┼│ │ ...
│ │┼┼│ │ │┼┼│ ││ │ │┼┼│ │┼┼│ ││ │ │┼┼│┼┼│┼┼│ │
│ └──┴──┴──┴──┘ ││ └──┴──┴──┴──┘ ││ └──┴──┴──┴──┘ │
└─────────────────┘└─────────────────┘└─────────────────┘
# = Sample cached in RAM
I call this trick stochastic caching. It’s conceptually simple and easy to implement. Given this simplicity, it’s surprising how effective it can be. Stochastic caching can:
- Provide dataloading speedups for datasets too large to fit in memory.
- Be lazy (i.e. only cache samples the first time they’re needed).
- Give speedups that scale linearly with the % of the dataset being cached.
The naive solution (doesn’t work)
What we want is to lazily cache samples in RAM the first time
they are read in __getitem__()
. If we can do this
right, the first epoch will still be bottlenecked by
dataloading, but future epochs will be faster, as some of the
data can be retrieved from RAM instead of disk.
A naive implementation would look something like this:
import torch
from torch import Tensor
from torch.utils.data import Dataset

class NaiveDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        # other dataset logic
        ...
        self.cache_length = N  # number of samples you want to cache
        data_dims = (C, H, W)  # shape of data (not including batch)

        self.cache = torch.empty(
            (self.cache_length, *data_dims), dtype=torch.float32
        )
        self.cached_idxs = set()  # keeps track of what has been cached

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        if idx >= self.cache_length:
            x = ...  # read image to float32 tensor
        elif idx not in self.cached_idxs:
            x = ...  # read image to float32 tensor
            self.cache[idx] = x
            self.cached_idxs.add(idx)
        else:
            x = self.cache[idx]  # get float32 image tensor from cache

        x = transforms(x)
        label = ...  # load label
        return x, label
The key problem with any method that implements caching like
this is that the NaiveDataset
object gets copied
for each worker in your dataloader. This means that you’ll have
to keep several whole copies of your dataset in RAM. Worse
still, since each worker sees only a subset of the dataset each
epoch, it will take many epochs for each worker to have filled
its own cache. Outside of the special case where num_workers=0, this method is inefficient in memory and provides only a marginal speedup.
What we need is a method for sharing the cache between all worker processes…
Caching in shared memory
Fortunately, it is possible to share data between multiple processes in Python using multiprocessing. In fact, I started developing this project after finding an implementation of shared-memory datasets that @ptrblck shared in 2018.
Adapting this implementation gets us to something like this:
import ctypes
import multiprocessing as mp

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset

class CacheDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        # other dataset logic
        ...
        self.cache_length = N  # number of samples you want to cache
        data_dims = (C, H, W)  # shape of data (not including batch)

        # allocate a flat float32 array in shared memory and wrap it in a tensor
        shared_array_base = mp.Array(
            ctypes.c_float, self.cache_length * int(np.prod(data_dims))
        )
        shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
        shared_array = shared_array.reshape(self.cache_length, *data_dims)
        self.cache = torch.from_numpy(shared_array)
        self.cache *= 0

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        if idx >= self.cache_length:
            x = ...  # read image to float32 tensor
        # hack to see if the cache slot has changed since initialisation
        elif torch.all(self.cache[idx] == 0):
            x = ...  # read image to float32 tensor
            self.cache[idx] = x
        else:
            x = self.cache[idx]  # get float32 image tensor from cache

        x = transforms(x)
        label = ...  # load label
        return x, label
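Before moving on, here is a minimal usage sketch (the epoch count, batch size, and worker count are placeholders, and it assumes a Linux machine where dataloader workers are forked): because the cache lives in shared memory, every worker reads and writes the same buffer, so whatever gets cached during the first epoch is available to all workers afterwards.

from torch.utils.data import DataLoader

dataset = CacheDataset()
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)

for epoch in range(10):
    for x, label in loader:
        ...  # training step: epoch 1 fills the cache, later epochs read from it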
This works pretty great! It gets around the data copy problem from the naive solution and will form the backbone of what we will do going forwards. Now, let’s extract this idea into a simple library that can be dropped into any codebase for a free speedup.
The stocaching library
So far, we have a cool trick that requires quite a bit of
fiddly code and is not massively portable across datasets. If we
refactor the shared memory cache into a standalone module, add
some convenience features, and handle the edge-cases, we end up
with the stocaching
library, available now on GitHub
and PyPI.
The main contribution of stocaching is the SharedCache class. Using it is super simple and best demonstrated through a minimal example:
import torch
from stocaching import SharedCache
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self):
        super().__init__()
        # set up dataset
        ...
        dataset_len = N  # number of samples in the full dataset
        data_dims = (C, H, W)  # data dims (not including batch)

        # initialize the cache
        self.cache = SharedCache(
            size_limit_gib=32,
            dataset_len=dataset_len,
            data_dims=data_dims,
            dtype=torch.uint8,
        )

    def __getitem__(self, idx):
        # retrieve data from cache if it's there
        x = self.cache.get_slot(idx)
        # x will be None if the cache slot was empty or OOB
        if x is None:
            x = ...  # load data to uint8 tensor from disk
            self.cache.set_slot(idx, x)  # try to cache x
        return x
When initialising a SharedCache
object, we tell
it about the size of our dataset and the maximum amount of
memory we want the cache to take up. SharedCache
then calculates how many samples can fit in the cache and
allocates space accordingly.
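For intuition (a back-of-envelope calculation, not the library’s exact accounting), here is roughly how many uint8 samples fit in the 32 GiB cache from the example above, assuming images of shape (3, 224, 224):

size_limit_gib = 32
bytes_per_sample = 3 * 224 * 224  # uint8 images: 1 byte per element
n_cacheable = (size_limit_gib * 2**30) // bytes_per_sample
print(n_cacheable)  # roughly 228,000 samples fit in the cache

If your dataset is smaller than this, everything gets cached; otherwise only the samples whose index lands inside the cache do, and the rest fall into the ‘OOB’ case from the example above.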
You can interact with SharedCache
through
__getitem__
and __setitem__
methods,
or even by accessing the underlying PyTorch tensor. For
convenience, we expose the get_slot
and
set_slot
methods, which can gracefully handle cases
where they are passed an out-of-bounds index. This design helps
to reduce the amount of code necessary for common use cases.
Importantly, when using SharedCache
, you
don’t have to worry about whether the dataset is being fully or
partially cached. Just by changing the
size_limit_gib
parameter, you can run the same code
on any machine, getting benefits depending on how much RAM you
have. Even caching only 10% of your dataset can still give noticeable speedups when you are very data-bottlenecked (such as when training small models).
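One convenient pattern (just a sketch; the environment variable name and dataset numbers are made up) is to read size_limit_gib from the environment, so the same training script adapts to however much RAM each machine has:

import os

import torch
from stocaching import SharedCache

size_limit_gib = int(os.environ.get("STOCACHE_GIB", "8"))  # hypothetical env var

cache = SharedCache(
    size_limit_gib=size_limit_gib,
    dataset_len=50_000,        # placeholder dataset length
    data_dims=(3, 224, 224),   # placeholder data dims
    dtype=torch.uint8,
)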
Benchmarks
I’ve run some basic benchmarks to test the method under single GPU image classification workloads. More details can be found in the repo, but the headline results are here:
[Benchmark plots: local HDD and remote data workloads (see the repo for the full results)]
On both local HDD and network drive workloads, I found that we get speedups that scale linearly with the % of the dataset being cached. While there is a small overhead in the first epoch from filling the cache, this is quickly compensated for by the speedup in the second epoch. Even caching very small percentages of the data (5-10%) seems beneficial over multiple epochs.
Tips and tricks
1. How do I decide how much memory to allocate to the cache?
As much as you like! The speedup from caching scales linearly
with the % of your dataset being cached. When the
SharedCache
object is created, it will print out
the calculated size of your dataset and how many samples it
decided to cache.
The shared memory is stored in /dev/shm (tmpfs), so this is likely the limiting factor for you. We provide a convenience function get_shm_size() to check how large it is. Alternatively, check with df -h.
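If you just want a quick programmatic check without the library, the standard library can report the size of /dev/shm directly (a minimal sketch using shutil):

import shutil

total, used, free = shutil.disk_usage("/dev/shm")
print(f"/dev/shm: {total / 2**30:.1f} GiB total, {free / 2**30:.1f} GiB free")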
Most Unix-like systems have /dev/shm pre-set to 50% of your RAM. You can temporarily resize it (e.g. to 128 GiB) by running (as root):
mount -o remount,size=128G /dev/shm
2. How should I organise my transforms/augmentations?
Generally, you don’t want to do any random augmentations before caching because the cache will kill the randomness. It’s also a good idea to cache data in the smallest dtype you can to save space. For example, if you are reading images from jpg/png, you can usually cache them in uint8 format instead of float32.
Splitting your transforms/augmentation pipeline into two phases is a good idea. The first phase converts your data to a (possibly resized) uint8 tensor. The output of this phase gets cached. The second phase should do random augmentations, convert to float32, and normalise. This phase happens ‘on-line’ and the output goes straight into your model.
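As a concrete illustration (a sketch, not the benchmark code itself; the image size, normalisation constants, and file-loading details are placeholders), a two-phase torchvision pipeline could look like this:

import torch
from PIL import Image
from torchvision import transforms

# Phase 1: deterministic pre-processing; the uint8 output is what gets cached
pre_cache = transforms.Compose([
    transforms.Resize((224, 224)),  # fixed size so every cached sample has the same shape
    transforms.PILToTensor(),       # uint8 tensor of shape (C, H, W), no scaling
])

# Phase 2: random augmentations + float conversion, applied 'on-line' after the cache
post_cache = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ConvertImageDtype(torch.float32),  # uint8 -> float32 in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def load_sample(cache, idx: int, path: str) -> torch.Tensor:
    """Fetch one sample, filling the shared cache the first time it is seen."""
    x = cache.get_slot(idx)
    if x is None:
        img = Image.open(path).convert("RGB")
        x = pre_cache(img)       # uint8, cache-friendly
        cache.set_slot(idx, x)
    return post_cache(x)         # random augs, float32, normalise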
For a working example of how to do this properly, see the benchmark implementation here.
3. What about multi-GPU (DDP) training?
The current implementation works for DDP training; however, it may not be as efficient, since each GPU process gets its own cache and the caches will overlap. I’m still working on the library and haven’t focused much on the DDP use case yet.
It shouldn’t be too difficult to extend the library to cover the DDP case. My current thinking is to make a second class, SharedCacheDDP, which will use the multiprocessing.shared_memory.SharedMemory feature. This would allow the cache to be shared more easily across the GPU processes, but would require a bit more attention from the user.
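For readers unfamiliar with it, here is a minimal sketch of that stdlib primitive on its own (not the planned SharedCacheDDP API; the block name and cache shape are placeholders): a named SharedMemory block created by rank 0 can be attached to by name from the other ranks, which is what makes sharing across separately-launched DDP processes possible.

import numpy as np
from multiprocessing import shared_memory

SHAPE, DTYPE = (100_000, 3, 224, 224), np.uint8  # placeholder cache dimensions
NBYTES = int(np.prod(SHAPE)) * np.dtype(DTYPE).itemsize

def attach_cache(rank: int, name: str = "my_ddp_cache"):
    """Rank 0 creates the named block; the other ranks attach to it by name.

    The caller is responsible for keeping the returned SharedMemory object alive,
    making sure rank 0 creates the block before the others attach (e.g. with a
    dist.barrier()), and unlinking the block at the end of training.
    """
    if rank == 0:
        shm = shared_memory.SharedMemory(name=name, create=True, size=NBYTES)
    else:
        shm = shared_memory.SharedMemory(name=name)  # attach to the existing block
    cache = np.ndarray(SHAPE, dtype=DTYPE, buffer=shm.buf)
    return shm, cache

That extra bookkeeping (creation order, keeping the handle alive, cleanup) is the ‘bit more attention from the user’ mentioned above.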
Conclusions
GPUs are expensive. If we have a bottleneck in our dataloading pipeline, we are leaving performance and money on the table.
For dataloading-bottlenecked image classification workloads, my testing has shown this trick to be pretty successful. I’m honestly not too sure why I haven’t seen anything similar in deep learning libraries and codebases, but I’m now starting to use it in my research.
If you have any comments or suggestions (or even want to contribute to the codebase!), feel free to open an issue on the GitHub repo.