Regression test for pegasus bugfix (#6606)

Sam Shleifer 2020-08-20 15:34:43 -04:00 committed by GitHub
parent 86c07e634f
commit 5bf4465e6c
4 changed files with 73 additions and 63 deletions

View File

@@ -22,6 +22,7 @@ from .file_utils import add_start_docstrings_to_callable
logger = logging.getLogger(__name__)
# These config values do not vary between checkpoints
DEFAULTS = dict(
vocab_size=96103,
max_position_embeddings=512,
@@ -46,6 +47,47 @@ DEFAULTS = dict(
num_beams=8,
activation_function="relu",
)
# Config values that vary between checkpoints: for testing and conversion
max_gen_length = {
# See appendix C of paper
"xsum": 64,
"cnn_dailymail": 128,
"newsroom": 128,
"wikihow": 256,
"multi_news": 256,
"reddit_tifu": 128,
"big_patent": 256,
"arxiv": 256,
"pubmed": 256,
"gigaword": 32,
"aeslc": 32,
"billsum": 256,
"large": 256, # @sshleifer chose arbitrarily
}
max_model_length = {
"xsum": 512,
"cnn_dailymail": 1024,
"newsroom": 512,
"wikihow": 512,
"multi_news": 1024,
"reddit_tifu": 512,
"big_patent": 1024,
"arxiv": 1024,
"pubmed": 1024,
"gigaword": 128,
"aeslc": 512,
"billsum": 1024,
"large": 1024,
}
expected_alpha = {
"multinews": 0.9,
"wikihow": 0.6,
"reddit_tifu": 0.6,
"big_patent": 0.7,
"gigaword": 0.6,
"aeslc": 0.6,
"billsum": 0.6,
} # otherwise 0.8
@add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC)
@@ -56,7 +98,3 @@ class PegasusConfig(BartConfig):
"""
model_type = "pegasus"
# The implementation of the config object is in BartConfig
@property
def default_config_parameters(self):
return DEFAULTS
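
For readers skimming the diff: DEFAULTS carries the values shared by every checkpoint, while the three dicts above carry the per-dataset overrides, and the two are meant to be merged into a checkpoint-specific config. A minimal sketch of that combination follows; the helper name make_pegasus_config is hypothetical and not part of the library.

from transformers import PegasusConfig
from transformers.configuration_pegasus import DEFAULTS, expected_alpha, max_gen_length, max_model_length

def make_pegasus_config(dataset: str) -> PegasusConfig:
    # Dataset-specific settings override the shared DEFAULTS
    # (DEFAULTS itself ships max_position_embeddings=512).
    overrides = dict(
        max_length=max_gen_length[dataset],
        max_position_embeddings=max_model_length[dataset],
        length_penalty=expected_alpha.get(dataset, 0.8),  # 0.8 unless listed above
    )
    return PegasusConfig(**{**DEFAULTS, **overrides})

# e.g. make_pegasus_config("xsum") -> max_length=64, max_position_embeddings=512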

View File

@@ -22,7 +22,7 @@ import torch
from tqdm import tqdm
from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
from transformers.configuration_pegasus import DEFAULTS
from transformers.configuration_pegasus import DEFAULTS, expected_alpha, max_gen_length, max_model_length
PATTERNS = [
@@ -52,47 +52,7 @@ def rename_state_dict_key(k):
# See appendix C of paper for all hyperparams
max_gen_length = {
# See appendix C of paper
"xsum": 64,
"cnn_dailymail": 128,
"newsroom": 128,
"wikihow": 256,
"multi_news": 256,
"reddit_tifu": 128,
"big_patent": 256,
"arxiv": 256,
"pubmed": 256,
"gigaword": 32,
"aeslc": 32,
"billsum": 256,
"large": 256, # @sshleifer chose arbitrarily
}
max_model_length = {
"xsum": 512,
"cnn_dailymail": 1024,
"newsroom": 512,
"wikihow": 512,
"multi_news": 1024,
"reddit_tifu": 512,
"big_patent": 1024,
"arxiv": 1024,
"pubmed": 1024,
"gigaword": 128,
"aeslc": 512,
"billsum": 1024,
"large": 1024,
}
expected_alpha = {
"multinews": 0.9,
"wikihow": 0.6,
"reddit_tifu": 0.6,
"big_patent": 0.7,
"gigaword": 0.6,
"aeslc": 0.6,
"billsum": 0.6,
} # otherwise 0.8
# TODO(SS): one constant
@@ -151,7 +111,11 @@ def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir):
# convert model
tf_weights = get_tf_weights_as_numpy(ckpt_path)
cfg_updates = dict(max_length=max_gen_length[dataset], length_penalty=expected_alpha.get(dataset, 0.8))
cfg_updates = dict(
max_length=max_gen_length[dataset],
length_penalty=expected_alpha.get(dataset, 0.8),
max_position_embeddings=desired_max_model_length,
)
torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates)
torch_model.save_pretrained(save_dir)
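
The hunk above uses desired_max_model_length without showing its definition. A plausible reading, consistent with the dicts now imported from configuration_pegasus, is sketched below; the dataset-from-path logic and the helper name build_cfg_updates are assumptions, not code from the diff.

from pathlib import Path
from transformers.configuration_pegasus import expected_alpha, max_gen_length, max_model_length

def build_cfg_updates(ckpt_path: str) -> dict:
    # Assumed layout: each TF checkpoint lives in a directory named after
    # its dataset, e.g. ".../cnn_dailymail/model.ckpt-xxxx".
    dataset = Path(ckpt_path).parent.name
    desired_max_model_length = max_model_length[dataset]
    return dict(
        max_length=max_gen_length[dataset],
        length_penalty=expected_alpha.get(dataset, 0.8),  # 0.8 unless overridden
        max_position_embeddings=desired_max_model_length,
    )

Passing max_position_embeddings here is the substantive change: without it, converted checkpoints kept the 512-position default from DEFAULTS even when their dataset allows 1024-token inputs.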

View File

@@ -23,6 +23,13 @@ from .modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration
@add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING)
class PegasusForConditionalGeneration(BartForConditionalGeneration):
config_class = PegasusConfig
authorized_missing_keys = [
r"final_logits_bias",
r"encoder\.version",
r"decoder\.version",
r"model.encoder.embed_positions",
"model.decoder.embed_positions",
]
r"""
PyTorch version of Google's Pegasus model for summarization.
Model API is identical to BartForConditionalGeneration.
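
For context on the new authorized_missing_keys list: when loading a checkpoint, from_pretrained matches missing state-dict keys against these regex patterns and drops any match from the missing-keys warning, presumably because the position embeddings here are deterministic and can be rebuilt at initialization rather than stored in every checkpoint. A rough, self-contained sketch of that filtering step (not the library's actual implementation):

import re

def filter_expected_missing(missing_keys, authorized_patterns):
    # Keys matching any authorized pattern are expected to be absent
    # and are removed from the report instead of being warned about.
    for pat in authorized_patterns:
        missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
    return missing_keys

print(filter_expected_missing(
    ["model.encoder.embed_positions.weight", "model.encoder.layers.0.fc1.weight"],
    [r"model\.encoder\.embed_positions", r"model\.decoder\.embed_positions"],
))  # -> ['model.encoder.layers.0.fc1.weight']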

View File

@@ -1,6 +1,7 @@
import unittest
from transformers import AutoConfig, is_torch_available
from transformers import AutoConfig, AutoTokenizer, is_torch_available
from transformers.configuration_pegasus import max_gen_length, max_model_length
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device
@@ -50,28 +51,28 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
class PegasusConfigTests(unittest.TestCase):
def test_all_config_max_lengths(self):
expected_max_length = {
# See appendix C of paper
"xsum": 64,
"cnn_dailymail": 128,
"newsroom": 128,
"wikihow": 256,
"multi_news": 256,
"reddit_tifu": 128,
"big_patent": 256,
"arxiv": 256,
"pubmed": 256,
"gigaword": 32,
"aeslc": 32,
"billsum": 256,
}
failures = []
pegasus_prefix = "google/pegasus"
for dataset, max_len in expected_max_length.items():
for dataset, max_len in max_gen_length.items():
mname = f"{pegasus_prefix}-{dataset}"
cfg = AutoConfig.from_pretrained(mname)
if cfg.max_length != max_len:
failures.append(f"config for {mname} had max_length: {cfg.max_length}, expected {max_len}")
if cfg.max_position_embeddings < max_model_length[dataset]:
# otherwise you get IndexError for e.g. position 513
# see https://github.com/huggingface/transformers/issues/6599
failures.append(
f"config for {mname} had max_position_embeddings: {cfg.max_position_embeddings}, expected {max_model_length[dataset]}"
)
tokenizer = AutoTokenizer.from_pretrained(mname)
if max_model_length[dataset] != tokenizer.model_max_length:
failures.append(
f"tokenizer.model_max_length {tokenizer.model_max_length} expected {max_model_length[dataset]}"
)
if failures == []:
return
# error
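
To make the IndexError referenced in the comment above (issue #6599) concrete, here is a standalone illustration, not taken from the repository: a 512-slot position-embedding table cannot look up position 513, which is exactly what happens when max_position_embeddings stays at 512 while the tokenizer accepts 1024-token inputs.

import torch
import torch.nn as nn

embed_positions = nn.Embedding(512, 16)   # 512 positions, toy hidden size
_ = embed_positions(torch.arange(512))    # positions 0..511 embed fine
try:
    embed_positions(torch.tensor([513]))  # position 513, as in the comment above
except IndexError as err:
    print("IndexError:", err)             # index out of range in self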