Have seq2seq just use gather (#27025)
* Have seq2seq just use gather * Change * Reset after * Make slow * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Clean * Simplify and just use gather * Update tests/trainer/test_trainer_seq2seq.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * gather always for seq2seq --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
250032e974
commit
067c4a310d
|
@ -3208,13 +3208,13 @@ class Trainer:
|
|||
|
||||
# Update containers on host
|
||||
if loss is not None:
|
||||
losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size)))
|
||||
losses = self.gather_function((loss.repeat(batch_size)))
|
||||
losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
|
||||
if labels is not None:
|
||||
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
|
||||
if inputs_decode is not None:
|
||||
inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
|
||||
inputs_decode = self.accelerator.gather_for_metrics((inputs_decode))
|
||||
inputs_decode = self.gather_function((inputs_decode))
|
||||
inputs_host = (
|
||||
inputs_decode
|
||||
if inputs_host is None
|
||||
|
@ -3224,11 +3224,11 @@ class Trainer:
|
|||
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
|
||||
if self.preprocess_logits_for_metrics is not None:
|
||||
logits = self.preprocess_logits_for_metrics(logits, labels)
|
||||
logits = self.accelerator.gather_for_metrics((logits))
|
||||
logits = self.gather_function((logits))
|
||||
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
|
||||
|
||||
if labels is not None:
|
||||
labels = self.accelerator.gather_for_metrics((labels))
|
||||
labels = self.gather_function((labels))
|
||||
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
|
||||
|
||||
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
|
||||
|
@ -3261,6 +3261,8 @@ class Trainer:
|
|||
# Set back to None to begin a new accumulation
|
||||
losses_host, preds_host, inputs_host, labels_host = None, None, None, None
|
||||
|
||||
# After all calls to `.gather_function`, reset to `gather_for_metrics`:
|
||||
self.gather_function = self.accelerator.gather_for_metrics
|
||||
if args.past_index and hasattr(self, "_past"):
|
||||
# Clean the state at the end of the evaluation loop
|
||||
delattr(self, "_past")
|
||||
|
@ -3930,6 +3932,8 @@ class Trainer:
|
|||
deepspeed_plugin=self.args.deepspeed_plugin,
|
||||
gradient_accumulation_plugin=gradient_accumulation_plugin,
|
||||
)
|
||||
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
|
||||
self.gather_function = self.accelerator.gather_for_metrics
|
||||
|
||||
# deepspeed and accelerate flags covering both trainer args and accelerate launcher
|
||||
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
||||
|
|
|
@ -160,8 +160,9 @@ class Seq2SeqTrainer(Trainer):
|
|||
gen_kwargs["max_length"] = self.args.generation_max_length
|
||||
if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
|
||||
gen_kwargs["num_beams"] = self.args.generation_num_beams
|
||||
# We don't want to drop samples in general
|
||||
self.gather_function = self.accelerator.gather
|
||||
self._gen_kwargs = gen_kwargs
|
||||
|
||||
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||
|
||||
def predict(
|
||||
|
@ -223,6 +224,7 @@ class Seq2SeqTrainer(Trainer):
|
|||
gen_kwargs["max_length"] = self.args.generation_max_length
|
||||
if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
|
||||
gen_kwargs["num_beams"] = self.args.generation_num_beams
|
||||
self.gather_function = self.accelerator.gather
|
||||
self._gen_kwargs = gen_kwargs
|
||||
|
||||
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
||||
|
|
|
@ -12,8 +12,16 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
|
||||
from transformers import (
|
||||
AutoModelForSeq2SeqLM,
|
||||
BertTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
EncoderDecoderModel,
|
||||
GenerationConfig,
|
||||
Seq2SeqTrainer,
|
||||
Seq2SeqTrainingArguments,
|
||||
T5Tokenizer,
|
||||
)
|
||||
from transformers.testing_utils import TestCasePlus, require_torch, slow
|
||||
from transformers.utils import is_datasets_available
|
||||
|
||||
|
@ -124,3 +132,52 @@ class Seq2seqTrainerTester(TestCasePlus):
|
|||
|
||||
# start training
|
||||
trainer.train()
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_return_sequences(self):
|
||||
# Tests that the number of generated sequences is correct when num_return_sequences > 1
|
||||
# and essentially ensuring that `accelerator.gather()` is used instead of `gather_for_metrics`
|
||||
INPUT_COLUMN = "question"
|
||||
TARGET_COLUMN = "answer"
|
||||
MAX_INPUT_LENGTH = 256
|
||||
MAX_TARGET_LENGTH = 256
|
||||
|
||||
dataset = datasets.load_dataset("gsm8k", "main", split="train[:38]")
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt", padding="longest")
|
||||
gen_config = GenerationConfig.from_pretrained(
|
||||
"t5-small", max_length=None, min_length=None, max_new_tokens=256, min_new_tokens=1, num_beams=5
|
||||
)
|
||||
|
||||
training_args = Seq2SeqTrainingArguments(".", predict_with_generate=True)
|
||||
|
||||
trainer = Seq2SeqTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
compute_metrics=lambda x: {"samples": x[0].shape[0]},
|
||||
)
|
||||
|
||||
def prepare_data(examples):
|
||||
# Remove pairs where at least one record is none
|
||||
inputs = examples[INPUT_COLUMN]
|
||||
targets = examples[TARGET_COLUMN]
|
||||
|
||||
model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True)
|
||||
labels = tokenizer(text_target=targets, max_length=MAX_TARGET_LENGTH, truncation=True)
|
||||
model_inputs["labels"] = labels["input_ids"]
|
||||
|
||||
return model_inputs
|
||||
|
||||
prepared_dataset = dataset.map(prepare_data, batched=True, remove_columns=[INPUT_COLUMN, TARGET_COLUMN])
|
||||
dataset_len = len(prepared_dataset) # 38
|
||||
|
||||
for num_return_sequences in range(3, 0, -1):
|
||||
gen_config.num_return_sequences = num_return_sequences
|
||||
metrics = trainer.evaluate(eval_dataset=prepared_dataset, generation_config=gen_config)
|
||||
assert (
|
||||
metrics["eval_samples"] == dataset_len * num_return_sequences
|
||||
), f"Got {metrics['eval_samples']}, expected: {dataset_len * num_return_sequences}"
|
||||
|
|
Loading…
Reference in New Issue