Avoid `jnp` import in `utils/generic.py` (#30322)
fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
60d5f8f9f0
commit
01ae3b87c0
|
@ -38,10 +38,6 @@ from .import_utils import (
|
|||
)
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class cached_property(property):
|
||||
"""
|
||||
Descriptor that mimics @property but caches output in member variable.
|
||||
|
@ -624,6 +620,8 @@ def transpose(array, axes=None):
|
|||
|
||||
return tf.transpose(array, perm=axes)
|
||||
elif is_jax_tensor(array):
|
||||
import jax.numpy as jnp
|
||||
|
||||
return jnp.transpose(array, axes=axes)
|
||||
else:
|
||||
raise ValueError(f"Type not supported for transpose: {type(array)}.")
|
||||
|
@ -643,6 +641,8 @@ def reshape(array, newshape):
|
|||
|
||||
return tf.reshape(array, newshape)
|
||||
elif is_jax_tensor(array):
|
||||
import jax.numpy as jnp
|
||||
|
||||
return jnp.reshape(array, newshape)
|
||||
else:
|
||||
raise ValueError(f"Type not supported for reshape: {type(array)}.")
|
||||
|
@ -662,6 +662,8 @@ def squeeze(array, axis=None):
|
|||
|
||||
return tf.squeeze(array, axis=axis)
|
||||
elif is_jax_tensor(array):
|
||||
import jax.numpy as jnp
|
||||
|
||||
return jnp.squeeze(array, axis=axis)
|
||||
else:
|
||||
raise ValueError(f"Type not supported for squeeze: {type(array)}.")
|
||||
|
@ -681,6 +683,8 @@ def expand_dims(array, axis):
|
|||
|
||||
return tf.expand_dims(array, axis=axis)
|
||||
elif is_jax_tensor(array):
|
||||
import jax.numpy as jnp
|
||||
|
||||
return jnp.expand_dims(array, axis=axis)
|
||||
else:
|
||||
raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
|
||||
|
|
Loading…
Reference in New Issue