Update Image segmentation description (#23261)

* Update Image segmentation description

* prompt -> label
This commit is contained in:
Lysandre Debut 2023-05-10 15:36:15 +02:00 committed by GitHub
parent 4f05bbf165
commit b203de7c86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 6 deletions

View File

@ -28,10 +28,9 @@ if is_vision_available():
class ImageSegmentationTool(PipelineTool):
description = (
"This is a tool that creates a segmentation mask identifiying elements inside an image according to a prompt. "
"It takes two arguments named `image` which should be the original image, and `prompt` which should be a text "
"describing the elements what should be identified in the segmentation mask. The tool returns the mask as a "
"black-and-white image."
"This is a tool that creates a segmentation mask of an image according to a label. It cannot create an image."
"It takes two arguments named `image` which should be the original image, and `label` which should be a text "
"describing the elements what should be identified in the segmentation mask. The tool returns the mask."
)
default_checkpoint = "CIDAS/clipseg-rd64-refined"
name = "image_segmenter"
@ -44,9 +43,9 @@ class ImageSegmentationTool(PipelineTool):
requires_backends(self, ["vision"])
super().__init__(*args, **kwargs)
def encode(self, image: "Image", prompt: str):
def encode(self, image: "Image", label: str):
self.pre_processor.image_processor.size = {"width": image.size[0], "height": image.size[1]}
return self.pre_processor(text=[prompt], images=[image], padding=True, return_tensors="pt")
return self.pre_processor(text=[label], images=[image], padding=True, return_tensors="pt")
def forward(self, inputs):
with torch.no_grad():