Switch to PyTorch DistributedDataParallel

This commit is contained in:
thomwolf 2019-06-18 12:03:13 +02:00
parent 9ce37af99b
commit 2ef5e0de87
1 changed file with 5 additions and 5 deletions

View File

@ -902,12 +902,12 @@ def main():
model.half()
model.to(device)
if args.local_rank != -1:
try:
from apex.parallel import DistributedDataParallel as DDP
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
# try:
# from apex.parallel import DistributedDataParallel as DDP
# except ImportError:
# raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
model = DDP(model)
model = torch.nn.parallel.DistributedDataParallel(model)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)