Split LMBert model in two (#4874)
* Split LMBert model in two * Fix example * Remove lm_labels * Adapt tests, refactor prepare_for_generation * Fix merge * Hide BeartLMHeadModel
This commit is contained in:
parent
f6da8b2200
commit
1e2631d6f8
|
@ -873,7 +873,116 @@ class BertForPreTraining(BertPreTrainedModel):
|
|||
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
|
||||
|
||||
|
||||
# TODO: Split with a different BertWithLMHead to get rid of `lm_labels` here and in encoder_decoder.
|
||||
@add_start_docstrings(
|
||||
"""Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
|
||||
)
|
||||
class BertLMHeadModel(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`."
|
||||
|
||||
self.bert = BertModel(config)
|
||||
self.cls = BertOnlyMLMHead(config)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.cls.predictions.decoder
|
||||
|
||||
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
output_attentions=None,
|
||||
**kwargs
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the left-to-right language modeling loss (next word prediction).
|
||||
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
|
||||
Used to hide legacy arguments that have been deprecated.
|
||||
|
||||
Returns:
|
||||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||
ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Next token prediction loss.
|
||||
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
||||
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import BertTokenizer, BertLMHeadModel
|
||||
import torch
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
model = BertLMHeadModel.from_pretrained('bert-base-uncased', is_decoder=True)
|
||||
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids, labels=input_ids)
|
||||
|
||||
loss, prediction_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
|
||||
outputs = self.bert(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
prediction_scores = self.cls(sequence_output)
|
||||
|
||||
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
||||
|
||||
if labels is not None:
|
||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||
prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
||||
labels = labels[:, 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss()
|
||||
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
outputs = (ltr_lm_loss,) + outputs
|
||||
|
||||
return outputs # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
||||
input_shape = input_ids.shape
|
||||
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_shape)
|
||||
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
||||
class BertForMaskedLM(BertPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
|
@ -899,7 +1008,6 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||
labels=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
lm_labels=None,
|
||||
output_attentions=None,
|
||||
**kwargs
|
||||
):
|
||||
|
@ -909,11 +1017,6 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the left-to-right language modeling loss (next word prediction).
|
||||
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
|
||||
Used to hide legacy arguments that have been deprecated.
|
||||
|
||||
|
@ -921,8 +1024,6 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
||||
masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Masked language modeling loss.
|
||||
ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_labels` is provided):
|
||||
Next token prediction loss.
|
||||
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
|
@ -957,6 +1058,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||
DeprecationWarning,
|
||||
)
|
||||
labels = kwargs.pop("masked_lm_labels")
|
||||
assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
|
||||
outputs = self.bert(
|
||||
|
@ -976,46 +1078,24 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||
|
||||
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
||||
|
||||
# Although this may seem awkward, BertForMaskedLM supports two scenarios:
|
||||
# 1. If a tensor that contains the indices of masked labels is provided,
|
||||
# the cross-entropy is the MLM cross-entropy that measures the likelihood
|
||||
# of predictions for masked words.
|
||||
# 2. If `lm_labels` is provided we are in a causal scenario where we
|
||||
# try to predict the next token for each input in the decoder.
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
||||
outputs = (masked_lm_loss,) + outputs
|
||||
|
||||
if lm_labels is not None:
|
||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||
prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
||||
lm_labels = lm_labels[:, 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss()
|
||||
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
|
||||
outputs = (ltr_lm_loss,) + outputs
|
||||
|
||||
return outputs # (ltr_lm_loss), (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
||||
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
||||
input_shape = input_ids.shape
|
||||
effective_batch_size = input_shape[0]
|
||||
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_shape)
|
||||
|
||||
# if model is does not use a causal mask then add a dummy token
|
||||
if self.config.is_decoder is False:
|
||||
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
|
||||
attention_mask = torch.cat(
|
||||
[attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1
|
||||
)
|
||||
|
||||
dummy_token = torch.full(
|
||||
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
|
||||
)
|
||||
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
||||
# add a dummy token
|
||||
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
|
||||
attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
|
||||
dummy_token = torch.full(
|
||||
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
|
||||
)
|
||||
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
||||
|
||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
|
||||
|
|
|
@ -192,7 +192,6 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||
decoder_head_mask=None,
|
||||
decoder_inputs_embeds=None,
|
||||
labels=None,
|
||||
lm_labels=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
|
@ -239,11 +238,6 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
Labels for computing the left-to-right language modeling loss (next word prediction) for the decoder.
|
||||
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
||||
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
||||
in ``[0, ..., config.vocab_size]``
|
||||
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
|
||||
- Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
|
||||
- With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function.
|
||||
|
@ -293,7 +287,6 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||
encoder_hidden_states=hidden_states,
|
||||
encoder_attention_mask=attention_mask,
|
||||
head_mask=decoder_head_mask,
|
||||
lm_labels=lm_labels,
|
||||
labels=labels,
|
||||
**kwargs_decoder,
|
||||
)
|
||||
|
|
|
@ -35,7 +35,7 @@ if is_torch_available():
|
|||
BertForTokenClassification,
|
||||
BertForMultipleChoice,
|
||||
)
|
||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST, BertLMHeadModel
|
||||
|
||||
|
||||
class BertModelTester:
|
||||
|
@ -211,6 +211,33 @@ class BertModelTester:
|
|||
)
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_bert_for_causal_lm(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
model = BertLMHeadModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(
|
||||
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_bert_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
|
@ -229,7 +256,7 @@ class BertModelTester:
|
|||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_bert_model_for_masked_lm_as_decoder(
|
||||
def create_and_check_bert_model_for_causal_lm_as_decoder(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
|
@ -241,7 +268,7 @@ class BertModelTester:
|
|||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
):
|
||||
model = BertForMaskedLM(config=config)
|
||||
model = BertLMHeadModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(
|
||||
|
@ -461,13 +488,17 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
encoder_attention_mask,
|
||||
)
|
||||
|
||||
def test_for_causal_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
self.model_tester.create_and_check_bert_for_causal_lm(*config_and_inputs)
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_for_masked_lm_decoder(self):
|
||||
def test_for_causal_lm_decoder(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||
self.model_tester.create_and_check_bert_model_for_masked_lm_as_decoder(*config_and_inputs)
|
||||
self.model_tester.create_and_check_bert_model_for_causal_lm_as_decoder(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
|
|
@ -27,7 +27,8 @@ from .utils import require_torch, slow, torch_device
|
|||
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import BertModel, BertForMaskedLM, EncoderDecoderModel, EncoderDecoderConfig
|
||||
from transformers import BertModel, EncoderDecoderModel, EncoderDecoderConfig
|
||||
from transformers.modeling_bert import BertLMHeadModel
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
@ -70,7 +71,6 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
|||
"decoder_token_labels": decoder_token_labels,
|
||||
"decoder_choice_labels": decoder_choice_labels,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"lm_labels": decoder_token_labels,
|
||||
"labels": decoder_token_labels,
|
||||
}
|
||||
|
||||
|
@ -116,7 +116,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
|||
**kwargs
|
||||
):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = BertForMaskedLM(decoder_config)
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
|
||||
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||
|
@ -153,7 +153,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
|||
**kwargs
|
||||
):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = BertForMaskedLM(decoder_config)
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
|
||||
enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
|
||||
enc_dec_model.to(torch_device)
|
||||
|
@ -179,7 +179,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
|||
**kwargs
|
||||
):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = BertForMaskedLM(decoder_config)
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
enc_dec_model.eval()
|
||||
|
@ -220,7 +220,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
|||
**kwargs
|
||||
):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = BertForMaskedLM(decoder_config)
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
enc_dec_model.eval()
|
||||
|
@ -269,7 +269,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
|||
**kwargs
|
||||
):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = BertForMaskedLM(decoder_config)
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
|
@ -288,41 +288,9 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
|||
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
|
||||
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
|
||||
|
||||
def create_and_check_bert_encoder_decoder_model_lm_labels(
|
||||
self,
|
||||
config,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
lm_labels,
|
||||
**kwargs
|
||||
):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = BertForMaskedLM(decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
input_ids=input_ids,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
lm_labels=lm_labels,
|
||||
)
|
||||
|
||||
lm_loss = outputs_encoder_decoder[0]
|
||||
self.check_loss_output(lm_loss)
|
||||
# check that backprop works
|
||||
lm_loss.backward()
|
||||
|
||||
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
|
||||
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
|
||||
|
||||
def create_and_check_bert_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = BertForMaskedLM(decoder_config)
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
enc_dec_model.to(torch_device)
|
||||
|
||||
|
@ -356,10 +324,6 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
|||
input_ids_dict = self.prepare_config_and_inputs_bert()
|
||||
self.create_and_check_bert_encoder_decoder_model_labels(**input_ids_dict)
|
||||
|
||||
def test_bert_encoder_decoder_model_lm_labels(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs_bert()
|
||||
self.create_and_check_bert_encoder_decoder_model_lm_labels(**input_ids_dict)
|
||||
|
||||
def test_bert_encoder_decoder_model_generate(self):
|
||||
input_ids_dict = self.prepare_config_and_inputs_bert()
|
||||
self.create_and_check_bert_encoder_decoder_model_generate(**input_ids_dict)
|
||||
|
|
Loading…
Reference in New Issue