Fix gradient checkpoint test in encoder-decoder (#20017)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2022-11-02 14:15:09 +01:00 committed by GitHub
parent a6b7759880
commit c6c9db3d0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 5 additions and 1 deletion

View File

@@ -618,8 +618,10 @@ class EncoderDecoderMixin:
)
model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
model.train()
model.to(torch_device)
model.gradient_checkpointing_enable()
model.train()
model.config.decoder_start_token_id = 0
model.config.pad_token_id = 0
@@ -629,6 +631,8 @@ class EncoderDecoderMixin:
"labels": inputs_dict["labels"],
"decoder_input_ids": inputs_dict["decoder_input_ids"],
}
model_inputs = {k: v.to(torch_device) for k, v in model_inputs.items()}
loss = model(**model_inputs).loss
loss.backward()