Generate: text generation pipeline no longer emits `max_length` warning when it is not set (#23139)
This commit is contained in:
parent
516dc6305f
commit
b369e507aa
|
@ -385,7 +385,6 @@ class FlaxGenerationMixin:
|
|||
UserWarning,
|
||||
)
|
||||
elif generation_config.max_new_tokens is not None:
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||
if not has_default_max_length:
|
||||
logger.warning(
|
||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||
|
@ -393,6 +392,7 @@ class FlaxGenerationMixin:
|
|||
"Please refer to the documentation for more information. "
|
||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||
)
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||
|
||||
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
||||
raise ValueError(
|
||||
|
|
|
@ -858,7 +858,6 @@ class TFGenerationMixin:
|
|||
UserWarning,
|
||||
)
|
||||
elif generation_config.max_new_tokens is not None:
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||
if not has_default_max_length:
|
||||
logger.warning(
|
||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||
|
@ -866,6 +865,7 @@ class TFGenerationMixin:
|
|||
"Please refer to the documentation for more information. "
|
||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||
)
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||
|
||||
# If the input length is a tensor (i.e. dynamic length), skip length checks
|
||||
if not isinstance(input_ids_seq_length, tf.Tensor):
|
||||
|
|
|
@ -1348,7 +1348,6 @@ class GenerationMixin:
|
|||
UserWarning,
|
||||
)
|
||||
elif generation_config.max_new_tokens is not None:
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||
if not has_default_max_length:
|
||||
logger.warning(
|
||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||
|
@ -1356,6 +1355,7 @@ class GenerationMixin:
|
|||
"Please refer to the documentation for more information. "
|
||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||
)
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
||||
|
||||
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
||||
raise ValueError(
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import copy
|
||||
import enum
|
||||
import warnings
|
||||
|
||||
|
@ -105,17 +106,8 @@ class TextGenerationPipeline(Pipeline):
|
|||
prefix_inputs = self.tokenizer(
|
||||
prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
|
||||
)
|
||||
prefix_length = prefix_inputs["input_ids"].shape[-1]
|
||||
generate_kwargs["prefix_length"] = prefix_inputs["input_ids"].shape[-1]
|
||||
|
||||
if "max_new_tokens" in generate_kwargs:
|
||||
pass
|
||||
elif "max_length" in generate_kwargs:
|
||||
generate_kwargs["max_length"] += prefix_length
|
||||
else:
|
||||
generate_kwargs["max_length"] = self.model.config.max_length + prefix_length
|
||||
|
||||
if "min_length" in generate_kwargs:
|
||||
generate_kwargs["min_length"] += prefix_length
|
||||
if handle_long_generation is not None:
|
||||
if handle_long_generation not in {"hole"}:
|
||||
raise ValueError(
|
||||
|
@ -247,6 +239,26 @@ class TextGenerationPipeline(Pipeline):
|
|||
else:
|
||||
in_b = input_ids.shape[0]
|
||||
prompt_text = model_inputs.pop("prompt_text")
|
||||
|
||||
# If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
|
||||
# generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
|
||||
generate_kwargs = copy.deepcopy(generate_kwargs)
|
||||
prefix_length = generate_kwargs.pop("prefix_length", 0)
|
||||
if prefix_length > 0:
|
||||
has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
|
||||
"generation_config" in generate_kwargs
|
||||
and generate_kwargs["generation_config"].max_new_tokens is not None
|
||||
)
|
||||
if not has_max_new_tokens:
|
||||
generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
|
||||
generate_kwargs["max_length"] += prefix_length
|
||||
has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
|
||||
"generation_config" in generate_kwargs
|
||||
and generate_kwargs["generation_config"].min_new_tokens is not None
|
||||
)
|
||||
if not has_min_new_tokens and "min_length" in generate_kwargs:
|
||||
generate_kwargs["min_length"] += prefix_length
|
||||
|
||||
# BS x SL
|
||||
generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
|
||||
out_b = generated_sequence.shape[0]
|
||||
|
|
|
@ -14,8 +14,15 @@
|
|||
|
||||
import unittest
|
||||
|
||||
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING, TextGenerationPipeline, pipeline
|
||||
from transformers import (
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
TextGenerationPipeline,
|
||||
logging,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_pipeline_test,
|
||||
require_accelerate,
|
||||
require_tf,
|
||||
|
@ -323,3 +330,26 @@ class TextGenerationPipelineTests(unittest.TestCase):
|
|||
|
||||
pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16)
|
||||
pipe("This is a test", do_sample=True, top_p=0.5)
|
||||
|
||||
def test_pipeline_length_setting_warning(self):
|
||||
prompt = """Hello world"""
|
||||
text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")
|
||||
if text_generator.model.framework == "tf":
|
||||
logger = logging.get_logger("transformers.generation.tf_utils")
|
||||
else:
|
||||
logger = logging.get_logger("transformers.generation.utils")
|
||||
logger_msg = "Both `max_new_tokens`" # The beggining of the message to be checked in this test
|
||||
|
||||
# Both are set by the user -> log warning
|
||||
with CaptureLogger(logger) as cl:
|
||||
_ = text_generator(prompt, max_length=10, max_new_tokens=1)
|
||||
self.assertIn(logger_msg, cl.out)
|
||||
|
||||
# The user only sets one -> no warning
|
||||
with CaptureLogger(logger) as cl:
|
||||
_ = text_generator(prompt, max_new_tokens=1)
|
||||
self.assertNotIn(logger_msg, cl.out)
|
||||
|
||||
with CaptureLogger(logger) as cl:
|
||||
_ = text_generator(prompt, max_length=10)
|
||||
self.assertNotIn(logger_msg, cl.out)
|
||||
|
|
Loading…
Reference in New Issue