diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py
index b242181b87..7336fd005c 100755
--- a/src/transformers/models/bert_generation/modeling_bert_generation.py
+++ b/src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -130,7 +130,7 @@ def load_tf_weights_in_bert_generation(
 
 
 class BertGenerationEmbeddings(nn.Module):
-    """Construct the embeddings from word, position and token_type embeddings."""
+    """Construct the embeddings from word and position embeddings."""
 
     def __init__(self, config):
         super().__init__()
@@ -468,7 +468,7 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
         >>> config.is_decoder = True
         >>> model = BertGenerationDecoder.from_pretrained('google/bert_for_seq_generation_L-24_bbc_encoder', config=config)
 
-        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> inputs = tokenizer("Hello, my dog is cute", return_token_type_ids=False, return_tensors="pt")
         >>> outputs = model(**inputs)
 
         >>> prediction_logits = outputs.logits
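
For context, a minimal sketch of the usage the updated doctest describes, assuming the setup lines (tokenizer and config loading) from the surrounding docstring; the checkpoint name and calls are taken from the diff itself. The point of the change is that BertGeneration has no token_type embeddings, so the tokenizer is asked not to return `token_type_ids`, which the model's `forward` does not accept.

```python
from transformers import BertGenerationConfig, BertGenerationDecoder, BertGenerationTokenizer

# Checkpoint name as used in the doctest above.
checkpoint = "google/bert_for_seq_generation_L-24_bbc_encoder"

tokenizer = BertGenerationTokenizer.from_pretrained(checkpoint)
config = BertGenerationConfig.from_pretrained(checkpoint)
config.is_decoder = True
model = BertGenerationDecoder.from_pretrained(checkpoint, config=config)

# return_token_type_ids=False keeps token_type_ids out of the encoded inputs,
# since BertGeneration models have no token_type embeddings and their forward
# would reject that keyword when the dict is unpacked with **inputs.
inputs = tokenizer("Hello, my dog is cute", return_token_type_ids=False, return_tensors="pt")
outputs = model(**inputs)
prediction_logits = outputs.logits
```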