Fix param error (#9273)

TypeError: forward() got an unexpected keyword argument 'token_type_ids'
This commit is contained in:
Xu Song 2020-12-23 18:34:57 +08:00 committed by GitHub
parent 58e8a7611f
commit 4bafc43b0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -130,7 +130,7 @@ def load_tf_weights_in_bert_generation(
class BertGenerationEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
"""Construct the embeddings from word and position embeddings."""
def __init__(self, config):
super().__init__()
@ -468,7 +468,7 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
>>> config.is_decoder = True
>>> model = BertGenerationDecoder.from_pretrained('google/bert_for_seq_generation_L-24_bbc_encoder', config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> inputs = tokenizer("Hello, my dog is cute", return_token_type_ids=False, return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits