Generate: add missing `**model_kwargs` in sample tests (#18696)

This commit is contained in:
Joao Gante 2022-08-19 16:14:27 +01:00 committed by GitHub
parent e54a1b49aa
commit e95d433d77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 0 deletions

View File

@ -327,6 +327,7 @@ class GenerationTesterMixin:
remove_invalid_values=True, remove_invalid_values=True,
**logits_warper_kwargs, **logits_warper_kwargs,
**process_kwargs, **process_kwargs,
**model_kwargs,
) )
torch.manual_seed(0) torch.manual_seed(0)
@ -361,6 +362,7 @@ class GenerationTesterMixin:
**kwargs, **kwargs,
**model_kwargs, **model_kwargs,
) )
return output_sample, output_generate return output_sample, output_generate
def _beam_search_generate( def _beam_search_generate(