Various fixes and cleanup in run_lm_finetuning
This commit is contained in:
parent
f94f1c6016
commit
a690edab17
|
@ -127,4 +127,7 @@ proc_data
|
||||||
|
|
||||||
# examples
|
# examples
|
||||||
runs
|
runs
|
||||||
examples/runs
|
examples/runs
|
||||||
|
|
||||||
|
# data
|
||||||
|
data
|
|
@ -25,33 +25,75 @@ import argparse
|
||||||
import glob
|
import glob
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import (DataLoader, SequentialSampler,)
|
from torch.utils.data import DataLoader, Dataset, SequentialSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
|
||||||
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
BertConfig, BertForMaskedLM, BertTokenizer,
|
||||||
BertConfig, BertForMaskedLM, BertTokenizer, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
|
||||||
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
|
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
|
||||||
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
from utils_lm import WikiTextDataset
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
MODEL_CLASSES = {
|
MODEL_CLASSES = {
|
||||||
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
||||||
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
||||||
"bert": (BertConfig, BertForMaskedLM, BertTokenizer),
|
'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
|
||||||
"roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TextDataset(Dataset):
    """Language-modeling dataset made of fixed-size blocks of token ids.

    The input text file is read in one go, tokenized, converted to ids and cut
    into consecutive, non-overlapping blocks of ``block_size`` ids. The list of
    blocks is pickled next to the input file as
    ``cached_lm_<block_size>_<filename>`` and reloaded from there when that
    file already exists.

    Args:
        tokenizer: object exposing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``. Not used when the cache exists.
        file_path: path of the raw text file.
        block_size: number of token ids per example; must be positive.

    Raises:
        FileNotFoundError: if ``file_path`` is not an existing regular file.
        ValueError: if ``block_size`` is not strictly positive.
    """

    def __init__(self, tokenizer, file_path='train', block_size=512):
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"Input file path {file_path} not found")
        if block_size <= 0:
            raise ValueError(f"block_size must be a positive integer, got {block_size}")

        directory, filename = os.path.split(file_path)
        # NOTE(review): the cache name depends only on block_size and the file name,
        # so a cache built with one tokenizer is silently reused with another.
        cached_features_file = os.path.join(directory, f'cached_lm_{block_size}_{filename}')

        if os.path.exists(cached_features_file):
            logger.info("Loading features from cached file %s", cached_features_file)
            # NOTE: unpickling can execute arbitrary code; the cache is assumed to be
            # a file previously written by this class, never an untrusted file.
            with open(cached_features_file, 'rb') as handle:
                self.examples = pickle.load(handle)
        else:
            logger.info("Creating features from dataset file at %s", directory)

            with open(file_path, encoding="utf-8") as f:
                text = f.read()

            tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))

            # Cut into full blocks of block_size ids. Stepping over start offsets
            # keeps this linear (re-slicing the remainder each time is quadratic).
            # Note that we are losing the last truncated example here for the sake
            # of simplicity (no padding). If your dataset is small, first you should
            # look for a bigger one :-) and second you can change this behavior by
            # adding (model specific) padding.
            last_start = len(tokenized_text) - block_size
            self.examples = [tokenized_text[start:start + block_size]
                             for start in range(0, last_start + 1, block_size)]

            logger.info("Saving features into cached file %s", cached_features_file)
            with open(cached_features_file, 'wb') as handle:
                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        """Return the number of full blocks."""
        return len(self.examples)

    def __getitem__(self, item):
        """Return block ``item`` as a 1-D tensor of token ids."""
        return torch.tensor(self.examples[item], dtype=torch.long)
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_cache_examples(args, tokenizer, evaluate=False):
    """Build the TextDataset for the evaluation file if `evaluate` is set, otherwise for the training file."""
    if evaluate:
        data_file = args.eval_data_file
    else:
        data_file = args.train_data_file
    return TextDataset(tokenizer, file_path=data_file, block_size=args.block_size)
|
||||||
|
|
||||||
|
|
||||||
def set_seed(args):
|
def set_seed(args):
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
|
@ -59,20 +101,27 @@ def set_seed(args):
|
||||||
if args.n_gpu > 0:
|
if args.n_gpu > 0:
|
||||||
torch.cuda.manual_seed_all(args.seed)
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
# Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original
|
|
||||||
def mask_tokens(inputs, tokenizer, args):
    """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

    Args:
        inputs: integer tensor of token ids. It is modified in place.
        tokenizer: object exposing ``mask_token``, ``convert_tokens_to_ids`` and ``__len__``
            (``len(tokenizer)`` is used as the exclusive upper bound for random ids).
        args: namespace providing ``mlm_probability``, the per-token selection probability.

    Returns:
        ``(inputs, labels)``: the same ``inputs`` tensor after corruption, and a clone of the
        original ids where every position that was not selected is set to -1.
    """
    labels = inputs.clone()

    # We sample a few tokens in each sequence for masked-LM training, each with probability
    # args.mlm_probability. Boolean masks are used throughout: `~` on a uint8 tensor is a
    # bitwise NOT on current PyTorch, not a logical inversion.
    selection_probs = torch.full(labels.shape, args.mlm_probability, dtype=torch.float)
    masked_indices = torch.bernoulli(selection_probs).bool()
    labels[~masked_indices] = -1  # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8, dtype=torch.float)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    # (half of the selected positions that were not replaced by the mask token)
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5, dtype=torch.float)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels
|
||||||
|
|
||||||
|
|
||||||
def train(args, train_dataset, model, tokenizer):
|
def train(args, train_dataset, model, tokenizer):
|
||||||
""" Train the model """
|
""" Train the model """
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
|
@ -146,13 +195,15 @@ def train(args, train_dataset, model, tokenizer):
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||||
scaled_loss.backward()
|
scaled_loss.backward()
|
||||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
|
||||||
else:
|
else:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
|
||||||
|
|
||||||
tr_loss += loss.item()
|
tr_loss += loss.item()
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
|
if args.fp16:
|
||||||
|
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||||
|
else:
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
scheduler.step() # Update learning rate schedule
|
scheduler.step() # Update learning rate schedule
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
|
@ -240,24 +291,22 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
|
||||||
dataset = WikiTextDataset(args, tokenizer, file="test" if evaluate else "train", directory=args.data_dir)
|
|
||||||
return dataset
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
## Required parameters
|
## Required parameters
|
||||||
parser.add_argument("--data_dir", default=None, type=str, required=True,
|
parser.add_argument("--train_data_file", default=None, type=str, required=True,
|
||||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
|
help="The input training data file (a text file).")
|
||||||
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
parser.add_argument("--output_dir", default=None, type=str, required=True,
|
||||||
help="The output directory where the model predictions and checkpoints will be written.")
|
help="The output directory where the model predictions and checkpoints will be written.")
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
parser.add_argument("--model_name", default="bert", type=str,
|
parser.add_argument("--eval_data_file", default=None, type=str,
|
||||||
|
help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
|
||||||
|
|
||||||
|
parser.add_argument("--model_type", default="bert", type=str,
|
||||||
help="The model architecture to be fine-tuned.")
|
help="The model architecture to be fine-tuned.")
|
||||||
parser.add_argument("--model_checkpoint", default="bert-base-cased", type=str,
|
parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str,
|
||||||
help="The model checkpoint for weights initialization.")
|
help="The model checkpoint for weights initialization.")
|
||||||
|
|
||||||
parser.add_argument("--mlm", action='store_true',
|
parser.add_argument("--mlm", action='store_true',
|
||||||
|
@ -266,20 +315,21 @@ def main():
|
||||||
help="Ratio of tokens to mask for masked language modeling loss")
|
help="Ratio of tokens to mask for masked language modeling loss")
|
||||||
|
|
||||||
parser.add_argument("--config_name", default="", type=str,
|
parser.add_argument("--config_name", default="", type=str,
|
||||||
help="Pretrained config name or path if not the same as model_name")
|
help="Optional pretrained config name or path if not the same as model_name_or_path")
|
||||||
parser.add_argument("--tokenizer_name", default="", type=str,
|
parser.add_argument("--tokenizer_name", default="", type=str,
|
||||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
|
||||||
parser.add_argument("--cache_dir", default="", type=str,
|
parser.add_argument("--cache_dir", default="", type=str,
|
||||||
help="Where do you want to store the pre-trained models downloaded from s3")
|
help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
|
||||||
parser.add_argument("--max_seq_length", default=128, type=int,
|
parser.add_argument("--block_size", default=-1, type=int,
|
||||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
help="Optional input sequence length after tokenization."
|
||||||
"than this will be truncated, sequences shorter will be padded.")
|
"The training dataset will be truncated in block of this size for training."
|
||||||
|
"Default to the model max input length.")
|
||||||
parser.add_argument("--do_train", action='store_true',
|
parser.add_argument("--do_train", action='store_true',
|
||||||
help="Whether to run training.")
|
help="Whether to run training.")
|
||||||
parser.add_argument("--do_eval", action='store_true',
|
parser.add_argument("--do_eval", action='store_true',
|
||||||
help="Whether to run eval on the dev set.")
|
help="Whether to run eval on the dev set.")
|
||||||
parser.add_argument("--evaluate_during_training", action='store_true',
|
parser.add_argument("--evaluate_during_training", action='store_true',
|
||||||
help="Rul evaluation during training at each logging step.")
|
help="Run evaluation during training at each logging step.")
|
||||||
parser.add_argument("--do_lower_case", action='store_true',
|
parser.add_argument("--do_lower_case", action='store_true',
|
||||||
help="Set this flag if you are using an uncased model.")
|
help="Set this flag if you are using an uncased model.")
|
||||||
|
|
||||||
|
@ -309,7 +359,7 @@ def main():
|
||||||
parser.add_argument('--save_steps', type=int, default=50,
|
parser.add_argument('--save_steps', type=int, default=50,
|
||||||
help="Save checkpoint every X updates steps.")
|
help="Save checkpoint every X updates steps.")
|
||||||
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
parser.add_argument("--eval_all_checkpoints", action='store_true',
|
||||||
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
|
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
|
||||||
parser.add_argument("--no_cuda", action='store_true',
|
parser.add_argument("--no_cuda", action='store_true',
|
||||||
help="Avoid using CUDA when available")
|
help="Avoid using CUDA when available")
|
||||||
parser.add_argument('--overwrite_output_dir', action='store_true',
|
parser.add_argument('--overwrite_output_dir', action='store_true',
|
||||||
|
@ -330,9 +380,12 @@ def main():
|
||||||
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.model_name in ["bert", "roberta"] and not args.mlm:
|
if args.model_type in ["bert", "roberta"] and not args.mlm:
|
||||||
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
||||||
"flag (masked language modeling).")
|
"flag (masked language modeling).")
|
||||||
|
if args.eval_data_file is None and args.do_eval:
|
||||||
|
raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
|
||||||
|
"or remove the --do_eval argument.")
|
||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||||
|
@ -368,30 +421,36 @@ def main():
|
||||||
|
|
||||||
# Load pretrained model and tokenizer
|
# Load pretrained model and tokenizer
|
||||||
if args.local_rank not in [-1, 0]:
|
if args.local_rank not in [-1, 0]:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
||||||
|
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_name]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_checkpoint)
|
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_checkpoint, do_lower_case=args.do_lower_case)
|
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
|
||||||
model = model_class.from_pretrained(args.model_checkpoint, from_tf=bool('.ckpt' in args.model_checkpoint), config=config)
|
if args.block_size <= 0:
|
||||||
args.num_embeddings = config.vocab_size # We need this to create the model at next line (number of embeddings to use)
|
args.block_size = tokenizer.max_len # Our input block size will be the max possible for the model
|
||||||
|
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
|
||||||
|
model.to(args.device)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
|
||||||
|
|
||||||
model.to(args.device)
|
|
||||||
|
|
||||||
logger.info("Training/evaluation parameters %s", args)
|
logger.info("Training/evaluation parameters %s", args)
|
||||||
|
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
if args.do_train:
|
if args.do_train:
|
||||||
|
if args.local_rank not in [-1, 0]:
|
||||||
|
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||||
|
|
||||||
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
|
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
|
||||||
|
|
||||||
|
if args.local_rank == 0:
|
||||||
|
torch.distributed.barrier()
|
||||||
|
|
||||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||||
|
|
||||||
|
|
||||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
|
||||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||||
|
@ -409,7 +468,7 @@ def main():
|
||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
model = model_class.from_pretrained(args.output_dir)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
|
|
|
@ -1,51 +0,0 @@
|
||||||
from torch.utils.data import Dataset, DataLoader
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import logging
|
|
||||||
import pickle
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class WikiTextDataset(Dataset):
|
|
||||||
def __init__(self, args, tokenizer, file='train', directory='wikitext', max_context_length=512, cache=None):
|
|
||||||
if args.local_rank not in [-1, 0]:
|
|
||||||
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
|
||||||
|
|
||||||
|
|
||||||
cached_features_file = os.path.join(args.data_dir, f'cached_lm_{file}_{args.max_seq_length}')
|
|
||||||
|
|
||||||
if os.path.exists(cached_features_file):
|
|
||||||
logger.info("Loading features from cached file %s", cached_features_file)
|
|
||||||
with open(cached_features_file, 'rb') as handle:
|
|
||||||
self.examples = pickle.load(handle)
|
|
||||||
else:
|
|
||||||
logger.info("Creating features from dataset file at %s", args.data_dir)
|
|
||||||
|
|
||||||
self.max_context_length = max_context_length
|
|
||||||
|
|
||||||
self.examples = []
|
|
||||||
|
|
||||||
with open(os.path.join(directory, f"wiki.{file}.raw"), encoding="utf-8") as f:
|
|
||||||
text = f.read()
|
|
||||||
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
|
|
||||||
|
|
||||||
while len(tokenized_text) > max_context_length:
|
|
||||||
self.examples.append(tokenized_text[:max_context_length])
|
|
||||||
tokenized_text = tokenized_text[max_context_length:]
|
|
||||||
|
|
||||||
if args.local_rank in [-1, 0]:
|
|
||||||
logger.info("Saving features into cached file %s", cached_features_file)
|
|
||||||
with open(cached_features_file, 'wb') as handle:
|
|
||||||
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
|
|
||||||
|
|
||||||
if args.local_rank == 0:
|
|
||||||
torch.distributed.barrier()
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.examples)
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
return torch.tensor(self.examples[item])
|
|
Loading…
Reference in New Issue