[idefics] idefics-9b test use 4bit quant (#25734)

This commit is contained in:
Stas Bekman 2023-08-24 08:33:14 -07:00 committed by GitHub
parent fecf08560c
commit 7a6efe1e9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 18 additions and 3 deletions

View File

@@ -16,8 +16,15 @@
import unittest
from transformers import IdeficsConfig, is_torch_available, is_vision_available
from transformers.testing_utils import TestCasePlus, require_torch, require_vision, slow, torch_device
from transformers import BitsAndBytesConfig, IdeficsConfig, is_torch_available, is_vision_available
from transformers.testing_utils import (
TestCasePlus,
require_bitsandbytes,
require_torch,
require_vision,
slow,
torch_device,
)
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -434,6 +441,7 @@ class IdeficsModelIntegrationTest(TestCasePlus):
def default_processor(self):
return IdeficsProcessor.from_pretrained("HuggingFaceM4/idefics-9b") if is_vision_available() else None
@require_bitsandbytes
@slow
def test_inference_natural_language_visual_reasoning(self):
cat_image_path = self.tests_dir / "fixtures/tests_samples/COCO/000000039769.png"
@@ -459,7 +467,14 @@ class IdeficsModelIntegrationTest(TestCasePlus):
],
]
model = IdeficsForVisionText2Text.from_pretrained("HuggingFaceM4/idefics-9b").to(torch_device)
# the CI gpu is small so using quantization to fit
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype="float16",
)
model = IdeficsForVisionText2Text.from_pretrained(
"HuggingFaceM4/idefics-9b", quantization_config=quantization_config, device_map="auto"
)
processor = self.default_processor
inputs = processor(prompts, return_tensors="pt").to(torch_device)
generated_ids = model.generate(**inputs, max_length=100)