Fix param error (#9273)
TypeError: forward() got an unexpected keyword argument 'token_type_ids'
This commit is contained in:
parent
58e8a7611f
commit
4bafc43b0e
|
@ -130,7 +130,7 @@ def load_tf_weights_in_bert_generation(
|
|||
|
||||
|
||||
class BertGenerationEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
"""Construct the embeddings from word and position embeddings."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
@ -468,7 +468,7 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
|
|||
>>> config.is_decoder = True
|
||||
>>> model = BertGenerationDecoder.from_pretrained('google/bert_for_seq_generation_L-24_bbc_encoder', config=config)
|
||||
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||
>>> inputs = tokenizer("Hello, my dog is cute", return_token_type_ids=False, return_tensors="pt")
|
||||
>>> outputs = model(**inputs)
|
||||
|
||||
>>> prediction_logits = outputs.logits
|
||||
|
|
Loading…
Reference in New Issue