Replace `as_target` context managers by direct calls (#18325)
* Preliminary work on tokenizers * Quality + fix tests * Treat processors * Fix pad * Remove all uses of in tests, docs and examples * Replace all as_target_tokenizer * Fix tests * Fix quality * Update examples/flax/image-captioning/run_image_captioning_flax.py Co-authored-by: amyeroberts <amy@huggingface.co> * Style Co-authored-by: amyeroberts <amy@huggingface.co>
This commit is contained in:
parent
a64bcb564d
commit
986526a0e4
|
@ -55,9 +55,7 @@ tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="en
|
|||
src_text = "Life is like a box of chocolates."
|
||||
tgt_text = "La vie est comme une boîte de chocolat."
|
||||
|
||||
model_inputs = tokenizer(src_text, return_tensors="pt")
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(tgt_text, return_tensors="pt").input_ids
|
||||
model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
|
||||
|
||||
loss = model(**model_inputs, labels=labels) # forward pass
|
||||
```
|
||||
|
|
|
@ -155,7 +155,7 @@ Example of translating english to many romance languages, using old-style 2 char
|
|||
## MarianTokenizer
|
||||
|
||||
[[autodoc]] MarianTokenizer
|
||||
- as_target_tokenizer
|
||||
- build_inputs_with_special_tokens
|
||||
|
||||
## MarianModel
|
||||
|
||||
|
|
|
@ -34,8 +34,8 @@ model is multilingual it expects the sequences in a different format. A special
|
|||
source and target text. The source text format is `X [eos, src_lang_code]` where `X` is the source text. The
|
||||
target text format is `[tgt_lang_code] X [eos]`. `bos` is never used.
|
||||
|
||||
The regular [`~MBartTokenizer.__call__`] will encode source text format, and it should be wrapped
|
||||
inside the context manager [`~MBartTokenizer.as_target_tokenizer`] to encode target text format.
|
||||
The regular [`~MBartTokenizer.__call__`] will encode source text format passed as first argument or with the `text`
|
||||
keyword, and target text format passed with the `text_target` keyword argument.
|
||||
|
||||
- Supervised training
|
||||
|
||||
|
@ -46,13 +46,11 @@ inside the context manager [`~MBartTokenizer.as_target_tokenizer`] to encode tar
|
|||
>>> example_english_phrase = "UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
|
||||
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(expected_translation_romanian, return_tensors="pt")
|
||||
>>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_romanian, return_tensors="pt")
|
||||
|
||||
>>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
|
||||
>>> # forward pass
|
||||
>>> model(**inputs, labels=batch["labels"])
|
||||
>>> model(**inputs)
|
||||
```
|
||||
|
||||
- Generation
|
||||
|
@ -108,11 +106,9 @@ tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_
|
|||
src_text = " UN Chief Says There Is No Military Solution in Syria"
|
||||
tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
|
||||
model_inputs = tokenizer(src_text, return_tensors="pt")
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(tgt_text, return_tensors="pt").input_ids
|
||||
model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
|
||||
|
||||
model(**model_inputs, labels=labels) # forward pass
|
||||
model(**model_inputs) # forward pass
|
||||
```
|
||||
|
||||
- Generation
|
||||
|
@ -154,7 +150,6 @@ tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
|||
## MBartTokenizer
|
||||
|
||||
[[autodoc]] MBartTokenizer
|
||||
- as_target_tokenizer
|
||||
- build_inputs_with_special_tokens
|
||||
|
||||
## MBartTokenizerFast
|
||||
|
|
|
@ -48,7 +48,6 @@ This model was contributed by [cwkeam](https://huggingface.co/cwkeam). The origi
|
|||
- save_pretrained
|
||||
- batch_decode
|
||||
- decode
|
||||
- as_target_processor
|
||||
|
||||
|
||||
## MCTCTModel
|
||||
|
|
|
@ -91,7 +91,6 @@ UN-Chef sagt, es gibt keine militärische Lösung in Syrien
|
|||
## NllbTokenizer
|
||||
|
||||
[[autodoc]] NllbTokenizer
|
||||
- as_target_tokenizer
|
||||
- build_inputs_with_special_tokens
|
||||
|
||||
## NllbTokenizerFast
|
||||
|
|
|
@ -45,8 +45,9 @@ target text format is `[tgt_lang_code] X [eos]`. `bos` is never used.
|
|||
|
||||
However, for fine-tuning, no language token is provided in some cases where a single language is used. Please refer to [the paper](https://arxiv.org/abs/2103.06333) to learn more about this.
|
||||
|
||||
In cases where the language code is needed, The regular [`~PLBartTokenizer.__call__`] will encode source text format, and it should be wrapped
|
||||
inside the context manager [`~PLBartTokenizer.as_target_tokenizer`] to encode target text format.
|
||||
In cases where the language code is needed, the regular [`~PLBartTokenizer.__call__`] will encode source text format
|
||||
when you pass texts as the first argument or with the keyword argument `text`, and will encode target text format if
|
||||
it's passed with the `text_target` keyword argument.
|
||||
|
||||
- Supervised training
|
||||
|
||||
|
@ -56,11 +57,7 @@ inside the context manager [`~PLBartTokenizer.as_target_tokenizer`] to encode ta
|
|||
>>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-base", src_lang="en_XX", tgt_lang="python")
|
||||
>>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
|
||||
>>> expected_translation_english = "Returns the maximum value of a b c."
|
||||
>>> inputs = tokenizer(example_python_phrase, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(expected_translation_english, return_tensors="pt")
|
||||
>>> inputs["labels"] = labels["input_ids"]
|
||||
>>> # forward pass
|
||||
>>> inputs = tokenizer(example_python_phrase, text_target=expected_translation_english, return_tensors="pt")
|
||||
>>> model(**inputs)
|
||||
```
|
||||
|
||||
|
@ -88,7 +85,6 @@ inside the context manager [`~PLBartTokenizer.as_target_tokenizer`] to encode ta
|
|||
## PLBartTokenizer
|
||||
|
||||
[[autodoc]] PLBartTokenizer
|
||||
- as_target_tokenizer
|
||||
- build_inputs_with_special_tokens
|
||||
|
||||
## PLBartModel
|
||||
|
|
|
@ -107,7 +107,7 @@ speech inputs) and `labels` (which are the `input_ids` of the encoded target seq
|
|||
>>> labels = tokenizer(ds[0]["text"], return_tensors="pt").input_ids
|
||||
|
||||
>>> # the forward function automatically creates the correct decoder_input_ids
|
||||
>>> loss = model(input_values, labels=labels).loss
|
||||
>>> loss = model(**input_features).loss
|
||||
>>> loss.backward()
|
||||
```
|
||||
|
||||
|
|
|
@ -120,7 +120,6 @@ See the [model hub](https://huggingface.co/models?filter=speech_to_text) to look
|
|||
- save_pretrained
|
||||
- batch_decode
|
||||
- decode
|
||||
- as_target_processor
|
||||
|
||||
## Speech2TextModel
|
||||
|
||||
|
|
|
@ -114,7 +114,6 @@ See [model hub](https://huggingface.co/models?filter=speech2text2) to look for S
|
|||
- save_pretrained
|
||||
- batch_decode
|
||||
- decode
|
||||
- as_target_processor
|
||||
|
||||
## Speech2Text2ForCausalLM
|
||||
|
||||
|
|
|
@ -94,7 +94,6 @@ See the [model hub](https://huggingface.co/models?filter=trocr) to look for TrOC
|
|||
- save_pretrained
|
||||
- batch_decode
|
||||
- decode
|
||||
- as_target_processor
|
||||
|
||||
## TrOCRForCausalLM
|
||||
|
||||
|
|
|
@ -62,7 +62,6 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
|
|||
- save_pretrained
|
||||
- batch_decode
|
||||
- decode
|
||||
- as_target_processor
|
||||
|
||||
## Wav2Vec2ProcessorWithLM
|
||||
|
||||
|
@ -73,7 +72,6 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
|
|||
- save_pretrained
|
||||
- batch_decode
|
||||
- decode
|
||||
- as_target_processor
|
||||
|
||||
## Wav2Vec2 specific outputs
|
||||
|
||||
|
|
|
@ -486,10 +486,8 @@ A processor combines a feature extractor and tokenizer. Load a processor with [`
|
|||
>>> def prepare_dataset(example):
|
||||
... audio = example["audio"]
|
||||
|
||||
... example["input_values"] = processor(audio["array"], sampling_rate=16000)
|
||||
... example.update(processor(audio=audio["array"], text=example["text"], sampling_rate=16000))
|
||||
|
||||
... with processor.as_target_processor():
|
||||
... example["labels"] = processor(example["text"]).input_ids
|
||||
... return example
|
||||
```
|
||||
|
||||
|
|
|
@ -109,11 +109,10 @@ The preprocessing function needs to:
|
|||
>>> def prepare_dataset(batch):
|
||||
... audio = batch["audio"]
|
||||
|
||||
... batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
|
||||
... batch = processor(audio=audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
|
||||
... batch["input_length"] = len(batch["input_values"])
|
||||
|
||||
... with processor.as_target_processor():
|
||||
... batch["labels"] = processor(batch["transcription"]).input_ids
|
||||
... batch["labels"] = processor(text=batch["transcription"]).input_ids
|
||||
... return batch
|
||||
```
|
||||
|
||||
|
@ -146,17 +145,9 @@ Unlike other data collators, this specific data collator needs to apply a differ
|
|||
... input_features = [{"input_values": feature["input_values"]} for feature in features]
|
||||
... label_features = [{"input_ids": feature["labels"]} for feature in features]
|
||||
|
||||
... batch = self.processor.pad(
|
||||
... input_features,
|
||||
... padding=self.padding,
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
... with self.processor.as_target_processor():
|
||||
... labels_batch = self.processor.pad(
|
||||
... label_features,
|
||||
... padding=self.padding,
|
||||
... return_tensors="pt",
|
||||
... )
|
||||
... batch = self.processor.pad(input_features, padding=self.padding, return_tensors="pt")
|
||||
|
||||
... labels_batch = self.processor.pad(labels=label_features, padding=self.padding, return_tensors="pt")
|
||||
|
||||
... # replace padding with -100 to ignore loss correctly
|
||||
... labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
||||
|
|
|
@ -67,7 +67,7 @@ Load the T5 tokenizer to process `text` and `summary`:
|
|||
The preprocessing function needs to:
|
||||
|
||||
1. Prefix the input with a prompt so T5 knows this is a summarization task. Some models capable of multiple NLP tasks require prompting for specific tasks.
|
||||
2. Use a context manager with the `as_target_tokenizer()` function to parallelize tokenization of inputs and labels.
|
||||
2. Use the `text_target` keyword argument when tokenizing labels.
|
||||
3. Truncate sequences to be no longer than the maximum length set by the `max_length` parameter.
|
||||
|
||||
```py
|
||||
|
@ -78,8 +78,7 @@ The preprocessing function needs to:
|
|||
... inputs = [prefix + doc for doc in examples["text"]]
|
||||
... model_inputs = tokenizer(inputs, max_length=1024, truncation=True)
|
||||
|
||||
... with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(examples["summary"], max_length=128, truncation=True)
|
||||
... labels = tokenizer(text_target=examples["summary"], max_length=128, truncation=True)
|
||||
|
||||
... model_inputs["labels"] = labels["input_ids"]
|
||||
... return model_inputs
|
||||
|
|
|
@ -78,12 +78,7 @@ The preprocessing function needs to:
|
|||
>>> def preprocess_function(examples):
|
||||
... inputs = [prefix + example[source_lang] for example in examples["translation"]]
|
||||
... targets = [example[target_lang] for example in examples["translation"]]
|
||||
... model_inputs = tokenizer(inputs, max_length=128, truncation=True)
|
||||
|
||||
... with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(targets, max_length=128, truncation=True)
|
||||
|
||||
... model_inputs["labels"] = labels["input_ids"]
|
||||
... model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True)
|
||||
... return model_inputs
|
||||
```
|
||||
|
||||
|
|
|
@ -471,10 +471,8 @@ Un processor combina un extractor de características y un tokenizador. Cargue u
|
|||
>>> def prepare_dataset(example):
|
||||
... audio = example["audio"]
|
||||
|
||||
... example["input_values"] = processor(audio["array"], sampling_rate=16000)
|
||||
... example.update(processor(audio=audio["array"], text=example["text"], sampling_rate=16000))
|
||||
|
||||
... with processor.as_target_processor():
|
||||
... example["labels"] = processor(example["text"]).input_ids
|
||||
... return example
|
||||
```
|
||||
|
||||
|
|
|
@ -471,10 +471,8 @@ Un processor combina un estrattore di caratteristiche e un tokenizer. Carica un
|
|||
>>> def prepare_dataset(example):
|
||||
... audio = example["audio"]
|
||||
|
||||
... example["input_values"] = processor(audio["array"], sampling_rate=16000)
|
||||
... example.update(processor(audio=audio["array"], text=example["text"], sampling_rate=16000))
|
||||
|
||||
... with processor.as_target_processor():
|
||||
... example["labels"] = processor(example["text"]).input_ids
|
||||
... return example
|
||||
```
|
||||
|
||||
|
|
|
@ -552,10 +552,13 @@ def main():
|
|||
targets = captions
|
||||
|
||||
model_inputs = {}
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
|
||||
labels = tokenizer(
|
||||
targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
|
||||
text_target=targets,
|
||||
max_length=max_target_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
decoder_input_ids = shift_tokens_right_fn(
|
||||
|
|
|
@ -590,9 +590,12 @@ def main():
|
|||
)
|
||||
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(
|
||||
targets, max_length=max_target_length, padding="max_length", truncation=True, return_tensors="np"
|
||||
text_target=targets,
|
||||
max_length=max_target_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
)
|
||||
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
|
|
@ -453,9 +453,8 @@ def main():
|
|||
inputs, targets = preprocess_squad_batch(examples, question_column, context_column, answer_column)
|
||||
|
||||
model_inputs = tokenizer(inputs, max_length=max_seq_length, padding=padding, truncation=True)
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(targets, max_length=max_answer_length, padding=padding, truncation=True)
|
||||
# Tokenize targets with text_target=...
|
||||
labels = tokenizer(text_target=targets, max_length=max_answer_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
|
@ -479,9 +478,8 @@ def main():
|
|||
return_overflowing_tokens=True,
|
||||
return_offsets_mapping=True,
|
||||
)
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(targets, max_length=max_answer_length, padding=padding, truncation=True)
|
||||
# Tokenize targets with the `text_target` keyword argument
|
||||
labels = tokenizer(text_target=targets, max_length=max_answer_length, padding=padding, truncation=True)
|
||||
|
||||
# Since one example might give us several features if it has a long context, we need a map from a feature to
|
||||
# its corresponding example. This key gives us just that.
|
||||
|
|
|
@ -305,9 +305,8 @@ class DataCollatorCTCWithPadding:
|
|||
return_tensors="pt",
|
||||
)
|
||||
|
||||
with self.processor.as_target_processor():
|
||||
labels_batch = self.processor.pad(
|
||||
label_features,
|
||||
labels=label_features,
|
||||
padding=self.padding,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of_labels,
|
||||
return_tensors="pt",
|
||||
|
|
|
@ -522,9 +522,8 @@ def main():
|
|||
inputs = [prefix + inp for inp in inputs]
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
|
||||
# Tokenize targets with the `text_target` keyword argument
|
||||
labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
|
|
|
@ -470,9 +470,8 @@ def main():
|
|||
inputs = [prefix + inp for inp in inputs]
|
||||
model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
|
||||
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
|
||||
# Tokenize targets with the `text_target` keyword argument
|
||||
labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
|
|
|
@ -443,9 +443,8 @@ def main():
|
|||
inputs = [prefix + inp for inp in inputs]
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
|
||||
# Tokenize targets with the `text_target` keyword argument
|
||||
labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
|
|
|
@ -452,9 +452,8 @@ def main():
|
|||
inputs = [prefix + inp for inp in inputs]
|
||||
model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
|
||||
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
|
||||
# Tokenize targets with the `text_target` keyword argument
|
||||
labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
|
|
|
@ -304,9 +304,8 @@ class DataCollatorCTCWithPadding:
|
|||
return_tensors="pt",
|
||||
)
|
||||
|
||||
with self.processor.as_target_processor():
|
||||
labels_batch = self.processor.pad(
|
||||
label_features,
|
||||
labels=label_features,
|
||||
padding=self.padding,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of_labels,
|
||||
return_tensors="pt",
|
||||
|
|
|
@ -301,9 +301,8 @@ class DataCollatorCTCWithPadding:
|
|||
return_tensors="pt",
|
||||
)
|
||||
|
||||
with self.processor.as_target_processor():
|
||||
labels_batch = self.processor.pad(
|
||||
label_features,
|
||||
labels=label_features,
|
||||
padding=self.padding,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of_labels,
|
||||
return_tensors="pt",
|
||||
|
|
|
@ -437,7 +437,6 @@ def main():
|
|||
table=tables, query=questions, max_length=data_args.max_source_length, padding=padding, truncation=True
|
||||
)
|
||||
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(
|
||||
answer=[", ".join(answer) for answer in answers],
|
||||
max_length=max_target_length,
|
||||
|
|
|
@ -413,7 +413,6 @@ def main():
|
|||
table=tables, query=questions, max_length=data_args.max_source_length, padding=padding, truncation=True
|
||||
)
|
||||
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(
|
||||
answer=[", ".join(answer) for answer in answers],
|
||||
max_length=max_target_length,
|
||||
|
|
|
@ -266,9 +266,8 @@ class DataCollatorCTCWithPadding:
|
|||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
with self.processor.as_target_processor():
|
||||
labels_batch = self.processor.pad(
|
||||
label_features,
|
||||
labels=label_features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length_labels,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of_labels,
|
||||
|
@ -419,9 +418,10 @@ def main():
|
|||
len(set(batch["sampling_rate"])) == 1
|
||||
), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
|
||||
|
||||
batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
|
||||
with processor.as_target_processor():
|
||||
batch["labels"] = processor(batch[data_args.target_text_column]).input_ids
|
||||
processed_batch = processor(
|
||||
audio=batch["speech"], text=batch[data_args.target_text_column], sampling_rate=batch["sampling_rate"][0]
|
||||
)
|
||||
batch.update(processed_batch)
|
||||
return batch
|
||||
|
||||
train_dataset = train_dataset.map(
|
||||
|
|
|
@ -185,9 +185,8 @@ class DataCollatorCTCWithPadding:
|
|||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
)
|
||||
with self.processor.as_target_processor():
|
||||
labels_batch = self.processor.pad(
|
||||
label_features,
|
||||
labels=label_features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length_labels,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of_labels,
|
||||
|
@ -414,10 +413,11 @@ def main():
|
|||
assert (
|
||||
len(set(batch["sampling_rate"])) == 1
|
||||
), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
|
||||
batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
|
||||
# Setup the processor for targets
|
||||
with processor.as_target_processor():
|
||||
batch["labels"] = processor(batch["target_text"]).input_ids
|
||||
|
||||
processed_batch = processor(
|
||||
audio=batch["speech"], text=batch["target_text"], sampling_rate=batch["sampling_rate"][0]
|
||||
)
|
||||
batch.update(processed_batch)
|
||||
return batch
|
||||
|
||||
train_dataset = train_dataset.map(
|
||||
|
|
|
@ -349,9 +349,8 @@ class SpeechDataCollatorWithPadding:
|
|||
|
||||
if self.pad_labels:
|
||||
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
||||
with self.processor.as_target_processor():
|
||||
labels_batch = self.processor.pad(
|
||||
label_features,
|
||||
labels=label_features,
|
||||
padding=self.padding,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of_labels,
|
||||
return_tensors="pt",
|
||||
|
|
|
@ -504,9 +504,8 @@ def main():
|
|||
inputs = [prefix + inp for inp in inputs]
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
|
||||
# Tokenize targets with the `text_target` keyword argument
|
||||
labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
|
|
|
@ -458,9 +458,8 @@ def main():
|
|||
inputs = [prefix + inp for inp in inputs]
|
||||
model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
|
||||
|
||||
# Setup the tokenizer for targets
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
|
||||
# Tokenize targets with the `text_target` keyword argument
|
||||
labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
|
||||
|
||||
# If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
|
||||
# padding in the loss.
|
||||
|
|
|
@ -1612,9 +1612,8 @@ class TFHubertForCTC(TFHubertPreTrainedModel):
|
|||
>>> # compute loss
|
||||
>>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
|
||||
|
||||
>>> # wrap processor as target processor to encode labels
|
||||
>>> with processor.as_target_processor():
|
||||
... labels = processor(transcription, return_tensors="tf").input_values
|
||||
>>> # Pass the transcription as text to encode labels
|
||||
>>> labels = processor(text=transcription, return_tensors="tf").input_values
|
||||
|
||||
>>> loss = model(input_values, labels=labels).loss
|
||||
```"""
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
"""Tokenization classes for M2M100."""
|
||||
import json
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
@ -116,10 +115,8 @@ class M2M100Tokenizer(PreTrainedTokenizer):
|
|||
>>> tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="en", tgt_lang="ro")
|
||||
>>> src_text = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> model_inputs = tokenizer(src_text, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(tgt_text, return_tensors="pt").input_ids
|
||||
>>> model(**model_inputs, labels=labels) # should work
|
||||
>>> model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
|
||||
>>> model(**model_inputs) # should work
|
||||
```"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
|
@ -346,16 +343,12 @@ class M2M100Tokenizer(PreTrainedTokenizer):
|
|||
inputs["forced_bos_token_id"] = tgt_lang_id
|
||||
return inputs
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
yield
|
||||
def _switch_to_input_mode(self):
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
|
||||
def _switch_to_target_mode(self):
|
||||
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
|
||||
def set_src_lang_special_tokens(self, src_lang: str) -> None:
|
||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
||||
lang_token = self.get_lang_token(src_lang)
|
||||
|
|
|
@ -15,7 +15,6 @@ import json
|
|||
import os
|
||||
import re
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
@ -112,10 +111,7 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||
>>> tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
|
||||
>>> src_texts = ["I am a small frog.", "Tom asked his teacher for advice."]
|
||||
>>> tgt_texts = ["Ich bin ein kleiner Frosch.", "Tom bat seinen Lehrer um Rat."] # optional
|
||||
>>> inputs = tokenizer(src_texts, return_tensors="pt", padding=True)
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(tgt_texts, return_tensors="pt", padding=True)
|
||||
>>> inputs["labels"] = labels["input_ids"]
|
||||
>>> inputs = tokenizer(src_texts, text_target=tgt_texts, return_tensors="pt", padding=True)
|
||||
# keys [input_ids, attention_mask, labels].
|
||||
|
||||
>>> outputs = model(**inputs) # should work
|
||||
|
@ -281,18 +277,14 @@ class MarianTokenizer(PreTrainedTokenizer):
|
|||
# We don't expect to process pairs, but leave the pair logic for API consistency
|
||||
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
def _switch_to_input_mode(self):
|
||||
self.current_spm = self.spm_source
|
||||
self.current_encoder = self.encoder
|
||||
|
||||
def _switch_to_target_mode(self):
|
||||
self.current_spm = self.spm_target
|
||||
if self.separate_vocabs:
|
||||
self.current_encoder = self.target_encoder
|
||||
yield
|
||||
self.current_spm = self.spm_source
|
||||
self.current_encoder = self.encoder
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
@ -69,10 +68,7 @@ class MBartTokenizer(PreTrainedTokenizer):
|
|||
>>> tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX", tgt_lang="ro_RO")
|
||||
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(expected_translation_romanian, return_tensors="pt")
|
||||
>>> inputs["labels"] = labels["input_ids"]
|
||||
>>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_romanian, return_tensors="pt")
|
||||
```"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
|
@ -340,15 +336,11 @@ class MBartTokenizer(PreTrainedTokenizer):
|
|||
self.tgt_lang = tgt_lang
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
yield
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
def _switch_to_input_mode(self):
|
||||
return self.set_src_lang_special_tokens(self.src_lang)
|
||||
|
||||
def _switch_to_target_mode(self):
|
||||
return self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
|
||||
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from shutil import copyfile
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
@ -82,10 +81,7 @@ class MBartTokenizerFast(PreTrainedTokenizerFast):
|
|||
... )
|
||||
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(expected_translation_romanian, return_tensors="pt")
|
||||
>>> inputs["labels"] = labels["input_ids"]
|
||||
>>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_romanian, return_tensors="pt")
|
||||
```"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
|
@ -240,15 +236,11 @@ class MBartTokenizerFast(PreTrainedTokenizerFast):
|
|||
self.tgt_lang = tgt_lang
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
yield
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
def _switch_to_input_mode(self):
|
||||
return self.set_src_lang_special_tokens(self.src_lang)
|
||||
|
||||
def _switch_to_target_mode(self):
|
||||
return self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
|
||||
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
@ -102,10 +101,8 @@ class MBart50Tokenizer(PreTrainedTokenizer):
|
|||
>>> tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
|
||||
>>> src_text = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> model_inputs = tokenizer(src_text, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(tgt_text, return_tensors="pt").input_ids
|
||||
>>> # model(**model_inputs, labels=labels) should work
|
||||
>>> model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
|
||||
>>> # model(**model_inputs) should work
|
||||
```"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
|
@ -337,15 +334,11 @@ class MBart50Tokenizer(PreTrainedTokenizer):
|
|||
self.tgt_lang = tgt_lang
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
yield
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
def _switch_to_input_mode(self):
|
||||
return self.set_src_lang_special_tokens(self.src_lang)
|
||||
|
||||
def _switch_to_target_mode(self):
|
||||
return self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
|
||||
def set_src_lang_special_tokens(self, src_lang: str) -> None:
|
||||
"""Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos]."""
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from shutil import copyfile
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
@ -98,10 +97,8 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
|
|||
>>> tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
|
||||
>>> src_text = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
|
||||
>>> model_inputs = tokenizer(src_text, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(tgt_text, return_tensors="pt").input_ids
|
||||
>>> # model(**model_inputs, labels=labels) should work
|
||||
>>> model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
|
||||
>>> # model(**model_inputs) should work
|
||||
```"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
|
@ -211,15 +208,11 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
|
|||
self.tgt_lang = tgt_lang
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
yield
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
def _switch_to_input_mode(self):
|
||||
return self.set_src_lang_special_tokens(self.src_lang)
|
||||
|
||||
def _switch_to_target_mode(self):
|
||||
return self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
|
||||
def set_src_lang_special_tokens(self, src_lang: str) -> None:
|
||||
"""Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos]."""
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""
|
||||
Speech processor class for M-CTC-T
|
||||
"""
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ...processing_utils import ProcessorMixin
|
||||
|
@ -39,6 +40,7 @@ class MCTCTProcessor(ProcessorMixin):
|
|||
def __init__(self, feature_extractor, tokenizer):
|
||||
super().__init__(feature_extractor, tokenizer)
|
||||
self.current_processor = self.feature_extractor
|
||||
self._in_target_context_manager = False
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
|
@ -47,8 +49,36 @@ class MCTCTProcessor(ProcessorMixin):
|
|||
[`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to AutoTokenizer's
|
||||
[`~AutoTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
|
||||
"""
|
||||
# For backward compatibility
|
||||
if self._in_target_context_manager:
|
||||
return self.current_processor(*args, **kwargs)
|
||||
|
||||
if "raw_speech" in kwargs:
|
||||
warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
|
||||
audio = kwargs.pop("raw_speech")
|
||||
else:
|
||||
audio = kwargs.pop("audio", None)
|
||||
text = kwargs.pop("text", None)
|
||||
if len(args) > 0:
|
||||
audio = args[0]
|
||||
args = args[1:]
|
||||
|
||||
if audio is None and text is None:
|
||||
raise ValueError("You need to specify either an `audio` or `text` input to process.")
|
||||
|
||||
if audio is not None:
|
||||
inputs = self.feature_extractor(audio, *args, **kwargs)
|
||||
if text is not None:
|
||||
encodings = self.tokenizer(text, **kwargs)
|
||||
|
||||
if text is None:
|
||||
return inputs
|
||||
elif audio is None:
|
||||
return encodings
|
||||
else:
|
||||
inputs["labels"] = encodings["input_ids"]
|
||||
return inputs
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer
|
||||
|
@ -63,8 +93,29 @@ class MCTCTProcessor(ProcessorMixin):
|
|||
[`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
|
||||
[`~PreTrainedTokenizer.pad`]. Please refer to the docstring of the above two methods for more information.
|
||||
"""
|
||||
# For backward compatibility
|
||||
if self._in_target_context_manager:
|
||||
return self.current_processor.pad(*args, **kwargs)
|
||||
|
||||
input_features = kwargs.pop("input_features", None)
|
||||
labels = kwargs.pop("labels", None)
|
||||
if len(args) > 0:
|
||||
input_features = args[0]
|
||||
args = args[1:]
|
||||
|
||||
if input_features is not None:
|
||||
input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
|
||||
if labels is not None:
|
||||
labels = self.tokenizer.pad(labels, **kwargs)
|
||||
|
||||
if labels is None:
|
||||
return input_features
|
||||
elif input_features is None:
|
||||
return labels
|
||||
else:
|
||||
input_features["labels"] = labels["input_ids"]
|
||||
return input_features
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
|
||||
|
@ -77,6 +128,13 @@ class MCTCTProcessor(ProcessorMixin):
|
|||
"""
|
||||
Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning MCTCT.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
|
||||
"labels by using the argument `text` of the regular `__call__` method (either in the same call as "
|
||||
"your audio inputs, or in a separate call."
|
||||
)
|
||||
self._in_target_context_manager = True
|
||||
self.current_processor = self.tokenizer
|
||||
yield
|
||||
self.current_processor = self.feature_extractor
|
||||
self._in_target_context_manager = False
|
||||
|
|
|
@ -57,8 +57,7 @@ class FlaxMT5Model(FlaxT5Model):
|
|||
>>> summary = "Weiter Verhandlung in Syrien."
|
||||
>>> inputs = tokenizer(article, return_tensors="np")
|
||||
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids
|
||||
>>> decoder_input_ids = tokenizer(text_target=summary, return_tensors="np").input_ids
|
||||
|
||||
>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=decoder_input_ids)
|
||||
>>> hidden_states = outputs.last_hidden_state
|
||||
|
@ -84,8 +83,7 @@ class FlaxMT5EncoderModel(FlaxT5EncoderModel):
|
|||
>>> summary = "Weiter Verhandlung in Syrien."
|
||||
>>> inputs = tokenizer(article, return_tensors="np")
|
||||
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids
|
||||
>>> decoder_input_ids = tokenizer(text_target=summary, return_tensors="np").input_ids
|
||||
|
||||
>>> outputs = model(input_ids=inputs["input_ids"])
|
||||
>>> hidden_states = outputs.last_hidden_state
|
||||
|
@ -111,8 +109,7 @@ class FlaxMT5ForConditionalGeneration(FlaxT5ForConditionalGeneration):
|
|||
>>> summary = "Weiter Verhandlung in Syrien."
|
||||
>>> inputs = tokenizer(article, return_tensors="np")
|
||||
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids
|
||||
>>> decoder_input_ids = tokenizer(text_target=summary, return_tensors="np").input_ids
|
||||
|
||||
>>> outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
|
||||
>>> logits = outputs.logits
|
||||
|
|
|
@ -40,8 +40,7 @@ class MT5Model(T5Model):
|
|||
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||
>>> summary = "Weiter Verhandlung in Syrien."
|
||||
>>> inputs = tokenizer(article, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(summary, return_tensors="pt")
|
||||
>>> labels = tokenizer(text_target=summary, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
|
||||
>>> hidden_states = outputs.last_hidden_state
|
||||
|
@ -73,11 +72,9 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
|
|||
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
|
||||
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||
>>> summary = "Weiter Verhandlung in Syrien."
|
||||
>>> inputs = tokenizer(article, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(summary, return_tensors="pt")
|
||||
>>> inputs = tokenizer(article, text_target=summary, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs, labels=labels["input_ids"])
|
||||
>>> outputs = model(**inputs)
|
||||
>>> loss = outputs.loss
|
||||
```"""
|
||||
|
||||
|
|
|
@ -40,8 +40,7 @@ class TFMT5Model(TFT5Model):
|
|||
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||
>>> summary = "Weiter Verhandlung in Syrien."
|
||||
>>> inputs = tokenizer(article, return_tensors="tf")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(summary, return_tensors="tf")
|
||||
>>> labels = tokenizer(text_target=summary, return_tensors="tf")
|
||||
|
||||
>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
|
||||
>>> hidden_states = outputs.last_hidden_state
|
||||
|
@ -64,11 +63,9 @@ class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration):
|
|||
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
|
||||
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
|
||||
>>> summary = "Weiter Verhandlung in Syrien."
|
||||
>>> inputs = tokenizer(article, return_tensors="tf")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(summary, return_tensors="tf")
|
||||
>>> inputs = tokenizer(article, text_target=summary, return_tensors="tf")
|
||||
|
||||
>>> outputs = model(**inputs, labels=labels["input_ids"])
|
||||
>>> outputs = model(**inputs)
|
||||
>>> loss = outputs.loss
|
||||
```"""
|
||||
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
@ -67,10 +66,7 @@ class NllbTokenizer(PreTrainedTokenizer):
|
|||
... )
|
||||
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
|
||||
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(expected_translation_french, return_tensors="pt")
|
||||
>>> inputs["labels"] = labels["input_ids"]
|
||||
>>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
|
||||
```
|
||||
|
||||
Args:
|
||||
|
@ -386,15 +382,11 @@ class NllbTokenizer(PreTrainedTokenizer):
|
|||
self.tgt_lang = tgt_lang
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
yield
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
def _switch_to_input_mode(self):
|
||||
return self.set_src_lang_special_tokens(self.src_lang)
|
||||
|
||||
def _switch_to_target_mode(self):
|
||||
return self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
|
||||
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from shutil import copyfile
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
@ -80,10 +79,7 @@ class NllbTokenizerFast(PreTrainedTokenizerFast):
|
|||
... )
|
||||
>>> example_english_phrase = " UN Chief Says There Is No Military Solution in Syria"
|
||||
>>> expected_translation_french = "Le chef de l'ONU affirme qu'il n'y a pas de solution militaire en Syrie."
|
||||
>>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(expected_translation_french, return_tensors="pt")
|
||||
>>> inputs["labels"] = labels["input_ids"]
|
||||
>>> inputs = tokenizer(example_english_phrase, text_target=expected_translation_french, return_tensors="pt")
|
||||
```
|
||||
|
||||
Args:
|
||||
|
@ -284,15 +280,11 @@ class NllbTokenizerFast(PreTrainedTokenizerFast):
|
|||
self.tgt_lang = tgt_lang
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
yield
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
def _switch_to_input_mode(self):
|
||||
return self.set_src_lang_special_tokens(self.src_lang)
|
||||
|
||||
def _switch_to_target_mode(self):
|
||||
return self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
|
||||
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
@ -153,10 +152,7 @@ class PLBartTokenizer(PreTrainedTokenizer):
|
|||
>>> tokenizer = PLBartTokenizer.from_pretrained("uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX")
|
||||
>>> example_python_phrase = "def maximum(a,b,c):NEW_LINE_INDENTreturn max([a,b,c])"
|
||||
>>> expected_translation_english = "Returns the maximum value of a b c."
|
||||
>>> inputs = tokenizer(example_python_phrase, return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... labels = tokenizer(expected_translation_english, return_tensors="pt")
|
||||
>>> inputs["labels"] = labels["input_ids"]
|
||||
>>> inputs = tokenizer(example_python_phrase, text_target=expected_translation_english, return_tensors="pt")
|
||||
```"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
|
@ -441,15 +437,11 @@ class PLBartTokenizer(PreTrainedTokenizer):
|
|||
self.tgt_lang = tgt_lang
|
||||
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
yield
|
||||
self.set_src_lang_special_tokens(self.src_lang)
|
||||
def _switch_to_input_mode(self):
|
||||
return self.set_src_lang_special_tokens(self.src_lang)
|
||||
|
||||
def _switch_to_target_mode(self):
|
||||
return self.set_tgt_lang_special_tokens(self.tgt_lang)
|
||||
|
||||
def set_src_lang_special_tokens(self, src_lang) -> None:
|
||||
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
|
||||
|
|
|
@ -818,8 +818,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
|
|||
>>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
|
||||
|
||||
>>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt")
|
||||
>>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
|
||||
>>> input_ids = inputs["input_ids"]
|
||||
>>> labels = targets["input_ids"]
|
||||
>>> outputs = model(input_ids=input_ids, labels=labels)
|
||||
|
@ -1287,8 +1286,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||
>>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
|
||||
|
||||
>>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
|
||||
>>> with tokenizer.as_target_tokenizer():
|
||||
... targets = tokenizer("In Paris, there are 10 million people.", return_tensors="pt")
|
||||
>>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
|
||||
>>> input_ids = inputs["input_ids"]
|
||||
>>> labels = targets["input_ids"]
|
||||
>>> outputs = model(input_ids=input_ids, labels=labels)
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
"""Tokenization classes for RAG."""
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Optional
|
||||
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
|
@ -68,16 +67,12 @@ class RagTokenizer:
|
|||
def decode(self, *args, **kwargs):
|
||||
return self.generator.decode(*args, **kwargs)
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.current_tokenizer = self.generator
|
||||
yield
|
||||
def _switch_to_input_mode(self):
|
||||
self.current_tokenizer = self.question_encoder
|
||||
|
||||
def _switch_to_target_mode(self):
|
||||
self.current_tokenizer = self.generator
|
||||
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
|
@ -110,11 +105,10 @@ class RagTokenizer:
|
|||
if tgt_texts is None:
|
||||
return model_inputs
|
||||
# Process tgt_texts
|
||||
with self.as_target_tokenizer():
|
||||
if max_target_length is None:
|
||||
max_target_length = self.current_tokenizer.model_max_length
|
||||
labels = self(
|
||||
tgt_texts,
|
||||
text_target=tgt_texts,
|
||||
add_special_tokens=True,
|
||||
return_tensors=return_tensors,
|
||||
padding=padding,
|
||||
|
|
|
@ -482,8 +482,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||
'Mr. Quilter ist der Apostel der Mittelschicht und wir freuen uns, sein Evangelium willkommen heißen zu können.'
|
||||
|
||||
>>> # Training: Train model on English transcription
|
||||
>>> with processor.as_target_processor():
|
||||
... labels = processor(ds[0]["text"], return_tensors="pt").input_ids
|
||||
>>> labels = processor(text=ds[0]["text"], return_tensors="pt").input_ids
|
||||
|
||||
>>> loss = model(input_values, labels=labels).loss
|
||||
>>> loss.backward()
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""
|
||||
Speech processor class for Speech2Text
|
||||
"""
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ...processing_utils import ProcessorMixin
|
||||
|
@ -41,6 +42,7 @@ class Speech2TextProcessor(ProcessorMixin):
|
|||
def __init__(self, feature_extractor, tokenizer):
|
||||
super().__init__(feature_extractor, tokenizer)
|
||||
self.current_processor = self.feature_extractor
|
||||
self._in_target_context_manager = False
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
|
@ -50,8 +52,36 @@ class Speech2TextProcessor(ProcessorMixin):
|
|||
[`~Speech2TextTokenizer.__call__`]. Please refer to the docstring of the above two methods for more
|
||||
information.
|
||||
"""
|
||||
# For backward compatibility
|
||||
if self._in_target_context_manager:
|
||||
return self.current_processor(*args, **kwargs)
|
||||
|
||||
if "raw_speech" in kwargs:
|
||||
warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
|
||||
audio = kwargs.pop("raw_speech")
|
||||
else:
|
||||
audio = kwargs.pop("audio", None)
|
||||
text = kwargs.pop("text", None)
|
||||
if len(args) > 0:
|
||||
audio = args[0]
|
||||
args = args[1:]
|
||||
|
||||
if audio is None and text is None:
|
||||
raise ValueError("You need to specify either an `audio` or `text` input to process.")
|
||||
|
||||
if audio is not None:
|
||||
inputs = self.feature_extractor(audio, *args, **kwargs)
|
||||
if text is not None:
|
||||
encodings = self.tokenizer(text, **kwargs)
|
||||
|
||||
if text is None:
|
||||
return inputs
|
||||
elif audio is None:
|
||||
return encodings
|
||||
else:
|
||||
inputs["labels"] = encodings["input_ids"]
|
||||
return inputs
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Speech2TextTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
|
@ -72,6 +102,13 @@ class Speech2TextProcessor(ProcessorMixin):
|
|||
Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
|
||||
Speech2Text.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
|
||||
"labels by using the argument `text` of the regular `__call__` method (either in the same call as "
|
||||
"your audio inputs, or in a separate call."
|
||||
)
|
||||
self._in_target_context_manager = True
|
||||
self.current_processor = self.tokenizer
|
||||
yield
|
||||
self.current_processor = self.feature_extractor
|
||||
self._in_target_context_manager = False
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""
|
||||
Speech processor class for Speech2Text2
|
||||
"""
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ...processing_utils import ProcessorMixin
|
||||
|
@ -40,6 +41,7 @@ class Speech2Text2Processor(ProcessorMixin):
|
|||
def __init__(self, feature_extractor, tokenizer):
|
||||
super().__init__(feature_extractor, tokenizer)
|
||||
self.current_processor = self.feature_extractor
|
||||
self._in_target_context_manager = False
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
|
@ -49,8 +51,36 @@ class Speech2Text2Processor(ProcessorMixin):
|
|||
Speech2Text2Tokenizer's [`~Speech2Text2Tokenizer.__call__`]. Please refer to the docstring of the above two
|
||||
methods for more information.
|
||||
"""
|
||||
# For backward compatibility
|
||||
if self._in_target_context_manager:
|
||||
return self.current_processor(*args, **kwargs)
|
||||
|
||||
if "raw_speech" in kwargs:
|
||||
warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
|
||||
audio = kwargs.pop("raw_speech")
|
||||
else:
|
||||
audio = kwargs.pop("audio", None)
|
||||
text = kwargs.pop("text", None)
|
||||
if len(args) > 0:
|
||||
audio = args[0]
|
||||
args = args[1:]
|
||||
|
||||
if audio is None and text is None:
|
||||
raise ValueError("You need to specify either an `audio` or `text` input to process.")
|
||||
|
||||
if audio is not None:
|
||||
inputs = self.feature_extractor(audio, *args, **kwargs)
|
||||
if text is not None:
|
||||
encodings = self.tokenizer(text, **kwargs)
|
||||
|
||||
if text is None:
|
||||
return inputs
|
||||
elif audio is None:
|
||||
return encodings
|
||||
else:
|
||||
inputs["labels"] = encodings["input_ids"]
|
||||
return inputs
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Speech2Text2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
|
@ -71,6 +101,13 @@ class Speech2Text2Processor(ProcessorMixin):
|
|||
Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
|
||||
Speech2Text2.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
|
||||
"labels by using the argument `text` of the regular `__call__` method (either in the same call as "
|
||||
"your audio inputs, or in a separate call."
|
||||
)
|
||||
self._in_target_context_manager = True
|
||||
self.current_processor = self.tokenizer
|
||||
yield
|
||||
self.current_processor = self.feature_extractor
|
||||
self._in_target_context_manager = False
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
import json
|
||||
import os
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from functools import lru_cache
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
@ -63,12 +62,6 @@ class TapexTruncationStrategy(ExplicitEnum):
|
|||
DROP_ROWS_TO_FIT = "drop_rows_to_fit"
|
||||
|
||||
|
||||
class TokenizerStrategy(ExplicitEnum):
|
||||
|
||||
TOKENIZE_SOURCE = "tokenize_source"
|
||||
TOKENIZE_TARGET = "tokenize_target"
|
||||
|
||||
|
||||
TAPEX_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
|
||||
add_special_tokens (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to encode the sequences with the special tokens relative to their model.
|
||||
|
@ -341,9 +334,6 @@ class TapexTokenizer(PreTrainedTokenizer):
|
|||
self.max_cell_length = max_cell_length
|
||||
self.table_linearize = IndexedRowTableLinearize()
|
||||
|
||||
# property to decide using which call function
|
||||
self.current_tokenizer = TokenizerStrategy.TOKENIZE_SOURCE
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
|
@ -555,9 +545,7 @@ class TapexTokenizer(PreTrainedTokenizer):
|
|||
Optionally, the corresponding answer to the questions as supervision.
|
||||
"""
|
||||
|
||||
if self.current_tokenizer == TokenizerStrategy.TOKENIZE_SOURCE:
|
||||
if table is None:
|
||||
raise ValueError("Please ensure that the table is not empty if you use TAPEX to encode source.")
|
||||
if table is not None:
|
||||
return self.source_call_func(
|
||||
table=table,
|
||||
query=query,
|
||||
|
@ -578,9 +566,7 @@ class TapexTokenizer(PreTrainedTokenizer):
|
|||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
if answer is None:
|
||||
raise ValueError("Please ensure that the answer is not empty if you use TAPEX to encode target.")
|
||||
elif answer is not None:
|
||||
return self.target_call_func(
|
||||
answer=answer,
|
||||
add_special_tokens=add_special_tokens,
|
||||
|
@ -599,6 +585,8 @@ class TapexTokenizer(PreTrainedTokenizer):
|
|||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError("You need to provide either a `table` or an `answer`.")
|
||||
|
||||
def source_call_func(
|
||||
self,
|
||||
|
@ -1330,17 +1318,6 @@ class TapexTokenizer(PreTrainedTokenizer):
|
|||
verbose=verbose,
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def as_target_tokenizer(self):
|
||||
"""
|
||||
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
|
||||
sequence-to-sequence models that need a slightly different processing for the labels.
|
||||
"""
|
||||
self.current_tokenizer = TokenizerStrategy.TOKENIZE_TARGET
|
||||
yield
|
||||
# restore the call function
|
||||
self.current_tokenizer = TokenizerStrategy.TOKENIZE_SOURCE
|
||||
|
||||
def prepare_table_query(
|
||||
self,
|
||||
table,
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""
|
||||
Processor class for TrOCR.
|
||||
"""
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ...processing_utils import ProcessorMixin
|
||||
|
@ -40,6 +41,7 @@ class TrOCRProcessor(ProcessorMixin):
|
|||
def __init__(self, feature_extractor, tokenizer):
|
||||
super().__init__(feature_extractor, tokenizer)
|
||||
self.current_processor = self.feature_extractor
|
||||
self._in_target_context_manager = False
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
|
@ -48,8 +50,36 @@ class TrOCRProcessor(ProcessorMixin):
|
|||
[`~TrOCRProcessor.as_target_processor`] this method forwards all its arguments to TrOCRTokenizer's
|
||||
[`~TrOCRTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
|
||||
"""
|
||||
# For backward compatibility
|
||||
if self._in_target_context_manager:
|
||||
return self.current_processor(*args, **kwargs)
|
||||
|
||||
if "raw_speech" in kwargs:
|
||||
warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
|
||||
audio = kwargs.pop("raw_speech")
|
||||
else:
|
||||
audio = kwargs.pop("audio", None)
|
||||
text = kwargs.pop("text", None)
|
||||
if len(args) > 0:
|
||||
audio = args[0]
|
||||
args = args[1:]
|
||||
|
||||
if audio is None and text is None:
|
||||
raise ValueError("You need to specify either an `audio` or `text` input to process.")
|
||||
|
||||
if audio is not None:
|
||||
inputs = self.feature_extractor(audio, *args, **kwargs)
|
||||
if text is not None:
|
||||
encodings = self.tokenizer(text, **kwargs)
|
||||
|
||||
if text is None:
|
||||
return inputs
|
||||
elif audio is None:
|
||||
return encodings
|
||||
else:
|
||||
inputs["labels"] = encodings["input_ids"]
|
||||
return inputs
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to TrOCRTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer
|
||||
|
@ -69,6 +99,13 @@ class TrOCRProcessor(ProcessorMixin):
|
|||
"""
|
||||
Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning TrOCR.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
|
||||
"labels by using the argument `text` of the regular `__call__` method (either in the same call as "
|
||||
"your audio inputs, or in a separate call."
|
||||
)
|
||||
self._in_target_context_manager = True
|
||||
self.current_processor = self.tokenizer
|
||||
yield
|
||||
self.current_processor = self.feature_extractor
|
||||
self._in_target_context_manager = False
|
||||
|
|
|
@ -1650,9 +1650,8 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
|
|||
>>> # compute loss
|
||||
>>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
|
||||
|
||||
>>> # wrap processor as target processor to encode labels
|
||||
>>> with processor.as_target_processor():
|
||||
... labels = processor(transcription, return_tensors="tf").input_ids
|
||||
>>> # Pass transcription as `text` to encode labels
|
||||
>>> labels = processor(text=target_transcription, return_tensors="tf").input_ids
|
||||
|
||||
>>> loss = model(input_values, labels=labels).loss
|
||||
```"""
|
||||
|
|
|
@ -43,6 +43,7 @@ class Wav2Vec2Processor(ProcessorMixin):
|
|||
def __init__(self, feature_extractor, tokenizer):
    """
    Build the processor from a feature extractor and a tokenizer.

    Both arguments are forwarded unchanged to the parent initializer.
    """
    super().__init__(feature_extractor, tokenizer)
    # Flag checked by `__call__` and `pad` to detect the legacy `as_target_processor` context manager.
    self._in_target_context_manager = False
    # Outside of that context manager, calls are routed to the feature extractor
    # (`self.feature_extractor` is presumably set by the parent initializer).
    self.current_processor = self.feature_extractor
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
||||
|
@ -70,8 +71,36 @@ class Wav2Vec2Processor(ProcessorMixin):
|
|||
[`~Wav2Vec2Processor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
|
||||
[`~PreTrainedTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
|
||||
"""
|
||||
# For backward compatibility
|
||||
if self._in_target_context_manager:
|
||||
return self.current_processor(*args, **kwargs)
|
||||
|
||||
if "raw_speech" in kwargs:
|
||||
warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
|
||||
audio = kwargs.pop("raw_speech")
|
||||
else:
|
||||
audio = kwargs.pop("audio", None)
|
||||
text = kwargs.pop("text", None)
|
||||
if len(args) > 0:
|
||||
audio = args[0]
|
||||
args = args[1:]
|
||||
|
||||
if audio is None and text is None:
|
||||
raise ValueError("You need to specify either an `audio` or `text` input to process.")
|
||||
|
||||
if audio is not None:
|
||||
inputs = self.feature_extractor(audio, *args, **kwargs)
|
||||
if text is not None:
|
||||
encodings = self.tokenizer(text, **kwargs)
|
||||
|
||||
if text is None:
|
||||
return inputs
|
||||
elif audio is None:
|
||||
return encodings
|
||||
else:
|
||||
inputs["labels"] = encodings["input_ids"]
|
||||
return inputs
|
||||
|
||||
def pad(self, *args, **kwargs):
|
||||
"""
|
||||
When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
|
||||
|
@ -79,8 +108,29 @@ class Wav2Vec2Processor(ProcessorMixin):
|
|||
[`~Wav2Vec2Processor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
|
||||
[`~PreTrainedTokenizer.pad`]. Please refer to the docstring of the above two methods for more information.
|
||||
"""
|
||||
# For backward compatibility
|
||||
if self._in_target_context_manager:
|
||||
return self.current_processor.pad(*args, **kwargs)
|
||||
|
||||
input_features = kwargs.pop("input_features", None)
|
||||
labels = kwargs.pop("labels", None)
|
||||
if len(args) > 0:
|
||||
input_features = args[0]
|
||||
args = args[1:]
|
||||
|
||||
if input_features is not None:
|
||||
input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
|
||||
if labels is not None:
|
||||
labels = self.tokenizer.pad(labels, **kwargs)
|
||||
|
||||
if labels is None:
|
||||
return input_features
|
||||
elif input_features is None:
|
||||
return labels
|
||||
else:
|
||||
input_features["labels"] = labels["input_ids"]
|
||||
return input_features
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
|
@ -101,6 +151,13 @@ class Wav2Vec2Processor(ProcessorMixin):
|
|||
Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
|
||||
Wav2Vec2.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
|
||||
"labels by using the argument `text` of the regular `__call__` method (either in the same call as "
|
||||
"your audio inputs, or in a separate call."
|
||||
)
|
||||
self._in_target_context_manager = True
|
||||
self.current_processor = self.tokenizer
|
||||
yield
|
||||
self.current_processor = self.feature_extractor
|
||||
self._in_target_context_manager = False
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
Speech processor class for Wav2Vec2
|
||||
"""
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import get_context
|
||||
|
@ -99,6 +100,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||
|
||||
self.decoder = decoder
|
||||
self.current_processor = self.feature_extractor
|
||||
self._in_target_context_manager = False
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
super().save_pretrained(save_directory)
|
||||
|
@ -214,8 +216,36 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||
Wav2Vec2CTCTokenizer's [`~Wav2Vec2CTCTokenizer.__call__`]. Please refer to the docstring of the above two
|
||||
methods for more information.
|
||||
"""
|
||||
# For backward compatibility
|
||||
if self._in_target_context_manager:
|
||||
return self.current_processor(*args, **kwargs)
|
||||
|
||||
if "raw_speech" in kwargs:
|
||||
warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
|
||||
audio = kwargs.pop("raw_speech")
|
||||
else:
|
||||
audio = kwargs.pop("audio", None)
|
||||
text = kwargs.pop("text", None)
|
||||
if len(args) > 0:
|
||||
audio = args[0]
|
||||
args = args[1:]
|
||||
|
||||
if audio is None and text is None:
|
||||
raise ValueError("You need to specify either an `audio` or `text` input to process.")
|
||||
|
||||
if audio is not None:
|
||||
inputs = self.feature_extractor(audio, *args, **kwargs)
|
||||
if text is not None:
|
||||
encodings = self.tokenizer(text, **kwargs)
|
||||
|
||||
if text is None:
|
||||
return inputs
|
||||
elif audio is None:
|
||||
return encodings
|
||||
else:
|
||||
inputs["labels"] = encodings["input_ids"]
|
||||
return inputs
|
||||
|
||||
def pad(self, *args, **kwargs):
|
||||
"""
|
||||
When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
|
||||
|
@ -224,8 +254,29 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||
Wav2Vec2CTCTokenizer's [`~Wav2Vec2CTCTokenizer.pad`]. Please refer to the docstring of the above two methods
|
||||
for more information.
|
||||
"""
|
||||
# For backward compatibility
|
||||
if self._in_target_context_manager:
|
||||
return self.current_processor.pad(*args, **kwargs)
|
||||
|
||||
input_features = kwargs.pop("input_features", None)
|
||||
labels = kwargs.pop("labels", None)
|
||||
if len(args) > 0:
|
||||
input_features = args[0]
|
||||
args = args[1:]
|
||||
|
||||
if input_features is not None:
|
||||
input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
|
||||
if labels is not None:
|
||||
labels = self.tokenizer.pad(labels, **kwargs)
|
||||
|
||||
if labels is None:
|
||||
return input_features
|
||||
elif input_features is None:
|
||||
return labels
|
||||
else:
|
||||
input_features["labels"] = labels["input_ids"]
|
||||
return input_features
|
||||
|
||||
def batch_decode(
|
||||
self,
|
||||
logits: np.ndarray,
|
||||
|
@ -486,9 +537,16 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||
@contextmanager
def as_target_processor(self):
    """
    Temporarily sets the processor for processing the target. Useful for encoding the labels when fine-tuning
    Wav2Vec2.

    Deprecated: a warning is emitted on entry; pass the labels through the `text` argument of `__call__`
    instead.

    While the `with` block is active, `self.current_processor` is the tokenizer and
    `self._in_target_context_manager` is `True`. Both are reset when the block exits, including when it exits
    because of an exception.
    """
    warnings.warn(
        "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
        "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
        "your audio inputs, or in a separate call)."
    )
    self._in_target_context_manager = True
    self.current_processor = self.tokenizer
    try:
        yield
    finally:
        # Without the `finally`, an exception raised inside the `with` body would leave the processor routing
        # every later `__call__`/`pad` to the tokenizer.
        self.current_processor = self.feature_extractor
        self._in_target_context_manager = False
|
||||
|
|
|
@ -1501,7 +1501,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||
self.deprecation_warnings = (
|
||||
{}
|
||||
) # Use to store when we have already noticed a deprecation warning (avoid overlogging).
|
||||
|
||||
self._in_target_context_manager = False
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
|
@ -2431,8 +2431,12 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||
text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
|
||||
text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||
text_pair_target: Optional[
|
||||
Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
|
||||
] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = False,
|
||||
|
@ -2455,15 +2459,85 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||
sequences.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
text (`str`, `List[str]`, `List[List[str]]`, *optional*):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
text_pair (`str`, `List[str]`, `List[List[str]]`):
|
||||
text_pair (`str`, `List[str]`, `List[List[str]]`, *optional*):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
text_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
|
||||
The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
|
||||
list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
|
||||
you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
text_pair_target (`str`, `List[str]`, `List[List[str]]`, *optional*):
|
||||
The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
|
||||
list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
|
||||
you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
"""
|
||||
# Gather the shared keyword arguments once so they are not duplicated in the source and target calls below
|
||||
all_kwargs = dict(
|
||||
add_special_tokens=add_special_tokens,
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
is_split_into_words=is_split_into_words,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_attention_mask=return_attention_mask,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask,
|
||||
return_offsets_mapping=return_offsets_mapping,
|
||||
return_length=return_length,
|
||||
verbose=verbose,
|
||||
)
|
||||
all_kwargs.update(kwargs)
|
||||
if text is None and text_target is None:
|
||||
raise ValueError("You need to specify either `text` or `text_target`.")
|
||||
if text is not None:
|
||||
# The context manager will send the inputs as normal texts and not text_target, but we shouldn't change the
|
||||
# input mode in this case.
|
||||
if not self._in_target_context_manager:
|
||||
self._switch_to_input_mode()
|
||||
encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
|
||||
if text_target is not None:
|
||||
self._switch_to_target_mode()
|
||||
target_encodings = self._call_one(text=text_target, text_pair=text_pair_target, **all_kwargs)
|
||||
# Put the tokenizer back in input mode
|
||||
self._switch_to_input_mode()
|
||||
|
||||
if text_target is None:
|
||||
return encodings
|
||||
elif text is None:
|
||||
return target_encodings
|
||||
else:
|
||||
encodings["labels"] = target_encodings["input_ids"]
|
||||
return encodings
|
||||
|
||||
def _call_one(
|
||||
self,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
|
||||
text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding: Union[bool, str, PaddingStrategy] = False,
|
||||
truncation: Union[bool, str, TruncationStrategy] = False,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
is_split_into_words: bool = False,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
# Input type checking for clearer error
|
||||
def _is_valid_text_input(t):
|
||||
if isinstance(t, str):
|
||||
|
@ -3456,13 +3530,34 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||
)
|
||||
self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
|
||||
|
||||
def _switch_to_input_mode(self):
|
||||
"""
|
||||
Private method to put the tokenizer in input mode (when it has different modes for input/outputs)
|
||||
"""
|
||||
pass
|
||||
|
||||
def _switch_to_target_mode(self):
|
||||
"""
|
||||
Private method to put the tokenizer in target mode (when it has different modes for input/outputs)
|
||||
"""
|
||||
pass
|
||||
|
||||
@contextmanager
def as_target_tokenizer(self):
    """
    Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
    sequence-to-sequence models that need a slightly different processing for the labels.

    Deprecated: a warning is emitted on entry; pass the labels through the `text_target` argument of
    `__call__` instead.

    On entry the tokenizer is switched to target mode and `self._in_target_context_manager` is set. On exit
    the flag is cleared and the tokenizer is switched back to input mode, including when the `with` block exits
    because of an exception.
    """
    warnings.warn(
        "`as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your "
        "labels by using the argument `text_target` of the regular `__call__` method (either in the same call as "
        "your input texts if you use the same keyword arguments, or in a separate call)."
    )
    self._switch_to_target_mode()
    self._in_target_context_manager = True
    try:
        yield
    finally:
        # Without the `finally`, an exception raised inside the `with` body would leave the flag set and the
        # tokenizer in target mode, so later plain calls would keep encoding their inputs as targets.
        self._in_target_context_manager = False
        self._switch_to_input_mode()
|
||||
|
||||
@classmethod
|
||||
def register_for_auto_class(cls, auto_class="AutoTokenizer"):
|
||||
|
@ -3563,14 +3658,17 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||
# docstyle-ignore
|
||||
formatted_warning = """
|
||||
`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of HuggingFace Transformers. Use the regular
|
||||
`__call__` method to prepare your inputs and the tokenizer under the `as_target_tokenizer` context manager to prepare
|
||||
your targets.
|
||||
`__call__` method to prepare your inputs and targets.
|
||||
|
||||
Here is a short example:
|
||||
|
||||
model_inputs = tokenizer(src_texts, text_target=tgt_texts, ...)
|
||||
|
||||
If you either need to use different keyword arguments for the source and target texts, you should do two calls like
|
||||
this:
|
||||
|
||||
model_inputs = tokenizer(src_texts, ...)
|
||||
with tokenizer.as_target_tokenizer():
|
||||
labels = tokenizer(tgt_texts, ...)
|
||||
labels = tokenizer(text_target=tgt_texts, ...)
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
See the documentation of your specific tokenizer for more details on the specific arguments to the tokenizer of choice.
|
||||
|
|
|
@ -428,8 +428,7 @@ PT_SPEECH_CTC_SAMPLE = r"""
|
|||
```
|
||||
|
||||
```python
|
||||
>>> with processor.as_target_processor():
|
||||
... inputs["labels"] = processor(dataset[0]["text"], return_tensors="pt").input_ids
|
||||
>>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids
|
||||
|
||||
>>> # compute loss
|
||||
>>> loss = model(**inputs).loss
|
||||
|
@ -849,8 +848,7 @@ TF_SPEECH_CTC_SAMPLE = r"""
|
|||
```
|
||||
|
||||
```python
|
||||
>>> with processor.as_target_processor():
|
||||
... inputs["labels"] = processor(dataset[0]["text"], return_tensors="tf").input_ids
|
||||
>>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="tf").input_ids
|
||||
|
||||
>>> # compute loss
|
||||
>>> loss = model(**inputs).loss
|
||||
|
|
|
@ -112,14 +112,13 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
|||
self.assertNotIn("decoder_attention_mask", batch)
|
||||
|
||||
@require_torch
|
||||
def test_as_target_tokenizer_target_length(self):
|
||||
def test_tokenizer_as_target_length(self):
|
||||
tgt_text = [
|
||||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text, max_length=32, padding="max_length", return_tensors="pt")
|
||||
targets = tokenizer(text_target=tgt_text, max_length=32, padding="max_length", return_tensors="pt")
|
||||
self.assertEqual(32, targets["input_ids"].shape[1])
|
||||
|
||||
@require_torch
|
||||
|
@ -140,8 +139,7 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
|||
]
|
||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||
inputs = tokenizer(src_text, return_tensors="pt")
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text, return_tensors="pt")
|
||||
targets = tokenizer(text_target=tgt_text, return_tensors="pt")
|
||||
input_ids = inputs["input_ids"]
|
||||
labels = targets["input_ids"]
|
||||
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
|
||||
|
|
|
@ -152,9 +152,8 @@ class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(
|
||||
tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
|
||||
text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
|
||||
)
|
||||
self.assertEqual(32, targets["input_ids"].shape[1])
|
||||
|
||||
|
@ -167,12 +166,10 @@ class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
expected_tgt_tokens = [86, 120, 112, 112, 100, 117, 124, 35, 114, 105, 35, 119, 107, 104, 35, 119, 104, 123, 119, 49, 35, 1]
|
||||
# fmt: on
|
||||
|
||||
batch = tokenizer(src_text)
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text)
|
||||
batch = tokenizer(src_text, text_target=tgt_text)
|
||||
|
||||
self.assertEqual(expected_src_tokens, batch["input_ids"][0])
|
||||
self.assertEqual(expected_tgt_tokens, targets["input_ids"][0])
|
||||
self.assertEqual(expected_tgt_tokens, batch["labels"][0])
|
||||
|
||||
# cannot use default save_and_load_tokenizer test method because tokenizer has no vocab
|
||||
def test_save_and_load_tokenizer(self):
|
||||
|
|
|
@ -80,8 +80,9 @@ class CanineTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
"What's the weater?",
|
||||
"It's about 25 degrees.",
|
||||
]
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors="pt")
|
||||
targets = tokenizer(
|
||||
text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
self.assertEqual(32, targets["input_ids"].shape[1])
|
||||
|
||||
# cannot use default save_and_load_tokenizer test method because tokenizer has no vocab
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from transformers import (
|
||||
DPRContextEncoderTokenizer,
|
||||
DPRContextEncoderTokenizerFast,
|
||||
|
|
|
@ -187,9 +187,7 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
|
|||
self.tokenizer.src_lang = "en"
|
||||
self.tokenizer.tgt_lang = "fr"
|
||||
|
||||
batch = self.tokenizer(self.src_text, padding=True, return_tensors="pt")
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
batch["labels"] = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt").input_ids
|
||||
batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
|
||||
|
||||
batch["decoder_input_ids"] = shift_tokens_right(
|
||||
batch["labels"], self.tokenizer.pad_token_id, self.tokenizer.eos_token_id
|
||||
|
@ -217,17 +215,19 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
|
|||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
|
||||
@require_torch
|
||||
def test_as_target_tokenizer(self):
|
||||
def test_tokenizer_target_mode(self):
|
||||
self.tokenizer.tgt_lang = "mr"
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
self.tokenizer._switch_to_target_mode()
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("mr")])
|
||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
self.tokenizer._switch_to_input_mode()
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
|
||||
|
||||
self.tokenizer.tgt_lang = "zh"
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
self.tokenizer._switch_to_target_mode()
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
|
||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
self.tokenizer._switch_to_input_mode()
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
|
||||
|
||||
@require_torch
|
||||
|
|
|
@ -438,10 +438,7 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
|||
src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
|
||||
expected_ids = [38, 121, 14, 697, 38848, 0]
|
||||
|
||||
model_inputs = self.tokenizer(src, return_tensors="pt").to(torch_device)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(tgt, return_tensors="pt")
|
||||
model_inputs["labels"] = targets["input_ids"].to(torch_device)
|
||||
model_inputs = self.tokenizer(src, text_target=tgt, return_tensors="pt").to(torch_device)
|
||||
|
||||
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
|
||||
|
||||
|
|
|
@ -145,8 +145,7 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
src_ids = tokenizer(source_text).input_ids
|
||||
self.assertListEqual(src_ids, expected_src_ids)
|
||||
|
||||
with tokenizer.as_target_tokenizer():
|
||||
target_ids = tokenizer(target_text).input_ids
|
||||
target_ids = tokenizer(text_target=target_text).input_ids
|
||||
self.assertListEqual(target_ids, expected_target_ids)
|
||||
|
||||
decoded = tokenizer.decode(target_ids, skip_special_tokens=True)
|
||||
|
|
|
@ -265,33 +265,27 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
|||
|
||||
@require_torch
|
||||
def test_batch_fairseq_parity(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
|
||||
batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
|
||||
|
||||
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
|
||||
assert batch.input_ids[1][-2:] == [2, EN_CODE]
|
||||
assert batch.decoder_input_ids[1][0] == RO_CODE
|
||||
assert batch.input_ids[1][-2:].tolist() == [2, EN_CODE]
|
||||
assert batch.decoder_input_ids[1][0].tolist() == RO_CODE
|
||||
assert batch.decoder_input_ids[1][-1] == 2
|
||||
assert labels[1][-2:].tolist() == [2, RO_CODE]
|
||||
assert batch.labels[1][-2:].tolist() == [2, RO_CODE]
|
||||
|
||||
@require_torch
|
||||
def test_enro_tokenizer_prepare_batch(self):
|
||||
batch = self.tokenizer(
|
||||
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(
|
||||
self.tgt_text,
|
||||
self.src_text,
|
||||
text_target=self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
|
||||
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
|
||||
|
@ -306,8 +300,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
|||
|
||||
def test_seq2seq_max_length(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
|
||||
targets = self.tokenizer(
|
||||
text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
|
||||
|
|
|
@ -256,35 +256,27 @@ class MBart50OneToManyIntegrationTest(unittest.TestCase):
|
|||
|
||||
@require_torch
|
||||
def test_batch_fairseq_parity(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
|
||||
labels = labels.tolist()
|
||||
batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
|
||||
|
||||
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
|
||||
assert batch.input_ids[1][0] == EN_CODE
|
||||
assert batch.input_ids[1][-1] == 2
|
||||
assert labels[1][0] == RO_CODE
|
||||
assert labels[1][-1] == 2
|
||||
assert batch.decoder_input_ids[1][:2] == [2, RO_CODE]
|
||||
assert batch.labels[1][0] == RO_CODE
|
||||
assert batch.labels[1][-1] == 2
|
||||
assert batch.decoder_input_ids[1][:2].tolist() == [2, RO_CODE]
|
||||
|
||||
@require_torch
|
||||
def test_tokenizer_prepare_batch(self):
|
||||
batch = self.tokenizer(
|
||||
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(
|
||||
self.tgt_text,
|
||||
self.src_text,
|
||||
text_target=self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
|
||||
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
|
||||
|
@ -299,8 +291,9 @@ class MBart50OneToManyIntegrationTest(unittest.TestCase):
|
|||
|
||||
def test_seq2seq_max_target_length(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
|
||||
targets = self.tokenizer(
|
||||
text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
|
||||
|
|
|
@ -125,8 +125,7 @@ class MCTCTProcessorTest(unittest.TestCase):
|
|||
|
||||
input_str = "This is a test string"
|
||||
|
||||
with processor.as_target_processor():
|
||||
encoded_processor = processor(input_str)
|
||||
encoded_processor = processor(text=input_str)
|
||||
|
||||
encoded_tok = tokenizer(input_str)
|
||||
|
||||
|
|
|
@ -112,14 +112,13 @@ class TestTokenizationMvp(TokenizerTesterMixin, unittest.TestCase):
|
|||
self.assertNotIn("decoder_attention_mask", batch)
|
||||
|
||||
@require_torch
|
||||
def test_as_target_tokenizer_target_length(self):
|
||||
def test_tokenizer_as_target_length(self):
|
||||
tgt_text = [
|
||||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text, max_length=32, padding="max_length", return_tensors="pt")
|
||||
targets = tokenizer(text_target=tgt_text, max_length=32, padding="max_length", return_tensors="pt")
|
||||
self.assertEqual(32, targets["input_ids"].shape[1])
|
||||
|
||||
@require_torch
|
||||
|
@ -139,11 +138,9 @@ class TestTokenizationMvp(TokenizerTesterMixin, unittest.TestCase):
|
|||
"Summary of the text.",
|
||||
]
|
||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||
inputs = tokenizer(src_text, return_tensors="pt")
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text, return_tensors="pt")
|
||||
inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
|
||||
input_ids = inputs["input_ids"]
|
||||
labels = targets["input_ids"]
|
||||
labels = inputs["labels"]
|
||||
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
|
||||
self.assertTrue((labels[:, 0] == tokenizer.bos_token_id).all().item())
|
||||
self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
|
||||
|
|
|
@ -373,19 +373,15 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
|
|||
@require_torch
|
||||
def test_enro_tokenizer_prepare_batch(self):
|
||||
batch = self.tokenizer(
|
||||
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(
|
||||
self.tgt_text,
|
||||
self.src_text,
|
||||
text_target=self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(
|
||||
labels, self.tokenizer.pad_token_id, self.tokenizer.lang_code_to_id["ron_Latn"]
|
||||
batch["labels"], self.tokenizer.pad_token_id, self.tokenizer.lang_code_to_id["ron_Latn"]
|
||||
)
|
||||
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
|
@ -401,8 +397,9 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
|
|||
|
||||
def test_seq2seq_max_length(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
|
||||
targets = self.tokenizer(
|
||||
text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(
|
||||
labels,
|
||||
|
|
|
@ -109,9 +109,8 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
src_texts = ["This is going to be way too long." * 150, "short example"]
|
||||
tgt_texts = ["not super long but more than 5 tokens", "tiny"]
|
||||
batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
|
||||
with self._large_tokenizer.as_target_tokenizer():
|
||||
targets = self._large_tokenizer(
|
||||
tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
|
||||
text_target=tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
assert batch.input_ids.shape == (2, 1024)
|
||||
|
@ -174,9 +173,8 @@ class BigBirdPegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
src_texts = ["This is going to be way too long." * 1000, "short example"]
|
||||
tgt_texts = ["not super long but more than 5 tokens", "tiny"]
|
||||
batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
|
||||
with self._large_tokenizer.as_target_tokenizer():
|
||||
targets = self._large_tokenizer(
|
||||
tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
|
||||
text_target=tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
assert batch.input_ids.shape == (2, 4096)
|
||||
|
|
|
@ -146,9 +146,8 @@ class PerceiverTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(
|
||||
tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
|
||||
text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
|
||||
)
|
||||
self.assertEqual(32, targets["input_ids"].shape[1])
|
||||
|
||||
|
|
|
@ -299,33 +299,26 @@ class PLBartPythonEnIntegrationTest(unittest.TestCase):
|
|||
|
||||
@require_torch
|
||||
def test_batch_fairseq_parity(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
|
||||
batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
|
||||
|
||||
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
|
||||
self.assertEqual(batch.input_ids[1][-2:], [2, PYTHON_CODE])
|
||||
self.assertEqual(batch.input_ids[1][-2:].tolist(), [2, PYTHON_CODE])
|
||||
self.assertEqual(batch.decoder_input_ids[1][0], EN_CODE)
|
||||
self.assertEqual(batch.decoder_input_ids[1][-1], 2)
|
||||
self.assertEqual(labels[1][-2:].tolist(), [2, EN_CODE])
|
||||
self.assertEqual(batch.labels[1][-2:].tolist(), [2, EN_CODE])
|
||||
|
||||
@require_torch
|
||||
def test_python_en_tokenizer_prepare_batch(self):
|
||||
batch = self.tokenizer(
|
||||
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(
|
||||
self.tgt_text,
|
||||
self.src_text,
|
||||
text_target=self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
|
||||
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
|
||||
|
@ -340,8 +333,9 @@ class PLBartPythonEnIntegrationTest(unittest.TestCase):
|
|||
|
||||
def test_seq2seq_max_length(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
|
||||
targets = self.tokenizer(
|
||||
text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
|
||||
|
|
|
@ -125,8 +125,7 @@ class Speech2TextProcessorTest(unittest.TestCase):
|
|||
|
||||
input_str = "This is a test string"
|
||||
|
||||
with processor.as_target_processor():
|
||||
encoded_processor = processor(input_str)
|
||||
encoded_processor = processor(text=input_str)
|
||||
|
||||
encoded_tok = tokenizer(input_str)
|
||||
|
||||
|
|
|
@ -210,9 +210,8 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(
|
||||
tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
|
||||
text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
|
||||
)
|
||||
self.assertEqual(32, targets["input_ids"].shape[1])
|
||||
|
||||
|
@ -235,12 +234,10 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1]
|
||||
expected_tgt_tokens = [20698, 13, 8, 1499, 5, 1]
|
||||
|
||||
batch = tokenizer(src_text)
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text)
|
||||
batch = tokenizer(src_text, text_target=tgt_text)
|
||||
|
||||
self.assertEqual(expected_src_tokens, batch["input_ids"][0])
|
||||
self.assertEqual(expected_tgt_tokens, targets["input_ids"][0])
|
||||
self.assertEqual(expected_tgt_tokens, batch["labels"][0])
|
||||
|
||||
def test_token_type_ids(self):
|
||||
src_text_1 = ["A first paragraph for summarization."]
|
||||
|
|
|
@ -859,7 +859,6 @@ class TapexTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
tokenizer = TapexTokenizer.from_pretrained("microsoft/tapex-base")
|
||||
answer_text = "tapex is a good model!"
|
||||
expected_src_tokens = [0, 90, 5776, 1178, 16, 10, 205, 1421, 328, 2]
|
||||
with tokenizer.as_target_tokenizer():
|
||||
answer_encoding = tokenizer(answer=answer_text)
|
||||
self.assertListEqual(answer_encoding.input_ids, expected_src_tokens)
|
||||
|
||||
|
@ -870,8 +869,6 @@ class TapexTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
answer_text = "Beijing, London, Paris"
|
||||
answer_text_lower = "beijing, london, paris"
|
||||
|
||||
with cased_tokenizer.as_target_tokenizer():
|
||||
with uncased_tokenizer.as_target_tokenizer():
|
||||
self.assertNotEqual(
|
||||
cased_tokenizer(answer=answer_text).input_ids, uncased_tokenizer(answer=answer_text).input_ids
|
||||
)
|
||||
|
|
|
@ -118,8 +118,7 @@ class Wav2Vec2ProcessorTest(unittest.TestCase):
|
|||
|
||||
input_str = "This is a test string"
|
||||
|
||||
with processor.as_target_processor():
|
||||
encoded_processor = processor(input_str)
|
||||
encoded_processor = processor(text=input_str)
|
||||
|
||||
encoded_tok = tokenizer(input_str)
|
||||
|
||||
|
|
|
@ -164,8 +164,7 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
|||
|
||||
input_str = "This is a test string"
|
||||
|
||||
with processor.as_target_processor():
|
||||
encoded_processor = processor(input_str)
|
||||
encoded_processor = processor(text=input_str)
|
||||
|
||||
encoded_tok = tokenizer(input_str)
|
||||
|
||||
|
|
Loading…
Reference in New Issue