diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py index 9cf6cbd818..67e695ef99 100644 --- a/examples/seq2seq/distillation.py +++ b/examples/seq2seq/distillation.py @@ -10,7 +10,7 @@ from torch import nn from torch.nn import functional as F from lightning_base import generic_train -from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration +from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration try: @@ -74,22 +74,22 @@ class BartSummarizationDistiller(SummarizationModule): def pre_init(self, hparams): self.output_dir = Path(hparams.output_dir) self.output_dir.mkdir(exist_ok=True) - teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval() + teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval() student_updates = { "decoder_layers": hparams.student_decoder_layers, "encoder_layers": hparams.student_encoder_layers, } if hparams.length_penalty != -1: student_updates["length_penalty"] = hparams.length_penalty - d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers) + d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers) e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers) hparams.d_layer_to_copy = d_layers_to_copy hparams.e_layer_to_copy = e_layers_to_copy kw = teacher.config.to_diff_dict() kw.update(student_updates) # Copy weights - student_cfg = BartConfig(**kw) - student = BartForConditionalGeneration(student_cfg) + student_cfg = teacher.config_class(**kw) + student = type(teacher)(student_cfg) student, _ = init_student(student, teacher) save_dir = self.output_dir.joinpath("student") self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher) @@ -252,7 +252,6 @@ class BartTranslationDistiller(BartSummarizationDistiller): def __init__(self, hparams, **kwargs): super().__init__(hparams, **kwargs) - assert isinstance(self.tokenizer, MBartTokenizer) assert hparams.src_lang is not None assert hparams.tgt_lang is not None self.dataset_kwargs["src_lang"] = hparams.src_lang diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 1f70cbd312..2f397c7adc 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -186,6 +186,7 @@ class TestSummarizationDistiller(unittest.TestCase): tgt_lang="ro_RO", ) model = self._test_distiller_cli(updates, check_contents=False) + assert model.model.config.model_type == "mbart" ckpts = list(Path(model.output_dir).glob("*.ckpt")) self.assertEqual(1, len(ckpts))