From 4bafc43b0ebf65dc1e9df70c4fe1a81dfa2475cf Mon Sep 17 00:00:00 2001
From: Xu Song
Date: Wed, 23 Dec 2020 18:34:57 +0800
Subject: [PATCH] Fix param error (#9273)

TypeError: forward() got an unexpected keyword argument 'token_type_ids'
---
 .../models/bert_generation/modeling_bert_generation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py
index b242181b87..7336fd005c 100755
--- a/src/transformers/models/bert_generation/modeling_bert_generation.py
+++ b/src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -130,7 +130,7 @@ def load_tf_weights_in_bert_generation(
 
 
 class BertGenerationEmbeddings(nn.Module):
-    """Construct the embeddings from word, position and token_type embeddings."""
+    """Construct the embeddings from word and position embeddings."""
 
     def __init__(self, config):
         super().__init__()
@@ -468,7 +468,7 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
         >>> config.is_decoder = True
         >>> model = BertGenerationDecoder.from_pretrained('google/bert_for_seq_generation_L-24_bbc_encoder', config=config)
 
-        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> inputs = tokenizer("Hello, my dog is cute", return_token_type_ids=False, return_tensors="pt")
         >>> outputs = model(**inputs)
 
         >>> prediction_logits = outputs.logits
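
For reference, below is a minimal runnable sketch of the corrected usage that the patched
docstring describes, assuming the google/bert_for_seq_generation_L-24_bbc_encoder checkpoint
referenced in the diff is reachable and that transformers exposes BertGenerationConfig,
BertGenerationTokenizer and BertGenerationDecoder as in the patched file:

    # Sketch of the fixed example: passing return_token_type_ids=False keeps the
    # tokenizer from emitting token_type_ids, which BertGenerationDecoder.forward()
    # does not accept (the TypeError this patch addresses).
    from transformers import (
        BertGenerationConfig,
        BertGenerationDecoder,
        BertGenerationTokenizer,
    )

    checkpoint = "google/bert_for_seq_generation_L-24_bbc_encoder"
    tokenizer = BertGenerationTokenizer.from_pretrained(checkpoint)
    config = BertGenerationConfig.from_pretrained(checkpoint)
    config.is_decoder = True
    model = BertGenerationDecoder.from_pretrained(checkpoint, config=config)

    # Without return_token_type_ids=False this call would add a token_type_ids
    # entry to `inputs`, and model(**inputs) would raise the reported TypeError.
    inputs = tokenizer("Hello, my dog is cute", return_token_type_ids=False, return_tensors="pt")
    outputs = model(**inputs)
    prediction_logits = outputs.logits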