From 50f4539b8201b26b18085260bf801cdeadfa6640 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 23 Apr 2021 15:36:27 +0200 Subject: [PATCH] push (#11400) --- .../convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py index a00b5cf5eb..cc902ee3bc 100644 --- a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py @@ -66,7 +66,8 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type): assert ( hf_shape == value.shape - ), f"Shape of hf {key + '.' + weight_type} is {hf_shape}, but should be {value.shape} for {full_name}" + ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}" + if weight_type == "weight": hf_pointer.weight.data = value elif weight_type == "weight_g":