Support arbitrary processor (#30875)

* Support arbitrary processor

* fix

* nit

* update

* nit

* nit

* fix and revert

* add a small test

* better check

* fixup

* bug so let's just use class for now

* oups

* .
This commit is contained in:
Arthur 2024-05-17 16:51:31 +02:00 committed by GitHub
parent 57edd84bdb
commit 0a9300f474
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 2 deletions

View File

@ -41,8 +41,8 @@ class LlavaProcessor(ProcessorMixin):
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "CLIPImageProcessor"
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor=None, tokenizer=None):
super().__init__(image_processor, tokenizer)

View File

@ -0,0 +1,30 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available
if is_vision_available():
from transformers import AutoTokenizer, LlavaProcessor
@require_vision
class LlavaProcessorTest(unittest.TestCase):
def test_can_load_various_tokenizers(self):
for checkpoint in ["Intel/llava-gemma-2b", "llava-hf/llava-1.5-7b-hf"]:
processor = LlavaProcessor.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
self.assertEqual(processor.tokenizer.__class__, tokenizer.__class__)