transformers/examples/research_projects/seq2seq-distillation/make_student.py

187 lines
8.0 KiB
Python

import warnings
from pathlib import Path
from typing import List, Tuple, Union
import fire
from torch import nn
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel
from transformers.utils import logging
logger = logging.get_logger(__name__)
def copy_layers(src_layers: nn.ModuleList, dest_layers: nn.ModuleList, layers_to_copy: List[int]) -> None:
layers_to_copy = nn.ModuleList([src_layers[i] for i in layers_to_copy])
assert len(dest_layers) == len(layers_to_copy), f"{len(dest_layers)} != {len(layers_to_copy)}"
dest_layers.load_state_dict(layers_to_copy.state_dict())
LAYERS_TO_COPY = {
# maps num layers in teacher -> num_layers in student -> which teacher layers to copy.
# 12: bart, 16: pegasus, 6: marian/Helsinki-NLP
12: {
1: [0], # This says that if the teacher has 12 layers and the student has 1, copy layer 0 of the teacher
2: [0, 6],
3: [0, 6, 11],
4: [0, 4, 8, 11],
6: [0, 2, 4, 7, 9, 11],
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
12: list(range(12)),
},
16: { # maps num layers in student -> which teacher layers to copy
1: [0],
2: [0, 15],
3: [0, 8, 15],
4: [0, 5, 10, 15],
6: [0, 3, 6, 9, 12, 15],
8: [0, 2, 4, 6, 8, 10, 12, 15],
9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
12: [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15],
16: list(range(16)),
},
6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
}
LAYERS_TO_SUPERVISE = {
# maps num layers in student -> which teacher layers to copy.
6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]},
12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]},
16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]},
}
def pick_layers_to_copy(n_student, n_teacher):
try:
val = LAYERS_TO_COPY[n_teacher][n_student]
return val
except KeyError:
if n_student != n_teacher:
warnings.warn(
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first"
f" {n_student}"
)
return list(range(n_student))
def get_layers_to_supervise(n_student, n_teacher) -> List[int]:
"""Used or the --supervise_forward kwarg"""
if n_student > n_teacher:
raise ValueError(f"Cannot perform intermediate supervision for student {n_student} > teacher {n_teacher}")
elif n_teacher == n_student:
return list(range(n_teacher))
elif n_student == 1:
return [n_teacher - 1]
else:
return LAYERS_TO_SUPERVISE[n_teacher][n_student]
def create_student_by_copying_alternating_layers(
teacher: Union[str, PreTrainedModel],
save_path: Union[str, Path] = "student",
e: Union[int, None] = None,
d: Union[int, None] = None,
copy_first_teacher_layers=False,
e_layers_to_copy=None,
d_layers_to_copy=None,
**extra_config_kwargs,
) -> Tuple[PreTrainedModel, List[int], List[int]]:
"""Make a student by copying alternating layers from a teacher, save it to save_path.
Args:
teacher: str or PreTrainedModel if str, this will call AutoModelForSeq2SeqLM.from_pretrained(teacher) before
copying layers
save_path: where to save the student, defaults to student directory.
e: how many Encoder layers should the student have, default is fully copy of teacher
d: how many Decoder layers should the student have, default is fully copy of teacher
copy_first_teacher_layers: [bool] dont copy alternating layers, just the first e/d.
**extra_config_kwargs: extra kwargs to pass to the student, by default the teacher config is used.
Returns:
student: new, smaller model. (Also saves it to save_path)
e_layers_to_copy: list of which teacher encoder layers were used
d_layers_to_copy: list of which teacher decoder layers were used
"""
_msg = "encoder_layers and decoder_layers cannot be both None-- you would just have an identical teacher."
assert (e is not None) or (d is not None), _msg
if isinstance(teacher, str):
AutoTokenizer.from_pretrained(teacher).save_pretrained(save_path) # purely for convenience
teacher = AutoModelForSeq2SeqLM.from_pretrained(teacher).eval()
else:
assert isinstance(teacher, PreTrainedModel), f"teacher must be a model or string got type {type(teacher)}"
init_kwargs = teacher.config.to_diff_dict()
try:
teacher_e, teacher_d = teacher.config.encoder_layers, teacher.config.decoder_layers
if e is None:
e = teacher_e
if d is None:
d = teacher_d
init_kwargs.update({"encoder_layers": e, "decoder_layers": d})
except AttributeError: # T5
if hasattr(teacher.config, "num_encoder_layers"):
teacher_e, teacher_d = teacher.config.num_encoder_layers, teacher.config.num_decoder_layers
else:
teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers
if e is None:
e = teacher_e
if d is None:
d = teacher_d
if hasattr(teacher.config, "num_encoder_layers"):
init_kwargs.update({"num_encoder_layers": e, "num_decoder_layers": d})
else:
init_kwargs.update({"num_layers": e, "num_decoder_layers": d})
# Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs
init_kwargs.update(extra_config_kwargs)
# Copy weights
student_cfg = teacher.config_class(**init_kwargs)
student = AutoModelForSeq2SeqLM.from_config(student_cfg)
# Start by copying the full teacher state dict this will copy the first N teacher layers to the student.
info = student.load_state_dict(teacher.state_dict(), strict=False)
assert info.missing_keys == [], info.missing_keys # every student key should have a teacher keys.
if copy_first_teacher_layers: # Our copying is done. We just log and save
e_layers_to_copy, d_layers_to_copy = list(range(e)), list(range(d))
logger.info(
f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to"
f" {save_path}"
)
student.save_pretrained(save_path)
return student, e_layers_to_copy, d_layers_to_copy
# Decide which layers of the teacher to copy. Not exactly alternating -- we try to keep first and last layer.
if e_layers_to_copy is None:
e_layers_to_copy: List[int] = pick_layers_to_copy(e, teacher_e)
if d_layers_to_copy is None:
d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d)
try:
if hasattr(
teacher, "prophetnet"
): # For ProphetNet, student.model.encoder.layers is called student.prophetnet.encoder.layers
copy_layers(teacher.prophetnet.encoder.layers, student.prophetnet.encoder.layers, e_layers_to_copy)
copy_layers(teacher.prophetnet.decoder.layers, student.prophetnet.decoder.layers, d_layers_to_copy)
else:
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
except AttributeError: # For t5, student.model.encoder.layers is called student.encoder.block
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)
logger.info(
f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
)
student.config.init_metadata = {
"teacher_type": teacher.config.model_type,
"copied_encoder_layers": e_layers_to_copy,
"copied_decoder_layers": d_layers_to_copy,
}
student.save_pretrained(save_path)
# Save information about copying for easier reproducibility
return student, e_layers_to_copy, d_layers_to_copy
if __name__ == "__main__":
fire.Fire(create_student_by_copying_alternating_layers)