Compare commits


17 Commits

Author SHA1 Message Date
Ita Zaporozhets 8b0aa67796 adding user defined tokens #30824 2024-05-30 16:31:20 +02:00
Ita Zaporozhets c651d94fb8 add user defined symbols to all tokenizers from SpmConverter 2024-05-30 16:31:07 +02:00
Ita Zaporozhets a9f52cbffa add comment 2024-05-30 16:29:46 +02:00
Ita Zaporozhets 60c890e9d8 draft commit 2024-05-30 16:05:43 +02:00
Arthur d799d6715f legacy to init the slow tokenizer when converting from slow was wrong (#30972) 2024-05-30 16:02:31 +02:00
Yih-Dar 774f7295f2 Finally fix the missing new model failure CI report (#30968)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2024-05-30 16:02:31 +02:00
amyeroberts 0c48df2e50 🚨 out_indices always a list (#30941)
* out_indices always a list

* Update src/transformers/utils/backbone_utils.py

* Update src/transformers/utils/backbone_utils.py

* Move type casting

* nit
2024-05-30 16:02:31 +02:00
Pablo Montalvo c1d906666e Paligemma - fix slow tests, add bf16 and f16 slow tests (#30851)
* fix slow tests, add bf16 and f16 slow tests

* few fixes

* [run-slow]paligemma

* add gate decorator

* [run-slow]paligemma

* add missing gating

* [run-slow]paligemma

* [run-slow]paligemma
2024-05-30 16:02:31 +02:00
Sanchit Gandhi 67672cfc97 [whisper] only trigger forced ids warning once (#30966) 2024-05-30 16:02:31 +02:00
Jonatan Kłosko c337d55988 Avoid extra chunk in speech recognition (#29539) 2024-05-30 16:02:31 +02:00
Vaibhav Srivastav a778108a3c [doc] Add references to the fine-tuning blog and distil-whisper to Whisper. (#30938)
[doc] Add references to the fine-tuning blog and distil-whisper to Whisper doc.
2024-05-30 16:02:31 +02:00
Marc Sun bb17199cd2 Fix low cpu mem usage tests (#30808)
* Fix tests

* fix udop failing test

* remove skip

* style
2024-05-30 16:02:31 +02:00
Ita Zaporozhets be2fb4fb8f more general approach 2024-05-22 16:00:36 +02:00
Ita Zaporozhets e75583dc85 Merge remote-tracking branch 'origin/main' into 30824-spmconverter-user-defined-symbol 2024-05-22 14:02:01 +02:00
Ita Zaporozhets 529e2be112 add comment 2024-05-21 15:40:48 +02:00
Ita Zaporozhets eab71cc26b add user defined symbols to all tokenizers from SpmConverter 2024-05-21 15:22:13 +02:00
Ita Zaporozhets 996ff224a3 adding user defined tokens #30824 2024-05-21 10:05:23 +02:00
15 changed files with 224 additions and 137 deletions

View File

@ -78,6 +78,8 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Whisper. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
- [Fine-tune Whisper](https://huggingface.co/blog/fine-tune-whisper) on your own dataset for better downstream performance.
- [Distil-Whisper](https://huggingface.co/distil-whisper): Up to 6x faster, 2x smaller distilled Whisper models for English. We release the [model checkpoints](https://huggingface.co/distil-whisper) and [distillation code](https://github.com/huggingface/distil-whisper).
- A fork with a script to [convert a Whisper model in Hugging Face format to OpenAI format](https://github.com/zuazo-forks/transformers/blob/convert_hf_to_openai/src/transformers/models/whisper/convert_hf_to_openai.py). 🌎
Usage example:
```bash

View File

@ -620,14 +620,26 @@ class SpmConverter(Converter):
def converted(self) -> Tokenizer:
tokenizer = self.tokenizer(self.proto)
# Tokenizer assemble
# Add user defined symbols
user_defined_symbols = [
AddedToken(token, normalized=True, special=False) for token in self.proto.trainer_spec.user_defined_symbols
]
control_symbols = [
AddedToken(token, normalized=True, special=False) for token in self.proto.trainer_spec.control_symbols
]
tokenizer.add_tokens(user_defined_symbols + control_symbols)
# Tokenizer assemble
normalizer = self.normalizer(self.proto)
if normalizer is not None:
tokenizer.normalizer = normalizer
replacement = ""
add_prefix_space = True
if hasattr(self.original_tokenizer, "add_prefix_space"):
#TODO:ita added 1
add_prefix_space = self.proto.normalizer_spec.add_dummy_prefix
tokenizer.add_prefix_space = add_prefix_space
if hasattr(self.original_tokenizer, "add_prefix_space") and self.original_tokenizer.add_prefix_space is not None:
add_prefix_space = self.original_tokenizer.add_prefix_space
pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
@ -1385,6 +1397,10 @@ class LlamaConverter(SpmConverter):
AddedToken(self.original_tokenizer.convert_ids_to_tokens(2), normalized=False, special=True),
]
)
user_defined_symbols = [
AddedToken(token, normalized=True, special=False) for token in proto.trainer_spec.user_defined_symbols
]
tokenizer.add_tokens(user_defined_symbols)
else:
raise Exception(
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"

View File

@ -143,7 +143,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
use_default_system_prompt=False,
spaces_between_special_tokens=False,
legacy=None,
add_prefix_space=True,
add_prefix_space=None,
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
@ -167,8 +167,8 @@ class LlamaTokenizer(PreTrainedTokenizer):
self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.use_default_system_prompt = use_default_system_prompt
self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
self.add_prefix_space = add_prefix_space
self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
super().__init__(
bos_token=bos_token,
@ -202,6 +202,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)")
model = model_pb2.ModelProto.FromString(sp_model)
normalizer_spec = model_pb2.NormalizerSpec()
self.add_prefix_space = normalizer_spec.add_dummy_prefix if self.add_prefix_space is None else self.add_prefix_space
normalizer_spec.add_dummy_prefix = False
model.normalizer_spec.MergeFrom(normalizer_spec)
sp_model = model.SerializeToString()
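
As a rough illustration of the new `add_prefix_space=None` default above: the intent appears to be deferring to the `add_dummy_prefix` flag carried by the SentencePiece model when the user does not set the argument explicitly. A hedged sketch of reading that flag directly, assuming `sentencepiece` is installed and a hypothetical local `spm.model`:

```python
# Sketch: inspect the add_dummy_prefix flag that now seeds add_prefix_space when it is None.
from sentencepiece import sentencepiece_model_pb2 as model_pb2

proto = model_pb2.ModelProto()
with open("spm.model", "rb") as f:  # hypothetical local SentencePiece model file
    proto.ParseFromString(f.read())

add_prefix_space = proto.normalizer_spec.add_dummy_prefix
print("add_prefix_space resolved from the model file:", add_prefix_space)
```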

View File

@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
from shutil import copyfile
from typing import Optional, Tuple
from tokenizers import processors
from tokenizers import pre_tokenizers, normalizers, processors
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging
@ -150,10 +151,11 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
legacy = True
self.legacy = legacy
#TODO:ita
self.add_prefix_space = add_prefix_space
# TODO:ita
if add_prefix_space is not None:
logger.warning_once(
"You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers"
)
kwargs["from_slow"] = True
super().__init__(
@ -166,11 +168,14 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
use_default_system_prompt=use_default_system_prompt,
legacy=legacy,
**kwargs,
)
self._add_bos_token = add_bos_token
self._add_eos_token = add_eos_token
self.update_post_processor()
self.update_pre_tokenizer()
self.update_normalizer()
self.use_default_system_prompt = use_default_system_prompt
self.vocab_file = vocab_file
@ -204,6 +209,38 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
single=single, pair=pair, special_tokens=special_tokens
)
def update_normalizer(self):
"""
Updates the underlying normalizer based on the current `legacy` and `add_prefix_space` settings.
"""
sequence = []
if getattr(self, "legacy", True):
if getattr(self, "add_prefix_space", True):
sequence += [normalizers.Prepend(prepend="▁")]
sequence += [normalizers.Replace(pattern=" ", content="▁")]
elif not getattr(self, "legacy", True):
self._tokenizer.normalizer = normalizers.Sequence(sequence)
def update_pre_tokenizer(self):
sequence = []
if getattr(self, "add_prefix_space") == False:
prepend_scheme = "never"
elif getattr(self, "add_prefix_space") == None:
curr_normalizer = json.loads(self._tokenizer.normalizer.__getstate__().decode('utf-8'))
prepend_normalizer = [n for n in curr_normalizer['normalizers'] if n['type'] == 'Prepend']
if prepend_normalizer:
prepend_normalizer = prepend_normalizer[0]
replacement = prepend_normalizer['prepend']
self.add_prefix_space = True
else:
prepend_scheme = "never"
if getattr(self, "add_prefix_space", True):
prepend_scheme = "always"
if not getattr(self, "legacy", True):
prepend_scheme = "first"
self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme=prepend_scheme, split=False)
@property
def add_eos_token(self):
return self._add_eos_token
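
To make the `prepend_scheme` values chosen in `update_pre_tokenizer` above easier to follow, here is a small standalone sketch of how the `Metaspace` pre-tokenizer from `tokenizers` behaves for each scheme. It assumes a recent `tokenizers` release where `Metaspace` exposes `prepend_scheme`; it is not code from this PR.

```python
# Sketch: behaviour of the Metaspace prepend_scheme options referenced above.
from tokenizers import pre_tokenizers

for scheme in ("always", "never", "first"):
    metaspace = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme=scheme)
    print(scheme, metaspace.pre_tokenize_str("Hey friend"))

# "always" prepends ▁ to the text before splitting, "never" leaves it untouched,
# and "first" only prepends on the first segment of a pre-tokenized sequence.
```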

View File

@ -1297,7 +1297,7 @@ class UdopStack(UdopPreTrainedModel):
# get weights from encoder position bias
self.relative_bias = self._get_relative_bias(config)
# tie weights of original position bias of encoder
def _tie_weights(self):
for bias in self.relative_bias.biases:
if isinstance(bias, RelativePositionBias1D):
self._tie_or_clone_weights(

View File

@ -1133,12 +1133,12 @@ class WhisperGenerationMixin:
forced_decoder_ids = config.forced_decoder_ids
if forced_decoder_ids is not None and task is not None:
logger.info(
logger.warning_once(
f"You have passed task={task}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of task={task}."
)
forced_decoder_ids = None
elif forced_decoder_ids is not None and language is not None:
logger.info(
logger.warning_once(
f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
)
forced_decoder_ids = None
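
The switch from `logger.info` to `logger.warning_once` above raises the visibility of the conflict message while avoiding log spam when generation is called repeatedly. A quick sketch of the deduplication behaviour, assuming transformers' logging utilities (the logger name and message are illustrative):

```python
# Sketch: warning_once emits a given message only once per process.
from transformers.utils import logging

logger = logging.get_logger("warning_once_demo")
for _ in range(3):
    logger.warning_once("`forced_decoder_ids` will be ignored in favor of task/language.")
# The warning is printed a single time, even though generate-like code may hit it repeatedly.
```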

View File

@ -67,8 +67,7 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
if dtype is not None:
processed = processed.to(dtype=dtype)
_stride_left = 0 if chunk_start_idx == 0 else stride_left
# all right strides must be full, otherwise it is the last item
is_last = chunk_end_idx > inputs_len if stride_right > 0 else chunk_end_idx >= inputs_len
is_last = chunk_end_idx >= inputs_len
_stride_right = 0 if is_last else stride_right
chunk_len = chunk.shape[0]
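
The `is_last` change above keeps `chunk_iter` from emitting a spurious trailing chunk when the input length is an exact multiple of the effective step. The following is a simplified model of the loop's arithmetic, not the pipeline code itself:

```python
# Simplified sketch of the chunking loop: with the corrected is_last condition, an input of
# exactly chunk_len samples yields a single chunk instead of two.
def chunk_bounds(inputs_len, chunk_len, stride_left, stride_right):
    step = chunk_len - stride_left - stride_right
    for chunk_start_idx in range(0, inputs_len, step):
        chunk_end_idx = chunk_start_idx + chunk_len
        is_last = chunk_end_idx >= inputs_len  # corrected condition
        yield chunk_start_idx, min(chunk_end_idx, inputs_len), is_last
        if is_last:
            break

print(list(chunk_bounds(100, 100, 20, 10)))  # [(0, 100, True)] -> one chunk, as the updated test expects
```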

View File

@ -102,15 +102,16 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
from_slow = kwargs.pop("from_slow", False)
added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
if from_slow and slow_tokenizer is None and self.slow_tokenizer_class is None:
raise ValueError(
"Cannot instantiate this tokenizer from a slow version. If it's based on sentencepiece, make sure you "
"have sentencepiece installed."
)
#TODO:Ita
# if from_slow and slow_tokenizer is None and self.slow_tokenizer_class is None:
# raise ValueError(
# "Cannot instantiate this tokenizer from a slow version. If it's based on sentencepiece, make sure you "
# "have sentencepiece installed."
# )
if tokenizer_object is not None:
fast_tokenizer = copy.deepcopy(tokenizer_object)
elif fast_tokenizer_file is not None and not from_slow:
elif fast_tokenizer_file is not None: # and not from_slow:
# We have a serialization from tokenizers which let us directly build the backend
fast_tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
elif slow_tokenizer is not None:

View File

@ -47,8 +47,8 @@ def verify_out_features_out_indices(
)
if out_indices is not None:
if not isinstance(out_indices, (list, tuple)):
raise ValueError(f"out_indices must be a list or tuple, got {type(out_indices)}")
if not isinstance(out_indices, list):
raise ValueError(f"out_indices must be a list, got {type(out_indices)}")
# Convert negative indices to their positive equivalent: [-1,] -> [len(stage_names) - 1,]
positive_indices = tuple(idx % len(stage_names) if idx < 0 else idx for idx in out_indices)
if any(idx for idx in positive_indices if idx not in range(len(stage_names))):
@ -58,7 +58,7 @@ def verify_out_features_out_indices(
msg += f"(equivalent to {positive_indices}))" if positive_indices != out_indices else ""
raise ValueError(msg)
if positive_indices != tuple(sorted(positive_indices)):
sorted_negative = tuple(idx for _, idx in sorted(zip(positive_indices, out_indices), key=lambda x: x[0]))
sorted_negative = [idx for _, idx in sorted(zip(positive_indices, out_indices), key=lambda x: x[0])]
raise ValueError(
f"out_indices must be in the same order as stage_names, expected {sorted_negative} got {out_indices}"
)
@ -122,6 +122,7 @@ def get_aligned_output_features_output_indices(
out_indices (`List[int]` or `Tuple[int]`): The indices of the features for the backbone to output.
stage_names (`List[str]`): The names of the stages of the backbone.
"""
out_indices = list(out_indices) if out_indices is not None else None
# First verify that the out_features and out_indices are valid
verify_out_features_out_indices(out_features=out_features, out_indices=out_indices, stage_names=stage_names)
output_features, output_indices = _align_output_features_output_indices(
@ -147,7 +148,10 @@ class BackboneMixin:
# the timm model has out_features = ['act', 'layer1', 'layer2', 'layer3', 'layer4']
self.stage_names = [stage["module"] for stage in self._backbone.feature_info.info]
self.num_features = [stage["num_chs"] for stage in self._backbone.feature_info.info]
out_indices = self._backbone.feature_info.out_indices
# In some timm versions, out_indices reflects the input type of out_indices on the `create_model` call,
# in later versions >= 1, it is always a tuple
out_indices = list(self._backbone.feature_info.out_indices)
out_features = self._backbone.feature_info.module_name()
# We verify the out indices and out features are valid
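
The net effect of the backbone changes above is that `out_indices` is always normalized to, and validated as, a `list`. A short sketch against the public helper, mirroring the updated tests further below:

```python
# Sketch: out_indices is now expected to be a list; tuples are rejected.
from transformers.utils.backbone_utils import verify_out_features_out_indices

# Valid: list indices aligned with out_features.
verify_out_features_out_indices(out_features=["b", "d"], out_indices=[1, -1], stage_names=["a", "b", "c", "d"])

# A tuple now fails the type check.
try:
    verify_out_features_out_indices(out_features=None, out_indices=(0, 1), stage_names=["a", "b"])
except ValueError as err:
    print(err)  # out_indices must be a list, got <class 'tuple'>
```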

View File

@ -28,7 +28,7 @@ from transformers import (
is_vision_available,
)
from transformers.testing_utils import (
require_bitsandbytes,
require_read_token,
require_torch,
require_torch_sdpa,
slow,
@ -260,60 +260,32 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, unittest.Test
@slow
@require_torch
@require_read_token
class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = PaliGemmaProcessor.from_pretrained("gv-hf/PaliGemma-test-224px-hf")
self.processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224")
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
@slow
@require_bitsandbytes
@require_read_token
def test_small_model_integration_test(self):
# Let' s make sure we test the preprocessing to replace what is used
model = PaliGemmaForConditionalGeneration.from_pretrained("gv-hf/PaliGemma-test-224px-hf")
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
prompt = ""
image_file = (
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
)
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = self.processor(text=prompt, images=raw_image, return_tensors="pt")
# fmt: off
EXPECTED_INPUT_IDS = torch.tensor([[256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000, 256000,
256000, 256000, 256000, 256000, 2, 108]])
# fmt: on
EXPECTED_INPUT_IDS = torch.tensor([[257152] * 256 + [2, 108]])
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = "\ncow standing on the beach" # fmt: skip
EXPECTED_DECODED_TEXT = "\ncow on the beach" # fmt: skip
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
@ -321,64 +293,56 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
)
@slow
@require_bitsandbytes
def test_small_model_integration_test_paligemma(self):
@require_read_token
def test_small_model_integration_test_paligemma_VQA(self):
# Let' s make sure we test the preprocessing to replace what is used
model_id = "gv-hf/PaliGemma-test-224px-hf"
model = PaliGemmaForConditionalGeneration.from_pretrained("gv-hf/PaliGemma-test-224px-hf")
processor = PaliGemmaProcessor.from_pretrained(model_id)
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
prompt = "answer en Where is the cow standing?"
image_file = (
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
)
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(torch.float16)
inputs = self.processor(text=prompt, images=raw_image, return_tensors="pt").to(torch.float16)
output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
EXPECTED_DECODED_TEXT = "answer en Where is the cow standing?\nbeach" # fmt: skip
self.assertEqual(
processor.decode(output[0], skip_special_tokens=True),
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
@require_bitsandbytes
def test_small_model_integration_test_paligemma_batched(self):
@require_read_token
def test_small_model_integration_test_paligemma_empty_prompt(self):
# Let' s make sure we test the preprocessing to replace what is used
model_id = "gv-hf/PaliGemma-test-224px-hf"
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
processor = PaliGemmaProcessor.from_pretrained(model_id)
prompts = [
"answer en Where is the cow standing?",
"",
]
image1 = Image.open(
requests.get(
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
stream=True,
).raw
prompt = ""
image_file = (
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
)
image2 = image1
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = self.processor(text=prompt, images=raw_image, return_tensors="pt").to(torch.float16)
inputs = processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True)
output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
EXPECTED_DECODED_TEXT = "\ncow on the beach" # fmt: skip
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow standing on the beach"] # fmt: skip
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
@require_bitsandbytes
def test_small_model_integration_test_batch(self):
@require_read_token
def test_small_model_integration_test_paligemma_batched(self):
# Let' s make sure we test the preprocessing to replace what is used
model = PaliGemmaForConditionalGeneration.from_pretrained("gv-hf/PaliGemma-test-224px-hf")
# The first batch is longer in terms of text, the second will be padded.
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
prompts = [
"answer en Where is the cow standing?",
"",
@ -395,20 +359,84 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow standing on the beach"] # fmt: skip
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
@slow
@require_bitsandbytes
@require_torch
@require_read_token
def test_small_model_integration_test_paligemma_batched_bf16(self):
# Let' s make sure we test the preprocessing to replace what is used
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id, revision="bfloat16", torch_dtype=torch.bfloat16
).to(torch_device)
# The first batch is longer in terms of text, the second will be padded.
prompts = [
"answer en Where is the cow standing?",
"",
]
image1 = Image.open(
requests.get(
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
stream=True,
).raw
)
image2 = image1
inputs = (
self.processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True)
.to(torch.bfloat16)
.to(torch_device)
)
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
@slow
@require_torch
@require_read_token
def test_small_model_integration_test_paligemma_batched_f16(self):
# Let' s make sure we test the preprocessing to replace what is used
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id, revision="float16", torch_dtype=torch.float16
).to(torch_device)
# The first batch is longer in terms of text, the second will be padded.
prompts = [
"answer en Where is the cow standing?",
"",
]
image1 = Image.open(
requests.get(
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
stream=True,
).raw
)
image2 = image1
inputs = (
self.processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True)
.to(torch.float16)
.to(torch_device)
)
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ["answer en Where is the cow standing?\nbeach", "\ncow on the beach"] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
@slow
@require_read_token
def test_paligemma_index_error_bug(self):
# This is a reproducer of https://github.com/huggingface/transformers/pull/28032 and makes sure it does not happen anymore
# Please refer to that PR, or specifically https://github.com/huggingface/transformers/pull/28032#issuecomment-1860650043 for
# more details
model_id = "gv-hf/PaliGemma-test-224px-hf"
model_id = "google/paligemma-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
processor = PaliGemmaProcessor.from_pretrained(model_id)
# Simulate a super long prompt
prompt = "\n" * 200
image_file = (
@ -416,7 +444,7 @@ class PaliGemmaForConditionalGenerationIntegrationTest(unittest.TestCase):
)
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(
inputs = self.processor(
text=prompt,
images=raw_image,
return_tensors="pt",

View File

@ -131,7 +131,7 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
# Out indices are set to the last layer by default. For timm models, we don't know
# the number of layers in advance, so we set it to (-1,), whereas for transformers
# models, we set it to [len(stage_names) - 1] (kept for backward compatibility).
self.assertEqual(timm_model.out_indices, (-1,))
self.assertEqual(timm_model.out_indices, [-1])
self.assertEqual(transformers_model.out_indices, [len(timm_model.stage_names) - 1])
timm_model = AutoBackbone.from_pretrained(timm_checkpoint, use_timm_backbone=True, out_indices=[1, 2, 3])
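
The test change above reflects that `timm`'s `feature_info.out_indices` type has varied across versions, which is why `BackboneMixin` now casts it to a list. A hedged standalone sketch (requires `timm`; the model name and index are just examples):

```python
# Sketch: timm's feature_info.out_indices may come back as a tuple depending on the
# timm version, so it is normalized to a list before validation.
import timm

backbone = timm.create_model("resnet18", features_only=True, out_indices=(4,))
raw = backbone.feature_info.out_indices
print(type(raw), raw)   # a tuple on some versions, mirroring the input type on others
print(list(raw))        # the normalized form used downstream
```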

View File

@ -1569,10 +1569,10 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
"input_values"
]
outs = list(chunk_iter(inputs, feature_extractor, 100, 20, 10))
self.assertEqual(len(outs), 2)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 10), (30, 20, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100), (1, 30)])
self.assertEqual([o["is_last"] for o in outs], [False, True])
self.assertEqual(len(outs), 1)
self.assertEqual([o["stride"] for o in outs], [(100, 0, 0)])
self.assertEqual([o["input_values"].shape for o in outs], [(1, 100)])
self.assertEqual([o["is_last"] for o in outs], [True])
outs = list(chunk_iter(inputs, feature_extractor, 80, 20, 10))
self.assertEqual(len(outs), 2)

View File

@ -21,7 +21,6 @@ import os.path
import random
import re
import tempfile
import unittest
import warnings
from collections import defaultdict
from typing import Dict, List, Tuple
@ -444,7 +443,6 @@ class ModelTesterMixin:
@slow
@require_accelerate
@mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
with tempfile.TemporaryDirectory() as saved_model_path:
@ -457,7 +455,6 @@ class ModelTesterMixin:
@slow
@require_accelerate
@mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage_checkpoints(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
with tempfile.TemporaryDirectory() as saved_model_path:
@ -471,7 +468,6 @@ class ModelTesterMixin:
@slow
@require_accelerate
@mark.accelerate_tests
@unittest.skip("Need to fix since we have a device mismatch")
def test_save_load_low_cpu_mem_usage_no_safetensors(self):
with tempfile.TemporaryDirectory() as saved_model_path:
for model_class in self.all_model_classes:
@ -482,6 +478,8 @@ class ModelTesterMixin:
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
from accelerate.utils.modeling import named_module_tensors
# Load the low usage and the normal models.
model_low_usage, loading_info = model_class.from_pretrained(
saved_model_path,
@ -496,16 +494,13 @@ class ModelTesterMixin:
# The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then
# subsequently loaded with the correct values and onto the correct device. We check if there are any
# remaining params that were not properly loaded.
for name, param in model_low_usage.named_parameters():
for name, tensor in named_module_tensors(model_low_usage, recurse=True):
self.assertNotEqual(
param.device,
tensor.device,
torch.device("meta"),
"Parameter '" + name + "' has not been properly loaded and has device=meta.",
"Tensor '" + name + "' has not been properly loaded and has device=meta.",
)
# Tests moving the model to a device other than meta.
model_low_usage.to(torch_device)
# Check that the parameters are equal.
for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()):
self.assertEquals(p1.data.ne(p2.data).sum(), 0)
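
The loop above now uses `accelerate`'s `named_module_tensors` instead of `named_parameters`, so buffers are checked for a leftover meta device too, not just parameters. A small sketch of the difference, using a plain `torch.nn` module purely for illustration:

```python
# Sketch: named_module_tensors walks parameters *and* buffers, so tensors such as
# BatchNorm running statistics are also covered by the device check.
import torch
from accelerate.utils.modeling import named_module_tensors

module = torch.nn.BatchNorm1d(4)
print([name for name, _ in module.named_parameters()])                   # ['weight', 'bias']
print([name for name, _ in named_module_tensors(module, recurse=True)])  # adds 'running_mean', 'running_var', ...
```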

View File

@ -70,52 +70,55 @@ class BackboneUtilsTester(unittest.TestCase):
with pytest.raises(
ValueError, match=r"out_features must be a subset of stage_names: \['a'\] got \['a', 'b'\]"
):
verify_out_features_out_indices(["a", "b"], (0, 1), ["a"])
verify_out_features_out_indices(["a", "b"], [0, 1], ["a"])
# Out features must contain no duplicates
with pytest.raises(ValueError, match=r"out_features must not contain any duplicates, got \['a', 'a'\]"):
verify_out_features_out_indices(["a", "a"], None, ["a"])
# Out indices must be a list or tuple
with pytest.raises(ValueError, match="out_indices must be a list or tuple, got <class 'int'>"):
# Out indices must be a list
with pytest.raises(ValueError, match="out_indices must be a list, got <class 'int'>"):
verify_out_features_out_indices(None, 0, ["a", "b"])
with pytest.raises(ValueError, match="out_indices must be a list, got <class 'tuple'>"):
verify_out_features_out_indices(None, (0, 1), ["a", "b"])
# Out indices must be a subset of stage names
with pytest.raises(
ValueError, match=r"out_indices must be valid indices for stage_names \['a'\], got \(0, 1\)"
ValueError, match=r"out_indices must be valid indices for stage_names \['a'\], got \[0, 1\]"
):
verify_out_features_out_indices(None, (0, 1), ["a"])
verify_out_features_out_indices(None, [0, 1], ["a"])
# Out indices must contain no duplicates
with pytest.raises(ValueError, match=r"out_indices must not contain any duplicates, got \(0, 0\)"):
verify_out_features_out_indices(None, (0, 0), ["a"])
with pytest.raises(ValueError, match=r"out_indices must not contain any duplicates, got \[0, 0\]"):
verify_out_features_out_indices(None, [0, 0], ["a"])
# Out features and out indices must be the same length
with pytest.raises(
ValueError, match="out_features and out_indices should have the same length if both are set"
):
verify_out_features_out_indices(["a", "b"], (0,), ["a", "b", "c"])
verify_out_features_out_indices(["a", "b"], [0], ["a", "b", "c"])
# Out features should match out indices
with pytest.raises(
ValueError, match="out_features and out_indices should correspond to the same stages if both are set"
):
verify_out_features_out_indices(["a", "b"], (0, 2), ["a", "b", "c"])
verify_out_features_out_indices(["a", "b"], [0, 2], ["a", "b", "c"])
# Out features and out indices should be in order
with pytest.raises(
ValueError,
match=r"out_features must be in the same order as stage_names, expected \['a', 'b'\] got \['b', 'a'\]",
):
verify_out_features_out_indices(["b", "a"], (0, 1), ["a", "b"])
verify_out_features_out_indices(["b", "a"], [0, 1], ["a", "b"])
with pytest.raises(
ValueError, match=r"out_indices must be in the same order as stage_names, expected \(-2, 1\) got \(1, -2\)"
ValueError, match=r"out_indices must be in the same order as stage_names, expected \[-2, 1\] got \[1, -2\]"
):
verify_out_features_out_indices(["a", "b"], (1, -2), ["a", "b"])
verify_out_features_out_indices(["a", "b"], [1, -2], ["a", "b"])
# Check passes with valid inputs
verify_out_features_out_indices(["a", "b", "d"], (0, 1, -1), ["a", "b", "c", "d"])
verify_out_features_out_indices(["a", "b", "d"], [0, 1, -1], ["a", "b", "c", "d"])
def test_backbone_mixin(self):
backbone = BackboneMixin()

View File

@ -1164,15 +1164,16 @@ if __name__ == "__main__":
json.dump(job_result, fp, indent=4, ensure_ascii=False)
prev_ci_artifacts = None
target_workflow = "huggingface/transformers/.github/workflows/self-scheduled.yml@refs/heads/main"
if os.environ.get("CI_WORKFLOW_REF") == target_workflow:
# Get the last previously completed CI's failure tables
artifact_names = [f"ci_results_{job_name}"]
output_dir = os.path.join(os.getcwd(), "previous_reports")
os.makedirs(output_dir, exist_ok=True)
prev_ci_artifacts = get_last_daily_ci_reports(
artifact_names=artifact_names, output_dir=output_dir, token=os.environ["ACCESS_REPO_INFO_TOKEN"]
)
if job_name == "run_models_gpu":
target_workflow = "huggingface/transformers/.github/workflows/self-scheduled-caller.yml@refs/heads/main"
if os.environ.get("CI_WORKFLOW_REF") == target_workflow:
# Get the last previously completed CI's failure tables
artifact_names = [f"ci_results_{job_name}"]
output_dir = os.path.join(os.getcwd(), "previous_reports")
os.makedirs(output_dir, exist_ok=True)
prev_ci_artifacts = get_last_daily_ci_reports(
artifact_names=artifact_names, output_dir=output_dir, token=os.environ["ACCESS_REPO_INFO_TOKEN"]
)
message = Message(
title,