feat: Sequential beam search (#26304)
This commit is contained in:
parent
268fc1fdfa
commit
d4fc1eb498
|
@ -200,7 +200,8 @@ class GenerationConfig(PushToHubMixin):
|
|||
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
|
||||
prompt, usually at the expense of poorer quality.
|
||||
low_memory (`bool`, *optional*):
|
||||
Switch to sequential topk for contrastive search to reduce peak memory. Used with contrastive search.
|
||||
Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
|
||||
Used with beam search and contrastive search.
|
||||
|
||||
|
||||
> Parameters that define the output variables of `generate`
|
||||
|
|
|
@ -1558,6 +1558,7 @@ class GenerationMixin:
|
|||
output_scores=generation_config.output_scores,
|
||||
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||
synced_gpus=synced_gpus,
|
||||
sequential=generation_config.low_memory,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
|
@ -1951,8 +1952,7 @@ class GenerationMixin:
|
|||
model_kwargs["past_key_values"] = tuple(new_key_values)
|
||||
|
||||
if sequential:
|
||||
all_outputs = {key: [] for key in outputs} # defined in first loop iteration
|
||||
all_last_hstates, all_hstates, all_logits = [], [], []
|
||||
all_outputs = []
|
||||
for i in range(top_k):
|
||||
# compute the candidate tokens by the language model and collect their hidden_states
|
||||
next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
|
||||
|
@ -1963,32 +1963,8 @@ class GenerationMixin:
|
|||
output_hidden_states=True,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
for key in all_outputs:
|
||||
all_outputs[key].append(outputs[key])
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
next_hidden = outputs.decoder_hidden_states[-1]
|
||||
full_hidden_states = outputs.decoder_hidden_states
|
||||
|
||||
else:
|
||||
next_hidden = outputs.hidden_states[-1]
|
||||
full_hidden_states = outputs.hidden_states
|
||||
|
||||
all_last_hstates.append(torch.squeeze(next_hidden, 0))
|
||||
all_hstates.append(full_hidden_states)
|
||||
all_logits.append(outputs.logits[:, -1, :])
|
||||
|
||||
# stack hidden states
|
||||
next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0)
|
||||
final_full_hstates = [0 for i in range(len(full_hidden_states))]
|
||||
for layer in range(len(full_hidden_states)):
|
||||
final_full_hstates[layer] = torch.stack(
|
||||
[torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0
|
||||
)
|
||||
full_hidden_states = tuple(final_full_hstates)
|
||||
|
||||
# stack logits
|
||||
logits = torch.cat(all_logits, dim=0)
|
||||
all_outputs.append(outputs)
|
||||
outputs = stack_model_outputs(all_outputs)
|
||||
|
||||
else:
|
||||
# compute the candidate tokens by the language model and collect their hidden_states
|
||||
|
@ -2001,15 +1977,15 @@ class GenerationMixin:
|
|||
output_hidden_states=True,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
# name is different for encoder-decoder and decoder-only models
|
||||
if self.config.is_encoder_decoder:
|
||||
next_hidden = outputs.decoder_hidden_states[-1]
|
||||
full_hidden_states = outputs.decoder_hidden_states
|
||||
else:
|
||||
next_hidden = outputs.hidden_states[-1]
|
||||
full_hidden_states = outputs.hidden_states
|
||||
# name is different for encoder-decoder and decoder-only models
|
||||
if self.config.is_encoder_decoder:
|
||||
next_hidden = outputs.decoder_hidden_states[-1]
|
||||
full_hidden_states = outputs.decoder_hidden_states
|
||||
else:
|
||||
next_hidden = outputs.hidden_states[-1]
|
||||
full_hidden_states = outputs.hidden_states
|
||||
|
||||
logits = outputs.logits[:, -1, :]
|
||||
logits = outputs.logits[:, -1, :]
|
||||
|
||||
context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
|
||||
|
||||
|
@ -2747,6 +2723,7 @@ class GenerationMixin:
|
|||
output_scores: Optional[bool] = None,
|
||||
return_dict_in_generate: Optional[bool] = None,
|
||||
synced_gpus: bool = False,
|
||||
sequential: Optional[bool] = None,
|
||||
**model_kwargs,
|
||||
) -> Union[GenerateBeamOutput, torch.LongTensor]:
|
||||
r"""
|
||||
|
@ -2792,6 +2769,10 @@ class GenerationMixin:
|
|||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||
sequential (`bool`, defaults to `False`):
|
||||
By default, beam search has `batch_size * num_beams` as effective batch size (see `beam_search()` for
|
||||
more details). This flag will avoid parallelizing the beam search and will instead run beam search
|
||||
sequentially.
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
||||
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
||||
|
@ -2858,6 +2839,7 @@ class GenerationMixin:
|
|||
# init values
|
||||
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||
sequential = sequential if sequential is not None else self.generation_config.low_memory
|
||||
if max_length is not None:
|
||||
warnings.warn(
|
||||
"`max_length` is deprecated in this function, use"
|
||||
|
@ -2932,12 +2914,39 @@ class GenerationMixin:
|
|||
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
outputs = self(
|
||||
**model_inputs,
|
||||
return_dict=True,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
# if sequential is True, split the input to batches of batch_size and run sequentially
|
||||
if sequential:
|
||||
if any(
|
||||
model_name in self.__class__.__name__.lower()
|
||||
for model_name in ["fsmt", "reformer", "bloom", "ctrl", "gpt_bigcode", "transo_xl", "xlnet", "cpm"]
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Currently generation for {self.__class__.__name__} is not supported "
|
||||
f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature."
|
||||
)
|
||||
|
||||
inputs_per_sub_batches = _split_model_inputs(
|
||||
model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
|
||||
)
|
||||
outputs_per_sub_batch = [
|
||||
self(
|
||||
**inputs_per_sub_batch,
|
||||
return_dict=True,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
for inputs_per_sub_batch in inputs_per_sub_batches
|
||||
]
|
||||
|
||||
outputs = stack_model_outputs(outputs_per_sub_batch)
|
||||
|
||||
else: # Unchanged original behavior
|
||||
outputs = self(
|
||||
**model_inputs,
|
||||
return_dict=True,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
if synced_gpus and this_peer_finished:
|
||||
cur_len = cur_len + 1
|
||||
|
@ -4656,3 +4665,139 @@ def _ranking_fast(
|
|||
contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K]
|
||||
_, selected_idx = contrastive_score.max(dim=-1) # [B]
|
||||
return selected_idx
|
||||
|
||||
|
||||
def _split(data, full_batch_size: int, split_size: int = None):
|
||||
"""
|
||||
Takes care of three cases:
|
||||
1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim
|
||||
2. data is a tuple: e.g. hidden_states, attentions etc. Keep the tuple as it is and split each tensor in it and
|
||||
return a list of tuples
|
||||
3. data is a tuple of tuples, e.g. past_key_values. Keep the tuple as it is and split each tuple in it and
|
||||
return a list of tuples of tuples
|
||||
(see documentation of ModelOutput)
|
||||
"""
|
||||
if data is None:
|
||||
return [None] * (full_batch_size // split_size)
|
||||
if isinstance(data, torch.Tensor):
|
||||
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
|
||||
elif isinstance(data, tuple):
|
||||
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
|
||||
if isinstance(data[0], tuple):
|
||||
return [
|
||||
tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data)
|
||||
for i in range(0, full_batch_size, split_size)
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
tuple(sub_tensor[i : i + split_size] for sub_tensor in data)
|
||||
for i in range(0, full_batch_size, split_size)
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unexpected attribute type: {type(data)}")
|
||||
|
||||
|
||||
def _split_model_inputs(
|
||||
model_input: Union[ModelOutput, Dict], split_size: int, full_batch_size: int
|
||||
) -> List[Union[ModelOutput, Dict]]:
|
||||
"""
|
||||
Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split
|
||||
size. The input object is dict when it was prepared for forward pass and ModelOutput when it was returned from
|
||||
previous forward pass.
|
||||
"""
|
||||
# Edge case: if model_input is None, return a list of Nones
|
||||
# this happens with Whisper where encoder_outputs is None
|
||||
if model_input is None:
|
||||
return [model_input] * (full_batch_size // split_size)
|
||||
# Infer the class from the object
|
||||
model_output_cls = type(model_input)
|
||||
if (full_batch_size % split_size) != 0:
|
||||
raise ValueError("`full_batch_size` must be divisible by `split_size`")
|
||||
|
||||
if split_size > full_batch_size:
|
||||
raise ValueError("`split_size` must be smaller or equal to `full_batch_size`")
|
||||
|
||||
# Helper function to split tensors or tuples of tensors
|
||||
|
||||
# Find all the dataclass fields (e.g., last_hidden_state, pooler_output etc.) and split them
|
||||
keys = (
|
||||
model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys()
|
||||
)
|
||||
# We only keep keys that are in the model_input
|
||||
keys = [k for k in keys if k in model_input]
|
||||
# Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a
|
||||
# ModelOutput object.
|
||||
# bool should not be split but replicated for each split
|
||||
bool_keys = [k for k in keys if isinstance(model_input[k], bool)]
|
||||
non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"]
|
||||
|
||||
# we split the tensors and tuples of tensors
|
||||
data_split_list = [
|
||||
{k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys}
|
||||
for i in range(full_batch_size // split_size)
|
||||
]
|
||||
# bool values are the same and replicated for each split
|
||||
bool_data = {k: model_input[k] for k in bool_keys}
|
||||
# encoder_outputs is a ModelOutput object and should be split by its own
|
||||
if "encoder_outputs" in model_input:
|
||||
encoder_outputs_split = _split_model_inputs(model_input["encoder_outputs"], split_size, full_batch_size)
|
||||
data_split_list = [
|
||||
{**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
|
||||
]
|
||||
|
||||
# Convert each dictionary in the list to an object of the inferred class
|
||||
split_model_inputs: List[Union[ModelOutput, Dict]] = [
|
||||
model_output_cls(**data_split, **bool_data) for data_split in data_split_list
|
||||
]
|
||||
|
||||
return split_model_inputs
|
||||
|
||||
|
||||
def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
|
||||
"""
|
||||
Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
|
||||
specific ModelOutput subclass from the list provided.
|
||||
"""
|
||||
if not model_outputs:
|
||||
raise ValueError("Input list is empty.")
|
||||
|
||||
# Infer the class from the first object in the list
|
||||
model_output_cls = type(model_outputs[0])
|
||||
|
||||
# Ensure all objects are of the same type
|
||||
if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
|
||||
raise ValueError("All elements in the list should be of the same type.")
|
||||
|
||||
# Helper function to concat tensors or tuples of tensors
|
||||
def _concat(data):
|
||||
"""
|
||||
Reverse of `_split` function above.
|
||||
"""
|
||||
if any(data is None for data in data):
|
||||
return None
|
||||
if isinstance(data[0], torch.Tensor):
|
||||
return torch.cat(data, dim=0)
|
||||
elif isinstance(data[0], tuple):
|
||||
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
|
||||
if isinstance(data[0][0], tuple):
|
||||
return tuple(
|
||||
tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
|
||||
for i in range(len(data[0]))
|
||||
)
|
||||
else:
|
||||
return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
|
||||
elif isinstance(data[0], (int, float)):
|
||||
# If the elements are integers or floats, return a tensor
|
||||
return torch.tensor(data)
|
||||
else:
|
||||
raise ValueError(f"Unexpected attribute type: {type(data[0])}")
|
||||
|
||||
# Use a dictionary comprehension to gather attributes from all objects and concatenate them
|
||||
concatenated_data = {
|
||||
k: _concat([getattr(model_output, k) for model_output in model_outputs])
|
||||
for k in model_output_cls.__dataclass_fields__.keys()
|
||||
}
|
||||
|
||||
# Return a new object of the inferred class with the concatenated attributes
|
||||
return model_output_cls(**concatenated_data)
|
||||
|
|
|
@ -1539,6 +1539,39 @@ class GenerationTesterMixin:
|
|||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
def test_beam_search_low_memory(self):
|
||||
# Check that choosing 'low_memory' does not change the model output
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(
|
||||
model_name in model_class.__name__.lower()
|
||||
for model_name in [
|
||||
"bloom",
|
||||
"ctrl",
|
||||
"gptbigcode",
|
||||
"transo_xl",
|
||||
"xlnet",
|
||||
"cpm",
|
||||
]
|
||||
):
|
||||
self.skipTest("May fix in the future: need model-specific fixes")
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=2)
|
||||
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
|
||||
|
||||
config.use_cache = True
|
||||
config.is_decoder = True
|
||||
|
||||
# test output equality of low versus high memory
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True)
|
||||
|
||||
high_output = model.generate(
|
||||
input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail.
|
||||
def test_assisted_decoding_matches_greedy_search(self):
|
||||
# This test ensures that the assisted generation does not introduce output changes over greedy search.
|
||||
|
@ -2766,6 +2799,19 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||
|
||||
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
|
||||
|
||||
def test_beam_search_low_memory(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
model_inputs = tokenizer("I", return_tensors="pt")["input_ids"]
|
||||
|
||||
low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True)
|
||||
|
||||
high_output = model.generate(
|
||||
model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False
|
||||
)
|
||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||
|
||||
@slow
|
||||
def test_beam_search_example_integration(self):
|
||||
# PT-only test: TF doesn't have a BeamSearchScorer
|
||||
|
|
Loading…
Reference in New Issue