Using torchrun for Distributed Training

torchrun is a utility provided by PyTorch that simplifies launching distributed training jobs. It manages process spawning, inter-process communication, and resource allocation across multiple GPUs and nodes.

Here’s a detailed guide on how to use torchrun for distributed training:
1. Understand Distributed Training Concepts
- Distributed Data Parallel (DDP):
  - PyTorch’s torch.nn.parallel.DistributedDataParallel (DDP) is the backbone for distributed training.
  - It splits data across GPUs and synchronizes gradients during training (a data-sharding sketch follows this list).
- Backend Options:
  - NCCL: Recommended for GPU-based training (supports CUDA).
  - Gloo: Works for CPU-based training or smaller setups.
  - MPI: For large-scale multi-node clusters (requires MPI setup).
- Process Groups:
  - torchrun launches a group of processes that communicate with each other.
  - Each GPU typically corresponds to one process.
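The “splits data across GPUs” part is usually handled with torch.utils.data.distributed.DistributedSampler, which gives each process a disjoint shard of the dataset. A minimal sketch, assuming the process group has already been initialized (the random TensorDataset and batch size are placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Placeholder dataset; substitute your own Dataset implementation.
dataset = TensorDataset(torch.randn(1024, 16), torch.randint(0, 2, (1024,)))

# Each process gets a different, non-overlapping shard of the data.
# Requires dist.init_process_group(...) to have been called already.
sampler = DistributedSampler(dataset)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reshuffle the shards each epoch
    for inputs, targets in loader:
        pass  # forward/backward as in the training loop shown in step 3
```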
2. Install Necessary Dependencies
Ensure your PyTorch version supports torchrun:
pip install torch torchvision torchaudio
For multi-node distributed training:
- NCCL is automatically installed with PyTorch.
- For MPI, install the required libraries:
sudo apt-get install libopenmpi-dev
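Before launching anything, it can help to confirm that the distributed package and the backend you plan to use are actually available in your environment. A minimal sketch:

```python
import torch
import torch.distributed as dist

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("torch.distributed available:", dist.is_available())
print("NCCL backend available:", dist.is_nccl_available())
print("Gloo backend available:", dist.is_gloo_available())
print("MPI backend available:", dist.is_mpi_available())
```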
3. Prepare Your Training Script
Modify your PyTorch training script to use DistributedDataParallel.

Key Changes in train.py:
- Initialize Distributed Process Group:
```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
```
- Wrap Your Model with DDP:
```python
def main(rank, world_size):
    setup(rank, world_size)

    # Move the model to this process's GPU and wrap it in DDP.
    model = MyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    # Training loop: DDP averages gradients across processes during backward().
    optimizer = torch.optim.Adam(ddp_model.parameters())
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = ddp_model(inputs.to(rank))
        loss = criterion(outputs, targets.to(rank))
        loss.backward()
        optimizer.step()

    cleanup()
```
- Spawn Processes (only when launching with plain python):
  If you start the script directly with python rather than torchrun, use torch.multiprocessing.spawn for process management:

```python
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(main, args=(world_size,), nprocs=world_size, join=True)
```
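When you launch with torchrun instead (step 4 below), you do not call torch.multiprocessing.spawn at all: torchrun starts one process per GPU and publishes RANK, LOCAL_RANK, and WORLD_SIZE as environment variables. A minimal sketch of the corresponding entry point, reusing the main(rank, world_size) function above:

```python
import os
import torch

if __name__ == "__main__":
    # torchrun sets these environment variables for every process it starts.
    # On a single node RANK == LOCAL_RANK; in multi-node jobs, use RANK for
    # init_process_group and LOCAL_RANK to select the GPU on this node.
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    torch.cuda.set_device(local_rank)
    main(rank, world_size)
```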
4. Launch Training with torchrun
Use torchrun to manage distributed training processes.
Single Node, Multi-GPU Training
For a single node with 4 GPUs:
torchrun --nproc_per_node=4 train.py
- --nproc_per_node: Number of processes to launch (e.g., the number of GPUs).
Multi-Node Distributed Training
For multi-node setups:
torchrun --nnodes=2 --nproc_per_node=4 --rdzv_backend=c10d \
--rdzv_endpoint="master_ip:29500" train.py
- --nnodes: Number of nodes participating in training.
- --nproc_per_node: Number of processes per node (typically the number of GPUs per node).
- --rdzv_backend: Rendezvous backend (c10d is the default).
- --rdzv_endpoint: IP and port of the master node for communication.
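In multi-node jobs it also matters whether logic is keyed off the global rank (the RANK environment variable, unique across all nodes) or the local rank (LOCAL_RANK, unique within a node). A common pattern, sketched here with an illustrative checkpoint path, is to let only global rank 0 write checkpoints:

```python
import torch
import torch.distributed as dist

def save_checkpoint(ddp_model, path="checkpoint.pt"):  # path is illustrative
    # Only the process with global rank 0 writes the file; saving
    # ddp_model.module lets the checkpoint be loaded without DDP.
    if dist.get_rank() == 0:
        torch.save(ddp_model.module.state_dict(), path)
    # Keep the other processes from racing ahead before the save finishes.
    dist.barrier()
```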
5. Check System Configuration
Ensure the environment is configured correctly:
- NCCL Settings:
  - For multi-node training:
    export NCCL_DEBUG=INFO
    export NCCL_IB_DISABLE=0
    export NCCL_SOCKET_IFNAME=eth0  # or your network interface
- CUDA and GPU Settings:
  - Confirm GPU visibility (a Python-side check is sketched after this list):
    nvidia-smi
  - Set CUDA_VISIBLE_DEVICES if needed:
    export CUDA_VISIBLE_DEVICES=0,1,2,3
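A quick Python-side check of what a process can actually see (a minimal sketch; run it under the same CUDA_VISIBLE_DEVICES setting you plan to train with):

```python
import torch

print("CUDA available:", torch.cuda.is_available())
print("Visible GPUs:", torch.cuda.device_count())
for i in range(torch.cuda.device_count()):
    print(f"  cuda:{i} ->", torch.cuda.get_device_name(i))
```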
6. Monitor and Debug
- Use verbose logging for debugging:
  export TORCH_DISTRIBUTED_DEBUG=DETAIL
  torchrun --nproc_per_node=4 train.py
- Check GPU utilization:
watch -n 1 nvidia-smi
- Debug communication issues (e.g., NCCL or Gloo):
  - Check network connectivity between nodes.
  - Use dmesg or log files to look for hardware errors.
  - Run a small all-reduce sanity check (see the sketch after this list).
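To separate communication problems from bugs in the training script itself, a tiny all-reduce test is often enough. A minimal sketch, saved as a hypothetical ddp_check.py and launched with torchrun exactly like the real job (e.g., torchrun --nproc_per_node=4 ddp_check.py):

```python
import os
import torch
import torch.distributed as dist

if __name__ == "__main__":
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Rank and world size are picked up from the environment set by torchrun.
    dist.init_process_group("nccl")

    # Every rank contributes 1.0; after all_reduce each rank should hold the
    # world size. Anything else (or a hang) points at the communication layer.
    t = torch.ones(1, device=local_rank)
    dist.all_reduce(t)
    print(f"rank {dist.get_rank()}: got {t.item()}, expected {dist.get_world_size()}")

    dist.destroy_process_group()
```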
Example Scenarios
Single Node, 4 GPUs
torchrun --nproc_per_node=4 train.py
Two Nodes, 8 GPUs Total (4 GPUs Per Node)
- Start on the master node:
torchrun --nnodes=2 --nproc_per_node=4 --rdzv_endpoint="master_ip:29500" train.py
- Start on the worker node:
torchrun --nnodes=2 --nproc_per_node=4 --rdzv_endpoint="master_ip:29500" train.py
7. Best Practices
- Batch normalization layers compute their statistics per process under DDP; convert them to SyncBatchNorm when per-GPU batches are small, keeping the extra synchronization overhead in mind (see the sketch below).
- Optimize network communication with InfiniBand if available.
- Profile your training script to identify bottlenecks using PyTorch’s profiler.
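On the batch-norm point specifically, a minimal sketch of converting a model’s BatchNorm layers to SyncBatchNorm before wrapping it in DDP (MyModel and rank are the placeholders used in the training script above):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

model = MyModel().to(rank)

# Replace every BatchNorm*d layer with SyncBatchNorm so statistics are
# computed across all processes instead of per GPU; this adds communication,
# so it mainly pays off when per-GPU batches are small.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

ddp_model = DDP(model, device_ids=[rank])
```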