Compare commits
14 Commits
main
...
v4.41-release
Author | SHA1 | Date |
---|---|---|
ArthurZucker | ab0f050b42 | |
Matt | 57f5553d2e | |
oOraph | 73b180c2be | |
Aymeric Roucher | a6325a77b2 | |
Pablo Montalvo | 9ccdc84cb2 | |
Lucain | 12aa3167b4 | |
ArthurZucker | 75f15f39a0 | |
Pablo Montalvo | 8282db5cc9 | |
ArthurZucker | e5b788ade3 | |
Raushan Turganbay | 9d054596e7 | |
hoshi-hiyouga | e5d174f12a | |
Arthur | 04141855bd | |
Arthur | 6d2439a126 | |
ArthurZucker | 4c6c45ba13 |
|
@ -1,3 +1,4 @@
|
|||
# Optimizing inference
|
||||
|
||||
perf_infer_gpu_many: perf_infer_gpu_one
|
||||
transformers_agents: agents
|
||||
|
|
|
@ -61,7 +61,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
|
|
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
Array = Any
|
||||
Dataset = datasets.arrow_dataset.Dataset
|
||||
|
|
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ from transformers.utils.versions import require_version
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
|
|
@ -60,7 +60,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
|
@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
|
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
# You should update this to your particular problem to have better documentation of `model_type`
|
||||
|
|
|
@ -48,7 +48,8 @@ from transformers.utils.versions import require_version
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.40.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")
|
||||
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.40.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = get_logger(__name__)
|
||||
|
|
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
|
|
@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
|
||||
|
|
|
@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
||||
|
|
|
@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version(
|
||||
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"
|
||||
|
|
|
@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ from transformers.utils import PaddingStrategy, check_min_version, send_example_
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ except (ModuleNotFoundError, ImportError):
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
# region Checking dependencies
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ from transformers.utils import check_min_version, send_example_telemetry
|
|||
|
||||
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
task_to_keys = {
|
||||
"cola": ("sentence", None),
|
||||
|
|
|
@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
|
|||
|
||||
# region Dependencies and constants
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||
check_min_version("4.41.0.dev0")
|
||||
check_min_version("4.41.0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
|
||||
|
||||
|
|
2
setup.py
2
setup.py
|
@ -426,7 +426,7 @@ install_requires = [
|
|||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.41.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.41.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.41.0.dev0"
|
||||
__version__ = "4.41.2"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
|
|
@ -197,7 +197,10 @@ def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]) -
|
|||
Returns:
|
||||
`typing.Type`: The class looked for.
|
||||
"""
|
||||
name = os.path.normpath(module_path).rstrip(".py").replace(os.path.sep, ".")
|
||||
name = os.path.normpath(module_path)
|
||||
if name.endswith(".py"):
|
||||
name = name[:-3]
|
||||
name = name.replace(os.path.sep, ".")
|
||||
module_spec = importlib.util.spec_from_file_location(name, location=Path(HF_MODULES_CACHE) / module_path)
|
||||
module = sys.modules.get(name)
|
||||
if module is None:
|
||||
|
|
|
@ -1354,6 +1354,23 @@ class GenerationMixin:
|
|||
self._static_cache.reset() # reset the cache for a new generation
|
||||
return self._static_cache
|
||||
|
||||
def _get_decoder_start_token_id(
|
||||
self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None
|
||||
) -> int:
|
||||
decoder_start_token_id = (
|
||||
decoder_start_token_id
|
||||
if decoder_start_token_id is not None
|
||||
else self.generation_config.decoder_start_token_id
|
||||
)
|
||||
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
|
||||
|
||||
if decoder_start_token_id is not None:
|
||||
return decoder_start_token_id
|
||||
elif bos_token_id is not None:
|
||||
return bos_token_id
|
||||
else:
|
||||
return
|
||||
|
||||
def _prepare_special_tokens(
|
||||
self,
|
||||
generation_config: GenerationConfig,
|
||||
|
@ -1378,11 +1395,16 @@ class GenerationMixin:
|
|||
return token
|
||||
return torch.tensor(token, device=device, dtype=torch.long)
|
||||
|
||||
# for BC we also try to get `decoder_start_token_id` from model's generation config (#30892)
|
||||
if self.config.is_encoder_decoder:
|
||||
generation_config.decoder_start_token_id = self._get_decoder_start_token_id(
|
||||
generation_config.decoder_start_token_id, generation_config.bos_token_id
|
||||
)
|
||||
|
||||
bos_token_id = _tensor_or_none(generation_config.bos_token_id, device=device)
|
||||
eos_token_id = _tensor_or_none(generation_config.eos_token_id, device=device)
|
||||
pad_token_id = _tensor_or_none(generation_config.pad_token_id, device=device)
|
||||
decoder_start_token_id = _tensor_or_none(generation_config.decoder_start_token_id, device=device)
|
||||
decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
|
||||
|
||||
# We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
|
||||
if eos_token_id is not None and eos_token_id.ndim == 0:
|
||||
|
|
|
@ -31,17 +31,8 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
|
|||
import numpy as np
|
||||
import packaging.version
|
||||
|
||||
from .. import PreTrainedModel, TFPreTrainedModel
|
||||
from .. import __version__ as version
|
||||
from ..utils import (
|
||||
PushToHubMixin,
|
||||
flatten_dict,
|
||||
is_datasets_available,
|
||||
is_pandas_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
from ..utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -78,7 +69,6 @@ if TYPE_CHECKING and _has_neptune:
|
|||
except importlib.metadata.PackageNotFoundError:
|
||||
_has_neptune = False
|
||||
|
||||
from .. import modelcard # noqa: E402
|
||||
from ..trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
|
||||
from ..trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
|
||||
from ..training_args import ParallelMode # noqa: E402
|
||||
|
@ -673,22 +663,6 @@ class TensorBoardCallback(TrainerCallback):
|
|||
self.tb_writer = None
|
||||
|
||||
|
||||
def save_model_architecture_to_file(model: Any, output_dir: str):
|
||||
with open(f"{output_dir}/model_architecture.txt", "w+") as f:
|
||||
if isinstance(model, PreTrainedModel):
|
||||
print(model, file=f)
|
||||
elif is_tf_available() and isinstance(model, TFPreTrainedModel):
|
||||
|
||||
def print_to_file(s):
|
||||
print(s, file=f)
|
||||
|
||||
model.summary(print_fn=print_to_file)
|
||||
elif is_torch_available() and (
|
||||
isinstance(model, (torch.nn.Module, PushToHubMixin)) and hasattr(model, "base_model")
|
||||
):
|
||||
print(model, file=f)
|
||||
|
||||
|
||||
class WandbCallback(TrainerCallback):
|
||||
"""
|
||||
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
|
||||
|
@ -754,9 +728,6 @@ class WandbCallback(TrainerCallback):
|
|||
if hasattr(model, "config") and model.config is not None:
|
||||
model_config = model.config.to_dict()
|
||||
combined_dict = {**model_config, **combined_dict}
|
||||
if hasattr(model, "peft_config") and model.peft_config is not None:
|
||||
peft_config = model.peft_config
|
||||
combined_dict = {**{"peft_config": peft_config}, **combined_dict}
|
||||
trial_name = state.trial_name
|
||||
init_args = {}
|
||||
if trial_name is not None:
|
||||
|
@ -790,46 +761,6 @@ class WandbCallback(TrainerCallback):
|
|||
self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
|
||||
self._wandb.run._label(code="transformers_trainer")
|
||||
|
||||
# add number of model parameters to wandb config
|
||||
try:
|
||||
self._wandb.config["model/num_parameters"] = model.num_parameters()
|
||||
except AttributeError:
|
||||
logger.info("Could not log the number of model parameters in Weights & Biases.")
|
||||
|
||||
# log the initial model and architecture to an artifact
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
model_name = (
|
||||
f"model-{self._wandb.run.id}"
|
||||
if (args.run_name is None or args.run_name == args.output_dir)
|
||||
else f"model-{self._wandb.run.name}"
|
||||
)
|
||||
model_artifact = self._wandb.Artifact(
|
||||
name=model_name,
|
||||
type="model",
|
||||
metadata={
|
||||
"model_config": model.config.to_dict() if hasattr(model, "config") else None,
|
||||
"num_parameters": self._wandb.config.get("model/num_parameters"),
|
||||
"initial_model": True,
|
||||
},
|
||||
)
|
||||
model.save_pretrained(temp_dir)
|
||||
# add the architecture to a separate text file
|
||||
save_model_architecture_to_file(model, temp_dir)
|
||||
|
||||
for f in Path(temp_dir).glob("*"):
|
||||
if f.is_file():
|
||||
with model_artifact.new_file(f.name, mode="wb") as fa:
|
||||
fa.write(f.read_bytes())
|
||||
self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])
|
||||
|
||||
badge_markdown = (
|
||||
f'[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge'
|
||||
f'-28.svg" alt="Visualize in Weights & Biases" width="20'
|
||||
f'0" height="32"/>]({self._wandb.run.get_url()})'
|
||||
)
|
||||
|
||||
modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||
|
||||
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||
if self._wandb is None:
|
||||
return
|
||||
|
@ -860,25 +791,20 @@ class WandbCallback(TrainerCallback):
|
|||
else {
|
||||
f"eval/{args.metric_for_best_model}": state.best_metric,
|
||||
"train/total_floss": state.total_flos,
|
||||
"model/num_parameters": self._wandb.config.get("model/num_parameters"),
|
||||
}
|
||||
)
|
||||
metadata["final_model"] = True
|
||||
logger.info("Logging model artifacts. ...")
|
||||
model_name = (
|
||||
f"model-{self._wandb.run.id}"
|
||||
if (args.run_name is None or args.run_name == args.output_dir)
|
||||
else f"model-{self._wandb.run.name}"
|
||||
)
|
||||
# add the model architecture to a separate text file
|
||||
save_model_architecture_to_file(model, temp_dir)
|
||||
|
||||
artifact = self._wandb.Artifact(name=model_name, type="model", metadata=metadata)
|
||||
for f in Path(temp_dir).glob("*"):
|
||||
if f.is_file():
|
||||
with artifact.new_file(f.name, mode="wb") as fa:
|
||||
fa.write(f.read_bytes())
|
||||
self._wandb.run.log_artifact(artifact, aliases=["final_model"])
|
||||
self._wandb.run.log_artifact(artifact)
|
||||
|
||||
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
|
||||
single_value_scalars = [
|
||||
|
@ -908,30 +834,18 @@ class WandbCallback(TrainerCallback):
|
|||
for k, v in dict(self._wandb.summary).items()
|
||||
if isinstance(v, numbers.Number) and not k.startswith("_")
|
||||
}
|
||||
checkpoint_metadata["model/num_parameters"] = self._wandb.config.get("model/num_parameters")
|
||||
|
||||
ckpt_dir = f"checkpoint-{state.global_step}"
|
||||
artifact_path = os.path.join(args.output_dir, ckpt_dir)
|
||||
logger.info(f"Logging checkpoint artifacts in {ckpt_dir}. ...")
|
||||
checkpoint_name = (
|
||||
f"model-{self._wandb.run.id}"
|
||||
f"checkpoint-{self._wandb.run.id}"
|
||||
if (args.run_name is None or args.run_name == args.output_dir)
|
||||
else f"model-{self._wandb.run.name}"
|
||||
else f"checkpoint-{self._wandb.run.name}"
|
||||
)
|
||||
artifact = self._wandb.Artifact(name=checkpoint_name, type="model", metadata=checkpoint_metadata)
|
||||
artifact.add_dir(artifact_path)
|
||||
self._wandb.log_artifact(
|
||||
artifact, aliases=[f"epoch_{round(state.epoch, 2)}", f"checkpoint_global_step_{state.global_step}"]
|
||||
)
|
||||
|
||||
def on_predict(self, args, state, control, metrics, **kwargs):
|
||||
if self._wandb is None:
|
||||
return
|
||||
if not self._initialized:
|
||||
self.setup(args, state, **kwargs)
|
||||
if state.is_world_process_zero:
|
||||
metrics = rewrite_logs(metrics)
|
||||
self._wandb.log(metrics)
|
||||
self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"])
|
||||
|
||||
|
||||
class CometCallback(TrainerCallback):
|
||||
|
|
|
@ -3389,7 +3389,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||
)
|
||||
if resolved_archive_file is not None:
|
||||
is_sharded = True
|
||||
|
||||
if not local_files_only and not is_offline_mode():
|
||||
if resolved_archive_file is not None:
|
||||
if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]:
|
||||
# If the PyTorch file was found, check if there is a safetensors file on the repository
|
||||
|
|
|
@ -151,9 +151,6 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
|||
self.legacy = legacy
|
||||
|
||||
if add_prefix_space is not None:
|
||||
logger.warning_once(
|
||||
"You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers"
|
||||
)
|
||||
kwargs["from_slow"] = True
|
||||
|
||||
super().__init__(
|
||||
|
@ -166,6 +163,7 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
|
|||
add_bos_token=add_bos_token,
|
||||
add_eos_token=add_eos_token,
|
||||
use_default_system_prompt=use_default_system_prompt,
|
||||
legacy=legacy,
|
||||
**kwargs,
|
||||
)
|
||||
self._add_bos_token = add_bos_token
|
||||
|
|
|
@ -40,8 +40,8 @@ class LlavaNextProcessor(ProcessorMixin):
|
|||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = "LlavaNextImageProcessor"
|
||||
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(self, image_processor=None, tokenizer=None):
|
||||
super().__init__(image_processor, tokenizer)
|
||||
|
|
|
@ -282,9 +282,14 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||
self.vocab_size = model_embeds.num_embeddings
|
||||
return model_embeds
|
||||
|
||||
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||
def _merge_input_ids_with_image_features(
|
||||
self, image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
|
||||
):
|
||||
_, _, embed_dim = image_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
dtype, device = inputs_embeds.dtype, inputs_embeds.device
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
|
||||
scaled_image_features = image_features / (self.config.hidden_size**0.5)
|
||||
final_embedding = torch.zeros(
|
||||
batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||
|
@ -295,34 +300,56 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||
pad_mask = input_ids == self.pad_token_id
|
||||
|
||||
# expand masks to match embedding dimension
|
||||
text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
|
||||
pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
|
||||
text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
|
||||
pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
|
||||
# insert padding and text token embeddings
|
||||
final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
|
||||
final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
|
||||
# insert image embeddings - the image mask is always less or equal to the sentence in length
|
||||
final_embedding = final_embedding.masked_scatter(
|
||||
image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
|
||||
image_mask.unsqueeze(-1).expand_as(final_embedding).to(device=final_embedding.device),
|
||||
scaled_image_features.to(device=final_embedding.device, dtype=final_embedding.dtype),
|
||||
)
|
||||
final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
|
||||
if attention_mask is not None:
|
||||
position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
|
||||
else:
|
||||
position_ids = None
|
||||
|
||||
final_attention_mask_4d = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
|
||||
final_attention_mask_4d = final_attention_mask_4d.float().expand(
|
||||
-1, self.config.text_config.num_key_value_heads, -1, -1
|
||||
if token_type_ids is not None and labels is not None:
|
||||
# we are training thus we need to create a full mask on the image + prefix but causal on suffix
|
||||
target_length = cache_position[-1] + 1
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
|
||||
causal_mask.device
|
||||
)
|
||||
# unmask the prefill
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
# position_ids = torch.arange(0, sequence_length, device=input_ids.device).expand(batch_size, -1)
|
||||
# position_ids = torch.where(input_ids == self.pad_token_id, torch.ones_like(position_ids), position_ids)
|
||||
position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
|
||||
|
||||
if labels is not None:
|
||||
final_labels = torch.full(
|
||||
(batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||
)
|
||||
final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
|
||||
else:
|
||||
causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
|
||||
causal_mask = causal_mask.to(dtype).expand(-1, self.config.text_config.num_key_value_heads, -1, -1)
|
||||
final_labels = None
|
||||
return final_embedding, final_attention_mask_4d, final_labels, position_ids
|
||||
return final_embedding, causal_mask, final_labels, position_ids
|
||||
|
||||
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
|
@ -333,6 +360,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
|
@ -396,8 +424,10 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||
selected_image_feature = image_outputs.last_hidden_state
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
|
||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
|
||||
)
|
||||
|
||||
else:
|
||||
|
@ -456,7 +486,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||
shift_attention_mask = input_attention_mask[..., 1:]
|
||||
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
|
||||
shift_labels = shift_labels[shift_attention_mask.to(logits.device) != 0].contiguous()
|
||||
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
|
||||
else:
|
||||
shift_logits = shift_logits.contiguous()
|
||||
shift_labels = shift_labels.contiguous()
|
||||
|
@ -486,6 +516,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||
cache_position=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
|
@ -544,6 +575,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -42,7 +42,7 @@ class VideoLlavaProcessor(ProcessorMixin):
|
|||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
image_processor_class = "VideoLlavaImageProcessor"
|
||||
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
def __init__(self, image_processor=None, tokenizer=None):
|
||||
super().__init__(image_processor, tokenizer)
|
||||
|
|
|
@ -65,6 +65,7 @@ if is_torch_available():
|
|||
GenerateBeamEncoderDecoderOutput,
|
||||
GenerateDecoderOnlyOutput,
|
||||
GenerateEncoderDecoderOutput,
|
||||
GenerationConfig,
|
||||
GreedySearchDecoderOnlyOutput,
|
||||
GreedySearchEncoderDecoderOutput,
|
||||
LogitsProcessorList,
|
||||
|
@ -2478,6 +2479,35 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||
|
||||
self.assertListEqual(outputs.tolist(), outputs_batched_ids.tolist())
|
||||
|
||||
def test_decoder_start_id_from_config(self):
    # Regression test for #30899: a user-supplied `GenerationConfig` that omits `decoder_start_token_id`
    # must not prevent `generate` from using the one stored in the model's own generation config.
    checkpoint = "hf-internal-testing/tiny-random-bart"
    sentences = [
        "Justin Timberlake and Jessica Biel, welcome to parenthood.",
        "Michael Phelps is arguably the most decorated Olympian of all time.",
    ]

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = BartForConditionalGeneration.from_pretrained(checkpoint)
    model = model.to(torch_device)

    encoded = tokenizer(sentences, return_tensors="pt", padding=True)
    input_ids = encoded.input_ids.to(torch_device)

    # Remember the model's start token before it is wiped further down.
    start_token_id = model.generation_config.decoder_start_token_id

    # Step 1: the passed `GenerationConfig` carries no start token, so it has to come from the model's
    # generation config.
    generated_from_model_defaults = model.generate(
        input_ids, generation_config=GenerationConfig(do_sample=False)
    )

    # Step 2: once the generation config has neither `decoder_start_token_id` nor `bos_token_id`, the user
    # must provide the start token explicitly; the result must match step 1.
    model.generation_config.decoder_start_token_id = None
    model.generation_config.bos_token_id = None
    explicit_config = GenerationConfig(do_sample=False, decoder_start_token_id=start_token_id)
    generated_from_user_config = model.generate(input_ids, generation_config=explicit_config)

    self.assertListEqual(generated_from_model_defaults.tolist(), generated_from_user_config.tolist())

    # Step 3: with no start token available from either side, generation is expected to raise.
    with self.assertRaises(ValueError):
        model.generate(input_ids, generation_config=GenerationConfig(do_sample=False))
|
||||
|
||||
def test_contrastive_search_batched(self):
|
||||
# PT-only test: TF doesn't have constrained beam search
|
||||
# Tests that contrastive search works with batched inputs (i.e. has the same output as for non-batched inputs)
|
||||
|
|
|
@ -163,6 +163,8 @@ class PaliGemmaVisionText2TextModelTester:
|
|||
"pixel_values": pixel_values,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": input_ids,
|
||||
"token_type_ids": torch.zeros_like(input_ids),
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ from requests.exceptions import HTTPError
|
|||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForImageClassification,
|
||||
AutoModelForSequenceClassification,
|
||||
OwlViTForObjectDetection,
|
||||
PretrainedConfig,
|
||||
|
@ -76,7 +77,6 @@ sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
|||
|
||||
from test_module.custom_configuration import CustomConfig, NoSuperInitConfig # noqa E402
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from safetensors.torch import save_file as safe_save_file
|
||||
|
@ -194,6 +194,97 @@ if is_torch_available():
|
|||
attention_mask = _prepare_4d_attention_mask(mask, dtype=inputs_embeds.dtype)
|
||||
return attention_mask
|
||||
|
||||
class TestOffline(unittest.TestCase):
    """Check that `from_pretrained` respects offline mode and `local_files_only`.

    Each test points every hub cache location at a fresh, empty temporary directory, asserts that loading
    `TINY_IMAGE_CLASSIF` raises `OSError` while the cache is empty, fills the cache with
    `HfApi.snapshot_download`, and finally asserts that the very same load now succeeds.
    """

    def _redirect_hub_caches(self, offline=False):
        """Point all hub cache locations at a new temporary directory for the duration of the test.

        Args:
            offline: when true, also force `transformers.utils.hub._is_offline_mode` to `True`.

        Returns:
            The path of the temporary cache directory.

        Every patched attribute and the directory itself are restored/removed through `addCleanup`, so the
        original state comes back even when the test fails half-way.
        """
        # Monkeypatching module attributes: amending env vars here is too late as libs have already been
        # imported.
        from huggingface_hub import constants

        from transformers.utils import hub

        cache_dir = tempfile.TemporaryDirectory()
        # Registered first so that it runs last (cleanups are LIFO), i.e. after the attributes are restored.
        self.addCleanup(cache_dir.cleanup)
        LOG.info("Temporary cache dir %s", cache_dir.name)

        overrides = [
            (constants, "HF_HUB_CACHE", cache_dir.name),
            (constants, "HUGGINGFACE_HUB_CACHE", cache_dir.name),
            (constants, "default_cache_path", cache_dir.name),
            (hub, "TRANSFORMERS_CACHE", cache_dir.name),
        ]
        if offline:
            overrides.append((hub, "_is_offline_mode", True))

        for module, attribute, value in overrides:
            # Tear down: put the attribute back as it was before this test touched it.
            self.addCleanup(setattr, module, attribute, getattr(module, attribute))
            setattr(module, attribute, value)

        return cache_dir.name

    def _assert_load_needs_populated_cache(self, cache_dir, **load_kwargs):
        """Assert that loading fails on an empty cache and succeeds once the snapshot is downloaded.

        Args:
            cache_dir: directory the hub caches currently point to; the snapshot is downloaded there.
            **load_kwargs: extra keyword arguments forwarded to both `from_pretrained` calls.
        """
        # First load should fail: nothing is cached yet.
        with self.assertRaises(OSError, msg=f"Loading model {TINY_IMAGE_CLASSIF} in offline mode should fail"):
            AutoModelForImageClassification.from_pretrained(
                TINY_IMAGE_CLASSIF, revision="main", use_auth_token=None, **load_kwargs
            )
        LOG.info("Loading model %s in offline mode failed as expected", TINY_IMAGE_CLASSIF)

        # Download model -> Huggingface Hub not concerned by our offline mode
        LOG.info("Downloading %s for offline tests", TINY_IMAGE_CLASSIF)
        local_dir = HfApi().snapshot_download(TINY_IMAGE_CLASSIF, cache_dir=cache_dir)
        LOG.info("Model %s downloaded in %s", TINY_IMAGE_CLASSIF, local_dir)

        # Same call as above; it must not raise now that the files are in the cache.
        AutoModelForImageClassification.from_pretrained(
            TINY_IMAGE_CLASSIF, revision="main", use_auth_token=None, **load_kwargs
        )

    def test_offline(self):
        cache_dir = self._redirect_hub_caches(offline=True)
        self._assert_load_needs_populated_cache(cache_dir)

    def test_local_files_only(self):
        cache_dir = self._redirect_hub_caches()
        self._assert_load_needs_populated_cache(cache_dir, local_files_only=True)
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from transformers import FlaxBertModel
|
||||
|
@ -205,6 +296,9 @@ if is_tf_available():
|
|||
TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
||||
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
|
||||
TINY_MISTRAL = "hf-internal-testing/tiny-random-MistralForCausalLM"
|
||||
TINY_IMAGE_CLASSIF = "hf-internal-testing/tiny-random-SiglipForImageClassification"
|
||||
|
||||
LOG = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def check_models_equal(model1, model2):
|
||||
|
|
Loading…
Reference in New Issue