1149 lines
47 KiB
Python
1149 lines
47 KiB
Python
# coding=utf-8
|
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
""" Fine-pruning Masked BERT for question-answering on SQuAD."""
|
|
|
|
|
|
import argparse
|
|
import glob
|
|
import logging
|
|
import os
|
|
import random
|
|
import timeit
|
|
|
|
import numpy as np
|
|
import torch
|
|
from emmental import MaskedBertConfig, MaskedBertForQuestionAnswering
|
|
from torch import nn
|
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
from tqdm import tqdm, trange
|
|
|
|
from transformers import (
|
|
WEIGHTS_NAME,
|
|
AdamW,
|
|
BertConfig,
|
|
BertForQuestionAnswering,
|
|
BertTokenizer,
|
|
get_linear_schedule_with_warmup,
|
|
squad_convert_examples_to_features,
|
|
)
|
|
from transformers.data.metrics.squad_metrics import (
|
|
compute_predictions_log_probs,
|
|
compute_predictions_logits,
|
|
squad_evaluate,
|
|
)
|
|
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor
|
|
|
|
|
|
try:
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
except ImportError:
|
|
from tensorboardX import SummaryWriter
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Lookup table from the ``--model_type`` command-line value to the
# (config class, model class, tokenizer class) triple used to build it.
# Both model types share the plain BERT tokenizer.
MODEL_CLASSES = dict(
    bert=(BertConfig, BertForQuestionAnswering, BertTokenizer),
    masked_bert=(MaskedBertConfig, MaskedBertForQuestionAnswering, BertTokenizer),
)
|
|
|
|
|
|
def set_seed(args):
    """Seed every random number generator the script relies on.

    Python's ``random``, NumPy's global generator and torch's generator are all
    seeded with ``args.seed``; when ``args.n_gpu`` is positive every CUDA device
    is seeded too.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
def schedule_threshold(
    step: int,
    total_step: int,
    warmup_steps: int,
    initial_threshold: float,
    final_threshold: float,
    initial_warmup: int,
    final_warmup: int,
    final_lambda: float,
):
    """Compute the pruning threshold and regularization weight for ``step``.

    The schedule has three phases:

    * ``step <= initial_warmup * warmup_steps``: hold ``initial_threshold``;
    * ``step > total_step - final_warmup * warmup_steps``: hold ``final_threshold``;
    * otherwise: cubic interpolation from ``initial_threshold`` down/up to
      ``final_threshold`` over the ramp between those two boundaries.

    The regularization weight is ``final_lambda`` scaled by
    ``threshold / final_threshold``, so ``final_threshold`` must be non-zero.

    Returns:
        ``(threshold, regu_lambda)``.
    """
    ramp_start = initial_warmup * warmup_steps
    ramp_length = total_step - (initial_warmup + final_warmup) * warmup_steps

    if step <= ramp_start:
        threshold = initial_threshold
    elif step > total_step - final_warmup * warmup_steps:
        threshold = final_threshold
    else:
        # Share of the ramp still ahead: 1 at the ramp start, 0 at its end.
        remaining = 1 - (step - ramp_start) / ramp_length
        threshold = final_threshold + (initial_threshold - final_threshold) * (remaining**3)

    return threshold, final_lambda * threshold / final_threshold
|
|
|
|
|
|
def regularization(model: nn.Module, mode: str):
    """Average sparsity penalty over every ``mask_scores`` parameter of ``model``.

    A per-parameter penalty is computed for each parameter whose name contains
    ``"mask_scores"``, normalised by that parameter's element count, and the
    penalties are then averaged across those parameters:

    * ``"l1"``: L1 norm of ``sigmoid(scores)`` divided by ``numel``.
    * ``"l0"``: sum of ``sigmoid(scores - 2/3 * log(0.1 / 1.1))`` divided by
      ``numel`` (presumably the expected-L0 term of a hard-concrete gate with
      temperature 2/3 and stretch interval (-0.1, 1.1) -- confirm).

    Args:
        model: module whose ``mask_scores`` parameters are penalised.
        mode: ``"l1"`` or ``"l0"``.

    Returns:
        A 0-dim tensor. If the model has no ``mask_scores`` parameters, a zero
        tensor is returned instead of dividing by zero.

    Raises:
        ValueError: if ``mode`` is not ``"l1"`` or ``"l0"``.
    """
    if mode not in ("l1", "l0"):
        # Previously the exception was built but never raised.
        raise ValueError(f"Don't know this mode: {mode}.")

    penalties = []
    for name, param in model.named_parameters():
        if "mask_scores" not in name:
            continue
        if mode == "l1":
            penalties.append(torch.norm(torch.sigmoid(param), p=1) / param.numel())
        else:
            penalties.append(torch.sigmoid(param - 2 / 3 * np.log(0.1 / 1.1)).sum() / param.numel())

    if not penalties:
        # No mask scores to regularize (e.g. a dense model): contribute nothing.
        return torch.tensor(0.0)
    return sum(penalties) / len(penalties)
|
|
|
|
|
|
def to_list(tensor):
    """Return ``tensor`` as a (possibly nested) Python list on the host.

    The tensor is detached from the autograd graph and copied to CPU before
    conversion, so it works for tensors that require grad or live on a GPU.
    """
    host_copy = tensor.detach().to("cpu")
    return host_copy.tolist()
|
|
|
|
|
|
def train(args, train_dataset, model, tokenizer, teacher=None):
    """Fine-prune ``model`` on ``train_dataset``, optionally distilling from ``teacher``.

    At every step the pruning threshold and regularization weight are taken from
    ``schedule_threshold``. For ``masked`` model types the threshold is passed to
    the forward call. With ``args.global_topk`` the scheduled fraction is turned
    into one score cutoff shared by all ``mask_scores`` parameters, recomputed
    every ``args.global_topk_frequency_compute`` steps.

    The loss is assembled in up to three parts:

    * When ``teacher`` is given, a temperature-scaled KL term on start/end logits
      is mixed in with weights ``args.alpha_distil`` / ``args.alpha_ce``.
    * When ``args.regularization`` is set, ``regu_lambda * regularization(...)``
      is added.
    * The result is divided by ``args.gradient_accumulation_steps`` when that
      is greater than 1.

    Side effects:

    * Sets ``args.train_batch_size``, and ``args.num_train_epochs`` when
      ``args.max_steps > 0``.
    * On the main process, writes TensorBoard scalars and periodic checkpoints
      (model, tokenizer, args, optimizer and scheduler state) under
      ``args.output_dir``.

    Args:
        args: parsed command-line namespace.
        train_dataset: dataset whose batches are indexed as
            (input_ids, attention_mask, token_type_ids, start_positions,
            end_positions, ...).
        model: the student model to train.
        tokenizer: saved next to every checkpoint.
        teacher: optional teacher model, called under ``torch.no_grad()``.

    Returns:
        ``(global_step, tr_loss / global_step)``. ``global_step`` starts at 1
        (or at the resumed checkpoint step), not 0, and ``tr_loss`` sums the
        accumulation-scaled total loss.
    """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter(log_dir=args.output_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # Mask scores get their own learning rate (and the optimizer-level default
    # weight decay); the remaining trainable weights are split into a decayed
    # group and a no-decay group (biases and LayerNorm weights).
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if "mask_score" in n and p.requires_grad],
            "lr": args.mask_scores_learning_rate,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if "mask_score" not in n and p.requires_grad and not any(nd in n for nd in no_decay)
            ],
            "lr": args.learning_rate,
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if "mask_score" not in n and p.requires_grad and any(nd in n for nd in no_decay)
            ],
            "lr": args.learning_rate,
            "weight_decay": 0.0,
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    optimizer_path = os.path.join(args.model_name_or_path, "optimizer.pt")
    scheduler_path = os.path.join(args.model_name_or_path, "scheduler.pt")
    if os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(optimizer_path))
        scheduler.load_state_dict(torch.load(scheduler_path))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)
    # Distillation
    if teacher is not None:
        logger.info(" Training with distillation")

    global_step = 1
    # Global TopK: cached score cutoff, only used when args.global_topk is set.
    threshold_mem = None
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        try:
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(" Continuing training from checkpoint, will skip to saved global_step")
            logger.info(" Continuing training from epoch %d", epochs_trained)
            logger.info(" Continuing training from global step %d", global_step)
            logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            # model_name_or_path is not a "checkpoint-<step>" directory.
            logger.info(" Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    # Added here for reproducibility
    set_seed(args)

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            threshold, regu_lambda = schedule_threshold(
                step=global_step,
                total_step=t_total,
                warmup_steps=args.warmup_steps,
                final_threshold=args.final_threshold,
                initial_threshold=args.initial_threshold,
                final_warmup=args.final_warmup,
                initial_warmup=args.initial_warmup,
                final_lambda=args.final_lambda,
            )
            # Global TopK: convert the scheduled fraction into one score cutoff
            # shared by every mask_scores parameter.
            if args.global_topk:
                if threshold == 1.0:
                    threshold = -1e2  # Or an indefinitely low quantity
                else:
                    if (threshold_mem is None) or (global_step % args.global_topk_frequency_compute == 0):
                        # Sort all the values to get the global topK: pick the cutoff so
                        # that roughly ``threshold * n`` scores lie above it.
                        concat = torch.cat(
                            [param.view(-1) for name, param in model.named_parameters() if "mask_scores" in name]
                        )
                        n = concat.numel()
                        kth = max(n - (int(n * threshold) + 1), 1)
                        threshold_mem = concat.kthvalue(kth).values.item()
                    threshold = threshold_mem

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
                del inputs["token_type_ids"]

            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
                if hasattr(model, "config") and hasattr(model.config, "lang2id"):
                    inputs.update(
                        {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
                    )

            if "masked" in args.model_type:
                inputs["threshold"] = threshold

            outputs = model(**inputs)
            # Newer transformers versions return a ModelOutput (dict-like, iterating
            # over keys); normalise to the positional tuple form expected here.
            if hasattr(outputs, "to_tuple"):
                outputs = outputs.to_tuple()
            loss, start_logits_stu, end_logits_stu = outputs[:3]

            # Distillation loss
            if teacher is not None:
                # The student inputs may have dropped token_type_ids (see above);
                # the teacher is still fed them from the batch unless it is an XLM.
                if "token_type_ids" in inputs:
                    teacher_token_type_ids = inputs["token_type_ids"]
                else:
                    teacher_token_type_ids = None if args.teacher_type == "xlm" else batch[2]
                with torch.no_grad():
                    teacher_outputs = teacher(
                        input_ids=inputs["input_ids"],
                        token_type_ids=teacher_token_type_ids,
                        attention_mask=inputs["attention_mask"],
                    )
                if hasattr(teacher_outputs, "to_tuple"):
                    teacher_outputs = teacher_outputs.to_tuple()
                start_logits_tea, end_logits_tea = teacher_outputs[:2]

                loss_start = nn.functional.kl_div(
                    input=nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
                    target=nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
                    reduction="batchmean",
                ) * (args.temperature**2)
                loss_end = nn.functional.kl_div(
                    input=nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
                    target=nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
                    reduction="batchmean",
                ) * (args.temperature**2)
                loss_logits = (loss_start + loss_end) / 2.0

                loss = args.alpha_distil * loss_logits + args.alpha_ce * loss

            # Regularization
            if args.regularization is not None:
                regu_ = regularization(model=model, mode=args.regularization)
                loss = loss + regu_lambda * regu_

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                # Parameter/gradient statistics, logged before the optimizer step.
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    tb_writer.add_scalar("threshold", threshold, global_step)
                    for name, param in model.named_parameters():
                        if not param.requires_grad:
                            continue
                        tb_writer.add_scalar("parameter_mean/" + name, param.data.mean(), global_step)
                        tb_writer.add_scalar("parameter_std/" + name, param.data.std(), global_step)
                        tb_writer.add_scalar("parameter_min/" + name, param.data.min(), global_step)
                        tb_writer.add_scalar("parameter_max/" + name, param.data.max(), global_step)
                        # Pooler parameters are skipped (presumably they receive no
                        # gradient in the QA model -- confirm); any other parameter
                        # without a gradient would crash the grad statistics below.
                        if "pooler" in name or param.grad is None:
                            continue
                        tb_writer.add_scalar("grad_mean/" + name, param.grad.data.mean(), global_step)
                        tb_writer.add_scalar("grad_std/" + name, param.grad.data.std(), global_step)
                        if args.regularization is not None and "mask_scores" in name:
                            if args.regularization == "l1":
                                perc = (torch.sigmoid(param) > threshold).sum().item() / param.numel()
                            elif args.regularization == "l0":
                                perc = (torch.sigmoid(param - 2 / 3 * np.log(0.1 / 1.1))).sum().item() / param.numel()
                            else:
                                # Unknown mode: nothing meaningful to report.
                                continue
                            tb_writer.add_scalar("retained_weights_perc/" + name, perc, global_step)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    # get_last_lr() returns the rates computed by the last step()
                    # (get_lr() outside step() is deprecated and warns).
                    learning_rate_scalar = scheduler.get_last_lr()
                    tb_writer.add_scalar("lr", learning_rate_scalar[0], global_step)
                    if len(learning_rate_scalar) > 1:
                        for idx, lr in enumerate(learning_rate_scalar[1:]):
                            tb_writer.add_scalar(f"lr/{idx+1}", lr, global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    if teacher is not None:
                        tb_writer.add_scalar("loss/distil", loss_logits.item(), global_step)
                    if args.regularization is not None:
                        tb_writer.add_scalar("loss/regularization", regu_.item(), global_step)
                    if (teacher is not None) or (args.regularization is not None):
                        # Back out the plain cross-entropy term from the combined loss.
                        if (teacher is not None) and (args.regularization is not None):
                            tb_writer.add_scalar(
                                "loss/instant_ce",
                                (loss.item() - regu_lambda * regu_.item() - args.alpha_distil * loss_logits.item())
                                / args.alpha_ce,
                                global_step,
                            )
                        elif teacher is not None:
                            tb_writer.add_scalar(
                                "loss/instant_ce",
                                (loss.item() - args.alpha_distil * loss_logits.item()) / args.alpha_ce,
                                global_step,
                            )
                        else:
                            tb_writer.add_scalar(
                                "loss/instant_ce", loss.item() - regu_lambda * regu_.item(), global_step
                            )
                    logging_loss = tr_loss

                # Save model checkpoint
                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(model, "module") else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
|
|
|
|
|
|
def evaluate(args, model, tokenizer, prefix=""):
    """Evaluate ``model`` on the SQuAD dev data and return ``squad_evaluate`` metrics.

    Masked models are run at ``args.final_threshold``. With ``args.global_topk``,
    that fraction is instead turned into one score cutoff over all
    ``mask_scores`` parameters, computed once and reused for every batch.

    Predictions and n-best lists are written under ``args.output_dir`` as
    ``predictions_{prefix}.json`` and ``nbest_predictions_{prefix}.json``. For
    SQuAD v2 a ``null_odds_{prefix}.json`` file is written as well.

    Side effects: sets ``args.eval_batch_size``, and wraps ``model`` in
    ``nn.DataParallel`` when ``args.n_gpu > 1`` and it is not already wrapped.

    Args:
        args: parsed command-line namespace.
        model: model to evaluate.
        tokenizer: tokenizer used for post-processing predictions.
        prefix: suffix used in log messages and output file names.

    Returns:
        The result of ``squad_evaluate(examples, predictions)``.
    """
    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)

    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # multi-gpu eval
    if args.n_gpu > 1 and not isinstance(model, nn.DataParallel):
        model = nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info(" Num examples = %d", len(dataset))
    logger.info(" Batch size = %d", args.eval_batch_size)

    all_results = []
    start_time = timeit.default_timer()
    # Global TopK cutoff, computed lazily on the first batch (only used with args.global_topk).
    threshold_mem = None

    # Nothing in the loop switches the model back to training mode, so one call suffices.
    model.eval()
    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        batch = tuple(t.to(args.device) for t in batch)

        with torch.no_grad():
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
            }

            if args.model_type in ["xlm", "roberta", "distilbert", "camembert"]:
                del inputs["token_type_ids"]

            example_indices = batch[3]

            # XLNet and XLM use more arguments for their predictions
            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
                # for lang_id-sensitive xlm models
                if hasattr(model, "config") and hasattr(model.config, "lang2id"):
                    inputs.update(
                        {"langs": (torch.ones(batch[0].shape, dtype=torch.int64) * args.lang_id).to(args.device)}
                    )
            if "masked" in args.model_type:
                inputs["threshold"] = args.final_threshold
                if args.global_topk:
                    if threshold_mem is None:
                        # Cutoff such that roughly ``final_threshold * n`` scores lie above it.
                        concat = torch.cat(
                            [param.view(-1) for name, param in model.named_parameters() if "mask_scores" in name]
                        )
                        n = concat.numel()
                        kth = max(n - (int(n * args.final_threshold) + 1), 1)
                        threshold_mem = concat.kthvalue(kth).values.item()
                    inputs["threshold"] = threshold_mem
            outputs = model(**inputs)

        # Newer transformers versions return a ModelOutput (iterating over its keys);
        # normalise to the positional tuple form indexed below.
        if hasattr(outputs, "to_tuple"):
            outputs = outputs.to_tuple()

        for i, example_index in enumerate(example_indices):
            eval_feature = features[example_index.item()]
            unique_id = int(eval_feature.unique_id)

            output = [to_list(output[i]) for output in outputs]

            # Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
            # models only use two.
            if len(output) >= 5:
                start_logits = output[0]
                start_top_index = output[1]
                end_logits = output[2]
                end_top_index = output[3]
                cls_logits = output[4]

                result = SquadResult(
                    unique_id,
                    start_logits,
                    end_logits,
                    start_top_index=start_top_index,
                    end_top_index=end_top_index,
                    cls_logits=cls_logits,
                )

            else:
                start_logits, end_logits = output
                result = SquadResult(unique_id, start_logits, end_logits)

            all_results.append(result)

    eval_time = timeit.default_timer() - start_time
    # max(..., 1) keeps an empty dev set from raising ZeroDivisionError here.
    logger.info(
        " Evaluation done in total %f secs (%f sec per example)", eval_time, eval_time / max(len(dataset), 1)
    )

    # Compute predictions
    output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))

    if args.version_2_with_negative:
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
    else:
        output_null_log_odds_file = None

    # XLNet and XLM use a more complex post-processing procedure
    if args.model_type in ["xlnet", "xlm"]:
        start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top
        end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top

        predictions = compute_predictions_log_probs(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            start_n_top,
            end_n_top,
            args.version_2_with_negative,
            tokenizer,
            args.verbose_logging,
        )
    else:
        predictions = compute_predictions_logits(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            args.do_lower_case,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            args.verbose_logging,
            args.version_2_with_negative,
            args.null_score_diff_threshold,
            tokenizer,
        )

    # Compute the F1 and exact scores.
    results = squad_evaluate(examples, predictions)
    return results
|
|
|
|
|
|
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    """Load SQuAD examples as a tensor dataset, using an on-disk feature cache.

    The cache lives in ``args.data_dir`` (or ``.``) and is named
    ``cached_{dev|train}_{model}_{max_seq_length}_{data file}``. It is reused
    unless ``args.overwrite_cache`` is set. On a cache miss, examples are read
    from one of two sources:

    * tensorflow_datasets, when no data dir and no file for the requested split
      are given;
    * otherwise, the SQuAD v1/v2 processor.

    The examples are converted with ``squad_convert_examples_to_features``, and
    the result is saved by the main process.

    For training in distributed mode, non-main ranks wait at a barrier until
    rank 0 has built the cache.

    Args:
        args: parsed command-line namespace.
        tokenizer: tokenizer used for feature conversion.
        evaluate: load the dev split (``args.predict_file``) instead of training data.
        output_examples: also return the raw examples and features.

    Returns:
        ``dataset``, or ``(dataset, examples, features)`` when ``output_examples``.
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    # Load data features from cache or dataset file
    input_dir = args.data_dir if args.data_dir else "."
    model_tag = (
        args.tokenizer_name
        if args.tokenizer_name
        else list(filter(None, args.model_name_or_path.split("/"))).pop()
    )
    data_file = args.predict_file if evaluate else args.train_file
    if data_file:
        data_tag = list(filter(None, data_file.split("/"))).pop()
    else:
        # No explicit file: examples come from the processor's default file or from
        # tensorflow_datasets, so key the cache on the SQuAD version instead of
        # calling .split on None.
        data_tag = "default_v2" if args.version_2_with_negative else "default_v1"
    cached_features_file = os.path.join(
        input_dir,
        "cached_{}_{}_{}_{}".format(
            "dev" if evaluate else "train",
            model_tag,
            str(args.max_seq_length),
            data_tag,
        ),
    )

    # Init features and dataset from cache if it exists
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        # The cache is a pickle written by this function and holds feature/example
        # objects, not only tensors, so it cannot be read with weights_only=True
        # (the default on recent torch). Only load cache files you produced
        # yourself: unpickling an untrusted file can execute arbitrary code.
        try:
            features_and_dataset = torch.load(cached_features_file, weights_only=False)
        except TypeError:
            # Older torch versions do not accept the ``weights_only`` argument.
            features_and_dataset = torch.load(cached_features_file)
        features, dataset, examples = (
            features_and_dataset["features"],
            features_and_dataset["dataset"],
            features_and_dataset["examples"],
        )
    else:
        logger.info("Creating features from dataset file at %s", input_dir)

        if not args.data_dir and ((evaluate and not args.predict_file) or (not evaluate and not args.train_file)):
            try:
                import tensorflow_datasets as tfds
            except ImportError:
                raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")

            if args.version_2_with_negative:
                logger.warning("tensorflow_datasets does not handle version 2 of SQuAD.")

            tfds_examples = tfds.load("squad")
            examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
        else:
            processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
            if evaluate:
                examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
            else:
                examples = processor.get_train_examples(args.data_dir, filename=args.train_file)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=args.threads,
        )

        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    if output_examples:
        return dataset, examples, features
    return dataset
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser()
|
|
|
|
# Required parameters
|
|
parser.add_argument(
|
|
"--model_type",
|
|
default=None,
|
|
type=str,
|
|
required=True,
|
|
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
|
)
|
|
parser.add_argument(
|
|
"--model_name_or_path",
|
|
default=None,
|
|
type=str,
|
|
required=True,
|
|
help="Path to pretrained model or model identifier from huggingface.co/models",
|
|
)
|
|
parser.add_argument(
|
|
"--output_dir",
|
|
default=None,
|
|
type=str,
|
|
required=True,
|
|
help="The output directory where the model checkpoints and predictions will be written.",
|
|
)
|
|
|
|
# Other parameters
|
|
parser.add_argument(
|
|
"--data_dir",
|
|
default=None,
|
|
type=str,
|
|
help="The input data dir. Should contain the .json files for the task."
|
|
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
|
)
|
|
parser.add_argument(
|
|
"--train_file",
|
|
default=None,
|
|
type=str,
|
|
help="The input training file. If a data dir is specified, will look for the file there"
|
|
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
|
)
|
|
parser.add_argument(
|
|
"--predict_file",
|
|
default=None,
|
|
type=str,
|
|
help="The input evaluation file. If a data dir is specified, will look for the file there"
|
|
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
|
)
|
|
parser.add_argument(
|
|
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
|
)
|
|
parser.add_argument(
|
|
"--tokenizer_name",
|
|
default="",
|
|
type=str,
|
|
help="Pretrained tokenizer name or path if not the same as model_name",
|
|
)
|
|
parser.add_argument(
|
|
"--cache_dir",
|
|
default="",
|
|
type=str,
|
|
help="Where do you want to store the pre-trained models downloaded from huggingface.co",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--version_2_with_negative",
|
|
action="store_true",
|
|
help="If true, the SQuAD examples contain some that do not have an answer.",
|
|
)
|
|
parser.add_argument(
|
|
"--null_score_diff_threshold",
|
|
type=float,
|
|
default=0.0,
|
|
help="If null_score - best_non_null is greater than the threshold predict null.",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--max_seq_length",
|
|
default=384,
|
|
type=int,
|
|
help=(
|
|
"The maximum total input sequence length after WordPiece tokenization. Sequences "
|
|
"longer than this will be truncated, and sequences shorter than this will be padded."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--doc_stride",
|
|
default=128,
|
|
type=int,
|
|
help="When splitting up a long document into chunks, how much stride to take between chunks.",
|
|
)
|
|
parser.add_argument(
|
|
"--max_query_length",
|
|
default=64,
|
|
type=int,
|
|
help=(
|
|
"The maximum number of tokens for the question. Questions longer than this will "
|
|
"be truncated to this length."
|
|
),
|
|
)
|
|
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
|
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
|
parser.add_argument(
|
|
"--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
|
|
)
|
|
parser.add_argument(
|
|
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
|
)
|
|
|
|
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
|
parser.add_argument(
|
|
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
|
)
|
|
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
|
|
|
# Pruning parameters
|
|
parser.add_argument(
|
|
"--mask_scores_learning_rate",
|
|
default=1e-2,
|
|
type=float,
|
|
help="The Adam initial learning rate of the mask scores.",
|
|
)
|
|
parser.add_argument(
|
|
"--initial_threshold", default=1.0, type=float, help="Initial value of the threshold (for scheduling)."
|
|
)
|
|
parser.add_argument(
|
|
"--final_threshold", default=0.7, type=float, help="Final value of the threshold (for scheduling)."
|
|
)
|
|
parser.add_argument(
|
|
"--initial_warmup",
|
|
default=1,
|
|
type=int,
|
|
help=(
|
|
"Run `initial_warmup` * `warmup_steps` steps of threshold warmup during which threshold stays "
|
|
"at its `initial_threshold` value (sparsity schedule)."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--final_warmup",
|
|
default=2,
|
|
type=int,
|
|
help=(
|
|
"Run `final_warmup` * `warmup_steps` steps of threshold cool-down during which threshold stays "
|
|
"at its final_threshold value (sparsity schedule)."
|
|
),
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--pruning_method",
|
|
default="topK",
|
|
type=str,
|
|
help=(
|
|
"Pruning Method (l0 = L0 regularization, magnitude = Magnitude pruning, topK = Movement pruning,"
|
|
" sigmoied_threshold = Soft movement pruning)."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--mask_init",
|
|
default="constant",
|
|
type=str,
|
|
help="Initialization method for the mask scores. Choices: constant, uniform, kaiming.",
|
|
)
|
|
parser.add_argument(
|
|
"--mask_scale", default=0.0, type=float, help="Initialization parameter for the chosen initialization method."
|
|
)
|
|
|
|
parser.add_argument("--regularization", default=None, help="Add L0 or L1 regularization to the mask scores.")
|
|
parser.add_argument(
|
|
"--final_lambda",
|
|
default=0.0,
|
|
type=float,
|
|
help="Regularization intensity (used in conjunction with `regularization`.",
|
|
)
|
|
|
|
parser.add_argument("--global_topk", action="store_true", help="Global TopK on the Scores.")
|
|
parser.add_argument(
|
|
"--global_topk_frequency_compute",
|
|
default=25,
|
|
type=int,
|
|
help="Frequency at which we compute the TopK global threshold.",
|
|
)
|
|
|
|
# Distillation parameters (optional)
|
|
parser.add_argument(
|
|
"--teacher_type",
|
|
default=None,
|
|
type=str,
|
|
help=(
|
|
"Teacher type. Teacher tokenizer and student (model) tokenizer must output the same tokenization. Only for"
|
|
" distillation."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--teacher_name_or_path",
|
|
default=None,
|
|
type=str,
|
|
help="Path to the already SQuAD fine-tuned teacher model. Only for distillation.",
|
|
)
|
|
parser.add_argument(
|
|
"--alpha_ce", default=0.5, type=float, help="Cross entropy loss linear weight. Only for distillation."
|
|
)
|
|
parser.add_argument(
|
|
"--alpha_distil", default=0.5, type=float, help="Distillation loss linear weight. Only for distillation."
|
|
)
|
|
parser.add_argument(
|
|
"--temperature", default=2.0, type=float, help="Distillation temperature. Only for distillation."
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--gradient_accumulation_steps",
|
|
type=int,
|
|
default=1,
|
|
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
|
)
|
|
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
|
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
|
parser.add_argument(
|
|
"--num_train_epochs",
|
|
default=3.0,
|
|
type=float,
|
|
help="Total number of training epochs to perform.",
|
|
)
|
|
parser.add_argument(
|
|
"--max_steps",
|
|
default=-1,
|
|
type=int,
|
|
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
|
)
|
|
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
|
parser.add_argument(
|
|
"--n_best_size",
|
|
default=20,
|
|
type=int,
|
|
help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
|
|
)
|
|
parser.add_argument(
|
|
"--max_answer_length",
|
|
default=30,
|
|
type=int,
|
|
help=(
|
|
"The maximum length of an answer that can be generated. This is needed because the start "
|
|
"and end predictions are not conditioned on one another."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--verbose_logging",
|
|
action="store_true",
|
|
help=(
|
|
"If true, all of the warnings related to data processing will be printed. "
|
|
"A number of warnings are expected for a normal SQuAD evaluation."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--lang_id",
|
|
default=0,
|
|
type=int,
|
|
help=(
|
|
"language id of input for language-specific xlm models (see"
|
|
" tokenization_xlm.PRETRAINED_INIT_CONFIGURATION)"
|
|
),
|
|
)
|
|
|
|
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
|
|
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
|
|
parser.add_argument(
|
|
"--eval_all_checkpoints",
|
|
action="store_true",
|
|
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
|
)
|
|
parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
|
|
parser.add_argument(
|
|
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
|
)
|
|
parser.add_argument(
|
|
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
|
)
|
|
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
|
|
|
parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
|
|
parser.add_argument(
|
|
"--fp16",
|
|
action="store_true",
|
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
|
)
|
|
parser.add_argument(
|
|
"--fp16_opt_level",
|
|
type=str,
|
|
default="O1",
|
|
help=(
|
|
"For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
|
|
"See details at https://nvidia.github.io/apex/amp.html"
|
|
),
|
|
)
|
|
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
|
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
|
|
|
parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
|
|
args = parser.parse_args()
|
|
|
|
# Regularization
|
|
if args.regularization == "null":
|
|
args.regularization = None
|
|
|
|
if args.doc_stride >= args.max_seq_length - args.max_query_length:
|
|
logger.warning(
|
|
"WARNING - You've set a doc stride which may be superior to the document length in some "
|
|
"examples. This could result in errors when building features from the examples. Please reduce the doc "
|
|
"stride or increase the maximum length to ensure the features are correctly built."
|
|
)
|
|
|
|
if (
|
|
os.path.exists(args.output_dir)
|
|
and os.listdir(args.output_dir)
|
|
and args.do_train
|
|
and not args.overwrite_output_dir
|
|
):
|
|
raise ValueError(
|
|
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
|
args.output_dir
|
|
)
|
|
)
|
|
|
|
# Setup distant debugging if needed
|
|
if args.server_ip and args.server_port:
|
|
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
|
import ptvsd
|
|
|
|
print("Waiting for debugger attach")
|
|
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
|
ptvsd.wait_for_attach()
|
|
|
|
# Setup CUDA, GPU & distributed training
|
|
if args.local_rank == -1 or args.no_cuda:
|
|
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
|
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
|
|
else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
|
|
torch.cuda.set_device(args.local_rank)
|
|
device = torch.device("cuda", args.local_rank)
|
|
torch.distributed.init_process_group(backend="nccl")
|
|
args.n_gpu = 1
|
|
args.device = device
|
|
|
|
# Setup logging
|
|
logging.basicConfig(
|
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
|
datefmt="%m/%d/%Y %H:%M:%S",
|
|
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
|
)
|
|
logger.warning(
|
|
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
|
args.local_rank,
|
|
device,
|
|
args.n_gpu,
|
|
bool(args.local_rank != -1),
|
|
args.fp16,
|
|
)
|
|
|
|
# Set seed
|
|
set_seed(args)
|
|
|
|
# Load pretrained model and tokenizer
|
|
if args.local_rank not in [-1, 0]:
|
|
# Make sure only the first process in distributed training will download model & vocab
|
|
torch.distributed.barrier()
|
|
|
|
args.model_type = args.model_type.lower()
|
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
|
config = config_class.from_pretrained(
|
|
args.config_name if args.config_name else args.model_name_or_path,
|
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
|
pruning_method=args.pruning_method,
|
|
mask_init=args.mask_init,
|
|
mask_scale=args.mask_scale,
|
|
)
|
|
tokenizer = tokenizer_class.from_pretrained(
|
|
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
|
do_lower_case=args.do_lower_case,
|
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
|
)
|
|
model = model_class.from_pretrained(
|
|
args.model_name_or_path,
|
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
|
config=config,
|
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
|
)
|
|
|
|
if args.teacher_type is not None:
|
|
assert args.teacher_name_or_path is not None
|
|
assert args.alpha_distil > 0.0
|
|
assert args.alpha_distil + args.alpha_ce > 0.0
|
|
teacher_config_class, teacher_model_class, _ = MODEL_CLASSES[args.teacher_type]
|
|
teacher_config = teacher_config_class.from_pretrained(args.teacher_name_or_path)
|
|
teacher = teacher_model_class.from_pretrained(
|
|
args.teacher_name_or_path,
|
|
from_tf=False,
|
|
config=teacher_config,
|
|
cache_dir=args.cache_dir if args.cache_dir else None,
|
|
)
|
|
teacher.to(args.device)
|
|
else:
|
|
teacher = None
|
|
|
|
if args.local_rank == 0:
|
|
# Make sure only the first process in distributed training will download model & vocab
|
|
torch.distributed.barrier()
|
|
|
|
model.to(args.device)
|
|
|
|
logger.info("Training/evaluation parameters %s", args)
|
|
|
|
# Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
|
|
# Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will
|
|
# remove the need for this code, but it is still valid.
|
|
if args.fp16:
|
|
try:
|
|
import apex
|
|
|
|
apex.amp.register_half_function(torch, "einsum")
|
|
except ImportError:
|
|
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
|
|
|
# Training
|
|
if args.do_train:
|
|
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
|
|
global_step, tr_loss = train(args, train_dataset, model, tokenizer, teacher=teacher)
|
|
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
|
|
|
# Save the trained model and the tokenizer
|
|
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
|
logger.info("Saving model checkpoint to %s", args.output_dir)
|
|
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
|
# They can then be reloaded using `from_pretrained()`
|
|
# Take care of distributed/parallel training
|
|
model_to_save = model.module if hasattr(model, "module") else model
|
|
model_to_save.save_pretrained(args.output_dir)
|
|
tokenizer.save_pretrained(args.output_dir)
|
|
|
|
# Good practice: save your training arguments together with the trained model
|
|
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
|
|
|
# Load a trained model and vocabulary that you have fine-tuned
|
|
model = model_class.from_pretrained(args.output_dir) # , force_download=True)
|
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
|
model.to(args.device)
|
|
|
|
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
|
results = {}
|
|
if args.do_eval and args.local_rank in [-1, 0]:
|
|
if args.do_train:
|
|
logger.info("Loading checkpoints saved during training for evaluation")
|
|
checkpoints = [args.output_dir]
|
|
if args.eval_all_checkpoints:
|
|
checkpoints = [
|
|
os.path.dirname(c)
|
|
for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
|
]
|
|
|
|
else:
|
|
logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
|
|
checkpoints = [args.model_name_or_path]
|
|
|
|
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
|
|
|
for checkpoint in checkpoints:
|
|
# Reload the model
|
|
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
|
model = model_class.from_pretrained(checkpoint) # , force_download=True)
|
|
model.to(args.device)
|
|
|
|
# Evaluate
|
|
result = evaluate(args, model, tokenizer, prefix=global_step)
|
|
|
|
result = {k + ("_{}".format(global_step) if global_step else ""): v for k, v in result.items()}
|
|
results.update(result)
|
|
|
|
logger.info("Results: {}".format(results))
|
|
predict_file = list(filter(None, args.predict_file.split("/"))).pop()
|
|
if not os.path.exists(os.path.join(args.output_dir, predict_file)):
|
|
os.makedirs(os.path.join(args.output_dir, predict_file))
|
|
output_eval_file = os.path.join(args.output_dir, predict_file, "eval_results.txt")
|
|
with open(output_eval_file, "w") as writer:
|
|
for key in sorted(results.keys()):
|
|
writer.write("%s = %s\n" % (key, str(results[key])))
|
|
|
|
return results
|
|
|
|
|
|
# Script entry point. main() parses the command-line arguments, then runs the
# optional training and evaluation phases. It only runs when this file is
# executed directly; importing the module leaves it untouched. main()'s return
# value (the evaluation results dict) is intentionally discarded here.
if __name__ == "__main__":
    main()
|