Generate: add missing `**model_kwargs` in sample tests (#18696)
This commit is contained in:
parent
e54a1b49aa
commit
e95d433d77
|
@ -327,6 +327,7 @@ class GenerationTesterMixin:
|
|||
remove_invalid_values=True,
|
||||
**logits_warper_kwargs,
|
||||
**process_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
@ -361,6 +362,7 @@ class GenerationTesterMixin:
|
|||
**kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
return output_sample, output_generate
|
||||
|
||||
def _beam_search_generate(
|
||||
|
|
Loading…
Reference in New Issue