diff --git a/examples/run_glue.py b/examples/run_glue.py
index f757d46627..a2167ed807 100644
--- a/examples/run_glue.py
+++ b/examples/run_glue.py
@@ -160,7 +160,7 @@ def train(args, train_dataset, model, tokenizer):
                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
 
             tr_loss += loss.item()
-            if (step + 1) % args.gradient_accumulation_steps == 0:
+            if (step + 1) % args.gradient_accumulation_steps == 0 and not args.tpu:
                 optimizer.step()
                 scheduler.step()  # Update learning rate schedule
                 model.zero_grad()
@@ -186,6 +186,11 @@ def train(args, train_dataset, model, tokenizer):
                     torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                     logger.info("Saving model checkpoint to %s", output_dir)
 
+            if args.tpu:
+                args.xla_model.optimizer_step(optimizer, barrier=True)
+                model.zero_grad()
+                global_step += 1
+
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
                 break
@@ -385,6 +390,15 @@ def main():
     parser.add_argument('--seed', type=int, default=42,
                         help="random seed for initialization")
 
+    parser.add_argument('--tpu', action='store_true',
+                        help="Whether to run on the TPU defined in the environment variables")
+    parser.add_argument('--tpu_ip_address', type=str, default='',
+                        help="TPU IP address if none are set in the environment variables")
+    parser.add_argument('--tpu_name', type=str, default='',
+                        help="TPU name if none are set in the environment variables")
+    parser.add_argument('--xrt_tpu_config', type=str, default='',
+                        help="XRT TPU config if none are set in the environment variables")
+
     parser.add_argument('--fp16', action='store_true',
                         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
     parser.add_argument('--fp16_opt_level', type=str, default='O1',
@@ -418,6 +432,23 @@ def main():
         args.n_gpu = 1
     args.device = device
 
+    if args.tpu:
+        if args.tpu_ip_address:
+            os.environ["TPU_IP_ADDRESS"] = args.tpu_ip_address
+        if args.tpu_name:
+            os.environ["TPU_NAME"] = args.tpu_name
+        if args.xrt_tpu_config:
+            os.environ["XRT_TPU_CONFIG"] = args.xrt_tpu_config
+
+        assert "TPU_IP_ADDRESS" in os.environ
+        assert "TPU_NAME" in os.environ
+        assert "XRT_TPU_CONFIG" in os.environ
+
+        import torch_xla
+        import torch_xla.core.xla_model as xm
+        args.device = xm.xla_device()
+        args.xla_model = xm
+
     # Setup logging
     logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                         datefmt = '%m/%d/%Y %H:%M:%S',
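
For context on what the new `--tpu` path does, below is a minimal sketch of the single-process `torch_xla` training pattern the patch mirrors; the model, optimizer, and synthetic loader are placeholders for illustration and are not taken from `run_glue.py`. `xm.xla_device()` resolves the TPU from the `TPU_IP_ADDRESS`/`TPU_NAME`/`XRT_TPU_CONFIG` environment variables set above, and `xm.optimizer_step(optimizer, barrier=True)` stands in for the plain `optimizer.step()` that the patch now skips when `args.tpu` is set.

```python
# Minimal sketch of the single-process torch_xla loop the patch follows.
# The model, optimizer, and synthetic "loader" below are placeholders for
# illustration only; they are not part of run_glue.py.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # resolves the TPU from TPU_*/XRT_TPU_CONFIG env vars
model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Synthetic stand-in for a real DataLoader.
loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(3)]

for inputs, labels in loader:
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    loss.backward()
    # Steps the optimizer and, with barrier=True, forces the queued XLA graph
    # to execute on the TPU at each step -- this is why the patch routes the
    # TPU case through args.xla_model.optimizer_step instead of optimizer.step().
    xm.optimizer_step(optimizer, barrier=True)
```

Calling `xm.optimizer_step` with `barrier=True` triggers compilation and execution of the pending lazy graph every step, so the loop does not keep accumulating untraced operations between checkpoints.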