GLUE on TPU
commit d4e7934ac3
parent 6596e3d566
@@ -160,7 +160,7 @@ def train(args, train_dataset, model, tokenizer):
                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

             tr_loss += loss.item()
-            if (step + 1) % args.gradient_accumulation_steps == 0:
+            if (step + 1) % args.gradient_accumulation_steps == 0 and not args.tpu:
                 optimizer.step()
                 scheduler.step()  # Update learning rate schedule
                 model.zero_grad()
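Note: with this guard, the host-side optimizer.step() / scheduler.step() path never fires when --tpu is set; the TPU branch added in the next hunk steps the optimizer instead, once per batch, so gradient_accumulation_steps is effectively bypassed on TPU. A tiny standalone illustration of the gating, with an assumed accumulation factor of 4 (illustration only, not part of the diff):

# Which steps trigger the host optimizer for an assumed
# gradient_accumulation_steps of 4 over 12 batches.
gradient_accumulation_steps = 4
for tpu in (False, True):
    fired = [step for step in range(12)
             if (step + 1) % gradient_accumulation_steps == 0 and not tpu]
    print(tpu, fired)  # False -> [3, 7, 11]; True -> []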
@@ -186,6 +186,11 @@ def train(args, train_dataset, model, tokenizer):
                     torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                     logger.info("Saving model checkpoint to %s", output_dir)

+            if args.tpu:
+                args.xla_model.optimizer_step(optimizer, barrier=True)
+                model.zero_grad()
+                global_step += 1
+
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
                 break
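Note: xm.optimizer_step (reached here through args.xla_model, which is assigned in the last hunk below) applies the optimizer update and triggers execution of the pending XLA graph; barrier=True blocks until the device finishes the step. A minimal self-contained sketch of the same dual-path step, assuming a toy model and random data rather than the GLUE pipeline:

import torch

def train_step(model, optimizer, batch, target, tpu=False, xla_model=None):
    loss = torch.nn.functional.mse_loss(model(batch), target)
    loss.backward()
    if tpu:
        # torch_xla steps the optimizer and runs the pending XLA graph;
        # barrier=True waits for the TPU to finish the step.
        xla_model.optimizer_step(optimizer, barrier=True)
    else:
        optimizer.step()
    model.zero_grad()
    return loss.item()

# CPU usage example; on TPU, pass tpu=True and
# xla_model=torch_xla.core.xla_model instead.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
print(train_step(model, optimizer, torch.randn(8, 4), torch.randn(8, 1)))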
@@ -385,6 +390,15 @@ def main():
     parser.add_argument('--seed', type=int, default=42,
                         help="random seed for initialization")

+    parser.add_argument('--tpu', action='store_true',
+                        help="Whether to run on the TPU defined in the environment variables")
+    parser.add_argument('--tpu_ip_address', type=str, default='',
+                        help="TPU IP address if none are set in the environment variables")
+    parser.add_argument('--tpu_name', type=str, default='',
+                        help="TPU name if none are set in the environment variables")
+    parser.add_argument('--xrt_tpu_config', type=str, default='',
+                        help="XRT TPU config if none are set in the environment variables")
+
     parser.add_argument('--fp16', action='store_true',
                         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
     parser.add_argument('--fp16_opt_level', type=str, default='O1',
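Note: these four flags only declare configuration; the wiring into the environment happens in the next hunk. A quick parse-only check of the new arguments, using placeholder values (the TPU name and IP below are hypothetical, not real resources):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--tpu', action='store_true')
parser.add_argument('--tpu_ip_address', type=str, default='')
parser.add_argument('--tpu_name', type=str, default='')
parser.add_argument('--xrt_tpu_config', type=str, default='')

# Placeholder command line; 'my-tpu' and the IP are hypothetical.
args = parser.parse_args(['--tpu', '--tpu_name', 'my-tpu',
                          '--tpu_ip_address', '10.0.0.2'])
print(args.tpu, args.tpu_name, args.tpu_ip_address)  # True my-tpu 10.0.0.2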
@@ -418,6 +432,23 @@ def main():
         args.n_gpu = 1
     args.device = device

+    if args.tpu:
+        if args.tpu_ip_address:
+            os.environ["TPU_IP_ADDRESS"] = args.tpu_ip_address
+        if args.tpu_name:
+            os.environ["TPU_NAME"] = args.tpu_name
+        if args.xrt_tpu_config:
+            os.environ["XRT_TPU_CONFIG"] = args.xrt_tpu_config
+
+        assert "TPU_IP_ADDRESS" in os.environ
+        assert "TPU_NAME" in os.environ
+        assert "XRT_TPU_CONFIG" in os.environ
+
+        import torch_xla
+        import torch_xla.core.xla_model as xm
+        args.device = xm.xla_device()
+        args.xla_model = xm
+
     # Setup logging
     logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                         datefmt = '%m/%d/%Y %H:%M:%S',
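Note: the asserts mean all three variables must end up set, whether via the new flags or the ambient environment, and torch_xla is imported only once a TPU run is requested, so CPU/GPU runs do not need the package installed. A sketch of the resulting environment shape, assuming torch_xla is available and using placeholder values (the tpu_worker;0;<ip>:8470 layout follows the torch_xla convention of the time):

import os

# Placeholder values; a real run would point at an actual TPU.
os.environ.setdefault("TPU_IP_ADDRESS", "10.0.0.2")
os.environ.setdefault("TPU_NAME", "my-tpu")
os.environ.setdefault("XRT_TPU_CONFIG", "tpu_worker;0;10.0.0.2:8470")

import torch_xla.core.xla_model as xm  # requires a torch_xla install

device = xm.xla_device()  # resolves the device from XRT_TPU_CONFIG
print(device)             # e.g. "xla:1", depending on the torch_xla version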