新聞中心

EEPW首頁 > 智能計(jì)算 > 設(shè)計(jì)應(yīng)用 > GPU如何訓(xùn)練大批量模型?方法在這里

GPU如何訓(xùn)練大批量模型?方法在這里

作者: 時(shí)間:2018-10-22 來源:網(wǎng)絡(luò) 收藏

  分布式訓(xùn)練:在多臺(tái)機(jī)器上訓(xùn)練

本文引用地址:http://butianyuan.cn/article/201810/393173.htm

  在更大的批量上訓(xùn)練時(shí),我們要如何控制多個(gè)服務(wù)器的算力呢?

  最簡單的選擇是使用 PyTorch 的 DistributedDataParallel,它幾乎可以說是以上討論的 DataParallel 的直接替代元件。

  但要注意:盡管代碼看起來很相似,但在分布式設(shè)定中訓(xùn)練模型要改變工作流程,因?yàn)槟惚仨氃诿總€(gè)節(jié)點(diǎn)上啟動(dòng)一個(gè)獨(dú)立的 訓(xùn)練腳本。正如我們將看到的,一旦啟動(dòng),這些訓(xùn)練腳本可以通過使用 PyTorch 分布式后端一起同步化。

  在實(shí)踐中,這意味著每個(gè)訓(xùn)練腳本將擁有:

  它自己的優(yōu)化器,并在每次迭代中執(zhí)行一個(gè)完整的優(yōu)化步驟,不需要進(jìn)行參數(shù)傳播(DataParallel 中的步驟 2);

  一個(gè)獨(dú)立的 解釋器:這也將避免 GIL-freeze,這是在單個(gè) 解釋器上驅(qū)動(dòng)多個(gè)并行執(zhí)行線程時(shí)會(huì)出現(xiàn)的問題。

  當(dāng)多個(gè)并行前向調(diào)用由單個(gè)解釋器驅(qū)動(dòng)時(shí),在前向傳播中大量使用 Python 循環(huán)/調(diào)用的模型可能會(huì)被 Python 解釋器的 GIL 放慢速度。通過這種設(shè)置,DistributedDataParallel 甚至在單臺(tái)機(jī)器設(shè)置中也能很方便地替代 DataParallel。

  現(xiàn)在我們直接討論代碼和用途。

  DistributedDataParallel 是建立在 torch.distributed 包之上的,這個(gè)包可以為同步分布式運(yùn)算提供低級(jí)原語,并能以不同的性能使用多種后端(tcp、gloo、mpi、nccl)。在這篇文章中,我將選擇一種簡單的開箱即用的方式來使用它,但你應(yīng)該閱讀文檔和 Séb Arnold 寫的教程來深入理解這個(gè)模塊。

  文檔:https://pytorch.org/docs/stable/distributed.html

  教程:https://pytorch.org/tutorials/intermediate/dist_tuto.html

  我們將考慮使用具有兩個(gè) 4 - 服務(wù)器(節(jié)點(diǎn))的簡單但通用的設(shè)置:



  主服務(wù)器(服務(wù)器 1)擁有一個(gè)可訪問的 IP 地址和一個(gè)用于通信的開放端口。

  改寫 Python 訓(xùn)練腳本以適應(yīng)分布式訓(xùn)練

  首先我們需要改寫腳本,從而令其可以在每臺(tái)機(jī)器(節(jié)點(diǎn))上獨(dú)立運(yùn)行。我們將實(shí)現(xiàn)完全的分布式訓(xùn)練,并在每個(gè)節(jié)點(diǎn)的每塊 上運(yùn)行一個(gè)獨(dú)立的進(jìn)程,因此總共需要 8 個(gè)進(jìn)程。

  我們的訓(xùn)練腳本有點(diǎn)長,因?yàn)樾枰獮橥交跏蓟植际胶蠖?,封裝模型并準(zhǔn)備數(shù)據(jù),以在數(shù)據(jù)的一個(gè)子集上來訓(xùn)練每個(gè)進(jìn)程(每個(gè)進(jìn)程都是獨(dú)立的,因此我們需要自行處理)。以下是更新后的代碼:

  from torch.utils.data.distributed import DistributedSampler

  from torch.utils.data import DataLoader

  # Each process runs on 1 device specified by the local_rank argument.

  parser = argparse.ArgumentParser()

  parser.add_argument("--local_rank", type=int)

  args = parser.parse_args()

  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs

  torch.distributed.init_process_group(backend='nccl')

  # Encapsulate the model on the GPU assigned to the current process

  device = torch.device('cuda', arg.local_rank)

  model = model.to(device)

  distrib_model = torch.nn.parallel.DistributedDataParallel(model,

  device_ids=[args.local_rank],

  output_device=args.local_rank)

  # Restricts data loading to a subset of the dataset exclusive to the current process

  sampler = DistributedSampler(dataset)

  dataloader = DataLoader(dataset, sampler=sampler)

  for inputs, labels in dataloader:

  predictions = distrib_model(inputs.to(device)) # Forward pass

  loss = loss_function(predictions, labels.to(device)) # Compute loss function

  loss.backward() # Backward pass

  optimizer.step() # Optimizer step

  啟動(dòng) Python 訓(xùn)練腳本的多個(gè)實(shí)例

  我們就快完成了,只需要在每個(gè)服務(wù)器上啟動(dòng)訓(xùn)練腳本的一個(gè)實(shí)例。

  為了運(yùn)行腳本,我們將使用 PyTorch 的 torch.distributed.launch 工具。它將用來設(shè)置環(huán)境變量,并用正確的 local_rank 參數(shù)調(diào)用每個(gè)腳本。

  第一臺(tái)機(jī)器是最主要的,它應(yīng)該對(duì)于所有其它機(jī)器都是可訪問的,因此擁有一個(gè)可訪問的 IP 地址(我們的案例中是 192.168.1.1)以及一個(gè)開放端口(在我們的案例中是 1234)。在第一臺(tái)機(jī)器上,我們使用 torch.distributed.launch 來運(yùn)行訓(xùn)練腳本:

  python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" --master_port=1234 OUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of our training script) # Optimizer step

  在第二臺(tái)機(jī)器上,我們類似地啟動(dòng)腳本:

  python -m torch.distributed.launch --nproc_per_node=4 --nnodes=2 --node_rank=1 --master_addr="192.168.1.1" --master_port=1234 OUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of our training script)

  這兩個(gè)命令是相同的,除了—node_rank 參數(shù),其在第一臺(tái)機(jī)器上被設(shè)為 0,在第二臺(tái)機(jī)器上被設(shè)為 1(如果再加一臺(tái)機(jī)器,則設(shè)為 2,以此類推…)。


上一頁 1 2 3 下一頁

關(guān)鍵詞: GPU Python

評(píng)論


相關(guān)推薦

技術(shù)專區(qū)

關(guān)閉