[SDPA] Make sure attn mask creation is always done on CPU (#28400)

* [SDPA] Make sure attn mask creation is always done on CPU
* Update docker to 2.1.1
* revert test change
parent 5c7e11e010
commit 8604dd308d
@@ -9,9 +9,9 @@ SHELL ["sh", "-lc"]
 # The following `ARG` are mainly used to specify the versions explicitly & directly in this docker file, and not meant
 # to be used as arguments for docker build (so far).
-ARG PYTORCH='2.1.0'
+ARG PYTORCH='2.1.1'
 # (not always a valid torch version)
-ARG INTEL_TORCH_EXT='2.1.0'
+ARG INTEL_TORCH_EXT='2.1.1'
 # Example: `cu102`, `cu113`, etc.
 ARG CUDA='cu118'
|
@@ -11,7 +11,7 @@ ARG REF=main
 RUN git clone https://github.com/huggingface/transformers && cd transformers && git checkout $REF

 # If set to nothing, will install the latest version
-ARG PYTORCH='2.1.0'
+ARG PYTORCH='2.1.1'
 ARG TORCH_VISION=''
 ARG TORCH_AUDIO=''
 # Example: `cu102`, `cu113`, etc.
|
@@ -234,8 +234,8 @@ class AttentionMaskConverter:

         # Get the index of the first non-zero value for every sample in the batch.
         # In the above example, indices = [[2], [0], [1]]
-        tmp = torch.arange(attention_mask.shape[1], 0, -1, device=attention_mask.device)
-        indices = torch.argmax(attention_mask * tmp, 1, keepdim=True)
+        tmp = torch.arange(attention_mask.shape[1], 0, -1)
+        indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)

         # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
         # expanded mask will be completely unattended.
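The hunk above moves the weighted-argmax computation to CPU (the ramp tensor is created without a `device=` argument and the mask is moved via `.cpu()` before the `argmax`), matching the commit's goal of keeping attn mask creation on CPU. The weighted-argmax trick itself can be sketched in pure Python; `first_attended_index` below is a hypothetical helper written for illustration, not a function from transformers:

```python
def first_attended_index(mask_row):
    """Index of the first attended (non-zero) position in a 0/1 mask row.

    Mirrors the trick in AttentionMaskConverter: weight each position with a
    descending ramp (like torch.arange(n, 0, -1)) so that the argmax of
    mask * ramp lands on the leftmost attended token.
    """
    n = len(mask_row)
    # E.g. for [0, 0, 1, 1, 1] the ramp is [5, 4, 3, 2, 1] and the
    # weighted row is [0, 0, 3, 2, 1], whose argmax is position 2.
    weighted = [m * (n - i) for i, m in enumerate(mask_row)]
    # max() returns the first index achieving the maximum, like torch.argmax
    return max(range(n), key=lambda i: weighted[i])

# The three masks implied by the "indices = [[2], [0], [1]]" comment above:
masks = [[0, 0, 1, 1, 1], [1, 1, 1, 1, 1], [0, 1, 1, 1, 1]]
indices = [first_attended_index(row) for row in masks]  # [2, 0, 1]
```

For a left-padded row such as `[0, 0, 1, 1, 1]` this yields the first real token's position, which the converter then uses to detect batch entries whose leading expanded-mask rows are fully unattended.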