

Background

How can you speed up your training? What should you do when your model is too large to fit into a single GPU’s memory? How can you efficiently utilize multiple GPUs?

Distributed training is designed to address these challenges. In PyTorch, two common approaches for distributed training are DataParallel and Distributed Data Parallel (DDP).


DataParallel

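The DataParallel module runs in a single process. On each forward pass it replicates the model onto every GPU and splits the input batch into smaller chunks, one per GPU. The outputs are gathered on a master GPU (by default the first device), where the loss is computed. During the backward pass each replica computes the gradients for its own chunk, and these gradients are summed onto the master GPU, which holds the only persistent copy of the parameters and applies the optimizer update. The updated parameters are broadcast to the other GPUs again at the next forward pass.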

However, there are key limitations with DataParallel:

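  1. Communication Overhead: Parameters are broadcast to every GPU on each forward pass, and outputs and gradients are collected on the master GPU, causing significant communication overhead.
  2. Memory Bottleneck: The master GPU gathers the outputs of all replicas, computes the loss, and accumulates the gradients, so it uses more memory than the other GPUs. The usable batch size is limited by the master GPU, which prevents the full utilization of other GPUs’ memory.
  3. Slower Training: A single Python process drives all GPUs, so training is subject to the Global Interpreter Lock (GIL), and the model has to be replicated again on every forward pass.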
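
For reference, wrapping a model in DataParallel is a one-line change inside a single process. The sketch below uses a toy linear model and a random batch as stand-ins (both are assumptions made for illustration). On a machine without GPUs the wrapper simply calls the underlying module.

```python
import torch
import torch.nn as nn

# Toy model and batch, used purely for illustration
model = nn.Linear(8, 2)
batch = torch.randn(16, 8)

if torch.cuda.is_available():
    # DataParallel expects the parameters to live on the first (master) GPU
    model = model.cuda()
    batch = batch.cuda()

# One process drives every visible GPU; the batch is split along dim 0
dp_model = nn.DataParallel(model)
output = dp_model(batch)

# Outputs from all replicas are gathered back into one tensor
print(output.shape)
```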

Distributed Data Parallel (DDP)

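Distributed Data Parallel (DDP) is a more efficient solution that addresses the drawbacks of DataParallel. DDP runs one process per GPU, each with its own model replica and optimizer. It attaches autograd hooks to each parameter; as gradients become ready during the backward pass, the hooks trigger an AllReduce operation that averages them across GPUs. Every process therefore finishes back-propagation with identical gradients and applies the same optimizer step locally, so no master GPU is needed.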
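
The effect of the AllReduce step can be illustrated without any GPUs. The following is a plain-Python simulation, not DDP's actual implementation (which groups gradients into buckets and communicates through a backend such as NCCL); the function name and the gradient values are made up for the example.

```python
def all_reduce_mean(per_rank_grads):
    """Simulate AllReduce (sum, then divide by world size) over per-rank gradient lists."""
    world_size = len(per_rank_grads)
    # Element-wise sum across ranks
    summed = [sum(values) for values in zip(*per_rank_grads)]
    averaged = [s / world_size for s in summed]
    # Every rank receives its own copy of the same averaged gradient
    return [list(averaged) for _ in range(world_size)]


# Hypothetical local gradients computed by two ranks on different mini-batches
local_grads = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]
synced = all_reduce_mean(local_grads)
print(synced)  # [[2.0, 3.0, 4.0], [2.0, 3.0, 4.0]]
```

Because every rank ends up with the same gradient, identical optimizer steps keep the replicas in sync without ever broadcasting parameters.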

Key Advantages:

  • Reduced Communication Overhead: Only gradients are synchronized, reducing data transfer costs.
  • Balanced Memory Usage: Each GPU handles its own back-propagation, resulting in similar memory usage across GPUs.
  • Scalability: DDP supports multi-node setups and peer-to-peer communication between GPUs.
  • Improved Performance: Each GPU is driven by its own process, which avoids contention on Python’s Global Interpreter Lock (GIL).

For more details, see PyTorch Distributed Overview.

This guide focuses on implementing DDP for single-machine, multi-GPU setups.


Getting Started with DDP

Running DDP

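The torch.distributed.launch utility spawns multiple processes for you. Set nproc_per_node to the number of GPUs on your machine so that each process corresponds to one GPU. The launcher passes each process its GPU index through the --local_rank argument. Note that recent PyTorch releases deprecate torch.distributed.launch in favor of torchrun, which exposes the index through the LOCAL_RANK environment variable instead.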

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 main.py $args
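
Since the local rank may arrive either as a command-line flag or as an environment variable depending on the launcher, a script can accept both. The helper below is a sketch; resolve_local_rank is a hypothetical name, and the hyphenated --local-rank spelling is included on the assumption that some newer launcher versions use it.

```python
import argparse
import os


def resolve_local_rank(argv=None, env=None):
    """Return the local rank from the command line, falling back to LOCAL_RANK."""
    env = os.environ if env is None else env
    parser = argparse.ArgumentParser()
    # Accept both spellings of the flag; the value lands in args.local_rank
    parser.add_argument("--local_rank", "--local-rank", type=int, default=None)
    # parse_known_args ignores the script's other arguments
    args, _ = parser.parse_known_args(argv)
    if args.local_rank is not None:
        return args.local_rank
    # torchrun exports LOCAL_RANK instead of passing a flag
    return int(env.get("LOCAL_RANK", 0))


print(resolve_local_rank(["--local_rank=1"], {}))
print(resolve_local_rank([], {"LOCAL_RANK": "3"}))
```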

Preparing Data

Supervised Learning

Use DistributedSampler to split the dataset among processes:

train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=..., sampler=train_sampler)
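
To see what each process receives, the sampler can be inspected outside a distributed run by passing num_replicas and rank explicitly. This sketch uses a plain list as a stand-in dataset and disables shuffling so the partition is easy to read.

```python
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(10))  # any object with __len__ works as a stand-in dataset
world_size = 2

# Passing num_replicas and rank explicitly avoids needing an initialized process group
shards = [
    list(DistributedSampler(dataset, num_replicas=world_size, rank=r, shuffle=False))
    for r in range(world_size)
]
print(shards)  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
```

When the dataset size is not divisible by the number of processes, the sampler by default pads the index list with repeated samples so that every rank gets the same number of batches.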

Reinforcement Learning

In reinforcement learning there is usually no fixed dataset to split. Instead, run the environment in each rank process with a different seed so that the processes collect different experience.
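
One simple way to derive per-process seeds is to offset a base seed by the rank. The sketch below uses Python's random module as a stand-in for an environment's random source; make_rank_rng and the base seed value are hypothetical, and a real script would seed its environment and PyTorch in the same way.

```python
import os
import random


def make_rank_rng(base_seed, rank):
    """Return a random generator whose seed is offset by the rank."""
    return random.Random(base_seed + rank)


# RANK is set by the launcher; default to 0 when running a single process
rank = int(os.environ.get("RANK", 0))
rng = make_rank_rng(base_seed=1234, rank=rank)
print(f"rank {rank} first sample: {rng.random():.4f}")
```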


DDP Initialization with NVIDIA NCCL Backend

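import argparse

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
# torch.distributed.launch passes --local_rank to every process it spawns
parser.add_argument("--local_rank", type=int, default=-1)
local_rank = parser.parse_args().local_rank

# Bind this process to its GPU before creating the process group
torch.cuda.set_device(local_rank)

# Initialize DDP; env:// reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE set by the launcher
dist.init_process_group(backend='nccl', init_method='env://')
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"My rank={rank}, local_rank={local_rank}")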

Wrapping the Model

Wrap your model with DistributedDataParallel:

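# Move the model to this process's GPU before wrapping it
device = torch.device("cuda", local_rank)
model = model.to(device)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)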

Training

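Call set_epoch on the sampler at the start of each epoch so that the data is shuffled differently from one epoch to the next, then train as usual:

# Create the optimizer once, outside the training loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    train_loader.sampler.set_epoch(epoch)
    for data, label in train_loader:
        # Move the batch to this process's GPU
        data, label = data.to(local_rank), label.to(local_rank)
        optimizer.zero_grad()  # clear gradients left over from the previous step
        prediction = model(data)
        loss = loss_fn(prediction, label)
        loss.backward()  # DDP synchronizes gradients during this call
        optimizer.step()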
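
The role of set_epoch can be checked in isolation: the sampler derives its shuffle from a fixed seed plus the epoch number, so the order changes between epochs but is reproducible for a given epoch. The sketch below again uses a plain list as a stand-in dataset; order_for is a hypothetical helper.

```python
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(100))  # stand-in dataset
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=0)


def order_for(epoch):
    """Return the indices rank 0 would visit in the given epoch."""
    sampler.set_epoch(epoch)
    return list(sampler)


epoch0 = order_for(0)
epoch1 = order_for(1)
epoch0_again = order_for(0)

print(epoch0[:5], epoch1[:5])
print(epoch0 == epoch0_again)
```

Without the set_epoch call, every epoch would reuse the same ordering.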

Logging Data

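Use torch.distributed.reduce to aggregate data across ranks. After the call, only the destination rank (rank 0 here) holds the reduced value, so log from that rank. For example, summing the loss across GPUs and calculating the mean: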

loss = loss.clone().detach()
dist.reduce(loss, dst=0)
if dist.get_rank() == 0:
    loss_mean = loss / dist.get_world_size()
    print(f"Epoch: {epoch}, Loss: {loss_mean}")
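
If the same logging code should also run in a plain single-process session, the aggregation can be wrapped in a helper that checks whether a process group exists. This is a sketch: mean_across_ranks is a hypothetical name, and it uses all_reduce rather than reduce so that every rank ends up with the mean.

```python
import torch
import torch.distributed as dist


def mean_across_ranks(value):
    """Average a scalar tensor over all ranks; return a copy unchanged outside DDP."""
    value = value.detach().clone()
    if not (dist.is_available() and dist.is_initialized()):
        return value  # single-process run: nothing to aggregate
    # all_reduce leaves the summed result on every rank, not only on rank 0
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    return value / dist.get_world_size()


loss = torch.tensor(0.5)
print(mean_across_ranks(loss).item())
```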

Saving and Loading Checkpoints

Saving Checkpoints

All ranks hold identical parameters after each optimizer step, so it is enough to save checkpoints on rank 0 only:

if dist.get_rank() == 0:
    checkpoint_state = {
        'iter_no': iter_no,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint_state, checkpoint_path)
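
One detail to keep in mind: DDP stores the original model as its module attribute, so the keys of the wrapped model's state_dict carry a "module." prefix. To load such a checkpoint into an unwrapped model, either save model.module.state_dict() instead or strip the prefix afterwards. The helper below is a sketch with a hypothetical name, shown on a plain dictionary.

```python
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with the leading wrapper prefix removed from keys."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Plain values stand in for tensors here
wrapped = {"module.fc.weight": 1, "module.fc.bias": 2}
print(strip_module_prefix(wrapped))  # {'fc.weight': 1, 'fc.bias': 2}
```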

Loading Checkpoints

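The checkpoint was saved from rank 0’s GPU (cuda:0), so map it to the current rank’s device when loading. On a single machine the rank equals the local GPU index; in a multi-node setup, pass the local rank instead. If the checkpoint was written earlier in the same run, call dist.barrier() before loading so that no rank reads a partially written file.

def load_checkpoint(model, optimizer, rank, checkpoint_path):
    # Tensors were saved from cuda:0; remap them to this process's GPU
    map_location = {'cuda:0': f'cuda:{rank}'}
    checkpoint_state = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint_state['model'])
    optimizer.load_state_dict(checkpoint_state['optimizer'])
    # Resume from the iteration after the one that was saved
    return checkpoint_state['iter_no'] + 1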
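
The save and load steps can be tried end to end on a CPU-only machine. The sketch below uses a toy model, a temporary file, and a made-up iteration number, and maps everything to "cpu" instead of a GPU; the dictionary layout matches the checkpoint format used above.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

checkpoint_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save(
    {"iter_no": 7, "model": model.state_dict(), "optimizer": optimizer.state_dict()},
    checkpoint_path,
)

# map_location also accepts a device string; "cpu" works on any machine
state = torch.load(checkpoint_path, map_location="cpu")

# Restore into a freshly constructed model and optimizer
restored = nn.Linear(4, 2)
restored_optimizer = torch.optim.SGD(restored.parameters(), lr=0.001)
restored.load_state_dict(state["model"])
restored_optimizer.load_state_dict(state["optimizer"])

next_iter = state["iter_no"] + 1
print(next_iter)
```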

Handling BatchNorm

By default, each process computes BatchNorm statistics from its own mini-batch only. To synchronize BatchNorm across GPUs, convert the model to use SyncBatchNorm before wrapping it with DDP:

model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)
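
The conversion itself only swaps layer types, which can be verified on a toy model without GPUs. The small convolutional stack below is an assumption made for illustration.

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# Conversion replaces BatchNorm layers recursively and leaves other layers alone
converted = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(converted[1]).__name__)
```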

Common Issues and Troubleshooting

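  1. Program Hangs: Collective operations like reduce block until every rank has called them. Ensure that no rank skips a collective call, for example inside a branch that only runs on rank 0.
  2. NCCL Errors in Docker: Check the container’s shared-memory settings (for example, Docker’s --ipc=host or --shm-size flags) and set the NCCL_DEBUG=INFO environment variable to get more detailed logs.
  3. Unused Parameters: DDP expects every parameter to receive a gradient in each backward pass. If part of the model is not used in the forward pass, either remove it or pass find_unused_parameters=True when constructing DDP.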
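
Unused parameters can be spotted in a single-process dry run before wrapping the model with DDP: after one backward pass, any parameter whose grad is still None took no part in the forward computation. The toy model below, with one deliberately unused layer, is an assumption made for illustration.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 2)
        self.unused = nn.Linear(4, 2)  # never called in forward

    def forward(self, x):
        return self.used(x)


model = Net()
model(torch.randn(3, 4)).sum().backward()

# Parameters that took no part in the forward pass have no gradient
unused_names = [name for name, p in model.named_parameters() if p.grad is None]
print(unused_names)  # ['unused.weight', 'unused.bias']
```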

These issues will be covered in more detail in a future post.

