Bark model Flash Attention 2 enabling: pass the check_device_map parameter on to super() (#29357)
* Fix error #29332: the _check_and_enable_flash_attn_2() method receives a check_device_map parameter but did not forward it, causing a failure. * Style fixup
This commit is contained in:
parent
6d67837f06
commit
9a3f4d4daf
|
@ -1881,6 +1881,7 @@ class BarkModel(BarkPreTrainedModel):
|
|||
torch_dtype: Optional[torch.dtype] = None,
|
||||
device_map: Optional[Union[str, Dict[str, int]]] = None,
|
||||
hard_check_only: bool = False,
|
||||
check_device_map: bool = False,
|
||||
):
|
||||
"""
|
||||
`_check_and_enable_flash_attn_2` originally don't expand flash attention enabling to the model
|
||||
|
@ -1901,7 +1902,7 @@ class BarkModel(BarkPreTrainedModel):
|
|||
can initialize the correct attention module
|
||||
"""
|
||||
config = super()._check_and_enable_flash_attn_2(
|
||||
config, torch_dtype, device_map, hard_check_only=hard_check_only
|
||||
config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map
|
||||
)
|
||||
|
||||
config.semantic_config._attn_implementation = config._attn_implementation
|
||||
|
|
Loading…
Reference in New Issue