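"""Evaluate a causal language model on a streamed validation set with Accelerate.

Launch sketch (illustrative only; the exact flags come from
`arguments.EvaluationArguments`, which is assumed to expose the fields used
below: model_ckpt, dataset_name, seq_length, batch_size, max_eval_steps, seed):

    accelerate launch <this_script>.py --model_ckpt <checkpoint> --dataset_name <dataset>
"""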
import logging

import torch
from accelerate import Accelerator
from arguments import EvaluationArguments
from datasets import load_dataset
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader

from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed


class ConstantLengthDataset(IterableDataset):
    """Iterable dataset that yields constant-length chunks of token ids from a
    stream of text examples, concatenated and separated by the BOS token."""

    def __init__(self, tokenizer, dataset, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6):
        self.tokenizer = tokenizer
        self.concat_token_id = tokenizer.bos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        # Number of characters to buffer before tokenizing: on average enough
        # for `num_of_sequences` sequences of `seq_length` tokens.
        self.input_characters = seq_length * chars_per_token * num_of_sequences

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        while more_examples:
            # Fill a buffer with raw text from the streaming dataset.
            buffer, buffer_len = [], 0
            while True:
                if buffer_len >= self.input_characters:
                    break
                try:
                    buffer.append(next(iterator)["content"])
                    buffer_len += len(buffer[-1])
                except StopIteration:
                    more_examples = False
                    break
            # Tokenize the buffer and concatenate all examples, separating
            # them with the BOS token.
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input + [self.concat_token_id])
            # Slice into constant-length chunks; drop the trailing remainder.
            for i in range(0, len(all_token_ids), self.seq_length):
                input_ids = all_token_ids[i : i + self.seq_length]
                if len(input_ids) == self.seq_length:
                    yield torch.tensor(input_ids)
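
# Toy illustration of the packing above (assumed values, comments only): with
# seq_length=4 and bos_token_id=0, documents tokenized to [1, 2, 3] and [4, 5]
# concatenate to [1, 2, 3, 0, 4, 5, 0] and yield one chunk [1, 2, 3, 0]; the
# leftover [4, 5, 0] is shorter than seq_length and is dropped.
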
def create_dataloader(args):
    # Relies on the module-level `tokenizer` that is loaded below, before this
    # function is called; the dataset is streamed, so no full download occurs.
    ds_kwargs = {"streaming": True}
    valid_data = load_dataset(args.dataset_name, split="train", **ds_kwargs)
    valid_dataset = ConstantLengthDataset(tokenizer, valid_data, seq_length=args.seq_length)
    eval_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size)
    return eval_dataloader
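
# With PyTorch's default collate function, each batch from this dataloader is
# an int64 tensor of shape (args.batch_size, args.seq_length).
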
def evaluate(args):
    # Uses the module-level `model`, `eval_dataloader`, and `accelerator`
    # defined at script level below.
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(batch, labels=batch)
        # Repeat the scalar mean loss batch_size times so that gathering
        # across processes yields one loss value per sample.
        loss = outputs.loss.repeat(args.batch_size)
        losses.append(accelerator.gather(loss))

        if args.max_eval_steps > 0 and step >= args.max_eval_steps:
            break
    loss = torch.mean(torch.cat(losses))
    try:
        # Guard against overflow when exponentiating the mean loss.
        perplexity = torch.exp(loss)
    except OverflowError:
        perplexity = float("inf")
    return loss.item(), perplexity.item()
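
# Perplexity is the exponential of the mean cross-entropy loss: with assumed
# numbers for illustration, a mean loss of 2.0 nats corresponds to a
# perplexity of exp(2.0), which is about 7.39.
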

# Setup Accelerator
accelerator = Accelerator()

# Parse configuration
parser = HfArgumentParser(EvaluationArguments)
args = parser.parse_args()
set_seed(args.seed)

# Logging
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)

# Load dataset and dataloader
eval_dataloader = create_dataloader(args)

# Prepare everything with our `accelerator`.
model, eval_dataloader = accelerator.prepare(model, eval_dataloader)
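# After prepare(), batches are placed on the right device automatically and
# the model is wrapped for distributed execution when multiple processes run.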

# Evaluate the loaded checkpoint
logger.info("Evaluating model checkpoint")
eval_loss, perplexity = evaluate(args)
logger.info(f"loss/eval: {eval_loss}, perplexity: {perplexity}")