s2s distillation uses AutoModelForSeqToSeqLM (#6761)

parent 05e7150a53
commit 4bd7be9a42
```diff
@@ -10,7 +10,7 @@ from torch import nn
 from torch.nn import functional as F
 
 from lightning_base import generic_train
-from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration
+from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
 
 
 try:
```
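Note: the switch away from the hard-coded BART imports works because `AutoModelForSeq2SeqLM` dispatches on the checkpoint's config. A minimal sketch of the behavior this relies on (the checkpoint names are well-known public ones, not taken from this diff):

```python
from transformers import AutoModelForSeq2SeqLM

# The Auto class reads config.json from the checkpoint and instantiates the
# matching architecture, so one call covers BART, mBART, T5, ...
bart = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
t5 = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
print(type(bart).__name__)  # BartForConditionalGeneration
print(type(t5).__name__)    # T5ForConditionalGeneration
```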
```diff
@@ -74,22 +74,22 @@ class BartSummarizationDistiller(SummarizationModule):
     def pre_init(self, hparams):
         self.output_dir = Path(hparams.output_dir)
         self.output_dir.mkdir(exist_ok=True)
-        teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
+        teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
         student_updates = {
             "decoder_layers": hparams.student_decoder_layers,
             "encoder_layers": hparams.student_encoder_layers,
         }
         if hparams.length_penalty != -1:
             student_updates["length_penalty"] = hparams.length_penalty
-        d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
+        d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
         e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
         hparams.d_layer_to_copy = d_layers_to_copy
         hparams.e_layer_to_copy = e_layers_to_copy
         kw = teacher.config.to_diff_dict()
         kw.update(student_updates)
         # Copy weights
-        student_cfg = BartConfig(**kw)
-        student = BartForConditionalGeneration(student_cfg)
+        student_cfg = teacher.config_class(**kw)
+        student = type(teacher)(student_cfg)
         student, _ = init_student(student, teacher)
         save_dir = self.output_dir.joinpath("student")
         self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
```
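Note: the last two changed lines are the core of this commit. The student's config class and model class are taken from the teacher object itself rather than being pinned to BART. A self-contained sketch of that pattern, assuming a BART-family teacher (the checkpoint name and layer count are illustrative only):

```python
from transformers import AutoModelForSeq2SeqLM

# Load some seq2seq teacher; the concrete class is resolved from its config.
teacher = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn").eval()

# to_diff_dict() returns only the config values that differ from the
# architecture's defaults, so the student inherits everything else.
kw = teacher.config.to_diff_dict()
kw.update({"decoder_layers": 3})  # shrink the decoder (BART-style config key)

# Build a student of the *same* architecture as the teacher, with no
# isinstance branching: config class and model class come from the teacher.
student_cfg = teacher.config_class(**kw)
student = type(teacher)(student_cfg)
assert type(student) is type(teacher)
```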
```diff
@@ -252,7 +252,6 @@ class BartTranslationDistiller(BartSummarizationDistiller):
 
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
-        assert isinstance(self.tokenizer, MBartTokenizer)
         assert hparams.src_lang is not None
         assert hparams.tgt_lang is not None
         self.dataset_kwargs["src_lang"] = hparams.src_lang
```
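Note: dropping the `MBartTokenizer` isinstance check fits the same theme: with an Auto-loaded teacher, the translation distiller no longer assumes mBART, and only the language pair is still required. For reference, a hedged example of how an mBART tokenizer takes that language pair (the checkpoint name is a known public one, not from this diff):

```python
from transformers import MBartTokenizer

# src_lang/tgt_lang select the language codes mBART prepends to the
# source and target sequences.
tok = MBartTokenizer.from_pretrained(
    "facebook/mbart-large-en-ro", src_lang="en_XX", tgt_lang="ro_RO"
)
batch = tok("UN Chief Says There Is No Military Solution in Syria")
```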
```diff
@@ -186,6 +186,7 @@ class TestSummarizationDistiller(unittest.TestCase):
             tgt_lang="ro_RO",
         )
         model = self._test_distiller_cli(updates, check_contents=False)
+        assert model.model.config.model_type == "mbart"
 
         ckpts = list(Path(model.output_dir).glob("*.ckpt"))
         self.assertEqual(1, len(ckpts))
```
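Note: the new assert leans on `config.model_type`, the string every transformers config carries to identify its architecture; the test uses it to confirm the distiller really built an mBART student rather than silently defaulting to BART. The same check outside the test harness might look like this (the checkpoint name is a known public mBART model, assumed here):

```python
from transformers import AutoConfig

# model_type identifies the resolved architecture ("bart", "mbart", "t5", ...).
cfg = AutoConfig.from_pretrained("facebook/mbart-large-en-ro")
assert cfg.model_type == "mbart"
```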