Correct index in script

This commit is contained in:
LysandreJik 2019-12-10 18:33:17 -05:00
parent ec6fb25c21
commit b72f9d340e
1 changed files with 1 additions and 1 deletions

View File

@ -150,7 +150,7 @@ def mask_tokens(inputs, tokenizer, args):
special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()] special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool() masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -1 # We only compute loss on masked tokens labels[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices