import os
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()
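A note on the backend: the demos below move the model to the GPU whose id matches the process rank, and for CUDA tensors the nccl backend is generally preferred over gloo. The following is a minimal sketch of a setup variant under that assumption (the function name setup_nccl is just for illustration; the hardcoded address and port are kept as-is):

def setup_nccl(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # bind this process to its own GPU before creating the process group
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)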
Next, we build a simple toy model, wrap it with DDP, and feed it some randomly generated input data.
class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)
def demo_checkpoint(rank, world_size):
    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # All processes should see the same parameters, as they all start from
        # the same random parameters and gradients are synchronized in backward
        # passes. Therefore, saving the model in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that the other processes load the model
    # only after process 0 has saved it.
    dist.barrier()
    # configure map_location properly
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Not necessary to use a dist.barrier() to guard the file deletion below,
    # as the AllReduce ops in the backward pass of DDP already served as
    # a synchronization.
    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()
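The next demo combines DDP with model parallelism: each process spans two GPUs, placing the first linear layer on one device and the second on the other. The ToyMpModel class it relies on is not defined in this snippet; below is a minimal sketch, assuming a two-stage split that matches how demo_model_parallel constructs it with (dev0, dev1) and expects the outputs on dev1.

class ToyMpModel(nn.Module):
    def __init__(self, dev0, dev1):
        super(ToyMpModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        # first stage lives on dev0, second stage on dev1
        self.net1 = nn.Linear(10, 10).to(dev0)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        # run the first stage on dev0, then move the intermediate
        # activation to dev1 for the second stage
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        x = x.to(self.dev1)
        return self.net2(x)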
def demo_model_parallel(rank, world_size):
    print(f"Running DDP with model parallel example on rank {rank}.")
    setup(rank, world_size)

    # setup mp_model and devices for this process
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    mp_model = ToyMpModel(dev0, dev1)
    ddp_mp_model = DDP(mp_model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    # outputs will be on dev1
    outputs = ddp_mp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(dev1)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    if n_gpus < 8:
        print(f"Requires at least 8 GPUs to run, but got {n_gpus}.")
    else:
        run_demo(demo_basic, 8)
        run_demo(demo_checkpoint, 8)
        run_demo(demo_model_parallel, 4)
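The guard above insists on a machine with at least 8 GPUs. A possible adaptation (not part of the original example) is to size the process group from whatever hardware is available, for instance:

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    if n_gpus < 2:
        print(f"Requires at least 2 GPUs to run, but got {n_gpus}.")
    else:
        run_demo(demo_basic, n_gpus)
        run_demo(demo_checkpoint, n_gpus)
        # demo_model_parallel places each replica across two GPUs,
        # so it spawns half as many processes.
        run_demo(demo_model_parallel, n_gpus // 2)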