Updated quick-start example with `BertForMaskedLM`

As `convert_ids_to_tokens` returns a list, the code in the README currently throws an `AssertionError`, so I propose I quick fix.
This commit is contained in:
Davide Fiocco 2018-11-28 14:53:46 +01:00 committed by GitHub
parent 21f0196412
commit ec2c339b53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -142,7 +142,7 @@ predictions = model(tokens_tensor, segments_tensors)
# confirm we were able to predict 'henson'
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson'
```