Optimizing away the `fill-mask` pipeline. (#12113)

* Optimizing away the `fill-mask` pipeline.

- Don't send anything to the tokenizer unless needed. Vocab check is
much faster
- Keep BC by sending data to the tokenizer when needed. User handling warning messages will see performance benefits again
- Make `targets` and `top_k` work together better `top_k` cannot be
higher than `len(targets)` but can be smaller still.
- Actually simplify the `target_ids` in case of duplicate (it can happen
because we're parsing raw strings)
- Removed useless code to fail on empty strings. It works only if empty
string is in first position, moved to ignoring them instead.
- Changed the related tests as only the tests would fail correctly
(having incorrect value in first position)

* Make tests compatible for 2 different vocabs... (at the price of a
warning).

Co-authored-by: @EtaoinWu

* ValueError working globally

* Update src/transformers/pipelines/fill_mask.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* `tokenizer.vocab` -> `tokenizer.get_vocab()` for more compatiblity +
fallback.

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Nicolas Patry 2021-06-23 10:38:04 +02:00 committed by GitHub
parent 037e466b10
commit d4be498441
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 81 additions and 30 deletions

View File

@ -98,9 +98,9 @@ class FillMaskPipeline(Pipeline):
args (:obj:`str` or :obj:`List[str]`):
One or several texts (or one list of prompts) with masked tokens.
targets (:obj:`str` or :obj:`List[str]`, `optional`):
When passed, the model will return the scores for the passed token or tokens rather than the top k
predictions in the entire vocabulary. If the provided targets are not in the model vocab, they will be
tokenized and the first resulting token will be used (with a warning).
When passed, the model will limit the scores to the passed targets instead of looking up in the whole
vocab. If the provided targets are not in the model vocab, they will be tokenized and the first
resulting token will be used (with a warning, and that might be slower).
top_k (:obj:`int`, `optional`):
When passed, overrides the number of predictions to return.
@ -115,25 +115,56 @@ class FillMaskPipeline(Pipeline):
inputs = self._parse_and_tokenize(*args, **kwargs)
outputs = self._forward(inputs, return_tensors=True)
# top_k must be defined
if top_k is None:
top_k = self.top_k
results = []
batch_size = outputs.shape[0] if self.framework == "tf" else outputs.size(0)
if targets is not None:
if len(targets) == 0 or len(targets[0]) == 0:
raise ValueError("At least one target must be provided when passed.")
if isinstance(targets, str):
targets = [targets]
targets_proc = []
try:
vocab = self.tokenizer.get_vocab()
except Exception:
vocab = {}
target_ids = []
for target in targets:
target_enc = self.tokenizer.tokenize(target)
if len(target_enc) > 1 or target_enc[0] == self.tokenizer.unk_token:
id_ = vocab.get(target, None)
if id_ is None:
input_ids = self.tokenizer(
target,
add_special_tokens=False,
return_attention_mask=False,
return_token_type_ids=False,
max_length=1,
truncation=True,
)["input_ids"]
if len(input_ids) == 0:
logger.warning(
f"The specified target token `{target}` does not exist in the model vocabulary. "
f"We cannot replace it with anything meaningful, ignoring it"
)
continue
id_ = input_ids[0]
# XXX: If users encounter this pass
# it becomes pretty slow, so let's make sure
# The warning enables them to fix the input to
# get faster performance.
logger.warning(
f"The specified target token `{target}` does not exist in the model vocabulary. "
f"Replacing with `{target_enc[0]}`."
f"Replacing with `{self.tokenizer.convert_ids_to_tokens(id_)}`."
)
targets_proc.append(target_enc[0])
target_inds = np.array(self.tokenizer.convert_tokens_to_ids(targets_proc))
target_ids.append(id_)
target_ids = list(set(target_ids))
if len(target_ids) == 0:
raise ValueError("At least one target must be provided when passed.")
target_ids = np.array(target_ids)
# Cap top_k if there are targets
if top_k > target_ids.shape[0]:
top_k = target_ids.shape[0]
for i in range(batch_size):
input_ids = inputs["input_ids"][i]
@ -147,14 +178,11 @@ class FillMaskPipeline(Pipeline):
logits = outputs[i, masked_index.item(), :]
probs = tf.nn.softmax(logits)
if targets is None:
topk = tf.math.top_k(probs, k=top_k if top_k is not None else self.top_k)
values, predictions = topk.values.numpy(), topk.indices.numpy()
else:
values = tf.gather_nd(probs, tf.reshape(target_inds, (-1, 1)))
sort_inds = tf.reverse(tf.argsort(values), [0])
values = tf.gather_nd(values, tf.reshape(sort_inds, (-1, 1))).numpy()
predictions = target_inds[sort_inds.numpy()]
if targets is not None:
probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
topk = tf.math.top_k(probs, k=top_k)
values, predictions = topk.values.numpy(), topk.indices.numpy()
else:
masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
@ -163,13 +191,11 @@ class FillMaskPipeline(Pipeline):
logits = outputs[i, masked_index.item(), :]
probs = logits.softmax(dim=0)
if targets is None:
values, predictions = probs.topk(top_k if top_k is not None else self.top_k)
else:
values = probs[..., target_inds]
sort_inds = list(reversed(values.argsort(dim=-1)))
values = values[..., sort_inds]
predictions = target_inds[sort_inds]
if targets is not None:
probs = probs[..., target_ids]
values, predictions = probs.topk(top_k)
for v, p in zip(values.tolist(), predictions.tolist()):
tokens = input_ids.numpy()

View File

@ -78,7 +78,8 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
@require_torch
def test_torch_fill_mask_with_targets(self):
valid_inputs = ["My name is <mask>"]
valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
# ' Sam' will yield a warning but work
valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]]
invalid_targets = [[], [""], ""]
for model_name in self.small_models:
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
@ -89,10 +90,34 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
for targets in invalid_targets:
self.assertRaises(ValueError, unmasker, valid_inputs, targets=targets)
@require_torch
def test_torch_fill_mask_with_targets_and_topk(self):
model_name = self.small_models[0]
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
targets = [" Teven", "ĠPatrick", "ĠClara"]
top_k = 2
outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)
self.assertEqual(len(outputs), 2)
@require_torch
def test_torch_fill_mask_with_duplicate_targets_and_topk(self):
model_name = self.small_models[0]
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt")
# String duplicates + id duplicates
targets = [" Teven", "ĠPatrick", "ĠClara", "ĠClara", " Clara"]
top_k = 10
outputs = unmasker("My name is <mask>", targets=targets, top_k=top_k)
# The target list contains duplicates, so we can't output more
# than them
self.assertEqual(len(outputs), 3)
@require_tf
def test_tf_fill_mask_with_targets(self):
valid_inputs = ["My name is <mask>"]
valid_targets = [[" Teven", " Patrick", " Clara"], [" Sam"]]
# ' Sam' will yield a warning but work
valid_targets = [[" Teven", "ĠPatrick", "ĠClara"], ["ĠSam"], [" Sam"]]
invalid_targets = [[], [""], ""]
for model_name in self.small_models:
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf")
@ -111,7 +136,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
"My name is <mask>",
"The largest city in France is <mask>",
]
valid_targets = [" Patrick", " Clara"]
valid_targets = ["ĠPatrick", "ĠClara"]
for model_name in self.large_models:
unmasker = pipeline(
task="fill-mask",
@ -184,7 +209,7 @@ class FillMaskPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
"My name is <mask>",
"The largest city in France is <mask>",
]
valid_targets = [" Patrick", " Clara"]
valid_targets = ["ĠPatrick", "ĠClara"]
for model_name in self.large_models:
unmasker = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", top_k=2)