Updated quick-start example with `BertForMaskedLM`
As `convert_ids_to_tokens` returns a list, the code in the README currently throws an `AssertionError`, so I propose I quick fix.
This commit is contained in:
parent
21f0196412
commit
ec2c339b53
|
@ -142,7 +142,7 @@ predictions = model(tokens_tensor, segments_tensors)
|
|||
|
||||
# confirm we were able to predict 'henson'
|
||||
predicted_index = torch.argmax(predictions[0, masked_index]).item()
|
||||
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])
|
||||
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
|
||||
assert predicted_token == 'henson'
|
||||
```
|
||||
|
||||
|
|
Loading…
Reference in New Issue