Generate: add missing `**model_kwargs` in sample tests (#18696)
This commit is contained in:
parent
e54a1b49aa
commit
e95d433d77
|
@ -327,6 +327,7 @@ class GenerationTesterMixin:
|
||||||
remove_invalid_values=True,
|
remove_invalid_values=True,
|
||||||
**logits_warper_kwargs,
|
**logits_warper_kwargs,
|
||||||
**process_kwargs,
|
**process_kwargs,
|
||||||
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
@ -361,6 +362,7 @@ class GenerationTesterMixin:
|
||||||
**kwargs,
|
**kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output_sample, output_generate
|
return output_sample, output_generate
|
||||||
|
|
||||||
def _beam_search_generate(
|
def _beam_search_generate(
|
||||||
|
|
Loading…
Reference in New Issue