PyTorch Distributed Data Parallel Step by Step
Background
How can you speed up your training? What should you do when your model is too large to fit into a single GPU’s memory? How can you efficiently utilize multiple GPUs?
Distributed training is designed to address these challenges. In PyTorch, two common approaches for distributed training are DataParallel and Distributed Data Parallel (DDP).
DataParallel
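The DataParallel module runs in a single process. On every forward pass it replicates the model onto each GPU and splits the input batch into smaller chunks, one per GPU. The outputs are gathered on a master GPU (by default the first one), where the loss is computed. During the backward pass each replica computes gradients for its own chunk, and these gradients are summed onto the master GPU's copy of the parameters, where the optimizer updates them. The updated parameters are copied to all GPUs again at the next forward pass.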
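To make this data flow concrete, here is a conceptual, pure-Python sketch of one DataParallel-style step for a scalar linear model with a mean-squared-error loss. The helper names (split_batch, replica_grad, data_parallel_step) are hypothetical, and no GPUs or PyTorch calls are involved. It only shows that summing each replica's gradient contribution on the master reproduces the full-batch gradient.

```python
# Conceptual sketch with hypothetical helpers; this is not PyTorch's DataParallel code.
# Model: y_hat = w * x, loss = mean over the whole batch of (w * x - y) ** 2.

def split_batch(xs, ys, num_replicas):
    """Scatter: cut the batch into contiguous, near-equal chunks, one per replica."""
    n = len(xs)
    bounds = [(i * n // num_replicas, (i + 1) * n // num_replicas) for i in range(num_replicas)]
    return [(xs[lo:hi], ys[lo:hi]) for lo, hi in bounds]

def replica_grad(w, xs, ys, full_batch_size):
    """One replica's contribution to the gradient of the full-batch mean loss."""
    # d/dw of (w*x - y)^2 / N is 2 * (w*x - y) * x / N
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / full_batch_size

def data_parallel_step(w, xs, ys, num_replicas, lr):
    chunks = split_batch(xs, ys, num_replicas)
    # Each replica back-propagates through its own chunk ...
    grads = [replica_grad(w, cx, cy, len(xs)) for cx, cy in chunks]
    # ... and the contributions are summed onto the master copy of the parameter.
    total_grad = sum(grads)
    # The update happens once, on the master; it is re-replicated at the next forward.
    return w - lr * total_grad, total_grad
```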
However, DataParallel has some key limitations:
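- Communication Overhead: The model is re-replicated to every GPU on each iteration, and inputs, outputs, and gradients all pass through the master GPU, so a lot of data moves between GPUs.
- Memory Bottleneck: The master GPU gathers all outputs, computes the loss, and accumulates the gradients, so it uses more memory than the other GPUs. Its capacity limits the batch size and prevents full use of the other GPUs' memory.
- Slower Training: A single Python process drives all GPUs from multiple threads, so it contends for Python's Global Interpreter Lock (GIL), and the master GPU serializes the loss computation and the parameter update.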
Distributed Data Parallel (DDP)
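Distributed Data Parallel (DDP) is a more efficient solution that addresses these drawbacks. DDP runs one process per GPU, each holding a full replica of the model, and every process runs its own forward and backward pass on its own share of the data. DDP registers autograd hooks on the parameters; as gradients become ready during the backward pass, the hooks trigger an AllReduce operation that averages them across all processes, overlapping communication with computation. Every process therefore finishes the backward pass with identical gradients and applies the same optimizer step, so the replicas stay in sync without a master GPU.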
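For comparison, here is the same toy model trained DDP-style, again as a hypothetical pure-Python sketch rather than PyTorch code: each simulated rank computes the gradient of the mean loss over its own shard, an AllReduce-and-average step gives every rank the same gradient, and each rank applies its own identical update. With equal-sized shards, the averaged gradient equals the full-batch gradient.

```python
# Conceptual sketch with hypothetical helpers; not the torch.distributed API.
# Model: y_hat = w * x with a per-rank mean-squared-error loss.

def local_grad(w, xs, ys):
    """Gradient of the rank-local mean loss (each process only sees its own shard)."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def allreduce_mean(values):
    """AllReduce with SUM, then divide by world size: every rank gets the same mean."""
    mean = sum(values) / len(values)
    return [mean for _ in values]

def ddp_step(weights, shards, lr):
    # weights[r] is rank r's replica of the parameter; DDP starts them equal.
    grads = [local_grad(w, xs, ys) for w, (xs, ys) in zip(weights, shards)]
    # DDP does this bucket by bucket from its autograd hooks during backward.
    synced = allreduce_mean(grads)
    # Each rank runs its own optimizer step on identical gradients.
    return [w - lr * g for w, g in zip(weights, synced)]
```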
Key Advantages:
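- Reduced Communication Overhead: Only gradients are exchanged during training (parameters are broadcast from rank 0 once, when the model is wrapped), and the exchange overlaps with the backward computation.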
- Balanced Memory Usage: Each GPU handles its own back-propagation, resulting in similar memory usage across GPUs.
- Scalability: DDP supports multi-node setups and peer-to-peer communication between GPUs.
- Improved Performance: Multiple CPU processes are used, alleviating the limitations of Python’s Global Interpreter Lock (GIL).
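NCCL chooses among several AllReduce algorithms; a classic one is the ring algorithm, in which each rank only exchanges chunks with its neighbours. The sketch below simulates a ring AllReduce (reduce-scatter followed by all-gather) over Python lists. It illustrates the algorithm only and is not NCCL's implementation; ring_allreduce is a hypothetical helper. It also counts the elements each rank sends, which comes to 2(n-1)/n of the buffer size and so stays below twice the buffer size however many ranks take part.

```python
# Conceptual simulation of ring AllReduce (hypothetical helper, not NCCL code).

def ring_allreduce(buffers):
    """Sum equal-length lists across simulated ranks with the ring algorithm.

    Returns (per-rank results, elements sent by each rank).
    """
    n = len(buffers)
    length = len(buffers[0])
    # Split every buffer into n contiguous chunks; chunk c covers [lo, hi).
    bounds = [(c * length // n, (c + 1) * length // n) for c in range(n)]
    data = [list(b) for b in buffers]
    sent = [0] * n

    def run_phase(chunk_for, combine):
        for step in range(n - 1):
            # Snapshot all messages first: every rank sends simultaneously in a step.
            msgs = []
            for r in range(n):
                lo, hi = bounds[chunk_for(r, step)]
                msgs.append((r, lo, hi, data[r][lo:hi]))
                sent[r] += hi - lo
            for r, lo, hi, payload in msgs:
                dst = (r + 1) % n  # each rank only talks to its right neighbour
                for i, value in enumerate(payload):
                    data[dst][lo + i] = combine(data[dst][lo + i], value)

    # Reduce-scatter: afterwards rank r holds the complete sum of chunk (r + 1) % n.
    run_phase(lambda r, step: (r - step) % n, lambda old, new: old + new)
    # All-gather: circulate the finished chunks until every rank has all of them.
    run_phase(lambda r, step: (r + 1 - step) % n, lambda old, new: new)
    return data, sent
```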
For more details, see PyTorch Distributed Overview.
This guide focuses on implementing DDP for single-machine, multi-GPU setups.
Getting Started with DDP
Running DDP
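The torch.distributed.launch utility spawns multiple processes for you. Set nproc_per_node to the number of GPUs on your machine so that each process drives exactly one GPU:
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 main.py $args
In recent PyTorch releases, torch.distributed.launch is deprecated in favor of torchrun, which accepts the same --nproc_per_node option and gives each process its local rank through the LOCAL_RANK environment variable instead of a --local_rank argument:
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 main.py $args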
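Each spawned process is identified by a local rank (its index on the node, which the script uses to pick its GPU) and a global rank. Assuming every node runs the same number of processes, the numbering can be sketched as below; rank_layout is a hypothetical helper. On a single machine, the setup this guide targets, the global rank equals the local rank.

```python
# Hypothetical helper illustrating launcher rank numbering (equal processes per node assumed).

def rank_layout(nnodes, nproc_per_node):
    """Map (node_rank, local_rank) to the global rank and return the world size."""
    world_size = nnodes * nproc_per_node
    layout = {
        (node_rank, local_rank): node_rank * nproc_per_node + local_rank
        for node_rank in range(nnodes)
        for local_rank in range(nproc_per_node)
    }
    return world_size, layout
```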
Preparing Data
Supervised Learning
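Use DistributedSampler to give each process its own shard of the dataset. Pass it to the DataLoader through sampler and leave shuffle unset, since the sampler does its own shuffling; batch_size is then the per-process batch size: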
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=..., sampler=train_sampler)
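The sharding scheme of DistributedSampler can be sketched in plain Python: all ranks build the same (optionally shuffled) index list, pad it by repeating indices so its length divides evenly by the number of replicas, and then take every num_replicas-th index starting at their own rank. The function shard_indices is hypothetical, and random.Random stands in for the torch generator the real sampler uses, so the actual permutations differ. Because the shuffle is seeded with seed + epoch, forgetting to call set_epoch(epoch) replays the same order every epoch.

```python
import math
import random

# Hypothetical re-implementation for illustration; use DistributedSampler in real code.
def shard_indices(dataset_len, num_replicas, rank, epoch=0, shuffle=True, seed=0):
    """Return the dataset indices that one rank iterates over in a given epoch."""
    indices = list(range(dataset_len))
    if shuffle:
        # Same seed + epoch on every rank, so every rank computes the same permutation.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating indices from the front so the total divides evenly.
    total = math.ceil(dataset_len / num_replicas) * num_replicas
    padding = total - len(indices)
    if padding:
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Interleaved split: rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total:num_replicas]
```

With 10 samples and 3 ranks, each rank gets 4 indices, so two samples are seen twice per epoch.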
Reinforcement Learning
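In reinforcement learning there is usually no fixed dataset to split. Instead, run an environment in each rank's process and seed each one differently (for example, base seed plus rank) so that the ranks collect diverse experience.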
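A minimal sketch of per-rank seeding, assuming a base seed shared by all ranks. ToyEnv and make_env_for_rank are hypothetical placeholders; with a real environment you would pass the derived seed to the environment itself (for example, Gymnasium's env.reset(seed=...)) and to torch or NumPy as needed.

```python
import random

class ToyEnv:
    """Hypothetical stand-in for an RL environment whose randomness comes from an RNG."""

    def __init__(self, seed):
        self.rng = random.Random(seed)

    def reset(self):
        # Random initial state; a real environment would sample its start state here.
        return self.rng.uniform(-1.0, 1.0)

def make_env_for_rank(rank, base_seed=1234):
    # Offset the shared base seed by the rank: reproducible per rank, different across ranks.
    return ToyEnv(seed=base_seed + rank)
```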
DDP Initialization with NVIDIA NCCL Backend
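import os
import argparse
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
# torch.distributed.launch passes --local_rank (newer releases pass --local-rank);
# torchrun sets the LOCAL_RANK environment variable instead
parser.add_argument("--local_rank", "--local-rank", type=int,
                    default=int(os.environ.get("LOCAL_RANK", 0)))
local_rank = parser.parse_args().local_rank

# Pin this process to its GPU before creating the NCCL process group
torch.cuda.set_device(local_rank)
# Initialize DDP; env:// reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE set by the launcher
dist.init_process_group(backend='nccl', init_method='env://')
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"My rank={rank}, local_rank={local_rank}")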
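With init_method='env://', init_process_group reads the rendezvous information from environment variables that the launcher sets: MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE. Running main.py directly with python, without a launcher, therefore usually fails with an error about a missing variable. A hypothetical pre-flight helper such as the one below reports the problem up front; it is a sketch and not part of torch.distributed.

```python
import os

# Variables the env:// rendezvous relies on (set by torchrun / torch.distributed.launch).
REQUIRED_VARS = ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")

def read_env_rendezvous(environ=os.environ):
    """Hypothetical helper: collect rendezvous variables, failing loudly if any are absent."""
    missing = [name for name in REQUIRED_VARS if name not in environ]
    if missing:
        raise RuntimeError(f"start the script with a launcher; missing env vars: {missing}")
    return {
        "master_addr": environ["MASTER_ADDR"],
        "master_port": int(environ["MASTER_PORT"]),
        "rank": int(environ["RANK"]),
        "world_size": int(environ["WORLD_SIZE"]),
        # Falling back to RANK is only correct on a single machine.
        "local_rank": int(environ.get("LOCAL_RANK", environ["RANK"])),
    }
```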
Wrapping the Model
Wrap your model with DistributedDataParallel:
device = torch.device("cuda", local_rank)  # this process's GPU
model = model.to(device)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
Training
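Call set_epoch on the sampler at the start of every epoch so that each epoch gets a different shuffle (which all ranks still agree on), then train as usual:
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)  # create once, after wrapping with DDP
for epoch in range(num_epochs):
    train_loader.sampler.set_epoch(epoch)  # reshuffle differently each epoch
    for data, label in train_loader:
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        prediction = model(data)
        loss = loss_fn(prediction, label)
        loss.backward()  # DDP averages gradients across ranks during backward
        optimizer.step()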
Logging Data
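Use torch.distributed.reduce to aggregate values across ranks. It is a collective call, so every rank must execute it, but only the destination rank receives the result. For example, to sum the loss across GPUs and report the mean on rank 0: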
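loss = loss.clone().detach()  # detached copy, so the in-place reduce leaves the autograd graph alone
dist.reduce(loss, dst=0)      # default op is SUM; every rank must make this call
if dist.get_rank() == 0:
    loss_mean = loss / dist.get_world_size()
    print(f"Epoch: {epoch}, Loss: {loss_mean.item()}")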
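Dividing the summed loss by the world size gives the true mean over all samples only when every rank averaged over the same number of samples, which DistributedSampler arranges by padding. The hypothetical pure-Python helpers below mimic the reduce-then-divide pattern and show the alternative of reducing loss sums and sample counts separately, which stays exact when counts differ (for example, a per-token loss over variable-length sequences).

```python
# Conceptual simulation with hypothetical helpers; not torch.distributed itself.

def reduce_sum_to_dst(per_rank_values, dst=0):
    """Simulate dist.reduce with the default SUM op: only rank dst is guaranteed the result."""
    total = sum(per_rank_values)
    # None marks ranks whose value should not be relied upon after the call.
    return [total if rank == dst else None for rank in range(len(per_rank_values))]

def global_mean(loss_sums, sample_counts):
    """Exact mean over all samples: reduce per-rank loss sums and counts separately."""
    return sum(loss_sums) / sum(sample_counts)
```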
Saving and Loading Checkpoints
Saving Checkpoints
Only save checkpoints on rank 0:
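if dist.get_rank() == 0:
    # model is DDP-wrapped, so its state_dict keys carry a "module." prefix
    checkpoint_state = {
        'iter_no': iter_no,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint_state, checkpoint_path)
dist.barrier()  # all ranks wait here until rank 0 has finished writing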
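Because DDP stores the wrapped model as a submodule named module, calling state_dict() on the DDP object produces keys such as module.fc.weight. Loading such a checkpoint into an unwrapped model (for evaluation or export) requires removing that prefix; alternatively, save model.module.state_dict() in the first place. A minimal sketch of the prefix removal, using a hypothetical helper and plain values in place of tensors:

```python
from collections import OrderedDict

# Hypothetical helper for illustration.
def strip_ddp_prefix(state_dict, prefix="module."):
    """Drop the prefix DDP adds so the weights load into an unwrapped model."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
```

For example, strip_ddp_prefix(checkpoint_state['model']) could then be passed to the plain model's load_state_dict.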
Loading Checkpoints
Map the checkpoint to the current rank’s device:
def load_checkpoint(model, optimizer, local_rank, checkpoint_path):
    # The checkpoint was written from rank 0's GPU (cuda:0); remap its tensors to this process's GPU
    map_location = {'cuda:0': f'cuda:{local_rank}'}
    checkpoint_state = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint_state['model'])
    optimizer.load_state_dict(checkpoint_state['optimizer'])
    return checkpoint_state['iter_no'] + 1  # resume from the next iteration
Handling BatchNorm
To synchronize BatchNorm statistics across GPUs, convert the model's BatchNorm layers to SyncBatchNorm before wrapping it with DDP:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
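Plain BatchNorm normalizes with the statistics of each rank's local mini-batch, while SyncBatchNorm normalizes with statistics computed over the combined batch of all ranks. The pure-Python sketch below shows how per-rank counts, sums and sums of squares combine into the global mean and variance for one channel. It is a conceptual illustration with hypothetical helpers; PyTorch's kernels exchange and combine the statistics in a different, more numerically careful form.

```python
# Conceptual sketch with hypothetical helpers; not PyTorch's SyncBatchNorm implementation.

def local_stats(values):
    """Per-rank statistics for one channel: (count, sum, sum of squares)."""
    return len(values), sum(values), sum(v * v for v in values)

def combine_batch_stats(per_rank_stats):
    """Combine per-rank statistics into the global mean and biased variance."""
    count = sum(c for c, _, _ in per_rank_stats)
    total = sum(s for _, s, _ in per_rank_stats)
    total_sq = sum(q for _, _, q in per_rank_stats)
    mean = total / count
    var = total_sq / count - mean * mean  # E[x^2] - E[x]^2
    return mean, var
```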
Common Issues and Troubleshooting
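- Program Hangs: Collective operations such as reduce and barrier block until every rank has called them, so make sure no rank skips one (for example, by placing it inside an if rank == 0: branch).
- NCCL Errors in Docker: Check the NCCL configuration and the container's flags, for example the shared-memory size (--shm-size or --ipc=host).
- Unused Parameters: By default, DDP expects every parameter to receive a gradient in each iteration; parameters that do not take part in the forward pass typically cause an error at the next iteration. Remove them, or pass find_unused_parameters=True to DDP at the cost of some extra overhead.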
These issues will be covered in more detail in a future post.