[T5] allow config.decoder_layers to control decoder size (#7409)

* Working asymmetrical T5

* rename decoder_layers -> num_decoder_layers

* Fix docstring

* Allow creation of asymmetric T5 students
Sam Shleifer 2020-09-28 03:08:04 -04:00 committed by GitHub
parent 7296fea1d6
commit 748425d47d
5 changed files with 58 additions and 10 deletions
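
To illustrate what the commit message describes, here is a minimal usage sketch. It is not part of the diff; it assumes the transformers API around this commit (T5Config, T5ForConditionalGeneration, and the .block module list on each T5Stack):

# After this change, encoder and decoder depth can differ.
from transformers import T5Config, T5ForConditionalGeneration

config = T5Config(num_layers=6, num_decoder_layers=2)  # 6 encoder blocks, 2 decoder blocks
model = T5ForConditionalGeneration(config)

assert len(model.encoder.block) == 6
assert len(model.decoder.block) == 2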


@@ -116,12 +116,14 @@ def create_student_by_copying_alternating_layers(
             d = teacher_d
         init_kwargs.update({"encoder_layers": e, "decoder_layers": d})
     except AttributeError:  # T5
-        teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_hidden_layers
-        assert e == d, "T5 Students must be symmetric"
-        init_kwargs["num_layers"] = e
-    # Kwargs to instantiate student = teacher kwargs with updated layer numbers + **extra_config_kwargs
+        teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_decoder_layers
+        if e is None:
+            e = teacher_e
+        if d is None:
+            d = teacher_d
+        init_kwargs.update({"num_layers": e, "num_decoder_layers": d})
+    # Kwargs to instantiate student: teacher kwargs with updated layer numbers + **extra_config_kwargs
     init_kwargs.update(extra_config_kwargs)
     # Copy weights
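
For context, a hedged usage sketch of the updated helper (not part of the diff; it mirrors the tests, which pass a checkpoint name string as the teacher, and the import path depends on where the examples code sits on your sys.path; the teacher name below is only illustrative):

import tempfile

from make_student import create_student_by_copying_alternating_layers  # adjust the import to your setup

# Keep the teacher's encoder depth (e=None) but shrink the decoder to a single layer.
student, *_ = create_student_by_copying_alternating_layers(
    "t5-small", tempfile.mkdtemp(), e=None, d=1
)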


@@ -21,9 +21,7 @@ class MakeStudentTester(unittest.TestCase):
         student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=1)
         self.assertEqual(student.config.num_hidden_layers, 1)

-    def test_invalid_t5(self):
-        # T5 students must have the same e==d because there is only one config property
-        with self.assertRaises(AssertionError):
+    def test_asymmetric_t5(self):
+        student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=None)

     def test_same_decoder_small_encoder(self):


@@ -57,6 +57,8 @@ class T5Config(PretrainedConfig):
             Size of the intermediate feed forward layer in each :obj:`T5Block`.
         num_layers (:obj:`int`, `optional`, defaults to 6):
             Number of hidden layers in the Transformer encoder.
+        num_decoder_layers (:obj:`int`, `optional`):
+            Number of hidden layers in the Transformer decoder. Will use the same value as :obj:`num_layers` if not set.
         num_heads (:obj:`int`, `optional`, defaults to 8):
             Number of attention heads for each attention layer in
             the Transformer encoder.
@@ -80,6 +82,7 @@
         d_kv=64,
         d_ff=2048,
         num_layers=6,
+        num_decoder_layers=None,
         num_heads=8,
         relative_attention_num_buckets=32,
         dropout_rate=0.1,
@@ -102,6 +105,9 @@
         self.d_kv = d_kv
         self.d_ff = d_ff
         self.num_layers = num_layers
+        self.num_decoder_layers = (
+            num_decoder_layers if num_decoder_layers is not None else self.num_layers
+        )  # default = symmetry
         self.num_heads = num_heads
         self.relative_attention_num_buckets = relative_attention_num_buckets
         self.dropout_rate = dropout_rate
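
A quick sketch of the defaulting behaviour added above (illustrative values, not part of the diff):

from transformers import T5Config

cfg = T5Config(num_layers=4)                         # num_decoder_layers left unset
assert cfg.num_decoder_layers == 4                   # falls back to num_layers (symmetry)

cfg = T5Config(num_layers=4, num_decoder_layers=1)   # explicit asymmetric decoder
assert cfg.num_decoder_layers == 1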


@@ -907,7 +907,7 @@ T5_INPUTS_DOCSTRING = r"""
     T5_START_DOCSTRING,
 )
 class T5Model(T5PreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: T5Config):
         super().__init__(config)

         self.shared = nn.Embedding(config.vocab_size, config.d_model)
@@ -919,6 +919,7 @@ class T5Model(T5PreTrainedModel):
         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
         decoder_config.is_encoder_decoder = False
+        decoder_config.num_layers = config.num_decoder_layers
         self.decoder = T5Stack(decoder_config, self.shared)

         self.init_weights()
@@ -1077,6 +1078,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         decoder_config = copy.deepcopy(config)
         decoder_config.is_decoder = True
         decoder_config.is_encoder_decoder = False
+        decoder_config.num_layers = config.num_decoder_layers
         self.decoder = T5Stack(decoder_config, self.shared)

         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
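
Sketch of the effect of the two added lines: the decoder T5Stack is built from a copied config whose num_layers is set to config.num_decoder_layers, so encoder and decoder depth can differ (illustrative, assumes the API at this commit):

from transformers import T5Config, T5Model

model = T5Model(T5Config(num_layers=3, num_decoder_layers=1))
assert len(model.encoder.block) == 3  # encoder keeps num_layers
assert len(model.decoder.block) == 1  # decoder is built with num_decoder_layers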


@@ -59,6 +59,7 @@ class T5ModelTester:
         pad_token_id=0,
         decoder_start_token_id=0,
         scope=None,
+        decoder_layers=None,
     ):
         self.parent = parent
@@ -83,6 +84,7 @@ class T5ModelTester:
         self.pad_token_id = pad_token_id
         self.decoder_start_token_id = decoder_start_token_id
         self.scope = None
+        self.decoder_layers = decoder_layers

     def prepare_config_and_inputs(self):
         input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
@@ -105,6 +107,7 @@ class T5ModelTester:
             d_ff=self.d_ff,
             d_kv=self.hidden_size // self.num_attention_heads,
             num_layers=self.num_hidden_layers,
+            num_decoder_layers=self.decoder_layers,
             num_heads=self.num_attention_heads,
             relative_attention_num_buckets=self.relative_attention_num_buckets,
             dropout_rate=self.dropout_rate,
@@ -623,3 +626,40 @@ class T5ModelIntegrationTests(unittest.TestCase):
         output = model.generate(**inputs)
         translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
         self.assertEqual(translation, expected_translation)
+
+
+@require_torch
+class TestAsymmetricT5(unittest.TestCase):
+    def build_model_and_check_forward_pass(self, **kwargs):
+        tester = T5ModelTester(self, **kwargs)
+        config, *inputs = tester.prepare_config_and_inputs()
+        (
+            input_ids,
+            decoder_input_ids,
+            attention_mask,
+            decoder_attention_mask,
+            lm_labels,
+        ) = inputs
+        model = T5ForConditionalGeneration(config=config).to(torch_device).eval()
+        outputs = model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            labels=lm_labels,
+        )
+        # outputs = model(*inputs)
+        assert len(outputs) == 4
+        assert outputs["logits"].size() == (tester.batch_size, tester.decoder_seq_length, tester.vocab_size)
+        assert outputs["loss"].size() == ()
+        return model
+
+    def test_small_decoder(self):
+        # num_hidden_layers is passed to T5Config as num_layers
+        model = self.build_model_and_check_forward_pass(decoder_layers=1, num_hidden_layers=2)
+        assert len(model.encoder.block) == 2
+        assert len(model.decoder.block) == 1
+
+    def test_defaulting_to_symmetry(self):
+        # num_hidden_layers is passed to T5Config as num_layers
+        model = self.build_model_and_check_forward_pass(num_hidden_layers=2)
+        assert len(model.decoder.block) == len(model.encoder.block) == 2