Regression test for pegasus bugfix (#6606)
This commit is contained in:
parent
86c07e634f
commit
5bf4465e6c
|
@ -22,6 +22,7 @@ from .file_utils import add_start_docstrings_to_callable
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# These config values do not vary between checkpoints
|
||||
DEFAULTS = dict(
|
||||
vocab_size=96103,
|
||||
max_position_embeddings=512,
|
||||
|
@ -46,6 +47,47 @@ DEFAULTS = dict(
|
|||
num_beams=8,
|
||||
activation_function="relu",
|
||||
)
|
||||
# Config values that vary between checkpoints: for testing and conversion
|
||||
# Generation length limit per dataset name.
# The conversion code passes max_gen_length[dataset] as the `max_length` config
# value, and the config test compares each entry against `cfg.max_length` of the
# corresponding "google/pegasus-{dataset}" checkpoint.
max_gen_length = {
|
||||
# See appendix C of paper
|
||||
"xsum": 64,
|
||||
"cnn_dailymail": 128,
|
||||
"newsroom": 128,
|
||||
"wikihow": 256,
|
||||
"multi_news": 256,
|
||||
"reddit_tifu": 128,
|
||||
"big_patent": 256,
|
||||
"arxiv": 256,
|
||||
"pubmed": 256,
|
||||
"gigaword": 32,
|
||||
"aeslc": 32,
|
||||
"billsum": 256,
|
||||
"large": 256, # @sshleifer chose arbitrarily
|
||||
}
|
||||
# Model input length per dataset name.
# The config test requires `cfg.max_position_embeddings` to be at least this
# value (otherwise positions past the embedding table raise IndexError) and
# `tokenizer.model_max_length` to be equal to it.
max_model_length = {
|
||||
"xsum": 512,
|
||||
"cnn_dailymail": 1024,
|
||||
"newsroom": 512,
|
||||
"wikihow": 512,
|
||||
"multi_news": 1024,
|
||||
"reddit_tifu": 512,
|
||||
"big_patent": 1024,
|
||||
"arxiv": 1024,
|
||||
"pubmed": 1024,
|
||||
"gigaword": 128,
|
||||
"aeslc": 512,
|
||||
"billsum": 1024,
|
||||
"large": 1024,
|
||||
}
|
||||
# Length penalty per dataset name; the conversion code reads it with
# `expected_alpha.get(dataset, 0.8)` and stores it as `length_penalty`.
# Keys must be spelled exactly like the keys of `max_gen_length` /
# `max_model_length` (e.g. "multi_news", not "multinews"): the same `dataset`
# string is used for every lookup, so a differently spelled key is never hit
# and the dataset silently falls back to 0.8.
expected_alpha = {
    "multi_news": 0.9,
    "wikihow": 0.6,
    "reddit_tifu": 0.6,
    "big_patent": 0.7,
    "gigaword": 0.6,
    "aeslc": 0.6,
    "billsum": 0.6,
}  # otherwise 0.8
|
||||
|
||||
|
||||
@add_start_docstrings_to_callable(BART_CONFIG_ARGS_DOC)
|
||||
|
@ -56,7 +98,3 @@ class PegasusConfig(BartConfig):
|
|||
"""
|
||||
model_type = "pegasus"
|
||||
# The implementation of the config object is in BartConfig
|
||||
|
||||
@property
|
||||
def default_config_parameters(self):
|
||||
return DEFAULTS
|
||||
|
|
|
@ -22,7 +22,7 @@ import torch
|
|||
from tqdm import tqdm
|
||||
|
||||
from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
|
||||
from transformers.configuration_pegasus import DEFAULTS
|
||||
from transformers.configuration_pegasus import DEFAULTS, expected_alpha, max_gen_length, max_model_length
|
||||
|
||||
|
||||
PATTERNS = [
|
||||
|
@ -52,47 +52,7 @@ def rename_state_dict_key(k):
|
|||
|
||||
|
||||
# See appendix C of paper for all hyperparams
|
||||
# Generation length limit per dataset name.
# The conversion code passes max_gen_length[dataset] as the `max_length` config
# value for the converted checkpoint.
max_gen_length = {
|
||||
# See appendix C of paper
|
||||
"xsum": 64,
|
||||
"cnn_dailymail": 128,
|
||||
"newsroom": 128,
|
||||
"wikihow": 256,
|
||||
"multi_news": 256,
|
||||
"reddit_tifu": 128,
|
||||
"big_patent": 256,
|
||||
"arxiv": 256,
|
||||
"pubmed": 256,
|
||||
"gigaword": 32,
|
||||
"aeslc": 32,
|
||||
"billsum": 256,
|
||||
"large": 256, # @sshleifer chose arbitrarily
|
||||
}
|
||||
# Model input length per dataset name.
# The config test checks these values against `cfg.max_position_embeddings`
# (must be at least this large) and `tokenizer.model_max_length` (must match).
max_model_length = {
|
||||
"xsum": 512,
|
||||
"cnn_dailymail": 1024,
|
||||
"newsroom": 512,
|
||||
"wikihow": 512,
|
||||
"multi_news": 1024,
|
||||
"reddit_tifu": 512,
|
||||
"big_patent": 1024,
|
||||
"arxiv": 1024,
|
||||
"pubmed": 1024,
|
||||
"gigaword": 128,
|
||||
"aeslc": 512,
|
||||
"billsum": 1024,
|
||||
"large": 1024,
|
||||
}
|
||||
|
||||
# Length penalty per dataset name, read with `expected_alpha.get(dataset, 0.8)`
# and stored as `length_penalty` during conversion.
# Keys must match the dataset spelling used by `max_gen_length` (e.g.
# "multi_news"); a key spelled differently is never found and the dataset
# silently receives the 0.8 fallback.
expected_alpha = {
    "multi_news": 0.9,
    "wikihow": 0.6,
    "reddit_tifu": 0.6,
    "big_patent": 0.7,
    "gigaword": 0.6,
    "aeslc": 0.6,
    "billsum": 0.6,
}  # otherwise 0.8
|
||||
# TODO(SS): one constant
|
||||
|
||||
|
||||
|
@ -151,7 +111,11 @@ def convert_pegasus_ckpt_to_pytorch(ckpt_path, save_dir):
|
|||
|
||||
# convert model
|
||||
tf_weights = get_tf_weights_as_numpy(ckpt_path)
|
||||
cfg_updates = dict(max_length=max_gen_length[dataset], length_penalty=expected_alpha.get(dataset, 0.8))
|
||||
cfg_updates = dict(
|
||||
max_length=max_gen_length[dataset],
|
||||
length_penalty=expected_alpha.get(dataset, 0.8),
|
||||
max_position_embeddings=desired_max_model_length,
|
||||
)
|
||||
torch_model = convert_pegasus_to_bart(tf_weights, cfg_updates)
|
||||
torch_model.save_pretrained(save_dir)
|
||||
|
||||
|
|
|
@ -23,6 +23,13 @@ from .modeling_bart import BART_START_DOCSTRING, BartForConditionalGeneration
|
|||
@add_start_docstrings("The Pegasus Model for summarization ", BART_START_DOCSTRING)
|
||||
class PegasusForConditionalGeneration(BartForConditionalGeneration):
|
||||
config_class = PegasusConfig
|
||||
authorized_missing_keys = [
|
||||
r"final_logits_bias",
|
||||
r"encoder\.version",
|
||||
r"decoder\.version",
|
||||
r"model.encoder.embed_positions",
|
||||
"model.decoder.embed_positions",
|
||||
]
|
||||
r"""
|
||||
Pytorch version of google's pegasus model for summarization.
|
||||
Model API is identical to BartForConditionalGeneration.
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import unittest
|
||||
|
||||
from transformers import AutoConfig, is_torch_available
|
||||
from transformers import AutoConfig, AutoTokenizer, is_torch_available
|
||||
from transformers.configuration_pegasus import max_gen_length, max_model_length
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
|
@ -50,28 +51,28 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
|||
|
||||
class PegasusConfigTests(unittest.TestCase):
|
||||
def test_all_config_max_lengths(self):
|
||||
expected_max_length = {
|
||||
# See appendix C of paper
|
||||
"xsum": 64,
|
||||
"cnn_dailymail": 128,
|
||||
"newsroom": 128,
|
||||
"wikihow": 256,
|
||||
"multi_news": 256,
|
||||
"reddit_tifu": 128,
|
||||
"big_patent": 256,
|
||||
"arxiv": 256,
|
||||
"pubmed": 256,
|
||||
"gigaword": 32,
|
||||
"aeslc": 32,
|
||||
"billsum": 256,
|
||||
}
|
||||
failures = []
|
||||
pegasus_prefix = "google/pegasus"
|
||||
for dataset, max_len in expected_max_length.items():
|
||||
for dataset, max_len in max_gen_length.items():
|
||||
mname = f"{pegasus_prefix}-{dataset}"
|
||||
cfg = AutoConfig.from_pretrained(mname)
|
||||
|
||||
if cfg.max_length != max_len:
|
||||
failures.append(f"config for {mname} had max_length: {cfg.max_length}, expected {max_len}")
|
||||
|
||||
if cfg.max_position_embeddings < max_model_length[dataset]:
|
||||
# otherwise you get IndexError for e.g. position 513
|
||||
# see https://github.com/huggingface/transformers/issues/6599
|
||||
failures.append(
|
||||
f"config for {mname} had max_position_embeddings: {cfg.max_position_embeddings}, expected {max_model_length[dataset]}"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(mname)
|
||||
if max_model_length[dataset] != tokenizer.model_max_length:
|
||||
failures.append(
|
||||
f"tokenizer.model_max_length {tokenizer.model_max_length} expected {max_model_length[dataset]}"
|
||||
)
|
||||
|
||||
if failures == []:
|
||||
return
|
||||
# error
|
||||
|
|
Loading…
Reference in New Issue