switch to pytorch DistributedDataParallel
This commit is contained in:
parent
9ce37af99b
commit
2ef5e0de87
|
@@ -902,12 +902,12 @@ def main():
         model.half()
     model.to(device)
     if args.local_rank != -1:
-        try:
-            from apex.parallel import DistributedDataParallel as DDP
-        except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+        # try:
+        #     from apex.parallel import DistributedDataParallel as DDP
+        # except ImportError:
+        #     raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

-        model = DDP(model)
+        model = torch.nn.parallel.DistributedDataParallel(model)
     elif n_gpu > 1:
         model = torch.nn.DataParallel(model)
|
Loading…
Reference in New Issue