From 8604dd308d8930b5dc788eefcd9eefad7555a11c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 9 Jan 2024 11:05:19 +0100 Subject: [PATCH] [SDPA] Make sure attn mask creation is always done on CPU (#28400) * [SDPA] Make sure attn mask creation is always done on CPU * Update docker to 2.1.1 * revert test change --- docker/transformers-all-latest-gpu/Dockerfile | 4 ++-- docker/transformers-pytorch-gpu/Dockerfile | 2 +- src/transformers/modeling_attn_mask_utils.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index 0d694eaa72..baa0804430 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -9,9 +9,9 @@ SHELL ["sh", "-lc"] # The following `ARG` are mainly used to specify the versions explicitly & directly in this docker file, and not meant # to be used as arguments for docker build (so far). -ARG PYTORCH='2.1.0' +ARG PYTORCH='2.1.1' # (not always a valid torch version) -ARG INTEL_TORCH_EXT='2.1.0' +ARG INTEL_TORCH_EXT='2.1.1' # Example: `cu102`, `cu113`, etc. ARG CUDA='cu118' diff --git a/docker/transformers-pytorch-gpu/Dockerfile b/docker/transformers-pytorch-gpu/Dockerfile index 44f6095894..a45210e7d1 100644 --- a/docker/transformers-pytorch-gpu/Dockerfile +++ b/docker/transformers-pytorch-gpu/Dockerfile @@ -11,7 +11,7 @@ ARG REF=main RUN git clone https://github.com/huggingface/transformers && cd transformers && git checkout $REF # If set to nothing, will install the latest version -ARG PYTORCH='2.1.0' +ARG PYTORCH='2.1.1' ARG TORCH_VISION='' ARG TORCH_AUDIO='' # Example: `cu102`, `cu113`, etc. diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 32cecd4f2a..f0964f9402 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -234,8 +234,8 @@ class AttentionMaskConverter: # Get the index of the first non-zero value for every sample in the batch. # In the above example, indices = [[2], [0], [1]]] - tmp = torch.arange(attention_mask.shape[1], 0, -1, device=attention_mask.device) - indices = torch.argmax(attention_mask * tmp, 1, keepdim=True) + tmp = torch.arange(attention_mask.shape[1], 0, -1) + indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True) # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the # expanded mask will be completely unattended.