Make public versions of private tensor utils (#19775)

* Make public versions of private utils

* I need sleep
Sylvain Gugger 2022-10-21 09:34:01 -04:00 committed by GitHub
parent 3aaabaa214
commit 9151e649a5
8 changed files with 80 additions and 49 deletions
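
The diffs below all follow one pattern: the private helpers in `utils/generic.py` (`_is_numpy`, `_is_torch`, `_is_torch_device`, `_is_tensorflow`, `_is_jax`) gain public, availability-safe counterparts (`is_numpy_array`, `is_torch_tensor`, `is_torch_device`, `is_tf_tensor`, `is_jax_tensor`), and call sites stop pairing an `is_*_available()` check with a private helper. A minimal sketch of the new calling pattern (illustrative, not a verbatim excerpt from the diff):

```python
from transformers.utils import is_torch_tensor

# Old pattern (removed throughout this commit):
#     if is_torch_available() and _is_torch(first_element): ...
# New pattern: one public helper that already folds in the availability check.
first_element = [101, 2023, 102]
print(is_torch_tensor(first_element))  # False, with or without torch installed
```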


@@ -20,8 +20,7 @@ from typing import Dict, List, Optional, Union
import numpy as np
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from .utils import PaddingStrategy, TensorType, is_tf_available, is_torch_available, logging, to_numpy
-from .utils.generic import _is_tensorflow, _is_torch
+from .utils import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, logging, to_numpy
logger = logging.get_logger(__name__)
@@ -160,9 +159,9 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
first_element = required_input[index][0]
if return_tensors is None:
-    if is_tf_available() and _is_tensorflow(first_element):
+    if is_tf_tensor(first_element):
        return_tensors = "tf"
-    elif is_torch_available() and _is_torch(first_element):
+    elif is_torch_tensor(first_element):
        return_tensors = "pt"
    elif isinstance(first_element, (int, float, list, tuple, np.ndarray)):
        return_tensors = "np"


@@ -33,14 +33,16 @@ from .utils import (
    copy_func,
    download_url,
    is_flax_available,
+    is_jax_tensor,
+    is_numpy_array,
    is_offline_mode,
    is_remote_url,
    is_tf_available,
    is_torch_available,
+    is_torch_device,
    logging,
    torch_required,
)
-from .utils.generic import _is_jax, _is_numpy, _is_torch_device
if TYPE_CHECKING:
@@ -150,10 +152,10 @@ class BatchFeature(UserDict):
    import jax.numpy as jnp  # noqa: F811
    as_tensor = jnp.array
-    is_tensor = _is_jax
+    is_tensor = is_jax_tensor
else:
    as_tensor = np.asarray
-    is_tensor = _is_numpy
+    is_tensor = is_numpy_array
# Do the tensor conversion in batch
for key, value in self.items():
@@ -188,7 +190,7 @@ class BatchFeature(UserDict):
# This check catches things like APEX blindly calling "to" on all inputs to a module
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
-if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
+if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
    self.data = {k: v.to(device=device) for k, v in self.data.items()}
else:
    logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.")


@@ -21,14 +21,21 @@ from packaging import version
import requests
-from .utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
+from .utils import (
+    ExplicitEnum,
+    is_jax_tensor,
+    is_tf_tensor,
+    is_torch_available,
+    is_torch_tensor,
+    is_vision_available,
+    to_numpy,
+)
from .utils.constants import (  # noqa: F401
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
)
-from .utils.generic import ExplicitEnum, _is_jax, _is_tensorflow, _is_torch, to_numpy
if is_vision_available():
@@ -55,18 +62,6 @@ class ChannelDimension(ExplicitEnum):
    LAST = "channels_last"
-def is_torch_tensor(obj):
-    return _is_torch(obj) if is_torch_available() else False
-def is_tf_tensor(obj):
-    return _is_tensorflow(obj) if is_tf_available() else False
-def is_jax_tensor(obj):
-    return _is_jax(obj) if is_flax_available() else False
def is_valid_image(img):
    return (
        isinstance(img, (PIL.Image.Image, np.ndarray))
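
Because `is_torch_tensor`, `is_tf_tensor`, and `is_jax_tensor` are now imported here instead of being redefined locally, code that imported them from this module keeps working and resolves to the shared implementations. A quick check, assuming this file is the image utilities module (`transformers.image_utils`, which the `ChannelDimension`/`is_valid_image` context suggests but the diff does not name):

```python
# Assumption: this hunk belongs to transformers/image_utils.py.
from transformers.image_utils import is_torch_tensor as from_image_utils
from transformers.utils import is_torch_tensor as from_utils

# Same function object: the image module now simply re-imports the shared helper.
assert from_image_utils is from_utils
```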


@@ -33,11 +33,9 @@ from ...tokenization_utils_base import (
    TextInput,
    TextInputPair,
    TruncationStrategy,
-    _is_tensorflow,
-    _is_torch,
    to_py_obj,
)
-from ...utils import add_end_docstrings, is_tf_available, is_torch_available, logging
+from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging
logger = logging.get_logger(__name__)
@@ -1174,9 +1172,9 @@ class LukeTokenizer(RobertaTokenizer):
first_element = required_input[index][0]
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if not isinstance(first_element, (int, list, tuple)):
-    if is_tf_available() and _is_tensorflow(first_element):
+    if is_tf_tensor(first_element):
        return_tensors = "tf" if return_tensors is None else return_tensors
-    elif is_torch_available() and _is_torch(first_element):
+    elif is_torch_tensor(first_element):
        return_tensors = "pt" if return_tensors is None else return_tensors
    elif isinstance(first_element, np.ndarray):
        return_tensors = "np" if return_tensors is None else return_tensors


@@ -37,11 +37,9 @@ from ...tokenization_utils_base import (
    TextInput,
    TextInputPair,
    TruncationStrategy,
-    _is_tensorflow,
-    _is_torch,
    to_py_obj,
)
-from ...utils import add_end_docstrings, is_tf_available, is_torch_available, logging
+from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging
logger = logging.get_logger(__name__)
@@ -1287,9 +1285,9 @@ class MLukeTokenizer(PreTrainedTokenizer):
first_element = required_input[index][0]
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if not isinstance(first_element, (int, list, tuple)):
-    if is_tf_available() and _is_tensorflow(first_element):
+    if is_tf_tensor(first_element):
        return_tensors = "tf" if return_tensors is None else return_tensors
-    elif is_torch_available() and _is_torch(first_element):
+    elif is_torch_tensor(first_element):
        return_tensors = "pt" if return_tensors is None else return_tensors
    elif isinstance(first_element, np.ndarray):
        return_tensors = "np" if return_tensors is None else return_tensors


@@ -45,16 +45,20 @@ from .utils import (
    download_url,
    extract_commit_hash,
    is_flax_available,
+    is_jax_tensor,
+    is_numpy_array,
    is_offline_mode,
    is_remote_url,
    is_tf_available,
+    is_tf_tensor,
    is_tokenizers_available,
    is_torch_available,
+    is_torch_device,
+    is_torch_tensor,
    logging,
    to_py_obj,
    torch_required,
)
-from .utils.generic import _is_jax, _is_numpy, _is_tensorflow, _is_torch, _is_torch_device
if TYPE_CHECKING:
@@ -696,15 +700,10 @@ class BatchEncoding(UserDict):
    import jax.numpy as jnp  # noqa: F811
    as_tensor = jnp.array
-    is_tensor = _is_jax
+    is_tensor = is_jax_tensor
else:
    as_tensor = np.asarray
-    is_tensor = _is_numpy
-    # (mfuntowicz: This code is unreachable)
-    # else:
-    #     raise ImportError(
-    #         f"Unable to convert output to tensors format {tensor_type}"
-    #     )
+    is_tensor = is_numpy_array
# Do the tensor conversion in batch
for key, value in self.items():
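
The `as_tensor` / `is_tensor` pair selected above is then applied key by key in the loop that follows. A self-contained sketch of how such a pair is typically used for the numpy fallback branch (the `convert_batch` helper is illustrative, not the library's method):

```python
import numpy as np

from transformers.utils import is_numpy_array


def convert_batch(data: dict) -> dict:
    """Convert every value to a numpy array unless it already is one."""
    as_tensor, is_tensor = np.asarray, is_numpy_array
    return {key: value if is_tensor(value) else as_tensor(value) for key, value in data.items()}


print(convert_batch({"input_ids": [[1, 2, 3]], "attention_mask": np.ones((1, 3), dtype=np.int64)}))
```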
@@ -753,7 +752,7 @@
# This check catches things like APEX blindly calling "to" on all inputs to a module
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
-if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
+if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
    self.data = {k: v.to(device=device) for k, v in self.data.items()}
else:
    logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
@@ -2925,9 +2924,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
        break
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if not isinstance(first_element, (int, list, tuple)):
-    if is_tf_available() and _is_tensorflow(first_element):
+    if is_tf_tensor(first_element):
        return_tensors = "tf" if return_tensors is None else return_tensors
-    elif is_torch_available() and _is_torch(first_element):
+    elif is_torch_tensor(first_element):
        return_tensors = "pt" if return_tensors is None else return_tensors
    elif isinstance(first_element, np.ndarray):
        return_tensors = "np" if return_tensors is None else return_tensors


@@ -40,7 +40,12 @@ from .generic import (
    cached_property,
    find_labels,
    flatten_dict,
+    is_jax_tensor,
+    is_numpy_array,
    is_tensor,
+    is_tf_tensor,
+    is_torch_device,
+    is_torch_tensor,
    to_numpy,
    to_py_obj,
    working_or_temp_dir,
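
With this re-export, downstream code can import the helpers from `transformers.utils` instead of reaching into the private `utils.generic` names. A minimal check that also shows the availability-safe behaviour, using only numpy so it runs without torch or TensorFlow installed:

```python
import numpy as np

from transformers.utils import is_numpy_array, is_tf_tensor, is_torch_tensor

x = np.arange(4)
print(is_numpy_array(x))   # True
print(is_torch_tensor(x))  # False, and safe to call even without torch
print(is_tf_tensor(x))     # False, and safe to call even without TensorFlow
```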


@@ -83,30 +83,65 @@ def _is_numpy(x):
    return isinstance(x, np.ndarray)
+def is_numpy_array(x):
+    """
+    Tests if `x` is a numpy array or not.
+    """
+    return _is_numpy(x)
def _is_torch(x):
    import torch
    return isinstance(x, torch.Tensor)
+def is_torch_tensor(x):
+    """
+    Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed.
+    """
+    return False if not is_torch_available() else _is_torch(x)
def _is_torch_device(x):
    import torch
    return isinstance(x, torch.device)
+def is_torch_device(x):
+    """
+    Tests if `x` is a torch device or not. Safe to call even if torch is not installed.
+    """
+    return False if not is_torch_available() else _is_torch_device(x)
def _is_tensorflow(x):
    import tensorflow as tf
    return isinstance(x, tf.Tensor)
+def is_tf_tensor(x):
+    """
+    Tests if `x` is a tensorflow tensor or not. Safe to call even if tensorflow is not installed.
+    """
+    return False if not is_tf_available() else _is_tensorflow(x)
def _is_jax(x):
    import jax.numpy as jnp  # noqa: F811
    return isinstance(x, jnp.ndarray)
+def is_jax_tensor(x):
+    """
+    Tests if `x` is a Jax tensor or not. Safe to call even if jax is not installed.
+    """
+    return False if not is_flax_available() else _is_jax(x)
def to_py_obj(obj):
    """
    Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
@@ -115,11 +150,11 @@ def to_py_obj(obj):
    return {k: to_py_obj(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
    return [to_py_obj(o) for o in obj]
-elif is_tf_available() and _is_tensorflow(obj):
+elif is_tf_tensor(obj):
    return obj.numpy().tolist()
-elif is_torch_available() and _is_torch(obj):
+elif is_torch_tensor(obj):
    return obj.detach().cpu().tolist()
-elif is_flax_available() and _is_jax(obj):
+elif is_jax_tensor(obj):
    return np.asarray(obj).tolist()
elif isinstance(obj, (np.ndarray, np.number)):  # tolist also works on 0d np arrays
    return obj.tolist()
@@ -135,11 +170,11 @@ def to_numpy(obj):
    return {k: to_numpy(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
    return np.array(obj)
-elif is_tf_available() and _is_tensorflow(obj):
+elif is_tf_tensor(obj):
    return obj.numpy()
-elif is_torch_available() and _is_torch(obj):
+elif is_torch_tensor(obj):
    return obj.detach().cpu().numpy()
-elif is_flax_available() and _is_jax(obj):
+elif is_jax_tensor(obj):
    return np.asarray(obj)
else:
    return obj
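
A short sanity check of the two converters after the change, restricted to numpy and plain Python inputs so it runs without torch, TensorFlow, or JAX installed (illustrative):

```python
import numpy as np

from transformers.utils import to_numpy, to_py_obj

batch = {"input_ids": np.array([[101, 2023, 102]]), "lengths": [3]}
print(to_py_obj(batch))            # {'input_ids': [[101, 2023, 102]], 'lengths': [3]}
print(to_numpy(batch)["lengths"])  # array([3])
```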