Fix gradient checkpoint test in encoder-decoder (#20017)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
a6b7759880
commit
c6c9db3d0c
|
@ -618,8 +618,10 @@ class EncoderDecoderMixin:
|
|||
)
|
||||
|
||||
model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
model.train()
|
||||
model.to(torch_device)
|
||||
model.gradient_checkpointing_enable()
|
||||
model.train()
|
||||
|
||||
model.config.decoder_start_token_id = 0
|
||||
model.config.pad_token_id = 0
|
||||
|
||||
|
@ -629,6 +631,8 @@ class EncoderDecoderMixin:
|
|||
"labels": inputs_dict["labels"],
|
||||
"decoder_input_ids": inputs_dict["decoder_input_ids"],
|
||||
}
|
||||
model_inputs = {k: v.to(torch_device) for k, v in model_inputs.items()}
|
||||
|
||||
loss = model(**model_inputs).loss
|
||||
loss.backward()
|
||||
|
||||
|
|
Loading…
Reference in New Issue