Output hidden states (#4978)

* Configure all models to use output_hidden_states as argument passed to foward()

* Pass all tests

* Remove cast_bool_to_primitive in TF Flaubert model

* correct tf xlnet

* add pytorch test

* add tf test

* Fix broken tests

* Configure all models to use output_hidden_states as argument passed to foward()

* Pass all tests

* Remove cast_bool_to_primitive in TF Flaubert model

* correct tf xlnet

* add pytorch test

* add tf test

* Fix broken tests

* Refactor output_hidden_states for mobilebert

* Reset and remerge to master

Co-authored-by: Joseph Liu <joseph.liu@coinflex.com>
Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Joseph Liu 2020-06-22 22:10:45 +08:00 committed by GitHub
parent 866a8ccabb
commit f4e1f02210
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 814 additions and 349 deletions

View File

@ -269,7 +269,9 @@ class AlbertLayer(nn.Module):
self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
self.activation = ACT2FN[config.hidden_act]
def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
def forward(
self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
):
attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
ffn_output = self.ffn(attention_output[0])
ffn_output = self.activation(ffn_output)
@ -283,10 +285,11 @@ class AlbertLayerGroup(nn.Module):
def __init__(self, config):
super().__init__()
self.output_hidden_states = config.output_hidden_states
self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
def forward(
self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
):
layer_hidden_states = ()
layer_attentions = ()
@ -297,11 +300,11 @@ class AlbertLayerGroup(nn.Module):
if output_attentions:
layer_attentions = layer_attentions + (layer_output[1],)
if self.output_hidden_states:
if output_hidden_states:
layer_hidden_states = layer_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
if output_hidden_states:
outputs = outputs + (layer_hidden_states,)
if output_attentions:
outputs = outputs + (layer_attentions,)
@ -313,16 +316,17 @@ class AlbertTransformer(nn.Module):
super().__init__()
self.config = config
self.output_hidden_states = config.output_hidden_states
self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])
def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
def forward(
self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
):
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
all_attentions = ()
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = (hidden_states,)
for i in range(self.config.num_hidden_layers):
@ -337,17 +341,18 @@ class AlbertTransformer(nn.Module):
attention_mask,
head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
output_attentions,
output_hidden_states,
)
hidden_states = layer_group_output[0]
if output_attentions:
all_attentions = all_attentions + layer_group_output[-1]
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
@ -489,6 +494,7 @@ class AlbertModel(AlbertPreTrainedModel):
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
Return:
@ -504,7 +510,9 @@ class AlbertModel(AlbertPreTrainedModel):
This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned
when ``output_hidden_states=True`` is passed or when
``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -530,6 +538,9 @@ class AlbertModel(AlbertPreTrainedModel):
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -556,7 +567,11 @@ class AlbertModel(AlbertPreTrainedModel):
input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
encoder_outputs = self.encoder(
embedding_output, extended_attention_mask, head_mask=head_mask, output_attentions=output_attentions,
embedding_output,
extended_attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = encoder_outputs[0]
@ -603,6 +618,7 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
labels=None,
sentence_order_label=None,
output_attentions=None,
output_hidden_states=None,
**kwargs,
):
r"""
@ -628,7 +644,9 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
sop_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False
continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned
when ``output_hidden_states=True`` is passed or when
``config.output_hidden_states``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -672,6 +690,7 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output, pooled_output = outputs[:2]
@ -758,6 +777,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
**kwargs
):
r"""
@ -775,7 +795,9 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned
when ``output_hidden_states=True`` is passed or when
``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -815,6 +837,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_outputs = outputs[0]
@ -856,6 +879,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -870,7 +894,9 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
Classification (or regression if config.num_labels==1) loss.
logits ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned
when ``output_hidden_states=True`` is passed or when
``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -904,6 +930,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = outputs[1]
@ -953,6 +980,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -965,7 +993,9 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned
when ``output_hidden_states=True`` is passed or when
``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1001,6 +1031,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
@ -1052,6 +1083,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1071,7 +1103,9 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
Span-start scores (before SoftMax).
end_scores: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned
when ``output_hidden_states=True`` is passed or when
``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1107,6 +1141,7 @@ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
@ -1163,6 +1198,7 @@ class AlbertForMultipleChoice(AlbertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1178,7 +1214,9 @@ class AlbertForMultipleChoice(AlbertPreTrainedModel):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned
when ``output_hidden_states=True`` is passed or when
``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1228,6 +1266,7 @@ class AlbertForMultipleChoice(AlbertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = outputs[1]

View File

@ -248,7 +248,6 @@ class BartEncoder(nn.Module):
self.dropout = config.dropout
self.layerdrop = config.encoder_layerdrop
self.output_hidden_states = config.output_hidden_states
embed_dim = embed_tokens.embedding_dim
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
@ -269,7 +268,7 @@ class BartEncoder(nn.Module):
# mbart has one extra layer_norm
self.layer_norm = LayerNorm(config.d_model) if config.normalize_before else None
def forward(self, input_ids, attention_mask=None, output_attentions=False):
def forward(self, input_ids, attention_mask=None, output_attentions=False, output_hidden_states=False):
"""
Args:
input_ids (LongTensor): tokens in the source language of shape
@ -281,7 +280,7 @@ class BartEncoder(nn.Module):
shape `(src_len, batch, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *self.output_hidden_states:* is True.
Only populated if *output_hidden_states:* is True.
- **all_attentions** (List[Tensor]): Attention weights for each layer.
During training might not be of length n_layers because of layer dropout.
"""
@ -300,7 +299,7 @@ class BartEncoder(nn.Module):
encoder_states, all_attentions = [], []
for encoder_layer in self.layers:
if self.output_hidden_states:
if output_hidden_states:
encoder_states.append(x)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
@ -314,7 +313,7 @@ class BartEncoder(nn.Module):
if self.layer_norm:
x = self.layer_norm(x)
if self.output_hidden_states:
if output_hidden_states:
encoder_states.append(x)
# T x B x C -> B x T x C
@ -424,7 +423,6 @@ class BartDecoder(nn.Module):
def __init__(self, config: BartConfig, embed_tokens: nn.Embedding):
super().__init__()
self.output_hidden_states = config.output_hidden_states
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
self.padding_idx = embed_tokens.padding_idx
@ -455,6 +453,7 @@ class BartDecoder(nn.Module):
decoder_cached_states=None,
use_cache=False,
output_attentions=False,
output_hidden_states=False,
**unused,
):
"""
@ -502,7 +501,7 @@ class BartDecoder(nn.Module):
next_decoder_cache = []
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states += (x,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
@ -797,7 +796,6 @@ def _get_shape(t):
class BartModel(PretrainedBartModel):
def __init__(self, config: BartConfig):
super().__init__(config)
self.output_hidden_states = config.output_hidden_states
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
@ -818,8 +816,12 @@ class BartModel(PretrainedBartModel):
decoder_cached_states=None,
use_cache=False,
output_attentions=None,
output_hidden_states=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# make masks if user doesn't supply
if not use_cache:
@ -837,7 +839,10 @@ class BartModel(PretrainedBartModel):
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions,
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
assert isinstance(encoder_outputs, tuple)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@ -849,6 +854,7 @@ class BartModel(PretrainedBartModel):
decoder_causal_mask=causal_mask,
decoder_cached_states=decoder_cached_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
)
@ -910,6 +916,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
labels=None,
use_cache=False,
output_attentions=None,
output_hidden_states=None,
**unused,
):
r"""
@ -926,7 +933,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned
when ``output_hidden_states=True`` is passed or when
``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -969,6 +978,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
decoder_cached_states=decoder_cached_states,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
lm_logits = F.linear(outputs[0], self.model.shared.weight, bias=self.final_logits_bias)
outputs = (lm_logits,) + outputs[1:] # Add cache, hidden states and attention if they are here
@ -1060,6 +1070,7 @@ class BartForSequenceClassification(PretrainedBartModel):
decoder_attention_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1073,7 +1084,9 @@ class BartForSequenceClassification(PretrainedBartModel):
Classification loss (cross entropy)
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned
when ``output_hidden_states=True`` is passed or when
``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
@ -1104,6 +1117,7 @@ class BartForSequenceClassification(PretrainedBartModel):
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
x = outputs[0] # last hidden state
eos_mask = input_ids.eq(self.config.eos_token_id)
@ -1148,6 +1162,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1167,7 +1182,9 @@ class BartForQuestionAnswering(PretrainedBartModel):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned
when ``output_hidden_states=True`` is passed or when
``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1206,6 +1223,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]

View File

@ -391,7 +391,6 @@ class BertLayer(nn.Module):
class BertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
def forward(
@ -402,11 +401,12 @@ class BertEncoder(nn.Module):
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
output_hidden_states=False,
):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
@ -423,11 +423,11 @@ class BertEncoder(nn.Module):
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
@ -656,6 +656,7 @@ class BertModel(BertPreTrainedModel):
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
Return:
@ -671,7 +672,7 @@ class BertModel(BertPreTrainedModel):
This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -698,6 +699,9 @@ class BertModel(BertPreTrainedModel):
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -747,6 +751,7 @@ class BertModel(BertPreTrainedModel):
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
@ -786,6 +791,7 @@ class BertForPreTraining(BertPreTrainedModel):
labels=None,
next_sentence_label=None,
output_attentions=None,
output_hidden_states=None,
**kwargs
):
r"""
@ -811,7 +817,7 @@ class BertForPreTraining(BertPreTrainedModel):
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False
continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -854,6 +860,7 @@ class BertForPreTraining(BertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output, pooled_output = outputs[:2]
@ -902,6 +909,7 @@ class BertLMHeadModel(BertPreTrainedModel):
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
**kwargs
):
r"""
@ -919,7 +927,7 @@ class BertLMHeadModel(BertPreTrainedModel):
Next token prediction loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -956,6 +964,7 @@ class BertLMHeadModel(BertPreTrainedModel):
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
@ -1012,6 +1021,7 @@ class BertForMaskedLM(BertPreTrainedModel):
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
**kwargs
):
r"""
@ -1029,7 +1039,7 @@ class BertForMaskedLM(BertPreTrainedModel):
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1074,6 +1084,7 @@ class BertForMaskedLM(BertPreTrainedModel):
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
@ -1126,6 +1137,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
inputs_embeds=None,
next_sentence_label=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1140,7 +1152,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
Next sequence prediction (classification) loss.
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1176,6 +1188,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = outputs[1]
@ -1218,6 +1231,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1232,7 +1246,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1268,6 +1282,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = outputs[1]
@ -1316,6 +1331,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1331,7 +1347,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1382,6 +1398,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = outputs[1]
@ -1427,6 +1444,7 @@ class BertForTokenClassification(BertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -1439,7 +1457,7 @@ class BertForTokenClassification(BertPreTrainedModel):
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1475,6 +1493,7 @@ class BertForTokenClassification(BertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
@ -1527,6 +1546,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1546,7 +1566,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1586,6 +1606,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]

View File

@ -296,7 +296,6 @@ CTRL_INPUTS_DOCSTRING = r"""
class CTRLModel(CTRLPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.output_hidden_states = config.output_hidden_states
self.d_model_size = config.n_embd
self.num_layers = config.n_layer
@ -338,6 +337,7 @@ class CTRLModel(CTRLPreTrainedModel):
inputs_embeds=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
):
r"""
Return:
@ -347,7 +347,8 @@ class CTRLModel(CTRLPreTrainedModel):
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned
when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -374,6 +375,9 @@ class CTRLModel(CTRLPreTrainedModel):
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -446,7 +450,7 @@ class CTRLModel(CTRLPreTrainedModel):
all_hidden_states = ()
all_attentions = []
for i, (h, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
outputs = h(
hidden_states,
@ -466,13 +470,13 @@ class CTRLModel(CTRLPreTrainedModel):
hidden_states = self.layernorm(hidden_states)
hidden_states = hidden_states.view(*output_shape)
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if use_cache is True:
outputs = outputs + (presents,)
if self.output_hidden_states:
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
@ -518,6 +522,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
labels=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -536,7 +541,8 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned
when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -571,6 +577,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = transformer_outputs[0]

View File

@ -259,12 +259,11 @@ class Transformer(nn.Module):
def __init__(self, config):
super().__init__()
self.n_layers = config.n_layers
self.output_hidden_states = config.output_hidden_states
layer = TransformerBlock(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layers)])
def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False):
def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False):
"""
Parameters
----------
@ -289,7 +288,7 @@ class Transformer(nn.Module):
hidden_state = x
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)
layer_outputs = layer_module(
@ -305,11 +304,11 @@ class Transformer(nn.Module):
assert len(layer_outputs) == 1
# Add last layer
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)
outputs = (hidden_state,)
if self.output_hidden_states:
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
@ -411,14 +410,20 @@ class DistilBertModel(DistilBertPreTrainedModel):
@add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING)
def forward(
self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, output_attentions=None,
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.DistilBertConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -445,6 +450,9 @@ class DistilBertModel(DistilBertPreTrainedModel):
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -466,7 +474,11 @@ class DistilBertModel(DistilBertPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids) # (bs, seq_length, dim)
tfmr_output = self.transformer(
x=inputs_embeds, attn_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions,
x=inputs_embeds,
attn_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_state = tfmr_output[0]
output = (hidden_state,) + tfmr_output[1:]
@ -480,7 +492,6 @@ class DistilBertModel(DistilBertPreTrainedModel):
class DistilBertForMaskedLM(DistilBertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.output_hidden_states = config.output_hidden_states
self.distilbert = DistilBertModel(config)
self.vocab_transform = nn.Linear(config.dim, config.dim)
@ -503,6 +514,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
**kwargs
):
r"""
@ -520,7 +532,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -558,6 +570,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
@ -599,6 +612,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -613,7 +627,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -644,6 +658,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
@ -691,6 +706,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -710,7 +726,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -742,6 +758,7 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
@ -797,6 +814,7 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -809,7 +827,7 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -841,6 +859,7 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
@ -891,6 +910,7 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -906,7 +926,7 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -953,6 +973,7 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_state = outputs[0] # (bs * num_choices, seq_len, dim)

View File

@ -273,13 +273,14 @@ class ElectraModel(ElectraPreTrainedModel):
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -307,6 +308,9 @@ class ElectraModel(ElectraPreTrainedModel):
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -339,6 +343,7 @@ class ElectraModel(ElectraPreTrainedModel):
attention_mask=extended_attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
return hidden_states
@ -388,6 +393,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -402,7 +408,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -430,7 +436,14 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
"""
discriminator_hidden_states = self.electra(
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, output_attentions
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
)
sequence_output = discriminator_hidden_states[0]
@ -478,6 +491,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
@ -492,7 +506,8 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
Total loss of the ELECTRA objective.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`)
Prediction scores of the head (scores for each token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned
when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -521,7 +536,14 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
"""
discriminator_hidden_states = self.electra(
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, output_attentions,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
)
discriminator_sequence_output = discriminator_hidden_states[0]
@ -578,6 +600,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
**kwargs
):
r"""
@ -595,7 +618,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -630,7 +653,14 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
generator_hidden_states = self.electra(
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, output_attentions,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
)
generator_sequence_output = generator_hidden_states[0]
@ -677,6 +707,7 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -689,7 +720,7 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -718,7 +749,14 @@ class ElectraForTokenClassification(ElectraPreTrainedModel):
"""
discriminator_hidden_states = self.electra(
input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, output_attentions,
input_ids,
attention_mask,
token_type_ids,
position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
)
discriminator_sequence_output = discriminator_hidden_states[0]
@ -776,6 +814,7 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -795,7 +834,7 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -833,6 +872,7 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = discriminator_hidden_states[0]

View File

@ -131,13 +131,14 @@ class FlaubertModel(XLMModel):
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -162,6 +163,9 @@ class FlaubertModel(XLMModel):
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# removed: src_enc=None, src_len=None
if input_ids is not None:
@ -240,7 +244,7 @@ class FlaubertModel(XLMModel):
if self.training and (dropout_probability < self.layerdrop):
continue
if self.output_hidden_states:
if output_hidden_states:
hidden_states = hidden_states + (tensor,)
# self attention
@ -281,7 +285,7 @@ class FlaubertModel(XLMModel):
tensor *= mask.unsqueeze(-1).to(tensor.dtype)
# Add last hidden state
if self.output_hidden_states:
if output_hidden_states:
hidden_states = hidden_states + (tensor,)
# update cache length
@ -292,7 +296,7 @@ class FlaubertModel(XLMModel):
# tensor = tensor.transpose(0, 1)
outputs = (tensor,)
if self.output_hidden_states:
if output_hidden_states:
outputs = outputs + (hidden_states,)
if output_attentions:
outputs = outputs + (attentions,)

View File

@ -347,7 +347,6 @@ GPT2_INPUTS_DOCSTRING = r"""
class GPT2Model(GPT2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.output_hidden_states = config.output_hidden_states
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
@ -382,6 +381,7 @@ class GPT2Model(GPT2PreTrainedModel):
inputs_embeds=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
):
r"""
Return:
@ -392,7 +392,7 @@ class GPT2Model(GPT2PreTrainedModel):
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True``) is passed or when ``config.output_hidden_states=True``:
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -417,6 +417,9 @@ class GPT2Model(GPT2PreTrainedModel):
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -486,7 +489,7 @@ class GPT2Model(GPT2PreTrainedModel):
all_attentions = []
all_hidden_states = ()
for i, (block, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
outputs = block(
@ -509,13 +512,13 @@ class GPT2Model(GPT2PreTrainedModel):
hidden_states = hidden_states.view(*output_shape)
# Add last hidden state
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if use_cache is True:
outputs = outputs + (presents,)
if self.output_hidden_states:
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
@ -561,6 +564,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
labels=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -579,7 +583,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -614,6 +618,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = transformer_outputs[0]
@ -668,6 +673,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
mc_labels=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
**kwargs
):
r"""
@ -700,7 +706,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -754,6 +760,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = transformer_outputs[0]

View File

@ -587,6 +587,7 @@ class LongformerModel(RobertaModel):
position_ids=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
@ -594,7 +595,7 @@ class LongformerModel(RobertaModel):
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -627,6 +628,9 @@ class LongformerModel(RobertaModel):
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# padding
attention_window = (
@ -668,6 +672,7 @@ class LongformerModel(RobertaModel):
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# undo padding
@ -706,6 +711,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
**kwargs
):
r"""
@ -723,7 +729,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -767,6 +773,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
@ -810,6 +817,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -824,7 +832,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -864,6 +872,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
@ -931,6 +940,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -949,7 +959,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
@ -997,6 +1007,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
@ -1057,6 +1068,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -1069,7 +1081,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1103,6 +1115,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
@ -1158,6 +1171,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
position_ids=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1173,7 +1187,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1239,6 +1253,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
global_attention_mask=flat_global_attention_mask,
inputs_embeds=flat_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = outputs[1]

View File

@ -163,7 +163,7 @@ class MMBTModel(nn.Module, ModuleUtilsMixin):
objective during Bert pretraining. This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
**hidden_states**: (`optional`, returned when ``output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
@ -200,6 +200,7 @@ class MMBTModel(nn.Module, ModuleUtilsMixin):
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_hidden_states=None,
):
if input_ids is not None and inputs_embeds is not None:
@ -257,6 +258,7 @@ class MMBTModel(nn.Module, ModuleUtilsMixin):
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
output_hidden_states=output_hidden_states,
)
sequence_output = encoder_outputs[0]
@ -293,7 +295,7 @@ class MMBTForClassification(nn.Module):
Classification (or regression if config.num_labels==1) loss.
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
**hidden_states**: (`optional`, returned when ``output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.

View File

@ -514,7 +514,6 @@ class MobileBertLayer(nn.Module):
class MobileBertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([MobileBertLayer(config) for _ in range(config.num_hidden_layers)])
def forward(
@ -525,11 +524,12 @@ class MobileBertEncoder(nn.Module):
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
output_hidden_states=False,
):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
@ -546,11 +546,11 @@ class MobileBertEncoder(nn.Module):
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
@ -757,6 +757,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_hidden_states=None,
output_attentions=None,
):
r"""
@ -773,7 +774,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -801,6 +802,9 @@ class MobileBertModel(MobileBertPreTrainedModel):
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -852,6 +856,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
@ -911,6 +916,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
labels=None,
next_sentence_label=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
@ -932,7 +938,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False
continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -962,6 +968,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
@ -1027,6 +1034,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
**kwargs
):
r"""
@ -1044,7 +1052,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1087,6 +1095,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
@ -1136,6 +1145,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
inputs_embeds=None,
next_sentence_label=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1150,7 +1160,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
Next sequence prediction (classification) loss.
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1186,6 +1196,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = outputs[1]
@ -1227,6 +1238,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1240,7 +1252,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
@ -1273,6 +1285,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
@ -1317,6 +1330,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1336,7 +1350,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1376,6 +1390,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
@ -1432,6 +1447,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1447,7 +1463,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1498,6 +1514,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = outputs[1]
@ -1543,6 +1560,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -1555,7 +1573,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1591,6 +1609,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]

View File

@ -334,7 +334,6 @@ OPENAI_GPT_INPUTS_DOCSTRING = r"""
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.output_hidden_states = config.output_hidden_states
self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
@ -366,13 +365,14 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.OpenAIGPTConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the last layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -397,6 +397,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -450,7 +453,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
all_attentions = ()
all_hidden_states = ()
for i, block in enumerate(self.h):
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions)
@ -459,11 +462,11 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
all_attentions = all_attentions + (outputs[1],)
# Add last layer
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
outputs = (hidden_states.view(*output_shape),)
if self.output_hidden_states:
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
@ -497,6 +500,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -516,7 +520,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -548,6 +552,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
@ -600,6 +605,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
labels=None,
mc_labels=None,
output_attentions=None,
output_hidden_states=None,
**kwargs
):
r"""
@ -633,7 +639,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -679,6 +685,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = transformer_outputs[0]

View File

@ -1256,7 +1256,7 @@ class _ReversibleFunction(Function):
num_hashes,
all_hidden_states,
all_attentions,
do_output_hidden_states,
output_hidden_states,
output_attentions,
):
all_buckets = ()
@ -1265,7 +1265,7 @@ class _ReversibleFunction(Function):
hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1)
for layer, layer_head_mask in zip(layers, head_mask):
if do_output_hidden_states is True:
if output_hidden_states is True:
all_hidden_states.append(hidden_states)
layer_outputs = layer(
@ -1284,7 +1284,7 @@ class _ReversibleFunction(Function):
all_attentions.append(layer_outputs.attention_probs)
# Add last layer
if do_output_hidden_states is True:
if output_hidden_states is True:
all_hidden_states.append(hidden_states)
# attach params to ctx for backward
@ -1360,7 +1360,7 @@ class ReformerEncoder(nn.Module):
attention_mask=None,
head_mask=None,
num_hashes=None,
do_output_hidden_states=False,
output_hidden_states=False,
output_attentions=False,
):
# hidden_states and attention lists to be filled if wished
@ -1377,7 +1377,7 @@ class ReformerEncoder(nn.Module):
num_hashes,
all_hidden_states,
all_attentions,
do_output_hidden_states,
output_hidden_states,
output_attentions,
)
@ -1546,7 +1546,7 @@ class ReformerModel(ReformerPreTrainedModel):
head_mask=None,
inputs_embeds=None,
num_hashes=None,
do_output_hidden_states=False,
output_hidden_states=None,
output_attentions=None,
):
r"""
@ -1554,7 +1554,7 @@ class ReformerModel(ReformerPreTrainedModel):
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1581,7 +1581,9 @@ class ReformerModel(ReformerPreTrainedModel):
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
do_output_hidden_states = self.config.output_hidden_states
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -1639,7 +1641,7 @@ class ReformerModel(ReformerPreTrainedModel):
head_mask=head_mask,
attention_mask=attention_mask,
num_hashes=num_hashes,
do_output_hidden_states=do_output_hidden_states,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
)
sequence_output = encoder_outputs.hidden_states
@ -1650,7 +1652,7 @@ class ReformerModel(ReformerPreTrainedModel):
outputs = (sequence_output,)
# TODO(PVP): Replace by named tuple after namedtuples are introduced in the library.
if do_output_hidden_states is True:
if output_hidden_states is True:
outputs = outputs + (encoder_outputs.all_hidden_states,)
if output_attentions is True:
outputs = outputs + (encoder_outputs.all_attentions,)
@ -1740,7 +1742,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
inputs_embeds=None,
num_hashes=None,
labels=None,
do_output_hidden_states=False,
output_hidden_states=None,
output_attentions=None,
):
r"""
@ -1756,7 +1758,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
Classification loss (cross entropy).
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1789,7 +1791,7 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
num_hashes=num_hashes,
do_output_hidden_states=do_output_hidden_states,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
)

View File

@ -188,6 +188,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
**kwargs
):
r"""
@ -205,7 +206,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -245,6 +246,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
@ -313,6 +315,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -327,7 +330,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -360,6 +363,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
@ -407,6 +411,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -422,7 +427,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -468,6 +473,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=flat_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = outputs[1]
@ -515,6 +521,7 @@ class RobertaForTokenClassification(BertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -527,7 +534,7 @@ class RobertaForTokenClassification(BertPreTrainedModel):
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -561,6 +568,7 @@ class RobertaForTokenClassification(BertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
@ -636,6 +644,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -655,7 +664,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -695,6 +704,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]

View File

@ -629,7 +629,6 @@ class T5PreTrainedModel(PreTrainedModel):
class T5Stack(T5PreTrainedModel):
def __init__(self, config, embed_tokens=None):
super().__init__(config)
self.output_hidden_states = config.output_hidden_states
self.embed_tokens = embed_tokens
self.is_decoder = config.is_decoder
@ -662,9 +661,13 @@ class T5Stack(T5PreTrainedModel):
past_key_value_states=None,
use_cache=False,
output_attentions=None,
output_hidden_states=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -726,7 +729,7 @@ class T5Stack(T5PreTrainedModel):
hidden_states = self.dropout(inputs_embeds)
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
@ -761,14 +764,14 @@ class T5Stack(T5PreTrainedModel):
hidden_states = self.dropout(hidden_states)
# Add last layer
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if use_cache is True:
assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self)
outputs = outputs + (present_key_value_states,)
if self.output_hidden_states:
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
@ -895,6 +898,7 @@ class T5Model(T5PreTrainedModel):
decoder_inputs_embeds=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
Returns:
@ -906,7 +910,7 @@ class T5Model(T5PreTrainedModel):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
Note that when using `decoder_past_key_value_states`, the model only outputs the last `hidden-state` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -938,6 +942,7 @@ class T5Model(T5PreTrainedModel):
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = encoder_outputs[0]
@ -961,6 +966,7 @@ class T5Model(T5PreTrainedModel):
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
if use_cache is True:
@ -1021,6 +1027,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
decoder_inputs_embeds=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
**kwargs
):
r"""
@ -1043,7 +1050,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
Note that when using `decoder_past_key_value_states`, the model only outputs the last `prediction_score` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
@ -1085,6 +1092,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = encoder_outputs[0]
@ -1113,6 +1121,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# insert decoder past at right place

View File

@ -361,13 +361,12 @@ class TFAlbertLayerGroup(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states
self.albert_layers = [
TFAlbertLayer(config, name="albert_layers_._{}".format(i)) for i in range(config.inner_group_num)
]
def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask, output_attentions = inputs
hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
layer_hidden_states = ()
layer_attentions = ()
@ -381,11 +380,11 @@ class TFAlbertLayerGroup(tf.keras.layers.Layer):
if cast_bool_to_primitive(output_attentions) is True:
layer_attentions = layer_attentions + (layer_output[1],)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
layer_hidden_states = layer_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
outputs = outputs + (layer_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (layer_attentions,)
@ -398,7 +397,6 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
super().__init__(**kwargs)
self.config = config
self.output_hidden_states = config.output_hidden_states
self.embedding_hidden_mapping_in = tf.keras.layers.Dense(
config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
@ -410,12 +408,12 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
]
def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask, output_attentions = inputs
hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
all_attentions = ()
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = (hidden_states,)
for i in range(self.config.num_hidden_layers):
@ -431,6 +429,7 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
attention_mask,
head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
output_attentions,
output_hidden_states,
],
training=training,
)
@ -439,11 +438,11 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
if cast_bool_to_primitive(output_attentions) is True:
all_attentions = all_attentions + layer_group_output[-1]
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (all_attentions,)
@ -503,6 +502,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.embeddings = TFAlbertEmbeddings(config, name="embeddings")
self.encoder = TFAlbertTransformer(config, name="encoder")
@ -539,6 +539,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
if isinstance(inputs, (tuple, list)):
@ -549,7 +550,8 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
assert len(inputs) <= 7, "Too many inputs."
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
@ -558,11 +560,13 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 7, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -607,7 +611,8 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
encoder_outputs = self.encoder(
[embedding_output, extended_attention_mask, head_mask, output_attentions], training=training
[embedding_output, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
training=training,
)
sequence_output = encoder_outputs[0]
@ -710,38 +715,39 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
@add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
r"""
Returns:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during Albert pretraining. This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Returns:
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during Albert pretraining. This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
tuple of :obj:`tf.Tensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
Examples::
import tensorflow as tf
from transformers import AlbertTokenizer, TFAlbertModel
import tensorflow as tf
from transformers import AlbertTokenizer, TFAlbertModel
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = TFAlbertModel.from_pretrained('albert-base-v2')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
model = TFAlbertModel.from_pretrained('albert-base-v2')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
outputs = self.albert(inputs, **kwargs)
@ -774,7 +780,8 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
sop_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, 2)`):
Prediction scores of the sentence order prediction (classification) head (scores of True/False continuation before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
@ -833,7 +840,8 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -894,6 +902,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -907,7 +916,8 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`)
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -940,6 +950,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
@ -984,6 +995,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -995,7 +1007,8 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1027,6 +1040,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
@ -1073,6 +1087,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1091,7 +1106,8 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
Span-start scores (before SoftMax).
end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1128,6 +1144,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringL
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
@ -1184,6 +1201,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1198,7 +1216,8 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
`num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1266,6 +1285,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
]
outputs = self.albert(flat_inputs, training=training)

View File

@ -378,16 +378,15 @@ class TFBertLayer(tf.keras.layers.Layer):
class TFBertEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states
self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask, output_attentions = inputs
hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
@ -399,11 +398,11 @@ class TFBertEncoder(tf.keras.layers.Layer):
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (all_attentions,)
@ -499,6 +498,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
self.num_hidden_layers = config.num_hidden_layers
self.initializer_range = config.initializer_range
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.embeddings = TFBertEmbeddings(config, name="embeddings")
self.encoder = TFBertEncoder(config, name="encoder")
@ -527,6 +527,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
if isinstance(inputs, (tuple, list)):
@ -537,7 +538,8 @@ class TFBertMainLayer(tf.keras.layers.Layer):
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
assert len(inputs) <= 7, "Too many inputs."
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
@ -546,11 +548,13 @@ class TFBertMainLayer(tf.keras.layers.Layer):
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 7, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -595,7 +599,8 @@ class TFBertMainLayer(tf.keras.layers.Layer):
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
encoder_outputs = self.encoder(
[embedding_output, extended_attention_mask, head_mask, output_attentions], training=training
[embedding_output, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
training=training,
)
sequence_output = encoder_outputs[0]
@ -712,7 +717,8 @@ class TFBertModel(TFBertPreTrainedModel):
objective during Bert pretraining. This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -764,7 +770,8 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -818,7 +825,8 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -868,7 +876,8 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
seq_relationship_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, 2)`)
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -931,6 +940,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -944,7 +954,8 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -977,6 +988,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
@ -1029,6 +1041,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1043,7 +1056,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
`num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1116,6 +1130,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
head_mask,
flat_inputs_embeds,
output_attentions,
output_hidden_states,
]
outputs = self.bert(flat_inputs, training=training)
@ -1162,6 +1177,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1173,7 +1189,8 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1205,6 +1222,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
@ -1252,6 +1270,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1270,7 +1289,8 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
Span-start scores (before SoftMax).
end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1305,6 +1325,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss)
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)

View File

@ -237,6 +237,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
inputs_embeds=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
training=False,
):
@ -250,7 +251,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
use_cache = inputs[7] if len(inputs) > 7 else use_cache
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
assert len(inputs) <= 9, "Too many inputs."
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
@ -261,11 +263,13 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 9, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
# If using past key value states, only the last tokens
# should be given as an input
@ -351,7 +355,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
all_hidden_states = ()
all_attentions = []
for i, (h, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h(
[hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
@ -367,13 +371,13 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
hidden_states = self.layernorm(hidden_states)
hidden_states = tf.reshape(hidden_states, output_shape)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if use_cache is True:
outputs = outputs + (presents,)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True:
# let the number of heads free (-1) so we can extract attention even after head pruning
@ -493,7 +497,7 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)` `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -573,7 +577,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.

View File

@ -351,7 +351,6 @@ class TFTransformer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.n_layers = config.n_layers
self.output_hidden_states = config.output_hidden_states
self.layer = [TFTransformerBlock(config, name="layer_._{}".format(i)) for i in range(config.n_layers)]
@ -375,14 +374,14 @@ class TFTransformer(tf.keras.layers.Layer):
Tuple of length n_layers with the attention weights from each layer
Optional: only if output_attentions=True
"""
x, attn_mask, head_mask, output_attentions = inputs
x, attn_mask, head_mask, output_attentions, output_hidden_states = inputs
all_hidden_states = ()
all_attentions = ()
hidden_state = x
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (hidden_state,)
layer_outputs = layer_module([hidden_state, attn_mask, head_mask[i], output_attentions], training=training)
@ -396,11 +395,11 @@ class TFTransformer(tf.keras.layers.Layer):
assert len(layer_outputs) == 1
# Add last layer
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (hidden_state,)
outputs = (hidden_state,)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (all_attentions,)
@ -415,6 +414,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings
self.transformer = TFTransformer(config, name="transformer") # Encoder
@ -430,7 +430,14 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
raise NotImplementedError
def call(
self, inputs, attention_mask=None, head_mask=None, inputs_embeds=None, output_attentions=None, training=False
self,
inputs,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
@ -438,18 +445,21 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
assert len(inputs) <= 5, "Too many inputs."
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 5, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 6, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -476,7 +486,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim)
tfmr_output = self.transformer(
[embedding_output, attention_mask, head_mask, output_attentions], training=training
[embedding_output, attention_mask, head_mask, output_attentions, output_hidden_states], training=training
)
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
@ -571,7 +581,8 @@ class TFDistilBertModel(TFDistilBertPreTrainedModel):
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers,DistilBertConfig`) and inputs:
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -622,7 +633,6 @@ class TFDistilBertLMHead(tf.keras.layers.Layer):
class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.output_hidden_states = config.output_hidden_states
self.vocab_size = config.vocab_size
self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
@ -644,7 +654,8 @@ class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel):
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers,DistilBertConfig`) and inputs:
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -710,6 +721,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -723,7 +735,8 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers,DistilBertConfig`) and inputs:
logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -753,6 +766,7 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSeque
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
@ -796,6 +810,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -807,7 +822,8 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers,DistilBertConfig`) and inputs:
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -837,6 +853,7 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
@ -893,6 +910,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -907,7 +925,8 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
`num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -964,6 +983,8 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
flat_attention_mask,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
]
distilbert_output = self.distilbert(flat_inputs, training=training)
@ -1012,6 +1033,7 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1030,7 +1052,8 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
Span-start scores (before SoftMax).
end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1061,6 +1084,8 @@ class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAn
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)

View File

@ -240,6 +240,7 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
if isinstance(inputs, (tuple, list)):
@ -250,7 +251,8 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
assert len(inputs) <= 7, "Too many inputs."
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
@ -259,11 +261,15 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 7, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -288,7 +294,8 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
hidden_states = self.embeddings_project(hidden_states, training=training)
hidden_states = self.encoder(
[hidden_states, extended_attention_mask, head_mask, output_attentions], training=training
[hidden_states, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
training=training,
)
return hidden_states
@ -382,7 +389,8 @@ class TFElectraModel(TFElectraPreTrainedModel):
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -433,6 +441,7 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -440,7 +449,8 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Prediction scores of the head (scores for each token before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -471,6 +481,7 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
training=training,
)
discriminator_sequence_output = discriminator_hidden_states[0]
@ -530,6 +541,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel):
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -537,7 +549,8 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel):
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -569,6 +582,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel):
head_mask,
inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
generator_sequence_output = generator_hidden_states[0]
@ -607,6 +621,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -618,7 +633,8 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -651,6 +667,7 @@ class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassific
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
training=training,
)
discriminator_sequence_output = discriminator_hidden_states[0]
@ -696,6 +713,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -714,7 +732,8 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
Span-start scores (before SoftMax).
end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -749,6 +768,7 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
training=training,
)
discriminator_sequence_output = discriminator_hidden_states[0]

View File

@ -137,6 +137,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
inputs_embeds=None,
training=False,
output_attentions=False,
output_hidden_states=False,
):
# removed: src_enc=None, src_len=None
if isinstance(inputs, (tuple, list)):
@ -251,15 +252,14 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
if training and (dropout_probability < self.layerdrop):
continue
if self.output_hidden_states:
if output_hidden_states:
hidden_states = hidden_states + (tensor,)
# self attention
if not self.pre_norm:
attn_outputs = self.attentions[i]([tensor, attn_mask, None, cache, head_mask[i]], training=training)
attn = attn_outputs[0]
if output_attentions:
attentions = attentions + (attn_outputs[1],)
attentions = attentions + (attn_outputs[1],)
attn = self.dropout(attn, training=training)
tensor = tensor + attn
tensor = self.layer_norm1[i](tensor)
@ -292,7 +292,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
tensor = tensor * mask[..., tf.newaxis]
# Add last hidden state
if self.output_hidden_states:
if output_hidden_states:
hidden_states = hidden_states + (tensor,)
# update cache length
@ -303,7 +303,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
# tensor = tensor.transpose(0, 1)
outputs = (tensor,)
if self.output_hidden_states:
if output_hidden_states:
outputs = outputs + (hidden_states,)
if output_attentions:
outputs = outputs + (attentions,)

View File

@ -257,6 +257,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
use_cache=True,
training=False,
output_attentions=None,
output_hidden_states=None,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
@ -268,7 +269,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
use_cache = inputs[7] if len(inputs) > 7 else use_cache
output_attentions = inputs[8] if len(inputs) > 7 else output_attentions
assert len(inputs) <= 9, "Too many inputs."
output_hidden_states = inputs[9] if len(inputs) > 8 else output_hidden_states
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
@ -279,11 +281,13 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 9, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -352,7 +356,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
all_attentions = []
all_hidden_states = ()
for i, (block, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block(
@ -370,14 +374,14 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
hidden_states = tf.reshape(hidden_states, output_shape)
# Add last hidden state
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if use_cache is True:
outputs = outputs + (presents,)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True:
# let the number of heads free (-1) so we can extract attention even after head pruning
@ -493,7 +497,7 @@ class TFGPT2Model(TFGPT2PreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)` `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -552,7 +556,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -620,6 +624,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
mc_token_ids=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -637,7 +642,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as `input_ids` as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -726,6 +731,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
inputs_embeds,
use_cache,
output_attentions,
output_hidden_states,
]
transformer_outputs = self.transformer(flat_inputs, training=training)

View File

@ -508,16 +508,15 @@ class TFMobileBertLayer(tf.keras.layers.Layer):
class TFMobileBertEncoder(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states
self.layer = [TFMobileBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask, output_attentions = inputs
hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
@ -529,11 +528,11 @@ class TFMobileBertEncoder(tf.keras.layers.Layer):
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (all_attentions,)
@ -643,6 +642,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
super().__init__(**kwargs)
self.num_hidden_layers = config.num_hidden_layers
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.embeddings = TFMobileBertEmbeddings(config, name="embeddings")
self.encoder = TFMobileBertEncoder(config, name="encoder")
@ -670,6 +670,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
if isinstance(inputs, (tuple, list)):
@ -680,7 +681,8 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
assert len(inputs) <= 7, "Too many inputs."
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
@ -689,11 +691,13 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 7, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -738,7 +742,8 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
encoder_outputs = self.encoder(
[embedding_output, extended_attention_mask, head_mask, output_attentions], training=training
[embedding_output, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
training=training,
)
sequence_output = encoder_outputs[0]
@ -1079,6 +1084,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1092,7 +1098,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1125,6 +1131,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
@ -1172,6 +1179,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1190,7 +1198,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
Span-start scores (before SoftMax).
end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1225,6 +1233,7 @@ class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAn
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
@ -1281,6 +1290,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1295,7 +1305,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
`num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1330,7 +1340,8 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
assert len(inputs) <= 7, "Too many inputs."
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
@ -1339,7 +1350,8 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 7, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
@ -1368,6 +1380,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
head_mask,
flat_inputs_embeds,
output_attentions,
output_hidden_states,
]
outputs = self.mobilebert(flat_inputs, training=training)
@ -1414,6 +1427,7 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1425,7 +1439,7 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.MobileBertConfig`) and inputs:
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` or ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1457,6 +1471,7 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)

View File

@ -246,6 +246,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
if isinstance(inputs, (tuple, list)):
@ -256,7 +257,8 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
assert len(inputs) <= 7, "Too many inputs."
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
@ -265,11 +267,13 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 7, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -332,7 +336,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
all_attentions = []
all_hidden_states = ()
for i, block in enumerate(self.h):
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block([hidden_states, attention_mask, head_mask[i], output_attentions], training=training)
@ -342,11 +346,11 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
hidden_states = tf.reshape(hidden_states, output_shape)
# Add last hidden state
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True:
# let the number of heads free (-1) so we can extract attention even after head pruning
@ -451,7 +455,7 @@ class TFOpenAIGPTModel(TFOpenAIGPTPreTrainedModel):
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.OpenAIGPTConfig`) and inputs:
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the last layer of the model.
hidden_states (:obj:`tuple(tf.Tensor)` `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)` `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -499,7 +503,7 @@ class TFOpenAIGPTLMHeadModel(TFOpenAIGPTPreTrainedModel):
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.OpenAIGPTConfig`) and inputs:
prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -564,6 +568,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
inputs_embeds=None,
mc_token_ids=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -581,7 +586,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -661,6 +666,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
]
transformer_outputs = self.transformer(flat_inputs, training=training)

View File

@ -207,7 +207,8 @@ class TFRobertaModel(TFRobertaPreTrainedModel):
objective during Bert pretraining. This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -283,7 +284,8 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -365,6 +367,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -372,7 +375,8 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -404,6 +408,7 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
@ -454,6 +459,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -468,7 +474,8 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
`num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -533,6 +540,8 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
flat_position_ids,
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
]
outputs = self.roberta(flat_inputs, training=training)
@ -579,6 +588,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -590,7 +600,8 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -622,6 +633,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
@ -668,6 +680,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -686,7 +699,8 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
Span-start scores (before SoftMax).
end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -723,6 +737,7 @@ class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnswerin
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)

View File

@ -558,6 +558,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
past_key_value_states=None,
use_cache=False,
output_attentions=None,
output_hidden_states=None,
training=False,
):
if isinstance(inputs, (tuple, list)):
@ -584,6 +585,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both inputs and inputs_embeds at the same time")
@ -696,7 +698,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
hidden_states = self.dropout(inputs_embeds, training=training)
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
@ -731,14 +733,14 @@ class TFT5MainLayer(tf.keras.layers.Layer):
hidden_states = self.dropout(hidden_states, training=training)
# Add last layer
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if use_cache is True:
assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self)
outputs = outputs + (present_key_value_states,)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (all_attentions,)
@ -912,7 +914,7 @@ class TFT5Model(TFT5PreTrainedModel):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
Note that when using `decoder_past_key_value_states`, the model only outputs the last `hidden-state` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -953,6 +955,7 @@ class TFT5Model(TFT5PreTrainedModel):
use_cache = kwargs.get("use_cache", True)
head_mask = kwargs.get("head_mask", None)
output_attentions = kwargs.get("output_attentions", None)
output_hidden_states = kwargs.get("output_hidden_states", None)
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
@ -962,6 +965,7 @@ class TFT5Model(TFT5PreTrainedModel):
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = encoder_outputs[0]
@ -985,6 +989,7 @@ class TFT5Model(TFT5PreTrainedModel):
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
if use_cache is True:
@ -1049,7 +1054,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
Note that when using `decoder_past_key_value_states`, the model only outputs the last `prediction_score` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1094,6 +1099,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
head_mask = kwargs.get("head_mask", None)
output_attentions = kwargs.get("output_attentions", None)
output_hidden_states = kwargs.get("output_hidden_states", None)
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
@ -1104,6 +1110,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = encoder_outputs[0]
@ -1127,6 +1134,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# insert decoder past at right place

View File

@ -520,25 +520,37 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
return new_mems
def call(self, inputs, mems=None, head_mask=None, inputs_embeds=None, output_attentions=None, training=False):
def call(
self,
inputs,
mems=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
mems = inputs[1] if len(inputs) > 1 else mems
head_mask = inputs[2] if len(inputs) > 2 else head_mask
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
assert len(inputs) <= 5, "Too many inputs."
output_hidden_states = inputs[5] if len(inputs) > 4 else output_hidden_states
assert len(inputs) <= 6, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
mems = inputs.get("mems", mems)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 5, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 6, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
@ -625,7 +637,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
# We transpose back here to shape [bsz, len, hidden_dim]
outputs = [tf.transpose(core_out, perm=(1, 0, 2)), new_mems]
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states):
# Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
hids.append(core_out)
hids = list(tf.transpose(t, perm=(1, 0, 2)) for t in hids)
@ -720,7 +732,7 @@ class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -807,6 +819,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -818,7 +831,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -867,7 +880,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
bsz, tgt_len = shape_list(inputs_embeds)[:2]
transformer_outputs = self.transformer(
[input_ids, mems, head_mask, inputs_embeds, output_attentions], training=training
[input_ids, mems, head_mask, inputs_embeds, output_attentions, output_hidden_states], training=training
)
last_hidden = transformer_outputs[0]

View File

@ -332,6 +332,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
training=False,
): # removed: src_enc=None, src_len=None
if isinstance(inputs, (tuple, list)):
@ -345,7 +346,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
output_attentions = inputs[9] if len(inputs) > 9 else output_attentions
assert len(inputs) <= 10, "Too many inputs."
output_hidden_states = inputs[10] if len(inputs) > 10 else output_hidden_states
assert len(inputs) <= 11, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
@ -357,11 +359,13 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 10, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 11, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@ -445,7 +449,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
hidden_states = ()
attentions = ()
for i in range(self.n_layers):
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
hidden_states = hidden_states + (tensor,)
# self attention
@ -472,7 +476,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
tensor = tensor * mask[..., tf.newaxis]
# Add last hidden state
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
hidden_states = hidden_states + (tensor,)
# update cache length
@ -483,7 +487,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
# tensor = tensor.transpose(0, 1)
outputs = (tensor,)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
outputs = outputs + (hidden_states,)
if cast_bool_to_primitive(output_attentions) is True:
outputs = outputs + (attentions,)
@ -610,7 +614,7 @@ class TFXLMModel(TFXLMPreTrainedModel):
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
last_hidden_state (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -706,7 +710,7 @@ class TFXLMWithLMHeadModel(TFXLMPreTrainedModel):
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
prediction_scores (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -766,6 +770,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -779,7 +784,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
logits (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -815,6 +820,7 @@ class TFXLMForSequenceClassification(TFXLMPreTrainedModel, TFSequenceClassificat
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
output = transformer_outputs[0]
@ -865,6 +871,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -879,7 +886,8 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
`num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -956,6 +964,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
head_mask,
inputs_embeds,
output_attentions,
output_hidden_states,
]
transformer_outputs = self.transformer(flat_inputs, training=training)
@ -1002,6 +1011,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1013,7 +1023,8 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1045,6 +1056,7 @@ class TFXLMForTokenClassification(TFXLMPreTrainedModel, TFTokenClassificationLos
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
@ -1093,6 +1105,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1111,7 +1124,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
Span-start scores (before SoftMax).
end_scores (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1150,6 +1163,7 @@ class TFXLMForQuestionAnsweringSimple(TFXLMPreTrainedModel, TFQuestionAnsweringL
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)

View File

@ -517,6 +517,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
inputs_embeds=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
training=False,
):
if isinstance(inputs, (tuple, list)):
@ -530,8 +531,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
use_cache = inputs[9] if len(inputs) > 9 else use_cache
output_attentions = inputs[-9] if len(inputs) > 10 else output_attentions
assert len(inputs) <= 11, "Too many inputs."
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
assert len(inputs) <= 12, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
@ -544,11 +546,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 11, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 12, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
@ -677,7 +681,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
# cache new mems
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
outputs = layer_module(
@ -700,7 +704,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
attentions.append(outputs[2])
# Add last hidden state
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
output = self.dropout(output_g if output_g is not None else output_h, training=training)
@ -711,7 +715,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
outputs = outputs + (new_mems,)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
if output_g is not None:
hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
else:
@ -838,7 +842,7 @@ class TFXLNetModel(TFXLNetPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -922,7 +926,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -996,6 +1000,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
use_cache=True,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1013,7 +1018,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1050,6 +1055,7 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassif
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
output = transformer_outputs[0]
@ -1106,6 +1112,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
use_cache=True,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1120,7 +1127,8 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
`num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when
``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1158,8 +1166,9 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
use_cache = inputs[9] if len(inputs) > 9 else use_cache
output_attentions = inputs[-9] if len(inputs) > 10 else output_attentions
assert len(inputs) <= 11, "Too many inputs."
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
assert len(inputs) <= 12, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
@ -1172,7 +1181,8 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 11, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_attentions)
assert len(inputs) <= 12, "Too many inputs."
else:
input_ids = inputs
@ -1200,6 +1210,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
inputs_embeds,
use_cache,
output_attentions,
output_hidden_states,
]
transformer_outputs = self.transformer(flat_inputs, training=training)
@ -1246,6 +1257,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
use_cache=True,
labels=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1261,7 +1273,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1298,6 +1310,7 @@ class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificatio
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
output = transformer_outputs[0]
@ -1345,6 +1358,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
p_mask=None,
is_impossible=None,
output_attentions=None,
output_hidden_states=None,
training=False,
):
r"""
@ -1369,7 +1383,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1408,6 +1422,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
@ -1457,7 +1472,7 @@ class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnswer
# that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
# if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
# See details in the docstring of the `mems` input above.
# **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
# **hidden_states**: (`optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``)
# list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
# of shape ``(batch_size, sequence_length, hidden_size)``:
# Hidden-states of the model at the output of each layer plus the initial embedding outputs.

View File

@ -634,7 +634,6 @@ TRANSFO_XL_INPUTS_DOCSTRING = r"""
class TransfoXLModel(TransfoXLPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.output_hidden_states = config.output_hidden_states
self.n_token = config.vocab_size
@ -750,7 +749,15 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
return new_mems
@add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
def forward(self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None, output_attentions=None):
def forward(
self,
input_ids=None,
mems=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.TransfoXLConfig`) and inputs:
@ -760,7 +767,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -785,6 +792,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
# so we transpose here from shape [bsz, len] to shape [len, bsz]
@ -873,7 +883,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
# We transpose back here to shape [bsz, len, hidden_dim]
outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
if self.output_hidden_states:
if output_hidden_states:
# Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
hids.append(core_out)
hids = list(t.transpose(0, 1).contiguous() for t in hids)
@ -936,7 +946,14 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
@add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
def forward(
self, input_ids=None, mems=None, head_mask=None, inputs_embeds=None, labels=None, output_attentions=None
self,
input_ids=None,
mems=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -956,7 +973,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -988,7 +1005,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
raise ValueError("You have to specify either input_ids or inputs_embeds")
transformer_outputs = self.transformer(
input_ids, mems=mems, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions
input_ids,
mems=mems,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
last_hidden = transformer_outputs[0]

View File

@ -314,7 +314,6 @@ XLM_INPUTS_DOCSTRING = r"""
class XLMModel(XLMPreTrainedModel):
def __init__(self, config): # , dico, is_encoder, with_output):
super().__init__(config)
self.output_hidden_states = config.output_hidden_states
# encoder / decoder, output layer
self.is_encoder = config.is_encoder
@ -408,13 +407,14 @@ class XLMModel(XLMPreTrainedModel):
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.XLMConfig`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -439,6 +439,9 @@ class XLMModel(XLMPreTrainedModel):
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None:
bs, slen = input_ids.size()
@ -511,7 +514,7 @@ class XLMModel(XLMPreTrainedModel):
hidden_states = ()
attentions = ()
for i in range(self.n_layers):
if self.output_hidden_states:
if output_hidden_states:
hidden_states = hidden_states + (tensor,)
# self attention
@ -538,7 +541,7 @@ class XLMModel(XLMPreTrainedModel):
tensor *= mask.unsqueeze(-1).to(tensor.dtype)
# Add last hidden state
if self.output_hidden_states:
if output_hidden_states:
hidden_states = hidden_states + (tensor,)
# update cache length
@ -549,7 +552,7 @@ class XLMModel(XLMPreTrainedModel):
# tensor = tensor.transpose(0, 1)
outputs = (tensor,)
if self.output_hidden_states:
if output_hidden_states:
outputs = outputs + (hidden_states,)
if output_attentions:
outputs = outputs + (attentions,)
@ -642,6 +645,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -657,7 +661,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
Language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -692,6 +696,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
output = transformer_outputs[0]
@ -730,6 +735,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -744,7 +750,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -780,6 +786,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
output = transformer_outputs[0]
@ -829,6 +836,7 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -848,7 +856,7 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -885,6 +893,7 @@ class XLMForQuestionAnsweringSimple(XLMPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = transformer_outputs[0]
@ -952,6 +961,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
cls_index=None,
p_mask=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -984,7 +994,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
Log probabilities for the ``is_impossible`` label of the answers.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1021,6 +1031,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
output = transformer_outputs[0]
@ -1066,6 +1077,7 @@ class XLMForTokenClassification(XLMPreTrainedModel):
head_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@ -1078,7 +1090,7 @@ class XLMForTokenClassification(XLMPreTrainedModel):
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1111,6 +1123,7 @@ class XLMForTokenClassification(XLMPreTrainedModel):
position_ids=position_ids,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]

View File

@ -630,7 +630,6 @@ XLNET_INPUTS_DOCSTRING = r"""
class XLNetModel(XLNetPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.output_hidden_states = config.output_hidden_states
self.mem_len = config.mem_len
self.reuse_len = config.reuse_len
@ -763,6 +762,7 @@ class XLNetModel(XLNetPreTrainedModel):
inputs_embeds=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
):
r"""
Return:
@ -774,7 +774,7 @@ class XLNetModel(XLNetPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -801,6 +801,9 @@ class XLNetModel(XLNetPreTrainedModel):
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
@ -934,7 +937,7 @@ class XLNetModel(XLNetPreTrainedModel):
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
# cache new mems
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if self.output_hidden_states:
if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
outputs = layer_module(
@ -954,7 +957,7 @@ class XLNetModel(XLNetPreTrainedModel):
attentions.append(outputs[2])
# Add last hidden state
if self.output_hidden_states:
if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
output = self.dropout(output_g if output_g is not None else output_h)
@ -965,7 +968,7 @@ class XLNetModel(XLNetPreTrainedModel):
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
outputs = outputs + (new_mems,)
if self.output_hidden_states:
if output_hidden_states:
if output_g is not None:
hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
else:
@ -1051,6 +1054,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
use_cache=True,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_predict)`, `optional`, defaults to :obj:`None`):
@ -1072,7 +1076,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1127,6 +1131,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
logits = self.lm_loss(transformer_outputs[0])
@ -1173,6 +1178,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
use_cache=True,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`)
@ -1191,7 +1197,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1229,6 +1235,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
output = transformer_outputs[0]
@ -1280,6 +1287,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
use_cache=True,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1297,7 +1305,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1337,6 +1345,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
@ -1391,6 +1400,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
use_cache=True,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1410,7 +1420,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1462,6 +1472,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
inputs_embeds=flat_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
output = transformer_outputs[0]
@ -1512,6 +1523,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1535,7 +1547,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1576,6 +1588,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
@ -1643,6 +1656,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
cls_index=None,
p_mask=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@ -1679,7 +1693,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@ -1718,6 +1732,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = transformer_outputs[0]
start_logits = self.start_logits(hidden_states, p_mask=p_mask)

View File

@ -143,14 +143,13 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
config.output_hidden_states = False
inputs_dict["output_hidden_states"] = False
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs[-1]
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
@ -162,7 +161,6 @@ class ModelTesterMixin:
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs[-1]
self.assertEqual(model.config.output_hidden_states, False)
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
@ -201,14 +199,13 @@ class ModelTesterMixin:
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
config.output_hidden_states = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
self.assertEqual(model.config.output_hidden_states, True)
self_attentions = outputs[-1]
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
@ -493,19 +490,16 @@ class ModelTesterMixin:
self.assertDictEqual(model.config.pruned_heads, {0: [0], 1: [1, 2], 2: [1, 2]})
def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
config.output_hidden_states = True
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs[-1]
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
if hasattr(self.model_tester, "encoder_seq_length"):
seq_length = self.model_tester.encoder_seq_length
if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
@ -517,6 +511,18 @@ class ModelTesterMixin:
list(hidden_states[0].shape[-2:]), [seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_resize_tokens_embeddings(self):
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:

View File

@ -392,17 +392,23 @@ class TFModelTesterMixin:
def test_hidden_states_output(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
config.output_hidden_states = True
def check_hidden_states_output(config, inputs_dict, model_class):
model = model_class(config)
outputs = model(self._prepare_for_class(inputs_dict, model_class))
hidden_states = [t.numpy() for t in outputs[-1]]
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
self.assertListEqual(
list(hidden_states[0].shape[-2:]), [self.model_tester.seq_length, self.model_tester.hidden_size],
)
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(config, inputs_dict, model_class)
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(config, inputs_dict, model_class)
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()