Update Image segmentation description (#23261)
* Update Image segmentation description * prompt -> label
This commit is contained in:
parent
4f05bbf165
commit
b203de7c86
|
@ -28,10 +28,9 @@ if is_vision_available():
|
|||
|
||||
class ImageSegmentationTool(PipelineTool):
|
||||
description = (
|
||||
"This is a tool that creates a segmentation mask identifiying elements inside an image according to a prompt. "
|
||||
"It takes two arguments named `image` which should be the original image, and `prompt` which should be a text "
|
||||
"describing the elements what should be identified in the segmentation mask. The tool returns the mask as a "
|
||||
"black-and-white image."
|
||||
"This is a tool that creates a segmentation mask of an image according to a label. It cannot create an image."
|
||||
"It takes two arguments named `image` which should be the original image, and `label` which should be a text "
|
||||
"describing the elements what should be identified in the segmentation mask. The tool returns the mask."
|
||||
)
|
||||
default_checkpoint = "CIDAS/clipseg-rd64-refined"
|
||||
name = "image_segmenter"
|
||||
|
@ -44,9 +43,9 @@ class ImageSegmentationTool(PipelineTool):
|
|||
requires_backends(self, ["vision"])
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def encode(self, image: "Image", prompt: str):
|
||||
def encode(self, image: "Image", label: str):
|
||||
self.pre_processor.image_processor.size = {"width": image.size[0], "height": image.size[1]}
|
||||
return self.pre_processor(text=[prompt], images=[image], padding=True, return_tensors="pt")
|
||||
return self.pre_processor(text=[label], images=[image], padding=True, return_tensors="pt")
|
||||
|
||||
def forward(self, inputs):
|
||||
with torch.no_grad():
|
||||
|
|
Loading…
Reference in New Issue