s2s distillation uses AutoModelForSeq2SeqLM (#6761)

Sam Shleifer 2020-08-26 23:25:11 -04:00 committed by GitHub
parent 05e7150a53
commit 4bd7be9a42
2 changed files with 6 additions and 6 deletions

@@ -10,7 +10,7 @@ from torch import nn
 from torch.nn import functional as F
 from lightning_base import generic_train
-from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration
+from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
 try:
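
For context, a minimal sketch (not part of this diff) of what the new import buys: AutoModelForSeq2SeqLM inspects a checkpoint's config.model_type and instantiates the matching architecture, so one loader call covers BART, mBART, and T5 teachers. The checkpoint name below is only an illustrative example.

from transformers import AutoModelForSeq2SeqLM

# The Auto class dispatches on config.model_type, so the caller never
# has to name a concrete architecture.
teacher = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn").eval()
print(type(teacher).__name__)  # BartForConditionalGeneration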
@@ -74,22 +74,22 @@ class BartSummarizationDistiller(SummarizationModule):
     def pre_init(self, hparams):
         self.output_dir = Path(hparams.output_dir)
         self.output_dir.mkdir(exist_ok=True)
-        teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
+        teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
         student_updates = {
             "decoder_layers": hparams.student_decoder_layers,
             "encoder_layers": hparams.student_encoder_layers,
         }
         if hparams.length_penalty != -1:
             student_updates["length_penalty"] = hparams.length_penalty
-        d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
+        d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
         e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
         hparams.d_layer_to_copy = d_layers_to_copy
         hparams.e_layer_to_copy = e_layers_to_copy
         kw = teacher.config.to_diff_dict()
         kw.update(student_updates)
         # Copy weights
-        student_cfg = BartConfig(**kw)
-        student = BartForConditionalGeneration(student_cfg)
+        student_cfg = teacher.config_class(**kw)
+        student = type(teacher)(student_cfg)
         student, _ = init_student(student, teacher)
         save_dir = self.output_dir.joinpath("student")
         self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
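
The teacher.config_class / type(teacher) pair above is what removes the hard-coded BartConfig and BartForConditionalGeneration: every PreTrainedModel subclass carries its own config class, and instantiating type(teacher) rebuilds the teacher's architecture at the student's size. A hedged sketch of the pattern, assuming a BART-family teacher (the layer-count keys are BART-style; T5 configs name them differently, and the checkpoint is just an example):

from transformers import AutoModelForSeq2SeqLM

teacher = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn").eval()
kw = teacher.config.to_diff_dict()                     # teacher config as a plain dict
kw.update({"decoder_layers": 3, "encoder_layers": 6})  # shrink the student
student_cfg = teacher.config_class(**kw)               # BartConfig here; MBartConfig for an mBART teacher
student = type(teacher)(student_cfg)                   # randomly initialized student of the same class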
@@ -252,7 +252,6 @@ class BartTranslationDistiller(BartSummarizationDistiller):
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
-        assert isinstance(self.tokenizer, MBartTokenizer)
         assert hparams.src_lang is not None
         assert hparams.tgt_lang is not None
         self.dataset_kwargs["src_lang"] = hparams.src_lang
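
Stepping back to the get_layers_to_copy calls in the pre_init hunk: they return the indices of the teacher layers whose weights seed the student. A hypothetical stand-in that illustrates the idea (not the repo's actual helper): spread the student's layers evenly across the teacher, keeping the first and last.

def pick_layers_to_copy(n_student, n_teacher):
    # Hypothetical illustration, not the repo's get_layers_to_copy:
    # choose n_student evenly spaced teacher layer indices.
    if n_student == 1:
        return [0]
    step = (n_teacher - 1) / (n_student - 1)
    return [round(i * step) for i in range(n_student)]

print(pick_layers_to_copy(3, 12))  # [0, 6, 11]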

@@ -186,6 +186,7 @@ class TestSummarizationDistiller(unittest.TestCase):
             tgt_lang="ro_RO",
         )
         model = self._test_distiller_cli(updates, check_contents=False)
+        assert model.model.config.model_type == "mbart"
         ckpts = list(Path(model.output_dir).glob("*.ckpt"))
         self.assertEqual(1, len(ckpts))
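
The new assertion works because every transformers config carries a model_type string naming the architecture family, which is exactly what the Auto classes dispatch on; checking it confirms the distiller really loaded an mBART teacher rather than falling back to plain BART. A quick illustration (the checkpoint name is an example):

from transformers import AutoConfig

# model_type is set by each config class; mBART checkpoints report "mbart".
print(AutoConfig.from_pretrained("facebook/mbart-large-en-ro").model_type)  # mbart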