Use the new API (`save_pretrained`) for save and load

This commit is contained in:
yzy5630 2019-07-18 15:45:23 +08:00
parent a7ba27b1b4
commit a1fe4ba9c9
2 changed files with 5 additions and 15 deletions

View File

@ -322,14 +322,8 @@ def main():
# Save a trained model
if n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1 :
logging.info("** ** * Saving fine-tuned model ** ** * ")
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(args.output_dir)
model.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
if __name__ == '__main__':

View File

@ -32,7 +32,7 @@ from tqdm import tqdm, trange
from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
from pytorch_transformers.modeling_bert import BertForPreTraining
from pytorch_transformers.tokenization_bert import BertTokenizer
from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
@ -610,12 +610,8 @@ def main():
# Save a trained model
if args.do_train and ( n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <=1):
logger.info("** ** * Saving fine - tuned model ** ** * ")
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
torch.save(model_to_save.state_dict(), output_model_file)
model_to_save.config.to_json_file(output_config_file)
tokenizer.save_vocabulary(args.output_dir)
model.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
def _truncate_seq_pair(tokens_a, tokens_b, max_length):