Jax: scipy version pin (#30402)

scipy pin for jax
This commit is contained in:
Joao Gante 2024-04-23 10:42:17 +01:00 committed by GitHub
parent 2d61823fa2
commit 31921d8d5e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 1 deletions

View File

@ -161,6 +161,7 @@ _deps = [
"safetensors>=0.4.1",
"sagemaker>=2.31.0",
"scikit-learn",
"scipy<1.13.0", # SciPy >= 1.13.0 is not supported with the current jax pin (`jax>=0.4.1,<=0.4.13`)
"sentencepiece>=0.1.91,!=0.1.92",
"sigopt",
"starlette",
@ -267,7 +268,7 @@ if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
else:
extras["retrieval"] = deps_list("faiss-cpu", "datasets")
extras["flax"] = deps_list("jax", "jaxlib", "flax", "optax")
extras["flax"] = deps_list("jax", "jaxlib", "flax", "optax", "scipy")
extras["tokenizers"] = deps_list("tokenizers")
extras["ftfy"] = deps_list("ftfy")

View File

@ -67,6 +67,7 @@ deps = {
"safetensors": "safetensors>=0.4.1",
"sagemaker": "sagemaker>=2.31.0",
"scikit-learn": "scikit-learn",
"scipy": "scipy<1.13.0",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"sigopt": "sigopt",
"starlette": "starlette",