From c6c9db3d0cc36c5fe57d508b669c305ac7894145 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 2 Nov 2022 14:15:09 +0100 Subject: [PATCH] Fix gradient checkpoint test in encoder-decoder (#20017) Co-authored-by: ydshieh --- .../models/encoder_decoder/test_modeling_encoder_decoder.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py index 1181b94789..8f565aec06 100644 --- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py @@ -618,8 +618,10 @@ class EncoderDecoderMixin: ) model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) - model.train() + model.to(torch_device) model.gradient_checkpointing_enable() + model.train() + model.config.decoder_start_token_id = 0 model.config.pad_token_id = 0 @@ -629,6 +631,8 @@ class EncoderDecoderMixin: "labels": inputs_dict["labels"], "decoder_input_ids": inputs_dict["decoder_input_ids"], } + model_inputs = {k: v.to(torch_device) for k, v in model_inputs.items()} + loss = model(**model_inputs).loss loss.backward()