Correct index in script
This commit is contained in:
parent
ec6fb25c21
commit
b72f9d340e
|
@ -150,7 +150,7 @@ def mask_tokens(inputs, tokenizer, args):
|
||||||
special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
|
special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
|
||||||
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
|
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
|
||||||
masked_indices = torch.bernoulli(probability_matrix).bool()
|
masked_indices = torch.bernoulli(probability_matrix).bool()
|
||||||
labels[~masked_indices] = -1 # We only compute loss on masked tokens
|
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
||||||
|
|
||||||
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
||||||
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
|
||||||
|
|
Loading…
Reference in New Issue