GLUE on TPU

This commit is contained in:
Lysandre 2019-10-10 19:03:06 +00:00
parent 6596e3d566
commit d4e7934ac3
1 changed files with 32 additions and 1 deletions

View File

@ -160,7 +160,7 @@ def train(args, train_dataset, model, tokenizer):
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
tr_loss += loss.item() tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0: if (step + 1) % args.gradient_accumulation_steps == 0 and not args.tpu:
optimizer.step() optimizer.step()
scheduler.step() # Update learning rate schedule scheduler.step() # Update learning rate schedule
model.zero_grad() model.zero_grad()
@ -186,6 +186,11 @@ def train(args, train_dataset, model, tokenizer):
torch.save(args, os.path.join(output_dir, 'training_args.bin')) torch.save(args, os.path.join(output_dir, 'training_args.bin'))
logger.info("Saving model checkpoint to %s", output_dir) logger.info("Saving model checkpoint to %s", output_dir)
if args.tpu:
args.xla_model.optimizer_step(optimizer, barrier=True)
model.zero_grad()
global_step += 1
if args.max_steps > 0 and global_step > args.max_steps: if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close() epoch_iterator.close()
break break
@ -385,6 +390,15 @@ def main():
parser.add_argument('--seed', type=int, default=42, parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization") help="random seed for initialization")
parser.add_argument('--tpu', action='store_true',
help="Whether to run on the TPU defined in the environment variables")
parser.add_argument('--tpu_ip_address', type=str, default='',
help="TPU IP address if none are set in the environment variables")
parser.add_argument('--tpu_name', type=str, default='',
help="TPU name if none are set in the environment variables")
parser.add_argument('--xrt_tpu_config', type=str, default='',
help="XRT TPU config if none are set in the environment variables")
parser.add_argument('--fp16', action='store_true', parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O1', parser.add_argument('--fp16_opt_level', type=str, default='O1',
@ -418,6 +432,23 @@ def main():
args.n_gpu = 1 args.n_gpu = 1
args.device = device args.device = device
if args.tpu:
if args.tpu_ip_address:
os.environ["TPU_IP_ADDRESS"] = args.tpu_ip_address
if args.tpu_name:
os.environ["TPU_NAME"] = args.tpu_name
if args.xrt_tpu_config:
os.environ["XRT_TPU_CONFIG"] = args.xrt_tpu_config
assert "TPU_IP_ADDRESS" in os.environ
assert "TPU_NAME" in os.environ
assert "XRT_TPU_CONFIG" in os.environ
import torch_xla
import torch_xla.core.xla_model as xm
args.device = xm.xla_device()
args.xla_model = xm
# Setup logging # Setup logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',