Attention mask is important in the case of batching... (#16222)

* Attention mask is important in the case of batching...

* Improve the fix.

* Making the sentence different enough that they exhibit different
predictions.
This commit is contained in:
Nicolas Patry 2022-03-18 10:02:12 +01:00 committed by GitHub
parent ec4e421b7d
commit ecb4662d17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 2 deletions

View File

@ -149,7 +149,7 @@ def pad_collate_fn(tokenizer, feature_extractor):
_padding_value = t_padding_value
elif key in {"input_values", "pixel_values", "input_features"}:
_padding_value = f_padding_value
elif key in {"p_mask"}:
elif key in {"p_mask", "special_tokens_mask"}:
_padding_value = 1
elif key in {"attention_mask", "token_type_ids"}:
_padding_value = 0

View File

@ -192,7 +192,6 @@ class TokenClassificationPipeline(Pipeline):
truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False
model_inputs = self.tokenizer(
sentence,
return_attention_mask=False,
return_tensors=self.framework,
truncation=truncation,
return_special_tokens_mask=True,

View File

@ -649,6 +649,23 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
],
)
# Batch size does not affect outputs (attention_mask are required)
sentences = ["This is a test !", "Another test this is with longer sentence"]
outputs = token_classifier(sentences)
outputs_batched = token_classifier(sentences, batch_size=2)
# Batching does not make a difference in predictions
self.assertEqual(nested_simplify(outputs_batched), nested_simplify(outputs))
self.assertEqual(
nested_simplify(outputs_batched),
[
[
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
],
[],
],
)
@require_torch
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"