320 lines
12 KiB
Python
320 lines
12 KiB
Python
import argparse
|
|
import logging
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import torch
|
|
from pytest import param
|
|
from torch.utils.data import DataLoader
|
|
|
|
from transformers import AutoTokenizer, MBartTokenizer
|
|
from transformers.testing_utils import require_multigpu
|
|
|
|
from .distillation import distill_main, evaluate_checkpoint
|
|
from .finetune import main
|
|
from .pack_dataset import pack_data_dir
|
|
from .run_eval import generate_summaries_or_translations, run_generate
|
|
from .utils import MBartDataset, Seq2SeqDataset, lmap, load_json
|
|
|
|
|
|
logging.basicConfig(level=logging.DEBUG)
|
|
|
|
logger = logging.getLogger()
|
|
CUDA_AVAILABLE = torch.cuda.is_available()
|
|
CHEAP_ARGS = {
|
|
"label_smoothing_eps": 0.2,
|
|
"logger_name": "default",
|
|
"length_penalty": 0.5,
|
|
"cache_dir": "",
|
|
"task": "summarization",
|
|
"num_workers": 2,
|
|
"alpha_hid": 0,
|
|
"freeze_embeds": True,
|
|
"enc_only": False,
|
|
"tgt_suffix": "",
|
|
"resume_from_checkpoint": None,
|
|
"sortish_sampler": True,
|
|
"student_decoder_layers": 1,
|
|
"val_check_interval": 1.0,
|
|
"output_dir": "",
|
|
"fp16": CUDA_AVAILABLE,
|
|
"no_teacher": False,
|
|
"fp16_opt_level": "O1",
|
|
"gpus": 1 if CUDA_AVAILABLE else 0,
|
|
"n_tpu_cores": 0,
|
|
"max_grad_norm": 1.0,
|
|
"do_train": True,
|
|
"do_predict": True,
|
|
"accumulate_grad_batches": 1,
|
|
"server_ip": "",
|
|
"server_port": "",
|
|
"seed": 42,
|
|
"model_name_or_path": "sshleifer/bart-tiny-random",
|
|
"config_name": "",
|
|
"tokenizer_name": "facebook/bart-large",
|
|
"do_lower_case": False,
|
|
"learning_rate": 0.3,
|
|
"weight_decay": 0.0,
|
|
"adam_epsilon": 1e-08,
|
|
"warmup_steps": 0,
|
|
"max_epochs": 1,
|
|
"train_batch_size": 2,
|
|
"eval_batch_size": 2,
|
|
"max_source_length": 12,
|
|
"max_target_length": 12,
|
|
"val_max_target_length": 12,
|
|
"test_max_target_length": 12,
|
|
"fast_dev_run": False,
|
|
"no_cache": False,
|
|
"n_train": -1,
|
|
"n_val": -1,
|
|
"n_test": -1,
|
|
"student_encoder_layers": 1,
|
|
"alpha_loss_encoder": 0.0,
|
|
"freeze_encoder": False,
|
|
"auto_scale_batch_size": False,
|
|
}
|
|
|
|
|
|
def _dump_articles(path: Path, articles: list):
|
|
content = "\n".join(articles)
|
|
Path(path).open("w").writelines(content)
|
|
|
|
|
|
ARTICLES = [" Sam ate lunch today.", "Sams lunch ingredients."]
|
|
SUMMARIES = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
|
|
T5_TINY = "patrickvonplaten/t5-tiny-random"
|
|
BART_TINY = "sshleifer/bart-tiny-random"
|
|
MBART_TINY = "sshleifer/tiny-mbart"
|
|
MARIAN_TINY = "sshleifer/tiny-marian-en-de"
|
|
stream_handler = logging.StreamHandler(sys.stdout)
|
|
logger.addHandler(stream_handler)
|
|
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
|
|
|
|
|
def make_test_data_dir(**kwargs):
|
|
tmp_dir = Path(tempfile.mkdtemp(**kwargs))
|
|
for split in ["train", "val", "test"]:
|
|
_dump_articles((tmp_dir / f"{split}.source"), ARTICLES)
|
|
_dump_articles((tmp_dir / f"{split}.target"), SUMMARIES)
|
|
return tmp_dir
|
|
|
|
|
|
class TestSummarizationDistiller(unittest.TestCase):
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
logging.disable(logging.CRITICAL) # remove noisy download output from tracebacks
|
|
return cls
|
|
|
|
@require_multigpu
|
|
def test_multigpu(self):
|
|
updates = dict(no_teacher=True, freeze_encoder=True, gpus=2, sortish_sampler=False,)
|
|
self._test_distiller_cli(updates)
|
|
|
|
def test_distill_no_teacher(self):
|
|
updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
|
|
self._test_distiller_cli(updates)
|
|
|
|
def test_distill_checkpointing_with_teacher(self):
|
|
updates = dict(
|
|
student_encoder_layers=2,
|
|
student_decoder_layers=1,
|
|
max_epochs=4,
|
|
val_check_interval=0.25,
|
|
alpha_hid=2.0,
|
|
model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
|
|
)
|
|
model = self._test_distiller_cli(updates, check_contents=False)
|
|
|
|
ckpts = list(Path(model.output_dir).glob("*.ckpt"))
|
|
self.assertEqual(1, len(ckpts))
|
|
transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
|
|
self.assertEqual(len(transformer_ckpts), 2)
|
|
examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
|
|
out_path = tempfile.mktemp()
|
|
generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
|
|
self.assertTrue(Path(out_path).exists())
|
|
|
|
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
|
|
|
|
@unittest.skip("T5 distillation is broken at the moment")
|
|
def test_distill_t5(self):
|
|
updates = dict(
|
|
student_encoder_layers=1,
|
|
student_decoder_layers=1,
|
|
alpha_hid=2.0,
|
|
teacher=T5_TINY,
|
|
model_name_or_path=T5_TINY,
|
|
tokenizer_name=T5_TINY,
|
|
)
|
|
self._test_distiller_cli(updates)
|
|
|
|
def _test_distiller_cli(self, updates, check_contents=True):
|
|
default_updates = dict(
|
|
train_batch_size=1,
|
|
eval_batch_size=2,
|
|
max_epochs=2,
|
|
alpha_mlm=0.2,
|
|
alpha_ce=0.8,
|
|
do_predict=True,
|
|
model_name_or_path="sshleifer/tinier_bart",
|
|
teacher=CHEAP_ARGS["model_name_or_path"],
|
|
val_check_interval=0.5,
|
|
alpha_encoder_loss=0.4,
|
|
)
|
|
default_updates.update(updates)
|
|
args_d: dict = CHEAP_ARGS.copy()
|
|
tmp_dir = make_test_data_dir()
|
|
output_dir = tempfile.mkdtemp(prefix="output_")
|
|
|
|
args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
|
|
model = distill_main(argparse.Namespace(**args_d))
|
|
if not check_contents:
|
|
return model
|
|
contents = os.listdir(output_dir)
|
|
ckpt_name = "val_avg_rouge2=0.0000-step_count=2.ckpt" # "val_avg_rouge2=0.0000-epoch=1.ckpt" # "epoch=1-val_avg_rouge2=0.0000.ckpt"
|
|
contents = {os.path.basename(p) for p in contents}
|
|
self.assertIn(ckpt_name, contents)
|
|
|
|
self.assertIn("test_generations.txt", contents)
|
|
self.assertIn("test_results.txt", contents)
|
|
|
|
metrics = load_json(model.metrics_save_path)
|
|
last_step_stats = metrics["val"][-1]
|
|
self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
|
|
self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
|
|
self.assertIsInstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
|
|
desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) + 1)
|
|
self.assertEqual(len(metrics["val"]), desired_n_evals)
|
|
self.assertEqual(len(metrics["test"]), 1)
|
|
return model
|
|
|
|
|
|
@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
|
|
def test_run_eval_bart(model):
|
|
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
|
|
output_file_name = input_file_name.parent / "utest_output.txt"
|
|
assert not output_file_name.exists()
|
|
articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
|
|
_dump_articles(input_file_name, articles)
|
|
testargs = ["run_eval.py", model, str(input_file_name), str(output_file_name)] # TODO: test score_path
|
|
with patch.object(sys, "argv", testargs):
|
|
run_generate()
|
|
assert Path(output_file_name).exists()
|
|
os.remove(Path(output_file_name))
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
|
|
)
|
|
def test_finetune(model):
|
|
args_d: dict = CHEAP_ARGS.copy()
|
|
task = "translation" if model in [MBART_TINY, MARIAN_TINY] else "summarization"
|
|
tmp_dir = make_test_data_dir()
|
|
output_dir = tempfile.mkdtemp(prefix="output_")
|
|
args_d.update(
|
|
data_dir=tmp_dir,
|
|
model_name_or_path=model,
|
|
tokenizer_name=None,
|
|
train_batch_size=2,
|
|
eval_batch_size=2,
|
|
output_dir=output_dir,
|
|
do_predict=True,
|
|
task=task,
|
|
src_lang="en_XX",
|
|
tgt_lang="ro_RO",
|
|
freeze_encoder=True,
|
|
freeze_embeds=True,
|
|
)
|
|
assert "n_train" in args_d
|
|
args = argparse.Namespace(**args_d)
|
|
module = main(args)
|
|
|
|
input_embeds = module.model.get_input_embeddings()
|
|
assert not input_embeds.weight.requires_grad
|
|
if model == T5_TINY:
|
|
lm_head = module.model.lm_head
|
|
assert not lm_head.weight.requires_grad
|
|
assert (lm_head.weight == input_embeds.weight).all().item()
|
|
|
|
else:
|
|
bart = module.model.model
|
|
embed_pos = bart.decoder.embed_positions
|
|
assert not embed_pos.weight.requires_grad
|
|
assert not bart.shared.weight.requires_grad
|
|
# check that embeds are the same
|
|
assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
|
|
assert bart.decoder.embed_tokens == bart.shared
|
|
|
|
|
|
def test_pack_dataset():
|
|
tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
|
|
tmp_dir = Path(make_test_data_dir())
|
|
save_dir = Path(tempfile.mkdtemp(prefix="packed_"))
|
|
pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
|
|
orig_paths = {x.name for x in tmp_dir.iterdir()}
|
|
new_paths = {x.name for x in save_dir.iterdir()}
|
|
assert orig_paths == new_paths
|
|
|
|
|
|
def test_mbart_dataset_truncation():
|
|
tokenizer = MBartTokenizer.from_pretrained(MBART_TINY)
|
|
tmp_dir = make_test_data_dir()
|
|
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
|
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
|
trunc = 4
|
|
src_lang, tgt_lang = "ro_RO", "de_DE" # NOT WHAT IT WAS TRAINED ON
|
|
train_dataset = MBartDataset(
|
|
tokenizer,
|
|
data_dir=tmp_dir,
|
|
type_path="train",
|
|
max_source_length=trunc,
|
|
max_target_length=1000, # ignored
|
|
src_lang=src_lang,
|
|
tgt_lang=tgt_lang,
|
|
)
|
|
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
|
for batch in dataloader:
|
|
assert isinstance(batch, dict)
|
|
assert batch["attention_mask"].shape == batch["input_ids"].shape
|
|
# show that articles were trimmed.
|
|
assert batch["input_ids"].shape[1] == trunc
|
|
# show that targets are the same len
|
|
assert batch["decoder_input_ids"].shape[1] == trunc
|
|
# check language codes in correct place
|
|
assert batch["decoder_input_ids"][0, 0].item() == tokenizer.lang_code_to_id[tgt_lang]
|
|
assert batch["decoder_input_ids"][0, -1].item() == tokenizer.eos_token_id
|
|
assert batch["input_ids"][0, -2].item() == tokenizer.eos_token_id
|
|
assert batch["input_ids"][0, -1].item() == tokenizer.lang_code_to_id[src_lang]
|
|
|
|
assert max_len_target > trunc # Truncated
|
|
assert max_len_source > trunc
|
|
break # No need to test every batch
|
|
|
|
|
|
@pytest.mark.parametrize(["tok"], [pytest.param(T5_TINY), pytest.param(BART_TINY), param(MARIAN_TINY)])
|
|
def test_summarization_dataset_truncation(tok):
|
|
tokenizer = AutoTokenizer.from_pretrained(tok)
|
|
tmp_dir = make_test_data_dir()
|
|
max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
|
|
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
|
trunc_target = 4
|
|
train_dataset = Seq2SeqDataset(
|
|
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
|
|
)
|
|
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
|
for batch in dataloader:
|
|
assert batch["attention_mask"].shape == batch["input_ids"].shape
|
|
# show that articles were trimmed.
|
|
assert batch["input_ids"].shape[1] == max_len_source
|
|
assert 20 >= batch["input_ids"].shape[1] # trimmed significantly
|
|
# show that targets were truncated
|
|
assert batch["decoder_input_ids"].shape[1] == trunc_target # Truncated
|
|
assert max_len_target > trunc_target # Truncated
|
|
break # No need to test every batch
|