Bark model: pass the check_device_map parameter on to super() when enabling Flash Attention 2 (#29357)

* Fix error #29332: the overridden _check_and_enable_flash_attn_2() method is passed a check_device_map parameter it does not accept, so it fails.

* style fixup
Damith Senanayake 2024-03-11 23:44:12 +11:00 committed by GitHub
parent 6d67837f06
commit 9a3f4d4daf
1 changed file with 2 additions and 1 deletion


@@ -1881,6 +1881,7 @@ class BarkModel(BarkPreTrainedModel):
         torch_dtype: Optional[torch.dtype] = None,
         device_map: Optional[Union[str, Dict[str, int]]] = None,
         hard_check_only: bool = False,
+        check_device_map: bool = False,
     ):
         """
         `_check_and_enable_flash_attn_2` originally don't expand flash attention enabling to the model
@@ -1901,7 +1902,7 @@
         can initialize the correct attention module
         """
         config = super()._check_and_enable_flash_attn_2(
-            config, torch_dtype, device_map, hard_check_only=hard_check_only
+            config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map
         )
         config.semantic_config._attn_implementation = config._attn_implementation
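
For context, a minimal sketch of the user-facing call that exercises this code path (the checkpoint name "suno/bark-small" and the CUDA device are illustrative assumptions; flash-attn needs to be installed for the Flash Attention 2 implementation to load):

import torch
from transformers import BarkModel

# Requesting Flash Attention 2 routes through the overridden
# _check_and_enable_flash_attn_2 classmethod patched above. Before this fix,
# the base class called the override with check_device_map=..., a keyword the
# override did not accept, so loading failed; with the fix the keyword is
# forwarded to super() and the resulting _attn_implementation is copied onto
# the Bark sub-configurations (semantic, coarse, fine).
model = BarkModel.from_pretrained(
    "suno/bark-small",                        # example checkpoint, assumed for illustration
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")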