Make public versions of private tensor utils (#19775)
* Make public versions of private utils * I need sleep
This commit is contained in:
parent
3aaabaa214
commit
9151e649a5
|
@ -20,8 +20,7 @@ from typing import Dict, List, Optional, Union
|
|||
import numpy as np
|
||||
|
||||
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
|
||||
from .utils import PaddingStrategy, TensorType, is_tf_available, is_torch_available, logging, to_numpy
|
||||
from .utils.generic import _is_tensorflow, _is_torch
|
||||
from .utils import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, logging, to_numpy
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -160,9 +159,9 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
|
|||
first_element = required_input[index][0]
|
||||
|
||||
if return_tensors is None:
|
||||
if is_tf_available() and _is_tensorflow(first_element):
|
||||
if is_tf_tensor(first_element):
|
||||
return_tensors = "tf"
|
||||
elif is_torch_available() and _is_torch(first_element):
|
||||
elif is_torch_tensor(first_element):
|
||||
return_tensors = "pt"
|
||||
elif isinstance(first_element, (int, float, list, tuple, np.ndarray)):
|
||||
return_tensors = "np"
|
||||
|
|
|
@ -33,14 +33,16 @@ from .utils import (
|
|||
copy_func,
|
||||
download_url,
|
||||
is_flax_available,
|
||||
is_jax_tensor,
|
||||
is_numpy_array,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
is_torch_device,
|
||||
logging,
|
||||
torch_required,
|
||||
)
|
||||
from .utils.generic import _is_jax, _is_numpy, _is_torch_device
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -150,10 +152,10 @@ class BatchFeature(UserDict):
|
|||
import jax.numpy as jnp # noqa: F811
|
||||
|
||||
as_tensor = jnp.array
|
||||
is_tensor = _is_jax
|
||||
is_tensor = is_jax_tensor
|
||||
else:
|
||||
as_tensor = np.asarray
|
||||
is_tensor = _is_numpy
|
||||
is_tensor = is_numpy_array
|
||||
|
||||
# Do the tensor conversion in batch
|
||||
for key, value in self.items():
|
||||
|
@ -188,7 +190,7 @@ class BatchFeature(UserDict):
|
|||
# This check catches things like APEX blindly calling "to" on all inputs to a module
|
||||
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
||||
# into a HalfTensor
|
||||
if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
|
||||
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
|
||||
self.data = {k: v.to(device=device) for k, v in self.data.items()}
|
||||
else:
|
||||
logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.")
|
||||
|
|
|
@ -21,14 +21,21 @@ from packaging import version
|
|||
|
||||
import requests
|
||||
|
||||
from .utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available
|
||||
from .utils import (
|
||||
ExplicitEnum,
|
||||
is_jax_tensor,
|
||||
is_tf_tensor,
|
||||
is_torch_available,
|
||||
is_torch_tensor,
|
||||
is_vision_available,
|
||||
to_numpy,
|
||||
)
|
||||
from .utils.constants import ( # noqa: F401
|
||||
IMAGENET_DEFAULT_MEAN,
|
||||
IMAGENET_DEFAULT_STD,
|
||||
IMAGENET_STANDARD_MEAN,
|
||||
IMAGENET_STANDARD_STD,
|
||||
)
|
||||
from .utils.generic import ExplicitEnum, _is_jax, _is_tensorflow, _is_torch, to_numpy
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
|
@ -55,18 +62,6 @@ class ChannelDimension(ExplicitEnum):
|
|||
LAST = "channels_last"
|
||||
|
||||
|
||||
def is_torch_tensor(obj):
|
||||
return _is_torch(obj) if is_torch_available() else False
|
||||
|
||||
|
||||
def is_tf_tensor(obj):
|
||||
return _is_tensorflow(obj) if is_tf_available() else False
|
||||
|
||||
|
||||
def is_jax_tensor(obj):
|
||||
return _is_jax(obj) if is_flax_available() else False
|
||||
|
||||
|
||||
def is_valid_image(img):
|
||||
return (
|
||||
isinstance(img, (PIL.Image.Image, np.ndarray))
|
||||
|
|
|
@ -33,11 +33,9 @@ from ...tokenization_utils_base import (
|
|||
TextInput,
|
||||
TextInputPair,
|
||||
TruncationStrategy,
|
||||
_is_tensorflow,
|
||||
_is_torch,
|
||||
to_py_obj,
|
||||
)
|
||||
from ...utils import add_end_docstrings, is_tf_available, is_torch_available, logging
|
||||
from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -1174,9 +1172,9 @@ class LukeTokenizer(RobertaTokenizer):
|
|||
first_element = required_input[index][0]
|
||||
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
|
||||
if not isinstance(first_element, (int, list, tuple)):
|
||||
if is_tf_available() and _is_tensorflow(first_element):
|
||||
if is_tf_tensor(first_element):
|
||||
return_tensors = "tf" if return_tensors is None else return_tensors
|
||||
elif is_torch_available() and _is_torch(first_element):
|
||||
elif is_torch_tensor(first_element):
|
||||
return_tensors = "pt" if return_tensors is None else return_tensors
|
||||
elif isinstance(first_element, np.ndarray):
|
||||
return_tensors = "np" if return_tensors is None else return_tensors
|
||||
|
|
|
@ -37,11 +37,9 @@ from ...tokenization_utils_base import (
|
|||
TextInput,
|
||||
TextInputPair,
|
||||
TruncationStrategy,
|
||||
_is_tensorflow,
|
||||
_is_torch,
|
||||
to_py_obj,
|
||||
)
|
||||
from ...utils import add_end_docstrings, is_tf_available, is_torch_available, logging
|
||||
from ...utils import add_end_docstrings, is_tf_tensor, is_torch_tensor, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -1287,9 +1285,9 @@ class MLukeTokenizer(PreTrainedTokenizer):
|
|||
first_element = required_input[index][0]
|
||||
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
|
||||
if not isinstance(first_element, (int, list, tuple)):
|
||||
if is_tf_available() and _is_tensorflow(first_element):
|
||||
if is_tf_tensor(first_element):
|
||||
return_tensors = "tf" if return_tensors is None else return_tensors
|
||||
elif is_torch_available() and _is_torch(first_element):
|
||||
elif is_torch_tensor(first_element):
|
||||
return_tensors = "pt" if return_tensors is None else return_tensors
|
||||
elif isinstance(first_element, np.ndarray):
|
||||
return_tensors = "np" if return_tensors is None else return_tensors
|
||||
|
|
|
@ -45,16 +45,20 @@ from .utils import (
|
|||
download_url,
|
||||
extract_commit_hash,
|
||||
is_flax_available,
|
||||
is_jax_tensor,
|
||||
is_numpy_array,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
is_tf_available,
|
||||
is_tf_tensor,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
is_torch_device,
|
||||
is_torch_tensor,
|
||||
logging,
|
||||
to_py_obj,
|
||||
torch_required,
|
||||
)
|
||||
from .utils.generic import _is_jax, _is_numpy, _is_tensorflow, _is_torch, _is_torch_device
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -696,15 +700,10 @@ class BatchEncoding(UserDict):
|
|||
import jax.numpy as jnp # noqa: F811
|
||||
|
||||
as_tensor = jnp.array
|
||||
is_tensor = _is_jax
|
||||
is_tensor = is_jax_tensor
|
||||
else:
|
||||
as_tensor = np.asarray
|
||||
is_tensor = _is_numpy
|
||||
# (mfuntowicz: This code is unreachable)
|
||||
# else:
|
||||
# raise ImportError(
|
||||
# f"Unable to convert output to tensors format {tensor_type}"
|
||||
# )
|
||||
is_tensor = is_numpy_array
|
||||
|
||||
# Do the tensor conversion in batch
|
||||
for key, value in self.items():
|
||||
|
@ -753,7 +752,7 @@ class BatchEncoding(UserDict):
|
|||
# This check catches things like APEX blindly calling "to" on all inputs to a module
|
||||
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
||||
# into a HalfTensor
|
||||
if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
|
||||
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
|
||||
self.data = {k: v.to(device=device) for k, v in self.data.items()}
|
||||
else:
|
||||
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
|
||||
|
@ -2925,9 +2924,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||
break
|
||||
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
|
||||
if not isinstance(first_element, (int, list, tuple)):
|
||||
if is_tf_available() and _is_tensorflow(first_element):
|
||||
if is_tf_tensor(first_element):
|
||||
return_tensors = "tf" if return_tensors is None else return_tensors
|
||||
elif is_torch_available() and _is_torch(first_element):
|
||||
elif is_torch_tensor(first_element):
|
||||
return_tensors = "pt" if return_tensors is None else return_tensors
|
||||
elif isinstance(first_element, np.ndarray):
|
||||
return_tensors = "np" if return_tensors is None else return_tensors
|
||||
|
|
|
@ -40,7 +40,12 @@ from .generic import (
|
|||
cached_property,
|
||||
find_labels,
|
||||
flatten_dict,
|
||||
is_jax_tensor,
|
||||
is_numpy_array,
|
||||
is_tensor,
|
||||
is_tf_tensor,
|
||||
is_torch_device,
|
||||
is_torch_tensor,
|
||||
to_numpy,
|
||||
to_py_obj,
|
||||
working_or_temp_dir,
|
||||
|
|
|
@ -83,30 +83,65 @@ def _is_numpy(x):
|
|||
return isinstance(x, np.ndarray)
|
||||
|
||||
|
||||
def is_numpy_array(x):
|
||||
"""
|
||||
Tests if `x` is a numpy array or not.
|
||||
"""
|
||||
return _is_numpy(x)
|
||||
|
||||
|
||||
def _is_torch(x):
|
||||
import torch
|
||||
|
||||
return isinstance(x, torch.Tensor)
|
||||
|
||||
|
||||
def is_torch_tensor(x):
|
||||
"""
|
||||
Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed.
|
||||
"""
|
||||
return False if not is_torch_available() else _is_torch(x)
|
||||
|
||||
|
||||
def _is_torch_device(x):
|
||||
import torch
|
||||
|
||||
return isinstance(x, torch.device)
|
||||
|
||||
|
||||
def is_torch_device(x):
|
||||
"""
|
||||
Tests if `x` is a torch device or not. Safe to call even if torch is not installed.
|
||||
"""
|
||||
return False if not is_torch_available() else _is_torch_device(x)
|
||||
|
||||
|
||||
def _is_tensorflow(x):
|
||||
import tensorflow as tf
|
||||
|
||||
return isinstance(x, tf.Tensor)
|
||||
|
||||
|
||||
def is_tf_tensor(x):
|
||||
"""
|
||||
Tests if `x` is a tensorflow tensor or not. Safe to call even if tensorflow is not installed.
|
||||
"""
|
||||
return False if not is_tf_available() else _is_tensorflow(x)
|
||||
|
||||
|
||||
def _is_jax(x):
|
||||
import jax.numpy as jnp # noqa: F811
|
||||
|
||||
return isinstance(x, jnp.ndarray)
|
||||
|
||||
|
||||
def is_jax_tensor(x):
|
||||
"""
|
||||
Tests if `x` is a Jax tensor or not. Safe to call even if jax is not installed.
|
||||
"""
|
||||
return False if not is_flax_available() else _is_jax(x)
|
||||
|
||||
|
||||
def to_py_obj(obj):
|
||||
"""
|
||||
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
|
||||
|
@ -115,11 +150,11 @@ def to_py_obj(obj):
|
|||
return {k: to_py_obj(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return [to_py_obj(o) for o in obj]
|
||||
elif is_tf_available() and _is_tensorflow(obj):
|
||||
elif is_tf_tensor(obj):
|
||||
return obj.numpy().tolist()
|
||||
elif is_torch_available() and _is_torch(obj):
|
||||
elif is_torch_tensor(obj):
|
||||
return obj.detach().cpu().tolist()
|
||||
elif is_flax_available() and _is_jax(obj):
|
||||
elif is_jax_tensor(obj):
|
||||
return np.asarray(obj).tolist()
|
||||
elif isinstance(obj, (np.ndarray, np.number)): # tolist also works on 0d np arrays
|
||||
return obj.tolist()
|
||||
|
@ -135,11 +170,11 @@ def to_numpy(obj):
|
|||
return {k: to_numpy(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return np.array(obj)
|
||||
elif is_tf_available() and _is_tensorflow(obj):
|
||||
elif is_tf_tensor(obj):
|
||||
return obj.numpy()
|
||||
elif is_torch_available() and _is_torch(obj):
|
||||
elif is_torch_tensor(obj):
|
||||
return obj.detach().cpu().numpy()
|
||||
elif is_flax_available() and _is_jax(obj):
|
||||
elif is_jax_tensor(obj):
|
||||
return np.asarray(obj)
|
||||
else:
|
||||
return obj
|
||||
|
|
Loading…
Reference in New Issue