Fix deprecation warnings for int div (#15180)

* Fix deprecation warnings for int div

Co-authored-by: mgoldey <matthew.goldey@gmail.com>

* Fix import

* Ensure that tensor output is a Python scalar

* Make backward compatible

* Make code more readable

* Adapt test functions

Co-authored-by: mgoldey <matthew.goldey@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sylvain Gugger, 2022-01-18 07:28:53 -05:00 (committed by GitHub)
parent f6d3fee855
commit 531336bbfd
10 changed files with 43 additions and 30 deletions
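
Background: on newer PyTorch releases, floor division on tensors through the // operator takes the deprecated __floordiv__ path and emits a UserWarning, while torch.div(..., rounding_mode="floor") is the supported spelling. A minimal sketch of the pattern this commit replaces (the printed warning text varies by torch version and is not taken from the commit):

    import warnings

    import torch

    t = torch.tensor([7, 9])
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        _ = t // 2                                  # old pattern: deprecated on newer torch
        _ = torch.div(t, 2, rounding_mode="floor")  # new pattern: warning-free, same result
    for w in caught:
        print(w.category.__name__, w.message)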

examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py

@@ -302,6 +302,8 @@ class DataCollatorForWav2Vec2Pretraining:
         batch_size = batch["input_values"].shape[0]
         mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
+        # make sure masked sequence length is a Python scalar
+        mask_indices_seq_length = int(mask_indices_seq_length)
         # make sure that no loss is computed on padded inputs
         if batch.get("attention_mask") is not None:
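
Why the int(...) cast matters here: torch.div returns a tensor even when both operands are plain Python ints, so once _get_feat_extract_output_lengths is routed through torch_int_div (added below in modeling_utils.py) its result is a 0-dim tensor rather than an int. A small illustration with made-up values:

    import torch

    length = torch.div(399, 2, rounding_mode="floor")  # tensor(199), a 0-dim tensor
    length = int(length)                               # 199, a plain Python scalar usable in shape tuples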

src/transformers/modeling_utils.py

@@ -23,6 +23,7 @@ from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 import torch
+from packaging import version
 from torch import Tensor, device, nn
 from torch.nn import CrossEntropyLoss
@@ -2362,3 +2363,13 @@ def apply_chunking_to_forward(
         return torch.cat(output_chunks, dim=chunk_dim)
     return forward_fn(*input_tensors)
+
+
+def torch_int_div(tensor1, tensor2):
+    """
+    A function that performs integer division across different versions of PyTorch.
+    """
+    if version.parse(torch.__version__) < version.parse("1.8.0"):
+        return tensor1 // tensor2
+    else:
+        return torch.div(tensor1, tensor2, rounding_mode="floor")
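
A quick sanity check of the helper's semantics, as a sketch rather than part of the commit; the import location reflects where this diff defines the function:

    import torch
    from transformers.modeling_utils import torch_int_div

    lengths = torch.tensor([16000, 8000])
    print(torch_int_div(lengths - 10, 5) + 1)  # tensor([3199, 1599]) on either side of the version gate

Note that rounding_mode="floor" matches the semantics of Python's // (rounding toward negative infinity), so both branches agree even for negative inputs.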

src/transformers/models/hubert/modeling_hubert.py

@@ -33,7 +33,7 @@ from ...file_utils import (
     replace_return_docstrings,
 )
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import PreTrainedModel, torch_int_div
 from ...utils import logging
 from .configuration_hubert import HubertConfig
@@ -829,7 +829,7 @@ class HubertPreTrainedModel(PreTrainedModel):
         def _conv_out_length(input_length, kernel_size, stride):
             # 1D convolutional layer output length formula taken
             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
-            return (input_length - kernel_size) // stride + 1
+            return torch_int_div(input_length - kernel_size, stride) + 1
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
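
To make the formula concrete: each conv layer shrinks the time axis to (L - kernel_size) // stride + 1, applied layer by layer. A standalone sketch below; the kernel/stride defaults mirror the wav2vec 2.0 base feature extractor and are assumptions, not values read from this diff:

    def feat_extract_output_length(input_length, kernels=(10, 3, 3, 3, 3, 2, 2), strides=(5, 2, 2, 2, 2, 2, 2)):
        # Apply the Conv1d output-length formula once per convolutional layer.
        for kernel_size, stride in zip(kernels, strides):
            input_length = (input_length - kernel_size) // stride + 1
        return input_length

    print(feat_extract_output_length(16000))  # 49 frames for one second of 16 kHz audio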

src/transformers/models/sew/modeling_sew.py

@@ -29,7 +29,7 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import PreTrainedModel, torch_int_div
 from ...utils import logging
 from .configuration_sew import SEWConfig
@@ -735,7 +735,7 @@ class SEWPreTrainedModel(PreTrainedModel):
         def _conv_out_length(input_length, kernel_size, stride):
             # 1D convolutional layer output length formula taken
             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
-            return (input_length - kernel_size) // stride + 1
+            return torch_int_div(input_length - kernel_size, stride) + 1
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

src/transformers/models/sew_d/modeling_sew_d.py

@@ -30,7 +30,7 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import PreTrainedModel, torch_int_div
 from ...utils import logging
 from .configuration_sew_d import SEWDConfig
@@ -1266,7 +1266,7 @@ class SEWDPreTrainedModel(PreTrainedModel):
         def _conv_out_length(input_length, kernel_size, stride):
             # 1D convolutional layer output length formula taken
             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
-            return (input_length - kernel_size) // stride + 1
+            return torch_int_div(input_length - kernel_size, stride) + 1
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

src/transformers/models/unispeech/modeling_unispeech.py

@@ -35,7 +35,7 @@ from ...file_utils import (
     replace_return_docstrings,
 )
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import PreTrainedModel, torch_int_div
 from ...utils import logging
 from .configuration_unispeech import UniSpeechConfig
@@ -969,7 +969,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
         def _conv_out_length(input_length, kernel_size, stride):
             # 1D convolutional layer output length formula taken
             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
-            return (input_length - kernel_size) // stride + 1
+            return torch_int_div(input_length - kernel_size, stride) + 1
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

src/transformers/models/unispeech_sat/modeling_unispeech_sat.py

@@ -35,7 +35,7 @@ from ...file_utils import (
     replace_return_docstrings,
 )
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import PreTrainedModel, torch_int_div
 from ...utils import logging
 from .configuration_unispeech_sat import UniSpeechSatConfig
@@ -1003,7 +1003,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
         def _conv_out_length(input_length, kernel_size, stride):
             # 1D convolutional layer output length formula taken
             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
-            return (input_length - kernel_size) // stride + 1
+            return torch_int_div(input_length - kernel_size, stride) + 1
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

src/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -41,7 +41,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import PreTrainedModel, torch_int_div
 from ...utils import logging
 from .configuration_wav2vec2 import Wav2Vec2Config
@@ -1104,7 +1104,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
         def _conv_out_length(input_length, kernel_size, stride):
             # 1D convolutional layer output length formula taken
             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
-            return (input_length - kernel_size) // stride + 1
+            return torch_int_div(input_length - kernel_size, stride) + 1
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

src/transformers/models/wavlm/modeling_wavlm.py

@@ -35,7 +35,7 @@ from ...file_utils import (
     add_start_docstrings_to_model_forward,
 )
 from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import PreTrainedModel, torch_int_div
 from ...utils import logging
 from .configuration_wavlm import WavLMConfig
@@ -1057,7 +1057,7 @@ class WavLMPreTrainedModel(PreTrainedModel):
         def _conv_out_length(input_length, kernel_size, stride):
             # 1D convolutional layer output length formula taken
             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
-            return (input_length - kernel_size) // stride + 1
+            return torch_int_div(input_length - kernel_size, stride) + 1
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

tests/test_modeling_wav2vec2.py

@@ -794,10 +794,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         model = Wav2Vec2ForPreTraining(config).to(torch_device)
-        features_shape = (
-            inputs_dict["input_values"].shape[0],
-            model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
-        )
+        batch_size = inputs_dict["input_values"].shape[0]
+        feature_seq_length = int(model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]))
+        features_shape = (batch_size, feature_seq_length)
         mask_time_indices = _compute_mask_indices(
             features_shape,
@@ -1158,10 +1158,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
-        features_shape = (
-            inputs_dict["input_values"].shape[0],
-            model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
-        )
+        batch_size = inputs_dict["input_values"].shape[0]
+        feature_seq_length = int(model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]))
+        features_shape = (batch_size, feature_seq_length)
         np.random.seed(4)
         mask_time_indices = _compute_mask_indices(
@@ -1208,10 +1208,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
-        features_shape = (
-            inputs_dict["input_values"].shape[0],
-            model._get_feat_extract_output_lengths(torch.tensor(inputs_dict["input_values"].shape[1])),
-        )
+        batch_size = inputs_dict["input_values"].shape[0]
+        feature_seq_length = int(model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]))
+        features_shape = (batch_size, feature_seq_length)
         torch.manual_seed(0)
         mask_time_indices = _compute_mask_indices(
@@ -1279,10 +1279,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         inputs_dict = feature_extractor(input_speech, return_tensors="pt", padding=True)
-        features_shape = (
-            inputs_dict["input_values"].shape[0],
-            model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]),
-        )
+        batch_size = inputs_dict["input_values"].shape[0]
+        feature_seq_length = int(model._get_feat_extract_output_lengths(inputs_dict["input_values"].shape[1]))
+        features_shape = (batch_size, feature_seq_length)
         torch.manual_seed(0)
         np.random.seed(0)