Avoid `jnp` import in `utils/generic.py` (#30322)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2024-04-18 19:46:46 +02:00 committed by GitHub
parent 60d5f8f9f0
commit 01ae3b87c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 8 additions and 4 deletions

View File

@ -38,10 +38,6 @@ from .import_utils import (
)
if is_flax_available():
import jax.numpy as jnp
class cached_property(property):
"""
Descriptor that mimics @property but caches output in member variable.
@ -624,6 +620,8 @@ def transpose(array, axes=None):
return tf.transpose(array, perm=axes)
elif is_jax_tensor(array):
import jax.numpy as jnp
return jnp.transpose(array, axes=axes)
else:
raise ValueError(f"Type not supported for transpose: {type(array)}.")
@ -643,6 +641,8 @@ def reshape(array, newshape):
return tf.reshape(array, newshape)
elif is_jax_tensor(array):
import jax.numpy as jnp
return jnp.reshape(array, newshape)
else:
raise ValueError(f"Type not supported for reshape: {type(array)}.")
@ -662,6 +662,8 @@ def squeeze(array, axis=None):
return tf.squeeze(array, axis=axis)
elif is_jax_tensor(array):
import jax.numpy as jnp
return jnp.squeeze(array, axis=axis)
else:
raise ValueError(f"Type not supported for squeeze: {type(array)}.")
@ -681,6 +683,8 @@ def expand_dims(array, axis):
return tf.expand_dims(array, axis=axis)
elif is_jax_tensor(array):
import jax.numpy as jnp
return jnp.expand_dims(array, axis=axis)
else:
raise ValueError(f"Type not supported for expand_dims: {type(array)}.")