Test composition (#23214)
* Remove nestedness in tool config * Really do it * Use remote tools descriptions * Work * Clean up eval * Changes * Tools * Tools * tool * Fix everything * Use last result/assign for evaluation * Prompt * Remove hardcoded selection * Evaluation for chat agents * correct some spelling * Small fixes * Change summarization model (#23172) * Fix link displayed * Update description of the tool * Fixes in chat prompt * Custom tools, custom prompt * Tool clean up * save_pretrained and push_to_hub for tool * Fix init * Tests * Fix tests * Tool save/from_hub/push_to_hub and tool->load_tool * Clean push_to_hub and add app file * Custom inference API for endpoints too * Clean up * old remote tool and new remote tool * Make a requirements * return_code adds tool creation * Avoid redundancy between global variables * Remote tools can be loaded * Tests * Text summarization tests * Quality * Properly mark tests * Test the python interpreter * And the CI shall be green. * fix loading of additional tools * Work on RemoteTool and fix tests * General clean up * Guard imports * Fix tools * docs: Fix broken link in 'How to add a model...' 
(#23216) fix link * Get default endpoint from the Hub * Add guide * Simplify tool config * Docs * Some fixes * Docs * Docs * Docs * Fix code returned by agent * Try this * Match args with signature in remote tool * Should fix python interpreter for Python 3.8 * Fix push_to_hub for tools * Other fixes to push_to_hub * Add API doc page * Docs * Docs * Custom tools * Pin tensorflow-probability (#23220) * Pin tensorflow-probability * [all-test] * [all-test] Fix syntax for bash * PoC for some chaining API * Text to speech * J'ai pris des libertés * Rename * Basic python interpreter * Add agents * Quality * Add translation tool * temp * GenQA + LID + S2T * Quality + word missing in translation * Add open assistance, support f-strings in evaluate * captioning + s2t fixes * Style * Refactor descriptions and remove chain * Support errors and rename OpenAssistantAgent * Add setup * Deal with typos + example of inference API * Some rename + README * Fixes * Update prompt * Unwanted change * Make sure everyone has a default * One prompt to rule them all. 
* SD * Description * Clean up remote tools * More remote tools * Add option to return code and update doc * Image segmentation * ControlNet * Gradio demo * Diffusers protection * Lib protection * ControlNet description * Cleanup * Style * Remove accelerate and try to be reproducible * No randomness * Male Basic optional in token * Clean description * Better prompts * Fix args eval in interpreter * Add tool wrapper * Tool on the Hub * Style post-rebase * Big refactor of descriptions, batch generation and evaluation for agents * Make problems easier - interface to debug * More problems, add python primitives * Back to one prompt * Remove dict for translation * Be consistent * Add prompts * New version of the agent * Evaluate new agents * New endpoints agents * Make all tools a dict variable * Typo * Add problems * Add to big prompt * Harmonize * Add tools * New evaluation * Add more tools * Build prompt with tools descriptions * Tools on the Hub * Let's chat! * Cleanup * Temporary bs4 safeguard * Cache agents and clean up * Blank init * Fix evaluation for agents * New format for tools on the Hub * Add method to reset state * Remove nestedness in tool config * Really do it * Use remote tools descriptions * Work * Clean up eval * Changes * Tools * Tools * tool * Fix everything * Use last result/assign for evaluation * Prompt * Remove hardcoded selection * Evaluation for chat agents * correct some spelling * Small fixes * Change summarization model (#23172) * Fix link displayed * Update description of the tool * Fixes in chat prompt * Custom tools, custom prompt * Tool clean up * save_pretrained and push_to_hub for tool * Fix init * Tests * Fix tests * Tool save/from_hub/push_to_hub and tool->load_tool * Clean push_to_hub and add app file * Custom inference API for endpoints too * Clean up * old remote tool and new remote tool * Make a requirements * return_code adds tool creation * Avoid redundancy between global variables * Remote tools can be loaded * Tests * Text 
summarization tests * Quality * Properly mark tests * Test the python interpreter * And the CI shall be green. * Work on RemoteTool and fix tests * fix loading of additional tools * General clean up * Guard imports * Fix tools * Get default endpoint from the Hub * Simplify tool config * Add guide * Docs * Some fixes * Docs * Docs * Fix code returned by agent * Try this * Docs * Match args with signature in remote tool * Should fix python interpreter for Python 3.8 * Fix push_to_hub for tools * Other fixes to push_to_hub * Add API doc page * Fixes * Doc fixes * Docs * Fix audio * Custom tools * Audio fix * Improve custom tools docstring * Docstrings * Trigger CI * Mode docstrings * More docstrings * Improve custom tools * Fix for remote tools * Style * Fix repo consistency * Quality * Tip * Cleanup on doc * Cleanup toc * Add disclaimer for starcoder vs openai * Remove disclaimer * Small fixed in the prompts * 4.29 * Update src/transformers/tools/agents.py Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr> * Complete documentation * Small fixes * Agent evaluation * Note about gradio-tools & LC * Clean up agents and prompt * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Note about gradio-tools & LC * Add copyrights and address review comments * Quality * Add all language codes * Add remote tool tests * Move custom prompts to other docs * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * TTS tests * Quality --------- Co-authored-by: Lysandre <hi@lyand.re> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com> Co-authored-by: Connor Henderson <connor.henderson@talkiatry.com> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr> Co-authored-by: Lysandre 
<lysandre@huggingface.co> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
366a8ca09e
commit
3335724376
@@ -43,6 +43,7 @@ def pytest_configure(config):
    )
    config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")
    config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate")
    config.addinivalue_line("markers", "tool_tests: mark the tool tests that are run on their specific schedule")


def pytest_addoption(parser):
@@ -21,6 +21,8 @@
    title: Set up distributed training with 🤗 Accelerate
  - local: model_sharing
    title: Share your model
  - local: transformers_agents
    title: Agents
  title: Tutorials
- sections:
  - sections:
@@ -99,6 +101,8 @@
    title: Notebooks with examples
  - local: community
    title: Community resources
  - local: custom_tools
    title: Custom Tools
  - local: troubleshooting
    title: Troubleshoot
  title: Developer guides
@@ -179,6 +183,8 @@
  title: Conceptual guides
- sections:
  - sections:
    - local: main_classes/agent
      title: Agents and Tools
    - local: model_doc/auto
      title: Auto Classes
    - local: main_classes/callback
@@ -0,0 +1,503 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Custom Tools and Prompts

<Tip>

If you are not aware of what tools and agents are in the context of transformers, we recommend you read the
[Transformers Agents](transformers_agents) page first.

</Tip>

<Tip warning={true}>

Transformers Agents is an experimental API which is subject to change at any time. Results returned by the agents
can vary as the APIs or underlying models are prone to change.

</Tip>

Creating and using custom tools and prompts is paramount to empowering the agent and having it perform new tasks.
In this guide we'll take a look at:

- How to customize the prompt
- How to use custom tools
- How to create custom tools
## Customizing the prompt

As explained in [Transformers Agents](transformers_agents), agents can run in [`~Agent.run`] and [`~Agent.chat`] mode.
Both modes rely on the same underlying logic: the language model powering the agent is conditioned on a long prompt
and asked to complete it by generating the next tokens until a stop token is reached.
The only difference between the two is that in `chat` mode the prompt is extended with
previous user inputs and model generations, which seemingly gives the agent a memory and allows it to refer to
past interactions.

Let's take a closer look at how the prompt is structured to understand how it can best be customized.
The prompt is broadly structured in four parts:

1. Introduction: how the agent should behave, explanation of the concept of tools.
2. Description of all the tools. This is defined by a `<<all_tools>>` token that is dynamically replaced at runtime with the tools defined/chosen by the user.
3. A set of examples of tasks and their solutions.
4. The current example, and the request for a solution.

To better understand each part, let's look at a shortened version of what such a prompt can look like in practice.
````text
I will ask you to perform a task, your job is to come up with a series of simple commands in Python that will perform the task.
[...]
You can print intermediate results if it makes sense to do so.

Tools:
- document_qa: This is a tool that answers a question about a document (pdf). It takes an input named `document` which should be the document containing the information, as well as a `question` that is the question about the document. It returns a text that contains the answer to the question.
- image_captioner: This is a tool that generates a description of an image. It takes an input named `image` which should be the image to caption, and returns a text that contains the description in English.
[...]

Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."

I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.

Answer:
```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
print(f"The translated question is {translated_question}.")
answer = image_qa(image=image, question=translated_question)
print(f"The answer is {answer}")
```

Task: "Identify the oldest person in the `document` and create an image showcasing the result as a banner."

I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.

Answer:
```py
answer = document_qa(document, question="What is the oldest person?")
print(f"The answer is {answer}.")
image = image_generator("A banner showing " + answer)
```

[...]

Task: "Draw me a picture of rivers and lakes"

I will use the following
````

The first part explains precisely how the model should behave and what it should do. This part
most likely does not need to be customized.

TODO(PVP) - explain better how the .description and .name influence the prompt

### Customizing the tool descriptions

The performance of the agent is directly linked to the prompt itself. We structure the prompt so that it works well
with what we intend for the agent to do, but for maximum customization we also offer the ability to specify a different prompt when instantiating the agent.

### Customizing the single-execution prompt

In order to specify a custom single-execution prompt, one would do the following:

```py
template = """ [...] """

agent = HfAgent(your_endpoint, run_prompt_template=template)
```

<Tip>

Please make sure to have the `<<all_tools>>` string defined somewhere in the `template` so that the agent can be aware
of the tools it has available to it.

</Tip>
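
To make the placeholder substitution concrete, here is a minimal, self-contained sketch of how a template containing `<<all_tools>>` can be filled in. Only the `<<all_tools>>` placeholder is documented above; the per-tool line format and the `tools` mapping below are illustrative assumptions, not the actual internals of `HfAgent`:

```python
# Illustrative sketch only: the real substitution happens inside the agent.
# The exact formatting of each tool line is an assumption for demonstration.
template = """I will ask you to perform a task [...]

Tools:
<<all_tools>>
"""

# Hypothetical name -> description mapping standing in for the agent's toolbox
tools = {
    "document_qa": "This is a tool that answers a question about a document (pdf).",
    "image_captioner": "This is a tool that generates a description of an image.",
}

# Render one "- name: description" line per tool and substitute the placeholder
tool_descriptions = "\n".join(f"- {name}: {description}" for name, description in tools.items())
prompt = template.replace("<<all_tools>>", tool_descriptions)
print(prompt)
```

A template missing `<<all_tools>>` would leave the agent blind to its toolbox, which is why the tip above insists on it.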

#### Chat-execution prompt

In order to specify a custom chat-execution prompt, one would do the following:

```py
template = """ [...] """

agent = HfAgent(
    url_endpoint=your_endpoint,
    token=your_hf_token,
    chat_prompt_template=template
)
```

<Tip>

Please make sure to have the `<<all_tools>>` string defined somewhere in the `template` so that the agent can be
aware of the tools it has available to it.

</Tip>

## Using custom tools

In this section, we'll be leveraging two existing custom tools that are specific to image generation:

- We replace [huggingface-tools/image-transformation](https://huggingface.co/spaces/huggingface-tools/image-transformation)
  with [diffusers/controlnet-canny-tool](https://huggingface.co/spaces/diffusers/controlnet-canny-tool)
  to allow for more image modifications.
- We add a new tool for image upscaling to the default toolbox:
  [diffusers/latent-upscaler-tool](https://huggingface.co/spaces/diffusers/latent-upscaler-tool).

We'll start by loading the custom tools with the convenient [`load_tool`] function:

```py
from transformers import load_tool

controlnet_transformer = load_tool("diffusers/controlnet-canny-tool")
upscaler = load_tool("diffusers/latent-upscaler-tool")
```

Upon adding custom tools to an agent, the tools' descriptions and names are automatically
included in the agent's prompt. Thus, it is imperative that custom tools have
a well-written description and name in order for the agent to understand how to use them.
Let's take a look at the description and name of `controlnet_transformer`:

```py
print(f"Description: '{controlnet_transformer.description}'")
print(f"Name: '{controlnet_transformer.name}'")
```

gives

```
Description: 'This is a tool that transforms an image with ControlNet according to a prompt.
It takes two inputs: `image`, which should be the image to transform, and `prompt`, which should be the prompt to use to change it. It returns the modified image.'
Name: 'image_transformer'
```

The name and description are accurate and fit the style of the [curated set of tools](./transformers_agents#a-curated-set-of-tools).
Next, let's instantiate an agent with `controlnet_transformer` and `upscaler`:

```py
tools = [controlnet_transformer, upscaler]
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=tools)
```

This command should give you the following info:

```
image_transformer has been replaced by <transformers_modules.diffusers.controlnet-canny-tool.bd76182c7777eba9612fc03c08718a60c0aa6312.image_transformation.ControlNetTransformationTool object at 0x7f1d3bfa3a00> as provided in `additional_tools`
```

The set of curated tools already has an `image_transformer` tool, which is hereby replaced with our custom tool.

<Tip>

Overwriting existing tools can be beneficial if we want to use a custom tool for exactly the same task as an existing one,
because the agent is already well-versed in that specific task. Beware that in this case the custom tool should follow
the exact same API as the overwritten tool.

</Tip>

The upscaler tool was given the name `image_upscaler`, which is not yet present in the default toolbox and is therefore simply added to the list of tools.
You can always have a look at the toolbox that is currently available to the agent via the `agent.toolbox` attribute:

```py
print("\n".join([f"- {a}" for a in agent.toolbox.keys()]))
```

```
- document_qa
- image_captioner
- image_qa
- image_segmenter
- transcriber
- summarizer
- text_classifier
- text_qa
- text_reader
- translator
- image_transformer
- text_downloader
- image_generator
- video_generator
- image_upscaler
```

Note how `image_upscaler` is now part of the agent's toolbox.
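
The replace-or-append behaviour described above can be sketched in plain Python. This is a schematic stand-in for how `additional_tools` interacts with the toolbox, not actual transformers code; the tool values are plain strings here purely for illustration:

```python
# The toolbox behaves like a dict keyed by tool name: an existing name is
# replaced, a new name is appended.
toolbox = {
    "image_transformer": "<default image transformation tool>",
    "summarizer": "<default summarization tool>",
}

additional_tools = {
    "image_transformer": "<custom ControlNet tool>",
    "image_upscaler": "<custom upscaler tool>",
}

for name, tool in additional_tools.items():
    if name in toolbox:
        # Matching name: the default tool is overwritten
        print(f"{name} has been replaced by {tool} as provided in `additional_tools`")
    toolbox[name] = tool

print(sorted(toolbox))
```

This mirrors the info message printed earlier: `image_transformer` is replaced, while `image_upscaler` is simply added.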

Let's now try out the new tools! We will re-use the image we generated in the [Transformers Agents Quickstart](./transformers_agents#single-execution-run).

```py
from diffusers.utils import load_image

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png"
)
```

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png" width=200>

Let's transform the image into a beautiful winter landscape:

```py
image = agent.run("Transform the image: 'A frozen lake and snowy forest'", image=image)
```

```
==Explanation from the agent==
I will use the following tool: `image_transformer` to transform the image.

==Code generated by the agent==
image = image_transformer(image, prompt="A frozen lake and snowy forest")
```

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_winter.png" width=200>

The new image processing tool is based on ControlNet, which can make very strong modifications to the image.
By default the image processing tool returns an image of size 512x512 pixels. Let's see if we can upscale it.

```py
image = agent.run("Upscale the image", image)
```

```
==Explanation from the agent==
I will use the following tool: `image_upscaler` to upscale the image.

==Code generated by the agent==
upscaled_image = image_upscaler(image)
```

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_winter_upscale.png" width=400>

The agent automatically mapped our prompt "Upscale the image" to the newly added upscaler tool, purely based on the upscaler's description and name,
and was able to run it correctly.

Next, let's have a look at how you can create a new custom tool.

### Adding new tools

In this section we show how to create a new tool that can be added to the agent.

#### Creating a new tool

We'll first start by creating a tool. We'll add the not-so-useful yet fun task of fetching the model on the Hugging Face
Hub with the most downloads for a given task.

We can do that with the following code:

```python
from huggingface_hub import list_models

task = "text-classification"

model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
print(model.id)
```

For the task `text-classification`, this returns `'facebook/bart-large-mnli'`; for `translation`, it returns `'t5-base'`.

How do we convert this to a tool that the agent can leverage? All tools depend on the superclass `Tool` that holds the
main attributes necessary. We'll create a class that inherits from it:

```python
from transformers import Tool


class HFModelDownloadsTool(Tool):
    pass
```

This class has a few needs:
- An attribute `name`, which corresponds to the name of the tool itself. To be in tune with other tools which have a
  performative name, we'll name it `model_download_counter`.
- An attribute `description`, which will be used to populate the prompt of the agent.
- `inputs` and `outputs` attributes. Defining these will help the Python interpreter make educated choices about types,
  and will allow for a Gradio demo to be spawned when we push our tool to the Hub. They're both a list of expected
  values, which can be `text`, `image`, or `audio`.
- A `__call__` method which contains the inference code. This is the code we've played with above!

Here's what our class looks like now:

```python
from transformers import Tool
from huggingface_hub import list_models


class HFModelDownloadsTool(Tool):
    name = "model_download_counter"
    description = (
        "This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. "
        "It takes the name of the category (such as text-classification, depth-estimation, etc), and "
        "returns the name of the checkpoint."
    )

    inputs = ["text"]
    outputs = ["text"]

    def __call__(self, task: str):
        model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
        return model.id
```

We now have our tool handy. Save it in a file and import it from your main script. Let's name this file
`model_downloads.py`, so the resulting import code looks like this:

```python
from model_downloads import HFModelDownloadsTool

tool = HFModelDownloadsTool()
```

In order to let others benefit from it, and for simpler initialization, we recommend pushing it to the Hub under your
namespace. To do so, just call `push_to_hub` on the `tool` variable:

```python
tool.push_to_hub("lysandre/hf-model-downloads")
```

You now have your code on the Hub! Let's take a look at the final step, which is to have the agent use it.

#### Having the agent use the tool

We now have our tool, which lives on the Hub and can be instantiated as such:

```python
from transformers import load_tool

tool = load_tool("lysandre/hf-model-downloads")
```

In order to use it in the agent, simply pass it in the `additional_tools` parameter of the agent initialization method:

```python
from transformers import HfAgent

agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=[tool])

agent.run(
    "Can you read out loud the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?"
)
```

which outputs the following:

```
==Code generated by the agent==
model = model_download_counter(task="text-to-video")
print(f"The model with the most downloads is {model}.")
audio_model = text_reader(model)


==Result==
The model with the most downloads is damo-vilab/text-to-video-ms-1.7b.
```

and generates the following audio.

| **Audio**                                                                                                                                                   |
|-------------------------------------------------------------------------------------------------------------------------------------------------------------|
| <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/damo.wav" type="audio/wav"/></audio> |

<Tip>

Depending on the LLM, some models are quite brittle and require very exact prompts in order to work well. Having a well-defined
description of the tool is paramount to having it be leveraged by the agent.

</Tip>

### Replacing existing tools

Replacing existing tools can be done simply by assigning a new item to the agent's toolbox. Here's how one would do so:

```python
from transformers import HfAgent, load_tool

agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
agent.toolbox["image_transformer"] = load_tool("diffusers/controlnet-canny-tool")
```

<Tip>

Beware when replacing tools with others! This will also adjust the agent's prompt. This can be good if you have a better
prompt suited for the task, but it can also result in your tool being selected far more often than others, or in other
tools being selected instead of the one you have defined.

</Tip>

## Leveraging gradio-tools

[gradio-tools](https://github.com/freddyaboulton/gradio-tools) is a powerful library that allows using Hugging
Face Spaces as tools. It supports many existing Spaces, as well as custom Spaces designed with it.

We offer support for `gradio_tools` through the `Tool.from_gradio` method. For example, we want to take
advantage of the `StableDiffusionPromptGeneratorTool` tool offered in the `gradio-tools` toolkit in order to
improve our prompts and generate better images.

We first import the tool from `gradio_tools` and instantiate it:

```python
from gradio_tools import StableDiffusionPromptGeneratorTool

gradio_tool = StableDiffusionPromptGeneratorTool()
```

We pass that instance to the `Tool.from_gradio` method:

```python
from transformers import Tool

tool = Tool.from_gradio(gradio_tool)
```

Now we can manage it exactly as we would a usual custom tool. We leverage it to improve our prompt
`a rabbit wearing a space suit`:

```python
from transformers import HfAgent

agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=[tool])

agent.run("Generate an image of the `prompt` after improving it.", prompt="A rabbit wearing a space suit")
```

The model adequately leverages the tool:

```
==Explanation from the agent==
I will use the following tools: `StableDiffusionPromptGenerator` to improve the prompt, then `image_generator` to generate an image according to the improved prompt.


==Code generated by the agent==
improved_prompt = StableDiffusionPromptGenerator(prompt)
print(f"The improved prompt is {improved_prompt}.")
image = image_generator(improved_prompt)
```

before finally generating the image:

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png">

<Tip warning={true}>

gradio-tools requires *textual* inputs and outputs, even when working with different modalities, while this implementation
works with image and audio objects. The two are currently incompatible, but will rapidly become compatible as we
work to improve the support.

</Tip>

## Future compatibility with Langchain

We love Langchain and think it has a very compelling suite of tools. In order to handle these tools,
Langchain requires *textual* inputs and outputs, even when working with different modalities.
This is often the serialized (i.e., saved-to-disk) version of the objects.
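
To make the text-only constraint concrete, here is a minimal stdlib sketch of shuttling a binary object through a textual interface. This is a generic illustration of text serialization, not code from transformers or Langchain:

```python
import base64

# A stand-in for binary image or audio data
binary_payload = b"\x89PNG\r\n... fake image bytes ..."

# Serialize to text so a text-only tool interface can carry the object
as_text = base64.b64encode(binary_payload).decode("ascii")

# The receiving side decodes the text back into the original object
round_tripped = base64.b64decode(as_text)
assert round_tripped == binary_payload
```

Every such round-trip adds encoding overhead and loses the native object type, which is why multimodal tools prefer passing image and audio objects directly.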

This difference means that multi-modality isn't handled between transformers-agents and langchain.
We aim for this limitation to be resolved in future versions, and welcome any help from avid langchain
users to help us achieve this compatibility.

We would love to have better support. If you would like to help, please
[open an issue](https://github.com/huggingface/transformers/issues/new) and share what you have in mind.

@@ -0,0 +1,64 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Agents & Tools

<Tip warning={true}>

Transformers Agents is an experimental API which is subject to change at any time. Results returned by the agents
can vary as the APIs or underlying models are prone to change.

</Tip>

To learn more about agents and tools, make sure to read the [introductory guide](../transformers_agents). This page
contains the API docs for the underlying classes.

## Agents

We provide two types of agents: [`HfAgent`] uses inference endpoints for open-source models, while [`OpenAiAgent`] uses OpenAI's closed models.

### HfAgent

[[autodoc]] HfAgent

### OpenAiAgent

[[autodoc]] OpenAiAgent

### Agent

[[autodoc]] Agent
    - chat
    - run
    - prepare_for_new_chat

## Tools

### load_tool

[[autodoc]] load_tool

### Tool

[[autodoc]] Tool

### PipelineTool

[[autodoc]] PipelineTool

### RemoteTool

[[autodoc]] RemoteTool

### launch_gradio_demo

[[autodoc]] launch_gradio_demo

@@ -0,0 +1,329 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Transformers Agent
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Transformers Agent is an experimental API which is subject to change at any time. Results returned by the agents
|
||||
can vary as the APIs or underlying models are prone to change.
|
||||
|
||||
</Tip>
|
||||
|
||||
Transformers version v4.29.0, building on the concept of *tools* and *agents*.
|
||||
|
||||
In short, it provides a natural language API on top of transformers: we define a set of curated tools, and design an
|
||||
agent to interpret natural language and to use these tools. It is extensible by design; we curated some relevant tools,
|
||||
but we'll show you how the system can be extended easily to use any tool developed by the community.
|
||||
|
||||
Let's start with a few examples of what can be achieved with this new API. It is particularly powerful when it comes
|
||||
to multimodal tasks, so let's take it for a spin to generate images and read text out loud.
|
||||
|
||||
```py
agent.run("Caption the following image", image=image)
```

| **Input**                                                                                                                   | **Output**                        |
|-----------------------------------------------------------------------------------------------------------------------------|-----------------------------------|
| <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/beaver.png" width=200> | A beaver is swimming in the water |

---

```py
agent.run("Read the following text out loud", text=text)
```

| **Input**                                                                                                               | **Output**                                   |
|-------------------------------------------------------------------------------------------------------------------------|----------------------------------------------|
| A beaver is swimming in the water | <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tts_example.wav" type="audio/wav"> Your browser does not support the audio element. </audio> |

---

```py
agent.run(
    "In the following `document`, where will the TRRF Scientific Advisory Council Meeting take place?",
    document=document,
)
```

| **Input**                                                                                                                   | **Output**     |
|-----------------------------------------------------------------------------------------------------------------------------|----------------|
| <img src="https://datasets-server.huggingface.co/assets/hf-internal-testing/example-documents/--/hf-internal-testing--example-documents/test/0/image/image.jpg" width=200> | ballroom foyer |

## Quickstart

Before being able to use `agent.run`, you will need to instantiate an agent, which is a large language model (LLM).
We recommend using the [bigcode/starcoder](https://huggingface.co/bigcode/starcoder) checkpoint, as it works very well
for the task at hand and is open-source, but you can find other examples below.
Start by logging in to have access to the Inference API:

```py
from huggingface_hub import login

login("<YOUR_TOKEN>")
```

Then, instantiate the agent:

```py
from transformers import HfAgent

agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
```

This uses the Inference API that Hugging Face provides for free at the moment. If you have your own inference
endpoint for this model (or another one), you can replace the URL above with your endpoint URL.

<Tip>

We're showcasing StarCoder as the default in the documentation as the model is free to use and performs admirably well
on simple tasks. However, the checkpoint doesn't hold up when handling more complex prompts. If you're facing such an
issue, we recommend trying out the OpenAI model which, while sadly not open-source, performs better at this given time.

</Tip>

You're now good to go! Let's dive into the two APIs that you now have at your disposal.

### Single execution (run)

The single execution method uses the [`~Agent.run`] method of the agent:

```py
agent.run("Draw me a picture of rivers and lakes")
```

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png" width=200>

It automatically selects the tool (or tools) appropriate for the task you want to perform and runs them appropriately. It
can perform one or several tasks in the same instruction (though the more complex your instruction, the more likely
the agent is to fail).

```py
agent.run("Draw me a picture of the sea then transform the picture to add an island.")
```

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/sea_and_island.png" width=200>

<br/>

Every [`~Agent.run`] operation is independent, so you can run it several times in a row with different tasks.

Note that your `agent` is just a large language model, so small variations in your prompt might yield completely
different results. It's important to explain the task you want to perform as clearly as possible.

If you'd like to keep a state across executions or to pass non-text objects to the agent, you can do so by specifying
variables that you would like the agent to use. For example, you could generate the first image of rivers and lakes,
and ask the model to update that picture to add an island by doing the following:

```python
picture = agent.run("Draw me a picture of rivers and lakes")
updated_picture = agent.run("Take that `picture` and add an island to it", picture=picture)
```

<Tip>

This can be helpful when the model is unable to understand your request and mixes tools. An example would be:

```python
agent.run("Draw me the picture of a capybara swimming in the sea")
```

Here, the model could interpret it in two ways:
- Have the `text-to-image` tool generate a capybara swimming in the sea
- Or, have the `text-to-image` tool generate a capybara, then use the `image-transformation` tool to have it swim in the sea

In case you would like to force the first scenario, you can do so by passing the prompt as an argument:

```python
agent.run("Draw me a picture of the `prompt`", prompt="a capybara swimming in the sea")
```

</Tip>

### Chat-based execution (chat)

The agent also has a chat-based approach, using the [`~Agent.chat`] method:

```py
agent.chat("Draw me a picture of rivers and lakes")
```

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png" width=200>

```py
agent.chat("Transform the picture so that there is a rock in there")
```

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_and_beaver.png" width=200>

<br/>

This is an interesting approach when you want to keep the state across instructions. It's better for experimentation,
but it tends to be much better at single instructions than at complex ones (which the [`~Agent.run`]
method is better at handling).

This method can also take arguments if you would like to pass non-text types or specific prompts.

### ⚠️ Remote execution

For demonstration purposes and so that this can be used with all setups, we have created remote executors for several
of the default tools the agent has access to. These are created using
[inference endpoints](https://huggingface.co/inference-endpoints). To see how to set up remote executor tools yourself,
we recommend reading the custom tool guide [TODO LINK].

In order to run with remote tools, specifying `remote=True` to either [`~Agent.run`] or [`~Agent.chat`] is sufficient.

For example, the following command could be run on any device efficiently, without needing significant RAM or GPU:

```python
agent.run("Draw me a picture of rivers and lakes", remote=True)
```

The same can be said for [`~Agent.chat`]:

```py
agent.chat("Draw me a picture of rivers and lakes", remote=True)
```

### What's happening here? What are tools, and what are agents?

#### Agents

The "agent" here is a large language model, and we're prompting it so that it has access to a specific set of tools.

LLMs are pretty good at generating small samples of code, so this API takes advantage of that by prompting the
LLM to give a small sample of code performing a task with a set of tools. This prompt is then completed by the
task you give your agent and the description of the tools you give it. This way it gets access to the documentation of the
tools you are using, in particular their expected inputs and outputs, and can generate the relevant code.
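
To make this concrete, here is a hypothetical sketch of how such a prompt could be assembled from tool descriptions; the actual templates live in `transformers.tools.prompts` and differ in wording, and the tool names below are made up for illustration:

```python
# Hypothetical prompt assembly: one "- name: description" line per tool,
# followed by the user's task. Not the exact template transformers uses.
TOOLS = {
    "image_captioner": "This is a tool that generates a description of an image.",
    "summarizer": "This is a tool that summarizes a long text in a few sentences.",
}


def build_prompt(task: str) -> str:
    # Render the toolbox as a bulleted list so the LLM can pick from it.
    tool_lines = "\n".join(f"- {name}: {desc}" for name, desc in TOOLS.items())
    return (
        "You have access to the following tools:\n"
        f"{tool_lines}\n"
        f"Task: {task}\n"
        "Answer with a short Python snippet using these tools."
    )


print(build_prompt("Caption the image, then summarize the caption."))
```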

#### Tools

Tools are very simple: they're a single function, with a name and a description. We then use these tools' descriptions
to prompt the agent. Through the prompt, we show the agent how it would leverage tools in order to perform what is
requested in the query.

This uses brand-new tools and not pipelines, because the agent writes better code with very atomic tools.
Pipelines are more refactored and often combine several tasks in one. Tools are meant to be focused on
one very simple task only.
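
Schematically, a tool boils down to something like the following. This is a plain-Python sketch with a made-up tool, not the actual `Tool` class in `transformers`, which carries extra machinery (setup, remote support, serialization):

```python
# Minimal sketch of the tool abstraction: a named, described callable
# doing exactly one atomic task.
class SimpleTool:
    name = "text_reverser"
    description = "This is a tool that reverses the text given to it."

    def __call__(self, text: str) -> str:
        return text[::-1]


tool = SimpleTool()
print(tool("hello"))  # olleh
```

The `name` and `description` are exactly what gets injected into the agent's prompt; the call signature is what the generated code ends up invoking.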

#### Code-execution?!

This code is then executed with our small Python interpreter on the set of inputs passed along with your tools.
We hear you screaming "Arbitrary code execution!" in the back, but let us explain why that is not the case.

The only functions that can be called are the tools you provided and the `print` function, so you're already
limited in what can be executed. You should be safe if it's limited to Hugging Face tools.

Then, we don't allow any attribute lookups or imports (which shouldn't be needed anyway for passing along
inputs/outputs to a small set of functions), so all the most obvious attacks (and you'd need to prompt the LLM
to output them anyway) shouldn't be an issue. If you want to be on the super safe side, you can execute the
`run()` method with the additional argument `return_code=True`, in which case the agent will just return the code
to execute, and you can decide whether to run it or not.

The execution will stop at any line trying to perform an illegal operation, or if there is a regular Python error
with the code generated by the agent.
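
To get a feel for the kind of restriction involved, here is a toy pre-screening check. This is not the interpreter shipped in `transformers` (which evaluates the AST directly rather than screening it), just an illustration of rejecting imports and attribute lookups:

```python
import ast


# Toy illustration: refuse any code containing an import or an attribute
# lookup before executing it. The real interpreter in transformers walks
# and evaluates the AST itself instead of pre-screening like this.
def is_allowed(code: str) -> bool:
    tree = ast.parse(code)
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom, ast.Attribute)):
            return False
    return True


print(is_allowed("x = summarizer(text)"))  # True: a plain tool call
print(is_allowed("import os"))             # False: imports are rejected
print(is_allowed("().__class__"))          # False: attribute lookup rejected
```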

### A curated set of tools

We identified a set of tools that can empower such agents. Here is an updated list of the tools we have integrated
in `transformers`:

- **Document question answering**: given a document (such as a PDF) in image format, answer a question on this document ([Donut](../model_doc/donut))
- **Text question answering**: given a long text and a question, answer the question in the text ([Flan-T5](../model_doc/flan-t5))
- **Unconditional image captioning**: caption the image! ([BLIP](../model_doc/blip))
- **Image question answering**: given an image, answer a question on this image ([VILT](../model_doc/vilt))
- **Image segmentation**: given an image and a prompt, output the segmentation mask of that prompt ([CLIPSeg](../model_doc/clipseg))
- **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](../model_doc/whisper))
- **Text to speech**: convert text to speech ([SpeechT5](../model_doc/speecht5))
- **Zero-shot text classification**: given a text and a list of labels, identify to which label the text corresponds the most ([BART](../model_doc/bart))
- **Text summarization**: summarize a long text in one or a few sentences ([BART](../model_doc/bart))
- **Translation**: translate the text into a given language ([NLLB](../model_doc/nllb))

These tools have an integration in transformers, and can be used manually as well, for example:

```py
from transformers import load_tool

tool = load_tool("text-to-speech")
audio = tool("This is a text to speech tool")
```

### Custom tools

While we identified a curated set of tools, we strongly believe that the main value provided by this implementation is
the ability to quickly create and share custom tools.

By pushing the code of a tool to a Hugging Face Space or a model repository, you're then able to leverage the tool
directly with the agent. We've added a few
**transformers-agnostic** tools to the `huggingface-tools` organization:

- **Text downloader**: to download a text from a web URL
- **Text to image**: generate an image according to a prompt, leveraging stable diffusion
- **Image transformation**: modify an image given an initial image and a prompt, leveraging instruct pix2pix stable diffusion

The text-to-image tool we have been using since the beginning is actually a remote tool that lives in
[*huggingface-tools/text-to-image*](https://huggingface.co/spaces/huggingface-tools/text-to-image)! We will
continue releasing such tools on this and other organizations, to further supercharge this implementation.

The agents have access by default to tools that reside on `huggingface-tools`.
We explain how you can write and share your own tools, as well as leverage any custom tool that resides on the Hub, in the [following guide](custom_tools).

### Leveraging different agents

We showcase here how to use the [bigcode/starcoder](https://huggingface.co/bigcode/starcoder) model as an LLM, but
it isn't the only model available. We also support the OpenAssistant model and OpenAI's models (davinci, GPT-3.5, and GPT-4).

We're planning on supporting local language models in a future version.

The tools defined in this implementation are agnostic to the agent used; we are showcasing the agents that work with
our prompts below, but the tools can also be used with LangChain, Minichain, or any other agent-based library.

#### Example code for the OpenAssistant model

```py
from transformers import HfAgent

agent = HfAgent(url_endpoint="https://api-inference.huggingface.co/models/OpenAssistant/oasst-sft-1-pythia-12b", token="<HF_TOKEN>")
```

#### Example code for OpenAI models

```py
from transformers import OpenAiAgent

agent = OpenAiAgent(model="text-davinci-003", api_key="<API_KEY>")
```

### Code generation

So far we have shown how to use the agents to perform actions for you. However, the agent only generates code
that we then execute using a very restricted Python interpreter. In case you would like to use the generated code in
a different setting, the agent can be prompted to return the code, along with the tool definitions and accurate imports.

For example, the following instruction

```python
agent.run("Draw me a picture of rivers and lakes", return_code=True)
```

returns the following code

```python
from transformers import load_tool

image_generator = load_tool("huggingface-tools/text-to-image")

image = image_generator(prompt="rivers and lakes")
```

that you can then modify and execute yourself.
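
If you do choose to execute returned code yourself, a common pattern is to run it in an explicit namespace so you keep a handle on what it defines. This is a generic Python sketch, unrelated to the agent's own interpreter; the `text-length` tool and its stub below are hypothetical, standing in for a real tool:

```python
# Hypothetical sketch: execute agent-returned code in an explicit namespace.
# We stub load_tool so the snippet stays self-contained.
returned_code = """
tool = load_tool("text-length")
result = tool("rivers and lakes")
"""


def load_tool(task):
    # Stand-in for transformers.load_tool: returns a toy "tool".
    return lambda text: len(text)


namespace = {"load_tool": load_tool}
exec(returned_code, namespace)
print(namespace["result"])  # 16
```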

@ -610,6 +610,16 @@ _import_structure = {
        "SpecialTokensMixin",
        "TokenSpan",
    ],
    "tools": [
        "Agent",
        "HfAgent",
        "OpenAiAgent",
        "PipelineTool",
        "RemoteTool",
        "Tool",
        "launch_gradio_demo",
        "load_tool",
    ],
    "trainer_callback": [
        "DefaultFlowCallback",
        "EarlyStoppingCallback",

@ -4340,6 +4350,9 @@ if TYPE_CHECKING:
        TokenSpan,
    )

    # Tools
    from .tools import Agent, HfAgent, OpenAiAgent, PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool

    # Trainer
    from .trainer_callback import (
        DefaultFlowCallback,

@ -115,9 +115,9 @@ def get_relative_import_files(module_file):
    return all_relative_imports


def check_imports(filename):
def get_imports(filename):
    """
    Check if the current Python environment contains all the libraries that are imported in a file.
    Extracts all the libraries that are imported in a file.
    """
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

@ -131,9 +131,14 @@ def check_imports(filename):
    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
    # Only keep the top-level module
    imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
    return list(set(imports))

    # Unique-ify and test we got them all
    imports = list(set(imports))


def check_imports(filename):
    """
    Check if the current Python environment contains all the libraries that are imported in a file.
    """
    imports = get_imports(filename)
    missing_packages = []
    for imp in imports:
        try:

@ -169,6 +174,7 @@ def get_cached_module_file(
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    _commit_hash: Optional[str] = None,
):
    """

@ -207,6 +213,8 @@ def get_cached_module_file(
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space, for instance).

    <Tip>

@ -229,7 +237,7 @@ def get_cached_module_file(
    else:
        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
        cached_module = try_to_load_from_cache(
            pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash
            pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
        )

    new_files = []

@ -245,6 +253,7 @@ def get_cached_module_file(
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            repo_type=repo_type,
            _commit_hash=_commit_hash,
        )
        if not is_local and cached_module != resolved_module_file:

@ -309,8 +318,10 @@ def get_cached_module_file(

    if len(new_files) > 0:
        new_files = "\n".join([f"- {f}" for f in new_files])
        repo_type_str = "" if repo_type is None else f"{repo_type}/"
        url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
        logger.warning(
            f"A new version of the following files was downloaded from {pretrained_model_name_or_path}:\n{new_files}"
            f"A new version of the following files was downloaded from {url}:\n{new_files}"
            "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
            "versions of the code file, you can pin a revision."
        )

@ -328,6 +339,7 @@ def get_class_from_dynamic_module(
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    **kwargs,
):
    """

@ -377,6 +389,8 @@ def get_class_from_dynamic_module(
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space, for instance).

    <Tip>

@ -418,6 +432,7 @@ def get_class_from_dynamic_module(
        use_auth_token=use_auth_token,
        revision=revision,
        local_files_only=local_files_only,
        repo_type=repo_type,
    )
    return get_class_in_module(class_name, final_module.replace(".py", ""))

@ -439,6 +454,7 @@ def custom_object_save(obj, folder, config=None):
            "this code in a separate module so we can include it in the saved folder and make it easier to share via "
            "the Hub."
        )
        return

    def _set_auto_map_in_config(_config):
        module_name = obj.__class__.__module__

@ -478,12 +494,17 @@ def custom_object_save(obj, folder, config=None):
    elif config is not None:
        _set_auto_map_in_config(config)

    result = []
    # Copy module file to the output folder.
    object_file = sys.modules[obj.__module__].__file__
    dest_file = Path(folder) / (Path(object_file).name)
    shutil.copy(object_file, dest_file)
    result.append(dest_file)

    # Gather all relative imports recursively and make sure they are copied as well.
    for needed_file in get_relative_import_files(object_file):
        dest_file = Path(folder) / (Path(needed_file).name)
        shutil.copy(needed_file, dest_file)
        result.append(dest_file)

    return result

@ -64,6 +64,10 @@ class ChannelDimension(ExplicitEnum):
    LAST = "channels_last"


def is_pil_image(img):
    return is_vision_available() and isinstance(img, PIL.Image.Image)


def is_valid_image(img):
    return (
        (is_vision_available() and isinstance(img, PIL.Image.Image))

@ -148,6 +148,7 @@ _run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=Fa
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
_run_tool_tests = parse_flag_from_env("RUN_TOOL_TESTS", default=False)


def is_pt_tf_cross_test(test_case):

@ -221,6 +222,21 @@ def is_pipeline_test(test_case):
    return pytest.mark.is_pipeline_test()(test_case)


def is_tool_test(test_case):
    """
    Decorator marking a test as a tool test. If RUN_TOOL_TESTS is set to a falsy value, those tests will be skipped.
    """
    if not _run_tool_tests:
        return unittest.skip("test is a tool test")(test_case)
    else:
        try:
            import pytest  # We don't need a hard dependency on pytest in the main library
        except ImportError:
            return test_case
        else:
            return pytest.mark.is_tool_test()(test_case)


def slow(test_case):
    """
    Decorator marking a test as slow.

@ -0,0 +1,73 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ..utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)


_import_structure = {
    "agents": ["Agent", "HfAgent", "OpenAiAgent"],
    "base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"]
    _import_structure["image_captioning"] = ["ImageCaptioningTool"]
    _import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"]
    _import_structure["image_segmentation"] = ["ImageSegmentationTool"]
    _import_structure["language_identifier"] = ["LanguageIdentificationTool"]
    _import_structure["speech_to_text"] = ["SpeechToTextTool"]
    _import_structure["text_classification"] = ["TextClassificationTool"]
    _import_structure["text_question_answering"] = ["TextQuestionAnsweringTool"]
    _import_structure["text_summarization"] = ["TextSummarizationTool"]
    _import_structure["text_to_speech"] = ["TextToSpeechTool"]
    _import_structure["translation"] = ["TranslationTool"]

if TYPE_CHECKING:
    from .agents import Agent, HfAgent, OpenAiAgent
    from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .document_question_answering import DocumentQuestionAnsweringTool
        from .image_captioning import ImageCaptioningTool
        from .image_question_answering import ImageQuestionAnsweringTool
        from .image_segmentation import ImageSegmentationTool
        from .language_identifier import LanguageIdentificationTool
        from .speech_to_text import SpeechToTextTool
        from .text_classification import TextClassificationTool
        from .text_question_answering import TextQuestionAnsweringTool
        from .text_summarization import TextSummarizationTool
        from .text_to_speech import TextToSpeechTool
        from .translation import TranslationTool
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

@ -0,0 +1,489 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import json
import os
import time
from dataclasses import dataclass

import requests
from huggingface_hub import HfFolder, hf_hub_download, list_spaces

from ..utils import is_openai_available, logging
from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote
from .prompts import CHAT_MESSAGE_PROMPT, CHAT_PROMPT_TEMPLATE, RUN_PROMPT_TEMPLATE
from .python_interpreter import evaluate


logger = logging.get_logger(__name__)


if is_openai_available():
    import openai

_tools_are_initialized = False


BASE_PYTHON_TOOLS = {
    "print": print,
    "float": float,
    "int": int,
    "bool": bool,
    "str": str,
}


@dataclass
class PreTool:
    task: str
    description: str
    repo_id: str


HUGGINGFACE_DEFAULT_TOOLS = {}


HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
    "image-transformation",
    "text-download",
    "text-to-image",
    "text-to-video",
]


def get_remote_tools(organization="huggingface-tools"):
    spaces = list_spaces(author=organization)
    tools = {}
    for space_info in spaces:
        repo_id = space_info.id
        resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
        with open(resolved_config_file, encoding="utf-8") as reader:
            config = json.load(reader)

        task = repo_id.split("/")[-1]
        tools[config["name"]] = PreTool(task=task, description=config["description"], repo_id=repo_id)

    return tools


def _setup_default_tools():
    global HUGGINGFACE_DEFAULT_TOOLS
    global _tools_are_initialized

    if _tools_are_initialized:
        return

    main_module = importlib.import_module("transformers")
    tools_module = main_module.tools

    remote_tools = get_remote_tools()
    for task_name in TASK_MAPPING:
        tool_class_name = TASK_MAPPING.get(task_name)
        tool_class = getattr(tools_module, tool_class_name)
        description = tool_class.description
        HUGGINGFACE_DEFAULT_TOOLS[tool_class.name] = PreTool(task=task_name, description=description, repo_id=None)

    for task_name in HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB:
        found = False
        for tool_name, tool in remote_tools.items():
            if tool.task == task_name:
                HUGGINGFACE_DEFAULT_TOOLS[tool_name] = tool
                found = True
                break

        if not found:
            raise ValueError(f"{task_name} is not implemented on the Hub.")

    _tools_are_initialized = True


def resolve_tools(code, toolbox, remote=False, cached_tools=None):
    if cached_tools is None:
        resolved_tools = BASE_PYTHON_TOOLS.copy()
    else:
        resolved_tools = cached_tools
    for name, tool in toolbox.items():
        if name not in code or name in resolved_tools:
            continue

        if isinstance(tool, Tool):
            resolved_tools[name] = tool
        else:
            task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
            _remote = remote and supports_remote(task_or_repo_id)
            resolved_tools[name] = load_tool(task_or_repo_id, remote=_remote)

    return resolved_tools


def get_tool_creation_code(code, toolbox, remote=False):
    code_lines = ["from transformers import load_tool", ""]
    for name, tool in toolbox.items():
        if name not in code or isinstance(tool, Tool):
            continue

        task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
        line = f'{name} = load_tool("{task_or_repo_id}"'
        if remote:
            line += ", remote=True"
        line += ")"
        code_lines.append(line)

    return "\n".join(code_lines) + "\n"
|
||||
|
||||
|
||||
def clean_code_for_chat(result):
    lines = result.split("\n")
    idx = 0
    while idx < len(lines) and not lines[idx].lstrip().startswith("```"):
        idx += 1
    explanation = "\n".join(lines[:idx]).strip()
    if idx == len(lines):
        return explanation, None

    idx += 1
    start_idx = idx
    # Guard against an unterminated code fence in the generation.
    while idx < len(lines) and not lines[idx].lstrip().startswith("```"):
        idx += 1
    code = "\n".join(lines[start_idx:idx]).strip()

    return explanation, code

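The fence-scanning logic above can be exercised standalone. This is a minimal sketch on a sample assistant reply; `split_chat_reply` is a hypothetical helper mirroring `clean_code_for_chat` for illustration only:

```python
def split_chat_reply(result):
    # Mirror of clean_code_for_chat: text before the first ``` fence is the
    # explanation, the fenced body (if any) is the code.
    lines = result.split("\n")
    idx = 0
    while idx < len(lines) and not lines[idx].lstrip().startswith("```"):
        idx += 1
    explanation = "\n".join(lines[:idx]).strip()
    if idx == len(lines):
        return explanation, None
    idx += 1
    start_idx = idx
    while idx < len(lines) and not lines[idx].lstrip().startswith("```"):
        idx += 1
    return explanation, "\n".join(lines[start_idx:idx]).strip()


reply = "I will generate an image.\n```py\nimage = image_generator('a lake')\n```"
explanation, code = split_chat_reply(reply)
```

A reply with no code fence yields `(explanation, None)`, which is why `chat` only evaluates when `code is not None`.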
def clean_code_for_run(result):
    result = f"I will use the following {result}"
    explanation, code = result.split("Answer:", 1)
    explanation = explanation.strip()
    code = code.strip()

    code_lines = code.split("\n")
    if code_lines[0] in ["```", "```py", "```python"]:
        code_lines = code_lines[1:]
    if code_lines[-1] == "```":
        code_lines = code_lines[:-1]
    code = "\n".join(code_lines)

    return explanation, code

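The `run`-mode parsing splits on the `Answer:` marker and strips optional code fences. A hedged standalone sketch of the same idea (`split_run_reply` is a hypothetical helper, not part of the API):

```python
def split_run_reply(result):
    # Mirror of clean_code_for_run: everything before "Answer:" is the
    # explanation, the rest is code with optional ``` fences stripped.
    explanation, code = result.split("Answer:", 1)
    code_lines = code.strip().split("\n")
    if code_lines[0] in ["```", "```py", "```python"]:
        code_lines = code_lines[1:]
    if code_lines[-1] == "```":
        code_lines = code_lines[:-1]
    return explanation.strip(), "\n".join(code_lines)


explanation, code = split_run_reply(
    "I will use the translator tool. Answer:\n```py\ntranslated = translator(text)\n```"
)
```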
class Agent:
    """
    Base class for all agents which contains the main API methods.

    Args:
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.
    """

    def __init__(self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
        _setup_default_tools()

        self.chat_prompt_template = CHAT_MESSAGE_PROMPT if chat_prompt_template is None else chat_prompt_template
        self.run_prompt_template = RUN_PROMPT_TEMPLATE if run_prompt_template is None else run_prompt_template
        self.toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy()
        if additional_tools is not None:
            if isinstance(additional_tools, (list, tuple)):
                additional_tools = {t.name: t for t in additional_tools}
            elif not isinstance(additional_tools, dict):
                additional_tools = {additional_tools.name: additional_tools}

            replacements = {name: tool for name, tool in additional_tools.items() if name in HUGGINGFACE_DEFAULT_TOOLS}
            self.toolbox.update(additional_tools)
            if len(replacements) > 1:
                names = "\n".join([f"- {n}: {t}" for n, t in replacements.items()])
                logger.warning(
                    f"The following tools have been replaced by the ones provided in `additional_tools`:\n{names}."
                )
            elif len(replacements) == 1:
                name = list(replacements.keys())[0]
                logger.warning(f"{name} has been replaced by {replacements[name]} as provided in `additional_tools`.")

        self.prepare_for_new_chat()

    def format_prompt(self, task, chat_mode=False):
        description = "\n".join([f"- {name}: {tool.description}" for name, tool in self.toolbox.items()])
        if chat_mode:
            if self.chat_history is None:
                prompt = CHAT_PROMPT_TEMPLATE.replace("<<all_tools>>", description)
            else:
                prompt = self.chat_history
            prompt += CHAT_MESSAGE_PROMPT.replace("<<task>>", task)
        else:
            prompt = self.run_prompt_template.replace("<<all_tools>>", description)
            prompt = prompt.replace("<<prompt>>", task)
        return prompt

    def chat(self, task, *, return_code=False, remote=False, **kwargs):
        """
        Sends a new request to the agent in a chat. Will use the previous ones in its history.

        Args:
            task (`str`): The task to perform.
            return_code (`bool`, *optional*, defaults to `False`):
                Whether to just return code and not evaluate it.
            remote (`bool`, *optional*, defaults to `False`):
                Whether or not to use remote tools (inference endpoints) instead of local ones.
            kwargs:
                Any keyword argument to send to the agent when evaluating the code.

        Example:

        ```py
        from transformers import HfAgent

        agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
        agent.chat("Draw me a picture of rivers and lakes")

        agent.chat("Transform the picture so that there is a rock in there")
        ```
        """
        prompt = self.format_prompt(task, chat_mode=True)
        result = self.generate_one(prompt, stop=["Human:", "====="])
        self.chat_history = prompt + result + "\n"
        explanation, code = clean_code_for_chat(result)

        print(f"==Explanation from the agent==\n{explanation}")

        if code is not None:
            print(f"\n\n==Code generated by the agent==\n{code}")
            if not return_code:
                print("\n\n==Result==")
                self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
                self.chat_state.update(kwargs)
                return evaluate(code, self.cached_tools, self.chat_state, chat_mode=True)
            else:
                tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
                return f"{tool_code}\n{code}"

    def prepare_for_new_chat(self):
        """
        Clears the history of prior calls to [`~Agent.chat`].
        """
        self.chat_history = None
        self.chat_state = {}
        self.cached_tools = None

    def run(self, task, *, return_code=False, remote=False, **kwargs):
        """
        Sends a request to the agent.

        Args:
            task (`str`): The task to perform.
            return_code (`bool`, *optional*, defaults to `False`):
                Whether to just return code and not evaluate it.
            remote (`bool`, *optional*, defaults to `False`):
                Whether or not to use remote tools (inference endpoints) instead of local ones.
            kwargs:
                Any keyword argument to send to the agent when evaluating the code.

        Example:

        ```py
        from transformers import HfAgent

        agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
        agent.run("Draw me a picture of rivers and lakes")
        ```
        """
        prompt = self.format_prompt(task)
        result = self.generate_one(prompt, stop=["Task:"])
        explanation, code = clean_code_for_run(result)

        print(f"==Explanation from the agent==\n{explanation}")

        print(f"\n\n==Code generated by the agent==\n{code}")
        if not return_code:
            print("\n\n==Result==")
            self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
            return evaluate(code, self.cached_tools, state=kwargs.copy())
        else:
            tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
            return f"{tool_code}\n{code}"

    def generate_one(self, prompt, stop):
        # This is the method to implement in your custom agent.
        raise NotImplementedError

    def generate_many(self, prompts, stop):
        # Override if you have a way to do batch generation faster than one by one.
        return [self.generate_one(prompt, stop) for prompt in prompts]


class OpenAiAgent(Agent):
    """
    Agent that uses the OpenAI API to generate code.

    <Tip warning={true}>

    The OpenAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
    `"text-davinci-003"` over the ChatGPT variants. Proper support for ChatGPT models will come in a next version.

    </Tip>

    Args:
        model (`str`, *optional*, defaults to `"text-davinci-003"`):
            The name of the OpenAI model to use.
        api_key (`str`, *optional*):
            The API key to use. If unset, will look for the environment variable `"OPENAI_API_KEY"`.
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    from transformers import OpenAiAgent

    agent = OpenAiAgent(model="text-davinci-003", api_key=xxx)
    agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
    ```
    """

    def __init__(
        self,
        model="text-davinci-003",
        api_key=None,
        chat_prompt_template=None,
        run_prompt_template=None,
        additional_tools=None,
    ):
        if not is_openai_available():
            raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.")

        if api_key is None:
            api_key = os.environ.get("OPENAI_API_KEY", None)
        if api_key is None:
            raise ValueError(
                "You need an OpenAI key to use `OpenAiAgent`. You can get one here: https://openai.com/api/. "
                "If you have one, set it in your env with `os.environ['OPENAI_API_KEY'] = xxx`."
            )
        else:
            openai.api_key = api_key
        self.model = model
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    def generate_many(self, prompts, stop):
        if "gpt" in self.model:
            return [self._chat_generate(prompt, stop) for prompt in prompts]
        else:
            return self._completion_generate(prompts, stop)

    def generate_one(self, prompt, stop):
        if "gpt" in self.model:
            return self._chat_generate(prompt, stop)
        else:
            return self._completion_generate([prompt], stop)[0]

    def _chat_generate(self, prompt, stop):
        result = openai.ChatCompletion.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            stop=stop,
        )
        return result["choices"][0]["message"]["content"]

    def _completion_generate(self, prompts, stop):
        result = openai.Completion.create(
            model=self.model,
            prompt=prompts,
            temperature=0,
            stop=stop,
            max_tokens=200,
        )
        return [answer["text"] for answer in result["choices"]]


class HfAgent(Agent):
    """
    Agent that uses an inference endpoint to generate code.

    Args:
        url_endpoint (`str`):
            The URL of the endpoint to use.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    from transformers import HfAgent

    agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
    agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
    ```
    """

    def __init__(
        self, url_endpoint, token=None, chat_prompt_template=None, run_prompt_template=None, additional_tools=None
    ):
        self.url_endpoint = url_endpoint
        if token is None:
            self.token = f"Bearer {HfFolder().get_token()}"
        elif token.startswith("Bearer") or token.startswith("Basic"):
            self.token = token
        else:
            self.token = f"Bearer {token}"
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    def generate_one(self, prompt, stop):
        headers = {"Authorization": self.token}
        inputs = {
            "inputs": prompt,
            "parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop},
        }

        response = requests.post(self.url_endpoint, json=inputs, headers=headers)
        if response.status_code == 429:
            print("Getting rate-limited, waiting a tiny bit before trying again.")
            time.sleep(1)
            return self.generate_one(prompt, stop)
        elif response.status_code != 200:
            raise ValueError(f"Error {response.status_code}: {response.json()}")

        result = response.json()[0]["generated_text"]
        # Inference API returns the stop sequence
        for stop_seq in stop:
            if result.endswith(stop_seq):
                result = result[: -len(stop_seq)]
        return result

#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import importlib
import inspect
import io
import json
import os
import tempfile
from typing import Any, Dict, List, Optional, Union

from huggingface_hub import CommitOperationAdd, HfFolder, create_commit, create_repo, hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError, get_session

from ..dynamic_module_utils import custom_object_save, get_class_from_dynamic_module, get_imports
from ..image_utils import is_pil_image
from ..models.auto import AutoProcessor
from ..utils import (
    CONFIG_NAME,
    cached_file,
    is_accelerate_available,
    is_torch_available,
    is_vision_available,
    logging,
)


logger = logging.get_logger(__name__)

if is_torch_available():
    import torch

if is_accelerate_available():
    from accelerate.utils import send_to_device


TOOL_CONFIG_FILE = "tool_config.json"

def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
    if repo_type is not None:
        return repo_type
    try:
        hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
        return "space"
    except RepositoryNotFoundError:
        try:
            hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
            return "model"
        except RepositoryNotFoundError:
            raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
        except Exception:
            return "model"
    except Exception:
        return "space"

# docstyle-ignore
APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
from {module_name} import {class_name}

launch_gradio_demo({class_name})
"""

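Rendered for a hypothetical tool module, the template above produces the small Gradio launcher that gets saved as `app.py` (the `my_tool`/`MyTool` names here are made up for illustration):

```python
# Same template as APP_FILE_TEMPLATE above, reproduced here so the sketch
# is self-contained.
APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
from {module_name} import {class_name}

launch_gradio_demo({class_name})
"""

# Hypothetical module/class names standing in for last_module and the
# tool's class name.
app_code = APP_FILE_TEMPLATE.format(module_name="my_tool", class_name="MyTool")
```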
class Tool:
    """
    A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
    following class attributes:

    - **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it
      will return. For instance 'This is a tool that downloads a file from a `url`. It takes the `url` as input, and
      returns the text contained in the file'.
    - **name** (`str`) -- A performative name that will be used for your tool in the prompt to the agent. For instance
      `"text-classifier"` or `"image_generator"`.
    - **inputs** (`List[str]`) -- The list of modalities expected for the inputs (in the same order as in the call).
      Modalities should be `"text"`, `"image"` or `"audio"`. This is only used by `launch_gradio_demo` or to make a
      nice space from your tool.
    - **outputs** (`List[str]`) -- The list of modalities returned by the tool (in the same order as the return of the
      call method). Modalities should be `"text"`, `"image"` or `"audio"`. This is only used by `launch_gradio_demo`
      or to make a nice space from your tool.

    You can also override the method [`~Tool.setup`] if your tool has an expensive operation to perform before being
    usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
    instantiation.
    """

    description: str = "This is a tool that ..."
    name: str = ""

    inputs: List[str]
    outputs: List[str]

    def __init__(self, *args, **kwargs):
        self.is_initialized = False

    def __call__(self, *args, **kwargs):
        raise NotImplementedError("Write this method in your subclass of `Tool`.")

    def setup(self):
        """
        Overwrite this method here for any operation that is expensive and needs to be executed before you start using
        your tool, such as loading a big model.
        """
        self.is_initialized = True

    def save(self, output_dir):
        """
        Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
        tool in `output_dir` as well as autogenerate:

        - a config file named `tool_config.json`
        - an `app.py` file so that your tool can be converted to a space
        - a `requirements.txt` containing the names of the modules used by your tool (as detected when inspecting its
          code)

        You should only use this method to save tools that are defined in a separate module (not `__main__`).

        Args:
            output_dir (`str`): The folder in which you want to save your tool.
        """
        os.makedirs(output_dir, exist_ok=True)
        # Save module file
        if self.__module__ == "__main__":
            raise ValueError(
                f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You "
                "have to put this code in a separate module so we can include it in the saved folder."
            )
        module_files = custom_object_save(self, output_dir)

        module_name = self.__class__.__module__
        last_module = module_name.split(".")[-1]
        full_name = f"{last_module}.{self.__class__.__name__}"

        # Save config file, merging into an existing one if present.
        config_file = os.path.join(output_dir, "tool_config.json")
        if os.path.isfile(config_file):
            with open(config_file, "r", encoding="utf-8") as f:
                tool_config = json.load(f)
        else:
            tool_config = {}

        tool_config.update({"tool_class": full_name, "description": self.description, "name": self.name})
        with open(config_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")

        # Save app file
        app_file = os.path.join(output_dir, "app.py")
        with open(app_file, "w", encoding="utf-8") as f:
            f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__))

        # Save requirements file
        requirements_file = os.path.join(output_dir, "requirements.txt")
        imports = []
        for module in module_files:
            imports.extend(get_imports(module))
        imports = list(set(imports))
        with open(requirements_file, "w", encoding="utf-8") as f:
            f.write("\n".join(imports) + "\n")

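As a concrete illustration of the config serialization above, here is how a tool defined as a hypothetical `MyTool` class in a `my_tool.py` module would be written out (the class/module names and description are assumptions, not real defaults):

```python
import json

# Hypothetical values standing in for full_name, self.description and
# self.name in the save method above.
tool_config = {
    "tool_class": "my_tool.MyTool",
    "description": "This is a tool that ...",
    "name": "my-tool",
}
# Same serialization call as in save: sorted keys, 2-space indent.
serialized = json.dumps(tool_config, indent=2, sort_keys=True) + "\n"
```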
    @classmethod
    def from_hub(cls, repo_id, model_repo_id=None, token=None, remote=False, **kwargs):
        """
        Loads a tool defined on the Hub.

        Args:
            repo_id (`str`):
                The name of the repo on the Hub where your tool is defined.
            model_repo_id (`str`, *optional*):
                If your tool uses a model and you want to use a different model than the default, you can pass a second
                repo ID or an endpoint url to this argument.
            token (`str`, *optional*):
                The token to identify you on hf.co. If unset, will use the token generated when running
                `huggingface-cli login` (stored in `~/.huggingface`).
            remote (`bool`, *optional*, defaults to `False`):
                Whether to use your tool by downloading the model or (if it is available) with an inference endpoint.
            kwargs:
                Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
                `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
                others will be passed along to its init.
        """
        if remote and model_repo_id is None:
            endpoints = get_default_endpoints()
            if repo_id not in endpoints:
                raise ValueError(
                    f"Could not infer a default endpoint for {repo_id}, you need to pass one using the "
                    "`model_repo_id` argument."
                )
            model_repo_id = endpoints[repo_id]
        hub_kwargs_names = [
            "cache_dir",
            "force_download",
            "resume_download",
            "proxies",
            "revision",
            "repo_type",
            "subfolder",
            "local_files_only",
        ]
        hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}

        # Try to get the tool config first.
        hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
        resolved_config_file = cached_file(
            repo_id,
            TOOL_CONFIG_FILE,
            use_auth_token=token,
            **hub_kwargs,
            _raise_exceptions_for_missing_entries=False,
            _raise_exceptions_for_connection_errors=False,
        )
        is_tool_config = resolved_config_file is not None
        if resolved_config_file is None:
            resolved_config_file = cached_file(
                repo_id,
                CONFIG_NAME,
                use_auth_token=token,
                **hub_kwargs,
                _raise_exceptions_for_missing_entries=False,
                _raise_exceptions_for_connection_errors=False,
            )
        if resolved_config_file is None:
            raise EnvironmentError(
                f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
            )

        with open(resolved_config_file, encoding="utf-8") as reader:
            config = json.load(reader)

        if not is_tool_config:
            if "custom_tool" not in config:
                raise EnvironmentError(
                    f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`."
                )
            custom_tool = config["custom_tool"]
        else:
            custom_tool = config

        tool_class = custom_tool["tool_class"]
        tool_class = get_class_from_dynamic_module(tool_class, repo_id, use_auth_token=token, **hub_kwargs)

        if remote:
            return RemoteTool(model_repo_id, token=token, tool_class=tool_class)
        return tool_class(model_repo_id, token=token, **kwargs)

    def push_to_hub(
        self,
        repo_id: str,
        commit_message: str = "Upload tool",
        private: Optional[bool] = None,
        token: Optional[Union[bool, str]] = None,
        create_pr: bool = False,
    ) -> str:
        """
        Upload the tool to the Hub.

        Parameters:
            repo_id (`str`):
                The name of the repository you want to push your tool to. It should contain your organization name when
                pushing to a given organization.
            commit_message (`str`, *optional*, defaults to `"Upload tool"`):
                Message to commit while pushing.
            private (`bool`, *optional*):
                Whether or not the repository created should be private.
            token (`bool` or `str`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`).
            create_pr (`bool`, *optional*, defaults to `False`):
                Whether or not to create a PR with the uploaded files or directly commit.
        """
        repo_url = create_repo(
            repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="space", space_sdk="gradio"
        )
        repo_id = repo_url.repo_id

        with tempfile.TemporaryDirectory() as work_dir:
            # Save all files.
            self.save(work_dir)
            operations = [
                CommitOperationAdd(path_or_fileobj=os.path.join(work_dir, f), path_in_repo=f)
                for f in os.listdir(work_dir)
            ]
            logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
            return create_commit(
                repo_id=repo_id,
                operations=operations,
                commit_message=commit_message,
                token=token,
                create_pr=create_pr,
                repo_type="space",
            )

    @staticmethod
    def from_gradio(gradio_tool):
        """
        Creates a [`Tool`] from a gradio tool.
        """

        class GradioToolWrapper(Tool):
            def __init__(self, _gradio_tool):
                super().__init__()
                self.name = _gradio_tool.name
                self.description = _gradio_tool.description

        GradioToolWrapper.__call__ = gradio_tool.run
        return GradioToolWrapper(gradio_tool)


class RemoteTool(Tool):
    """
    A [`Tool`] that will make requests to an inference endpoint.

    Args:
        endpoint_url (`str`):
            The URL of the endpoint to use.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        tool_class (`type`, *optional*):
            The corresponding `tool_class` if this is a remote version of an existing tool. Will help determine when
            the output should be converted to another type (like images).
    """

    def __init__(self, endpoint_url=None, token=None, tool_class=None):
        self.endpoint_url = endpoint_url
        self.client = EndpointClient(endpoint_url, token=token)
        self.tool_class = tool_class

    def prepare_inputs(self, *args, **kwargs):
        """
        Prepare the inputs received for the HTTP client sending data to the endpoint. Positional arguments will be
        matched with the signature of the `tool_class` if it was provided at instantiation. Images will be encoded into
        bytes.

        You can override this method in your custom class of [`RemoteTool`].
        """
        inputs = kwargs.copy()
        if len(args) > 0:
            if self.tool_class is not None:
                # Match args with the signature
                if issubclass(self.tool_class, PipelineTool):
                    call_method = self.tool_class.encode
                else:
                    call_method = self.tool_class.__call__
                signature = inspect.signature(call_method).parameters
                parameters = [
                    k
                    for k, p in signature.items()
                    if p.kind not in [inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD]
                ]
                if parameters[0] == "self":
                    parameters = parameters[1:]
                if len(args) > len(parameters):
                    raise ValueError(
                        f"{self.tool_class} only accepts {len(parameters)} arguments but {len(args)} were given."
                    )
                for arg, name in zip(args, parameters):
                    inputs[name] = arg
            elif len(args) > 1:
                raise ValueError("A `RemoteTool` can only accept one positional input.")
            elif len(args) == 1:
                if is_pil_image(args[0]):
                    return {"inputs": self.client.encode_image(args[0])}
                return {"inputs": args[0]}

        for key, value in inputs.items():
            if is_pil_image(value):
                inputs[key] = self.client.encode_image(value)

        return {"inputs": inputs}

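The positional-argument matching above relies on `inspect.signature`. A standalone sketch of the same idea, with a hypothetical `toy_call` function standing in for the tool's call method:

```python
import inspect


def toy_call(self, prompt, image=None):
    # Stand-in for a tool's __call__/encode method.
    pass


signature = inspect.signature(toy_call).parameters
# Keep only named parameters, dropping *args/**kwargs, as in
# prepare_inputs above (using the public inspect.Parameter enum).
parameters = [
    k
    for k, p in signature.items()
    if p.kind not in [inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD]
]
if parameters[0] == "self":
    parameters = parameters[1:]

# Positional args get paired with parameter names in order.
args = ("a lake",)
inputs = dict(zip(parameters, args))
```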
    def extract_outputs(self, outputs):
        """
        You can override this method in your custom class of [`RemoteTool`] to apply some custom post-processing of the
        outputs of the endpoint.
        """
        return outputs

    def __call__(self, *args, **kwargs):
        output_image = self.tool_class is not None and self.tool_class.outputs == ["image"]
        inputs = self.prepare_inputs(*args, **kwargs)
        if isinstance(inputs, dict):
            outputs = self.client(**inputs, output_image=output_image)
        else:
            outputs = self.client(inputs, output_image=output_image)
        if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
            outputs = outputs[0]
        return self.extract_outputs(outputs)


class PipelineTool(Tool):
    """
    A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
    need to specify:

    - **model_class** (`type`) -- The class to use to load the model in this tool.
    - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
    - **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
      pre-processor.
    - **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
      post-processor (when different from the pre-processor).

    Args:
        model (`str` or [`PreTrainedModel`], *optional*):
            The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
            value of the class attribute `default_checkpoint`.
        pre_processor (`str` or `Any`, *optional*):
            The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
            tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
            unset.
        post_processor (`str` or `Any`, *optional*):
            The name of the checkpoint to use for the post-processor, or the instantiated post-processor (can be a
            tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
            unset.
        device (`int`, `str` or `torch.device`, *optional*):
            The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the
            CPU otherwise.
        device_map (`str` or `dict`, *optional*):
            If passed along, will be used to instantiate the model.
        model_kwargs (`dict`, *optional*):
            Any keyword argument to send to the model instantiation.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        hub_kwargs:
            Any additional keyword argument to send to the methods that will load the data from the Hub.
    """

    pre_processor_class = AutoProcessor
    model_class = None
    post_processor_class = AutoProcessor
    default_checkpoint = None

    def __init__(
        self,
        model=None,
        pre_processor=None,
        post_processor=None,
        device=None,
        device_map=None,
        model_kwargs=None,
        token=None,
        **hub_kwargs,
    ):
        if not is_torch_available():
            raise ImportError("Please install torch in order to use this tool.")

        if not is_accelerate_available():
            raise ImportError("Please install accelerate in order to use this tool.")

        if model is None:
            if self.default_checkpoint is None:
                raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
            model = self.default_checkpoint
        if pre_processor is None:
            pre_processor = model

        self.model = model
        self.pre_processor = pre_processor
        self.post_processor = post_processor
        self.device = device
        self.device_map = device_map
        self.model_kwargs = {} if model_kwargs is None else model_kwargs
        if device_map is not None:
            self.model_kwargs["device_map"] = device_map
        self.hub_kwargs = hub_kwargs
        self.hub_kwargs["use_auth_token"] = token

        self.is_initialized = False

    def setup(self):
        """
        Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
|
||||
"""
|
||||
if isinstance(self.pre_processor, str):
|
||||
self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
|
||||
|
||||
if isinstance(self.model, str):
|
||||
self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
|
||||
|
||||
if self.post_processor is None:
|
||||
self.post_processor = self.pre_processor
|
||||
elif isinstance(self.post_processor, str):
|
||||
self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
|
||||
|
||||
if self.device is None:
|
||||
if self.device_map is not None:
|
||||
self.device = list(self.model.hf_device_map.values())[0]
|
||||
else:
|
||||
self.device = get_default_device()
|
||||
|
||||
if self.device_map is None:
|
||||
self.model.to(self.device)
|
||||
|
||||
def encode(self, raw_inputs):
|
||||
"""
|
||||
Uses the `pre_processor` to prepare the inputs for the `model`.
|
||||
"""
|
||||
return self.pre_processor(raw_inputs)
|
||||
|
||||
def forward(self, inputs):
|
||||
"""
|
||||
Sends the inputs through the `model`.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
return self.model(**inputs)
|
||||
|
||||
def decode(self, outputs):
|
||||
"""
|
||||
Uses the `post_processor` to decode the model output.
|
||||
"""
|
||||
return self.post_processor(outputs)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if not self.is_initialized:
|
||||
self.setup()
|
||||
|
||||
encoded_inputs = self.encode(*args, **kwargs)
|
||||
encoded_inputs = send_to_device(encoded_inputs, self.device)
|
||||
outputs = self.forward(encoded_inputs)
|
||||
outputs = send_to_device(outputs, "cpu")
|
||||
return self.decode(outputs)


def launch_gradio_demo(tool_class: Tool):
    """
    Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
    `inputs` and `outputs`.

    Args:
        tool_class (`type`): The class of the tool for which to launch the demo.
    """
    try:
        import gradio as gr
    except ImportError:
        raise ImportError("Gradio should be installed in order to launch a gradio demo.")

    tool = tool_class()

    def fn(*args, **kwargs):
        return tool(*args, **kwargs)

    gr.Interface(
        fn=fn,
        inputs=tool_class.inputs,
        outputs=tool_class.outputs,
        title=tool_class.__name__,
        article=tool.description,
    ).launch()


# TODO: Migrate to Accelerate for this once `PartialState.default_device` makes its way into a release.
def get_default_device():
    if not is_torch_available():
        raise ImportError("Please install torch in order to use this tool.")

    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device("mps")
    elif torch.cuda.is_available():
        return torch.device("cuda")
    else:
        return torch.device("cpu")


TASK_MAPPING = {
    "document-question-answering": "DocumentQuestionAnsweringTool",
    "image-captioning": "ImageCaptioningTool",
    "image-question-answering": "ImageQuestionAnsweringTool",
    "image-segmentation": "ImageSegmentationTool",
    "speech-to-text": "SpeechToTextTool",
    "summarization": "TextSummarizationTool",
    "text-classification": "TextClassificationTool",
    "text-question-answering": "TextQuestionAnsweringTool",
    "text-to-speech": "TextToSpeechTool",
    "translation": "TranslationTool",
}


def get_default_endpoints():
    endpoints_file = cached_file("huggingface-tools/default-endpoints", "default_endpoints.json", repo_type="dataset")
    with open(endpoints_file, "r", encoding="utf-8") as f:
        endpoints = json.load(f)
    return endpoints


def supports_remote(task_or_repo_id):
    endpoints = get_default_endpoints()
    return task_or_repo_id in endpoints


def load_tool(task_or_repo_id, model_repo_id=None, remote=False, token=None, **kwargs):
    """
    Main function to quickly load a tool, be it on the Hub or in the Transformers library.

    Args:
        task_or_repo_id (`str`):
            The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers
            are:

            - `"document-question-answering"`
            - `"image-captioning"`
            - `"image-question-answering"`
            - `"image-segmentation"`
            - `"speech-to-text"`
            - `"summarization"`
            - `"text-classification"`
            - `"text-question-answering"`
            - `"text-to-speech"`
            - `"translation"`

        model_repo_id (`str`, *optional*):
            Use this argument to use a different model than the default one for the tool you selected.
        remote (`bool`, *optional*, defaults to `False`):
            Whether to use your tool by downloading the model or (if it is available) with an inference endpoint.
        token (`str`, *optional*):
            The token to identify you on hf.co. If unset, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        kwargs:
            Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
            `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
            others will be passed along to its init.
    """
    if task_or_repo_id in TASK_MAPPING:
        tool_class_name = TASK_MAPPING[task_or_repo_id]
        main_module = importlib.import_module("transformers")
        tools_module = main_module.tools
        tool_class = getattr(tools_module, tool_class_name)

        if remote:
            if model_repo_id is None:
                endpoints = get_default_endpoints()
                if task_or_repo_id not in endpoints:
                    raise ValueError(
                        f"Could not infer a default endpoint for {task_or_repo_id}, you need to pass one using the "
                        "`model_repo_id` argument."
                    )
                model_repo_id = endpoints[task_or_repo_id]
            return RemoteTool(model_repo_id, token=token, tool_class=tool_class)
        else:
            return tool_class(model_repo_id, token=token, **kwargs)
    else:
        return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, remote=remote, **kwargs)


def add_description(description):
    """
    A decorator that adds a description to a function.
    """

    def inner(func):
        func.description = description
        func.name = func.__name__
        return func

    return inner
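As a quick illustration, the decorator simply attaches `description` and `name` attributes to the decorated function, which the agent prompt can later introspect. A minimal standalone sketch (the decorator body is reproduced here so the snippet runs on its own; `text_length` is a hypothetical tool used only for the demo):

```python
def add_description(description):
    """Copy of the decorator above, reproduced so this demo is self-contained."""

    def inner(func):
        func.description = description
        func.name = func.__name__
        return func

    return inner


@add_description("This is a tool that returns the length of a text.")
def text_length(text):
    return len(text)


# The metadata now lives on the function itself.
assert text_length.name == "text_length"
assert text_length.description == "This is a tool that returns the length of a text."
assert text_length("hello") == 5
```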


## Will move to the Hub
class EndpointClient:
    def __init__(self, endpoint_url: str, token: Optional[str] = None):
        if token is None:
            token = HfFolder().get_token()
        self.headers = {"authorization": f"Bearer {token}", "Content-Type": "application/json"}
        self.endpoint_url = endpoint_url

    @staticmethod
    def encode_image(image):
        _bytes = io.BytesIO()
        image.save(_bytes, format="PNG")
        b64 = base64.b64encode(_bytes.getvalue())
        return b64.decode("utf-8")

    @staticmethod
    def decode_image(raw_image):
        if not is_vision_available():
            raise ImportError(
                "This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)."
            )

        from PIL import Image

        b64 = base64.b64decode(raw_image)
        _bytes = io.BytesIO(b64)
        return Image.open(_bytes)

    def __call__(
        self,
        inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
        params: Optional[Dict] = None,
        data: Optional[bytes] = None,
        output_image: bool = False,
    ) -> Any:
        # Build payload
        payload = {}
        if inputs:
            payload["inputs"] = inputs
        if params:
            payload["parameters"] = params

        # Make API call
        response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data)

        # By default, parse the response for the user.
        if output_image:
            return self.decode_image(response.content)
        else:
            return response.json()
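The wire format here is simply base64-encoded PNG bytes carried as a UTF-8 string. A stdlib-only sketch of the round trip (raw bytes stand in for the PNG payload a PIL image would produce, so the snippet runs without Pillow):

```python
import base64
import io

# Stand-in for the PNG bytes a PIL image would write via `image.save(buffer, format="PNG")`.
png_bytes = b"\x89PNG\r\n\x1a\nfake image payload"

# Encode: bytes -> base64 -> UTF-8 string (what `encode_image` puts in the JSON payload).
b64_string = base64.b64encode(png_bytes).decode("utf-8")

# Decode: string -> bytes -> file-like object (what `decode_image` hands to `PIL.Image.open`).
round_tripped = io.BytesIO(base64.b64decode(b64_string)).getvalue()

assert round_tripped == png_bytes
```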


@@ -0,0 +1,80 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

from ..models.auto import AutoProcessor
from ..models.vision_encoder_decoder import VisionEncoderDecoderModel
from ..utils import is_vision_available
from .base import PipelineTool


if is_vision_available():
    from PIL import Image


class DocumentQuestionAnsweringTool(PipelineTool):
    default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
    description = (
        "This is a tool that answers a question about a document (pdf). It takes an input named `document` which "
        "should be the document containing the information, as well as a `question` that is the question about the "
        "document. It returns a text that contains the answer to the question."
    )
    name = "document_qa"
    pre_processor_class = AutoProcessor
    model_class = VisionEncoderDecoderModel

    inputs = ["image", "text"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        if not is_vision_available():
            raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.")

        super().__init__(*args, **kwargs)

    def encode(self, image: "Image", question: str):
        task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
        prompt = task_prompt.replace("{user_input}", question)
        decoder_input_ids = self.pre_processor.tokenizer(
            prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids
        pixel_values = self.pre_processor(image, return_tensors="pt").pixel_values

        return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}

    def forward(self, inputs):
        return self.model.generate(
            inputs["pixel_values"].to(self.device),
            decoder_input_ids=inputs["decoder_input_ids"].to(self.device),
            max_length=self.model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=self.pre_processor.tokenizer.pad_token_id,
            eos_token_id=self.pre_processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        ).sequences

    def decode(self, outputs):
        sequence = self.pre_processor.batch_decode(outputs)[0]
        sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "")
        sequence = sequence.replace(self.pre_processor.tokenizer.pad_token, "")
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
        sequence = self.pre_processor.token2json(sequence)

        return sequence["answer"]


@@ -0,0 +1,692 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .agents import BASE_PYTHON_TOOLS, clean_code_for_chat, clean_code_for_run
from .python_interpreter import InterpretorError, evaluate


### Fake tools for test
def classifier(text, labels):
    return f"This is the classification of {text} along {labels}."


def translator(text, src_lang, tgt_lang):
    return f"This is the translation of {text} from {src_lang} to {tgt_lang}."


def speaker(text):
    return f"This is actually a sound reading {text}."


def transcriber(audio):
    if "sound" not in audio:
        raise ValueError(f"`audio` ({audio}) is not a sound.")
    return f"This is the transcribed text from {audio}."


def image_generator(prompt):
    return f"This is actually an image representing {prompt}."


def image_captioner(image):
    if "image" not in image:
        raise ValueError(f"`image` ({image}) is not an image.")
    return f"This is a description of {image}."


def image_transformer(image, prompt):
    if "image" not in image:
        raise ValueError(f"`image` ({image}) is not an image.")
    return f"This is a transformation of {image} according to {prompt}."


def question_answerer(text, question):
    return f"This is the answer to {question} from {text}."


def image_qa(image, question):
    if "image" not in image:
        raise ValueError(f"`image` ({image}) is not an image.")
    return f"This is the answer to {question} from {image}."


def text_downloader(url):
    return f"This is the content of {url}."


def summarizer(text):
    return f"This is a summary of {text}."


def video_generator(prompt, seconds=2):
    return f"A video of {prompt}"


def document_qa(image, question):
    return f"This is the answer to {question} from the document {image}."


def image_segmenter(image, prompt):
    return f"This is the mask of {prompt} in {image}"


TEST_TOOLS = {
    "text_classifier": classifier,
    "translator": translator,
    "text_reader": speaker,
    "summarizer": summarizer,
    "transcriber": transcriber,
    "image_generator": image_generator,
    "image_captioner": image_captioner,
    "image_transformer": image_transformer,
    "text_qa": question_answerer,
    "text_downloader": text_downloader,
    "image_qa": image_qa,
    "video_generator": video_generator,
    "document_qa": document_qa,
    "image_segmenter": image_segmenter,
}


class Problem:
    """
    A class regrouping all the information to solve a problem on which we will evaluate agents.

    Args:
        task (`str` or `list[str]`):
            One or several descriptions of the task to perform. If a list, it should contain variations on the
            phrasing, but for the same task.
        inputs (`list[str]` or `dict[str, str]`):
            The inputs that will be fed to the tools. For this testing environment, only strings are accepted as
            values. Pass along a dictionary when you want to specify the value of each input, or just the list of
            inputs expected (the value used will be `<<input_name>>` in this case).
        answer (`str` or `list[str]`):
            The theoretical answer (or list of possible valid answers) to the problem, as code.
    """

    def __init__(self, task, inputs, answer):
        self.task = task
        self.inputs = inputs
        self.answer = answer


### The list of problems the agent will be evaluated on.
EVALUATION_TASKS = [
    Problem(
        task=[
            "Is the following `text` (in Spanish) positive or negative?",
            "Is the text in the variable `text` (in Spanish) positive or negative?",
            "Translate the following `text` from Spanish to English then tell me if it's positive or negative.",
        ],
        inputs=["text"],
        answer="""text_classifier(translator(text, src_lang="Spanish", tgt_lang="English"), labels=["positive", "negative"])""",
    ),
    Problem(
        task=[
            "Tell me out loud what the `image` contains.",
            "Describe the following `image` out loud.",
            "Find what is in the picture stored in `image` then read it out loud.",
        ],
        inputs=["image"],
        answer=[
            "text_reader(image_captioner(image))",
            "text_reader(image_qa(image, question='What is in the image?'))",
        ],
    ),
    Problem(
        task=[
            "Generate an image from the text given in `text_input`. Then transform it according to the text in `prompt`.",
            "Use the following `text_input` to generate an image, then transform it by using the text in `prompt`.",
        ],
        inputs=["text_input", "prompt"],
        answer="image_transformer(image_generator(text_input), prompt)",
    ),
    Problem(
        task=[
            "Download the content of `url`, summarize it then generate an image from its content.",
            "Use a summary of the web page at `url` to generate an image.",
            "Summarize the content of the web page at `url`, and use the result to generate an image.",
        ],
        inputs=["url"],
        answer="image_generator(summarizer(text_downloader(url)))",
    ),
    Problem(
        task=[
            "Transform the following `image` using the prompt in `text`. The prompt is in Spanish.",
            "Use the text prompt in `text` (in Spanish) to transform the following `image`.",
            "Translate the `text` from Spanish to English then use it to transform the picture in `image`.",
        ],
        inputs=["text", "image"],
        answer="image_transformer(image, translator(text, src_lang='Spanish', tgt_lang='English'))",
    ),
    Problem(
        task=[
            "Download the content of `url`, summarize it then read it out loud to me.",
            "Read me a summary of the web page at `url`.",
        ],
        inputs=["url"],
        answer="text_reader(summarizer(text_downloader(url)))",
    ),
    Problem(
        task=[
            "Generate an image from the text given in `text_input`.",
        ],
        inputs=["text_input"],
        answer="image_generator(text_input)",
    ),
    Problem(
        task=[
            "Replace the beaver in the `image` by the `prompt`.",
            "Transform the `image` so that it contains the `prompt`.",
            "Use `prompt` to transform this `image`.",
        ],
        inputs=["image", "prompt"],
        answer="image_transformer(image, prompt)",
    ),
    Problem(
        task=[
            "Provide me the summary of the `text`, then read it to me before transcribing it and translating it into French.",
            "Summarize `text`, read it out loud then transcribe the audio and translate it into French.",
            "Read me a summary of the `text` out loud. Transcribe this and translate it into French.",
        ],
        inputs=["text"],
        answer="translator(transcriber(text_reader(summarizer(text))), src_lang='English', tgt_lang='French')",
    ),
    Problem(
        task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
        inputs={"prompt": "A lobster swimming"},
        answer="video_generator('A lobster swimming')",
    ),
    Problem(
        task=[
            "Download the following file `url`, summarize it in a few words and generate a video from it.",
            "Fetch the file at this `url`, summarize it, and create an animation out of it.",
        ],
        inputs=["url"],
        answer="video_generator(summarizer(text_downloader(url)))",
    ),
]


EVALUATION_CHATS = [
    [
        Problem(
            task=[
                "Translate the following `text` from Spanish to English.",
                "Translate the following `text` from Spanish to English.",
            ],
            inputs=["text"],
            answer="translated_text = translator(text, src_lang='Spanish', tgt_lang='English')",
        ),
        Problem(
            task=[
                "Is it positive or negative?",
                "Tell me if it's positive or negative.",
            ],
            inputs=[],
            answer="text_classifier(translated_text, labels=['positive', 'negative'])",
        ),
    ],
    [
        Problem(
            task=[
                "What does this `image` contain?",
                "Describe the following `image`.",
                "Find what is in the picture stored in `image`",
            ],
            inputs=["image"],
            answer=[
                "description = image_captioner(image)",
                "description = image_qa(image, question='What is in the image?')",
            ],
        ),
        Problem(
            task=["Now, read the description out loud.", "Great! Can you read it out loud?", "Read it out loud."],
            inputs=[],
            answer=["audio = text_reader(description)", "audio = text_reader(description)"],
        ),
    ],
    [
        Problem(
            task=[
                "Generate an image from the text given in `text_input`.",
                "Use the following `text_input` to generate an image",
            ],
            inputs=["text_input"],
            answer="image = image_generator(text_input)",
        ),
        Problem(
            task=[
                "Transform it according to the text in `prompt`.",
                "Transform it by using the text in `prompt`.",
            ],
            inputs=["prompt"],
            answer="image_transformer(image, prompt)",
        ),
    ],
    [
        Problem(
            task=[
                "Download the content of `url` and summarize it.",
                "Summarize the content of the web page at `url`.",
            ],
            inputs=["url"],
            answer="summary = summarizer(text_downloader(url))",
        ),
        Problem(
            task=[
                "Generate an image from its content.",
                "Use the previous result to generate an image.",
            ],
            inputs=[],
            answer="image_generator(summary)",
        ),
    ],
    [
        Problem(
            task=[
                "Translate this Spanish `text` into English.",
                "Translate the `text` from Spanish to English.",
            ],
            inputs=["text"],
            answer="translated_text = translator(text, src_lang='Spanish', tgt_lang='English')",
        ),
        Problem(
            task=[
                "Transform the following `image` using the translated `text`.",
                "Use the previous result to transform the following `image`.",
            ],
            inputs=["image"],
            answer="image_transformer(image, translated_text)",
        ),
    ],
    [
        Problem(
            task=["Download the content of `url`.", "Get me the text on the web page `url`."],
            inputs=["url"],
            answer="text = text_downloader(url)",
        ),
        Problem(
            task=["Summarize this text.", "Summarize this text."],
            inputs=[],
            answer="summary = summarizer(text)",
        ),
        Problem(
            task=["Read it out loud to me.", "Read me the previous result."],
            inputs=[],
            answer="text_reader(summary)",
        ),
    ],
    [
        Problem(
            task=[
                "Generate an image from the text given in `text_input`.",
            ],
            inputs=["text_input"],
            answer="image_generator(text_input)",
        ),
    ],
    [
        Problem(
            task=[
                "Replace the beaver in the `image` by the `prompt`.",
                "Transform the `image` so that it contains the `prompt`.",
                "Use `prompt` to transform this `image`.",
            ],
            inputs=["image", "prompt"],
            answer="image_transformer(image, prompt)",
        ),
    ],
    [
        Problem(
            task=["Provide me the summary of the `text`.", "Summarize `text`."],
            inputs=["text"],
            answer="summary = summarizer(text)",
        ),
        Problem(
            task=["Read this summary to me.", "Read it out loud."],
            inputs=[],
            answer="audio = text_reader(summarizer(text))",
        ),
        Problem(
            task=["Transcribe the previous result back into text.", "Transcribe the audio."],
            inputs=[],
            answer="text = transcriber(audio)",
        ),
        Problem(
            task=["Translate the last result into French.", "Translate this into French."],
            inputs=[],
            answer="translator(text, src_lang='English', tgt_lang='French')",
        ),
    ],
    [
        Problem(
            task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
            inputs={"prompt": "A lobster swimming"},
            answer="video_generator('A lobster swimming')",
        ),
    ],
    [
        Problem(
            task=[
                "Download the content of `url` and summarize it.",
                "Summarize the content of the web page at `url`.",
            ],
            inputs=["url"],
            answer="summary = summarizer(text_downloader(url))",
        ),
        Problem(
            task=["Generate a video from it.", "Create an animation from the last result."],
            inputs=[],
            answer="video_generator(summary)",
        ),
    ],
]


def get_theoretical_tools(agent_answer, theoretical_answer, code_answer):
    if not isinstance(theoretical_answer, list):
        return {name for name in TEST_TOOLS if name in code_answer}

    if isinstance(agent_answer, dict):
        for one_answer, one_code in zip(theoretical_answer, code_answer):
            if one_answer in agent_answer.values():
                return {name for name in TEST_TOOLS if name in one_code}

    for one_answer, one_code in zip(theoretical_answer, code_answer):
        if agent_answer == one_answer:
            return {name for name in TEST_TOOLS if name in one_code}

    return {name for name in TEST_TOOLS if name in code_answer[0]}


def evaluate_code(code, inputs=None, state=None, verbose=False, return_interpretor_error=False):
    tools = BASE_PYTHON_TOOLS.copy()
    for name, tool in TEST_TOOLS.items():
        if name not in code:
            continue
        tools[name] = tool

    if isinstance(inputs, dict):
        inputs = inputs.copy()
    elif inputs is not None:
        inputs = {inp: f"<<{inp}>>" for inp in inputs}

    if state is not None:
        state.update(inputs)
    else:
        state = inputs

    try:
        return evaluate(code, tools, state)
    except InterpretorError as e:
        return str(e)
    except Exception as e:
        if verbose:
            print(e)
        return None
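To see why this makes agent answers comparable: each expected input name is mapped to the literal placeholder `<<name>>`, and the fake tools just interpolate their arguments, so any two pieces of code chaining the same tools on the same placeholder produce identical strings. A standalone sketch (the two fake tools are redefined locally so the snippet runs on its own):

```python
# Local copies of two of the fake test tools above, so this demo is self-contained.
def text_downloader(url):
    return f"This is the content of {url}."


def summarizer(text):
    return f"This is a summary of {text}."


# `inputs=["url"]` is expanded into a state mapping each input name to a placeholder.
state = {inp: f"<<{inp}>>" for inp in ["url"]}
assert state == {"url": "<<url>>"}

# Chaining the same tools on the same placeholder always yields the same string,
# which is how an agent's result can be compared against the theoretical answer.
agent_answer = summarizer(text_downloader(state["url"]))
theoretical_answer = summarizer(text_downloader("<<url>>"))

assert agent_answer == theoretical_answer
assert agent_answer.startswith("This is a summary of This is the content of <<url>>")
```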


def score_code(agent_answer, theoretical_answer, verbose: bool = False):
    if verbose:
        print(agent_answer, theoretical_answer)
    theoretical_answer = theoretical_answer if isinstance(theoretical_answer, list) else [theoretical_answer]

    if agent_answer in theoretical_answer:
        if verbose:
            print("Perfect!")
        return 1
    elif isinstance(agent_answer, dict) and any(v in theoretical_answer for v in agent_answer.values()):
        if verbose:
            print("Almost perfect, result in state!")
        return 0.75
    else:
        if verbose:
            print("Result is not the right one but code executed.")
        return 0.3
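The three scoring tiers can be exercised directly; a standalone sketch (the function is reproduced without verbose printing so the snippet runs on its own):

```python
def score_code(agent_answer, theoretical_answer):
    """Minimal reproduction of the scoring tiers above, without the verbose printing."""
    if not isinstance(theoretical_answer, list):
        theoretical_answer = [theoretical_answer]

    if agent_answer in theoretical_answer:
        return 1  # exact match with one of the valid answers
    elif isinstance(agent_answer, dict) and any(v in theoretical_answer for v in agent_answer.values()):
        return 0.75  # right result, but left in the interpreter state instead of returned
    else:
        return 0.3  # code executed but produced a different result


assert score_code("a summary", "a summary") == 1
assert score_code({"summary": "a summary"}, "a summary") == 0.75
assert score_code("something else", "a summary") == 0.3
```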


def evaluate_one_result(explanation, code, agent_answer, theoretical_answer, answer, verbose=False):
    tools_in_explanation = {name for name in TEST_TOOLS if f"`{name}`" in explanation}
    theoretical_tools = get_theoretical_tools(agent_answer, theoretical_answer, answer)
    if tools_in_explanation == theoretical_tools:
        tool_selection_score = 1.0
        tool_selection_errors = None
    else:
        missing_tools = len(theoretical_tools - tools_in_explanation)
        unexpected_tools = len(tools_in_explanation - theoretical_tools)
        tool_selection_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)

        tool_selection_errors = {
            "selected_tools": tools_in_explanation,
            "theoretical_tools": theoretical_tools,
        }

    tools_in_code = {name for name in TEST_TOOLS if name in code}
    if tools_in_code == theoretical_tools:
        tool_used_score = 1.0
        tool_used_errors = None
    else:
        missing_tools = len(theoretical_tools - tools_in_code)
        unexpected_tools = len(tools_in_code - theoretical_tools)
        tool_used_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)

        tool_used_errors = {
            "selected_tools": tools_in_code,
            "theoretical_tools": theoretical_tools,
        }

    score = score_code(agent_answer, theoretical_answer, verbose=verbose)
    if score < 1.0:
        code_errors = {
            "code_produced": code,
            "evaluation": agent_answer,
            "theoretical_answer": theoretical_answer,
        }
    else:
        code_errors = None

    return (tool_selection_score, tool_used_score, score), (tool_selection_errors, tool_used_errors, code_errors)
|
||||
|
||||
|
||||
def evaluate_agent(agent, batch_size=8, verbose=False, return_errors=False):
    """
    Evaluates a new agent on all `EVALUATION_TASKS`.

    Example:

    ```py
    agent = OpenAiAgent(model="text-davinci-003", api_key=your_api_key)
    scores = evaluate_agent(agent)
    print(scores)
    ```
    """
    # Sanity check
    agent_tools = set(agent.toolbox.keys())
    if agent_tools != set(TEST_TOOLS):
        missing_tools = set(TEST_TOOLS) - agent_tools
        unexpected_tools = agent_tools - set(TEST_TOOLS)
        raise ValueError(
            f"Fix the test tools in the evaluate_agent module. Tools missing: {missing_tools}. Extra tools: {unexpected_tools}."
        )

    eval_tasks = []
    eval_idx = []
    for idx, pb in enumerate(EVALUATION_TASKS):
        if isinstance(pb.task, list):
            eval_tasks.extend(pb.task)
            eval_idx.extend([idx] * len(pb.task))
        else:
            eval_tasks.append(pb.task)
            eval_idx.append(idx)

    tool_selection_score = 0
    tool_used_score = 0
    code_score = 0

    if return_errors:
        tool_selection_errors = {}
        tool_used_errors = {}
        code_errors = {}

    for start_idx in range(0, len(eval_tasks), batch_size):
        end_idx = min(start_idx + batch_size, len(eval_tasks))
        batch_tasks = eval_tasks[start_idx:end_idx]

        prompts = [agent.format_prompt(task) for task in batch_tasks]
        results = agent.generate_many(prompts, stop=["Task:"])

        for idx, result in enumerate(results):
            problem = EVALUATION_TASKS[eval_idx[start_idx + idx]]
            if verbose:
                print(f"====Task {start_idx + idx}====\n{batch_tasks[idx]}\n")
            explanation, code = clean_code_for_run(result)

            # Evaluate agent answer and code answer
            agent_answer = evaluate_code(code, problem.inputs, verbose=verbose)
            if isinstance(problem.answer, list):
                theoretical_answer = [evaluate_code(answer, problem.inputs) for answer in problem.answer]
            else:
                theoretical_answer = evaluate_code(problem.answer, problem.inputs)

            scores, errors = evaluate_one_result(
                explanation, code, agent_answer, theoretical_answer, problem.answer, verbose=verbose
            )

            tool_selection_score += scores[0]
            tool_used_score += scores[1]
            code_score += scores[2]

            if return_errors:
                if errors[0] is not None:
                    tool_selection_errors[batch_tasks[idx]] = errors[0]
                if errors[1] is not None:
                    tool_used_errors[batch_tasks[idx]] = errors[1]
                if errors[2] is not None:
                    code_errors[batch_tasks[idx]] = errors[2]

    scores = {
        "tool selection score": 100 * (tool_selection_score / len(eval_tasks)),
        "tool used score": 100 * (tool_used_score / len(eval_tasks)),
        "code score": 100 * (code_score / len(eval_tasks)),
    }

    if return_errors:
        return scores, tool_selection_errors, tool_used_errors, code_errors
    else:
        return scores

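The batching loop above walks `eval_tasks` in fixed-size slices; a minimal sketch of that slicing pattern with toy data:

```python
# Hypothetical example of the batch slicing used in the evaluation loop:
# each iteration takes at most `batch_size` tasks, the last batch may be short.
tasks = [f"task {i}" for i in range(10)]
batch_size = 8

batches = []
for start_idx in range(0, len(tasks), batch_size):
    end_idx = min(start_idx + batch_size, len(tasks))
    batches.append(tasks[start_idx:end_idx])
# a batch of 8 tasks followed by a batch of 2
```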
def evaluate_chat_agent(agent, verbose=False, return_errors=False):
    """
    Evaluates a new agent on all `EVALUATION_CHATS`.

    Example:

    ```py
    agent = OpenAiAgent(model="text-davinci-003", api_key=your_api_key)
    scores = evaluate_chat_agent(agent)
    print(scores)
    ```
    """
    # Sanity check
    agent_tools = set(agent.toolbox.keys())
    if agent_tools != set(TEST_TOOLS):
        missing_tools = set(TEST_TOOLS) - agent_tools
        unexpected_tools = agent_tools - set(TEST_TOOLS)
        raise ValueError(
            f"Fix the test tools in the evaluate_agent module. Tools missing: {missing_tools}. Extra tools: {unexpected_tools}."
        )

    tool_selection_score = 0
    tool_used_score = 0
    code_score = 0
    total_steps = 0

    if return_errors:
        tool_selection_errors = {}
        tool_used_errors = {}
        code_errors = {}

    for chat_problem in EVALUATION_CHATS:
        if isinstance(chat_problem[0].task, str):
            resolved_problems = [chat_problem]
        else:
            resolved_problems = [
                [Problem(task=pb.task[i], inputs=pb.inputs, answer=pb.answer) for pb in chat_problem]
                for i in range(len(chat_problem[0].task))
            ]
        for problem in resolved_problems:
            agent.prepare_for_new_chat()
            agent_state = {}
            theoretical_state = (
                [{} for _ in range(len(problem[0].answer))] if isinstance(problem[0].answer, list) else {}
            )

            for step, step_problem in enumerate(problem):
                if verbose:
                    print(step_problem.task)
                total_steps += 1
                prompt = agent.format_prompt(step_problem.task, chat_mode=True)
                result = agent.generate_one(prompt, stop=["Human:", "====="])
                agent.chat_history = prompt + result + "\n"

                explanation, code = clean_code_for_chat(result)

                if verbose:
                    print(f"==Explanation from the agent==\n{explanation}")
                    print(f"\n==Code generated by the agent==\n{code}")

                # Evaluate agent answer and code answer
                agent_answer = evaluate_code(code, step_problem.inputs, state=agent_state, verbose=verbose)

                answer = step_problem.answer
                if isinstance(answer, list):
                    theoretical_answer = [
                        evaluate_code(a, step_problem.inputs, state=state)
                        for a, state in zip(answer, theoretical_state)
                    ]
                else:
                    theoretical_answer = evaluate_code(answer, step_problem.inputs, state=theoretical_state)

                scores, errors = evaluate_one_result(
                    explanation, code, agent_answer, theoretical_answer, answer, verbose=verbose
                )

                tool_selection_score += scores[0]
                tool_used_score += scores[1]
                code_score += scores[2]

                if return_errors:
                    if errors[0] is not None:
                        tool_selection_errors[step_problem.task] = errors[0]
                    if errors[1] is not None:
                        tool_used_errors[step_problem.task] = errors[1]
                    if errors[2] is not None:
                        code_errors[step_problem.task] = errors[2]

    scores = {
        "tool selection score": 100 * (tool_selection_score / total_steps),
        "tool used score": 100 * (tool_used_score / total_steps),
        "code score": 100 * (code_score / total_steps),
    }

    if return_errors:
        return scores, tool_selection_errors, tool_used_errors, code_errors
    else:
        return scores

@@ -0,0 +1,51 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ..models.auto import AutoModelForVision2Seq
from ..utils import requires_backends
from .base import PipelineTool


if TYPE_CHECKING:
    from PIL import Image


class ImageCaptioningTool(PipelineTool):
    default_checkpoint = "Salesforce/blip-image-captioning-base"
    description = (
        "This is a tool that generates a description of an image. It takes an input named `image` which should be the "
        "image to caption, and returns a text that contains the description in English."
    )
    name = "image_captioner"
    model_class = AutoModelForVision2Seq

    inputs = ["image"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    def encode(self, image: "Image"):
        return self.pre_processor(images=image, return_tensors="pt")

    def forward(self, inputs):
        return self.model.generate(**inputs)

    def decode(self, outputs):
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()

@@ -0,0 +1,57 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

import torch

from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor
from ..utils import requires_backends
from .base import PipelineTool


if TYPE_CHECKING:
    from PIL import Image


class ImageQuestionAnsweringTool(PipelineTool):
    default_checkpoint = "dandelin/vilt-b32-finetuned-vqa"
    description = (
        "This is a tool that answers a question about an image. It takes an input named `image` which should be the "
        "image containing the information, as well as a `question` which should be the question in English. It "
        "returns a text that is the answer to the question."
    )
    name = "image_qa"
    pre_processor_class = AutoProcessor
    model_class = AutoModelForVisualQuestionAnswering

    inputs = ["image", "text"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    def encode(self, image: "Image", question: str):
        return self.pre_processor(image, question, return_tensors="pt")

    def forward(self, inputs):
        with torch.no_grad():
            return self.model(**inputs).logits

    def decode(self, outputs):
        idx = outputs.argmax(-1).item()
        return self.model.config.id2label[idx]

@@ -0,0 +1,60 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch

from ..models.clipseg import CLIPSegForImageSegmentation
from ..utils import is_vision_available, requires_backends
from .base import PipelineTool


if is_vision_available():
    from PIL import Image


class ImageSegmentationTool(PipelineTool):
    description = (
        "This is a tool that creates a segmentation mask identifying elements inside an image according to a prompt. "
        "It takes two arguments named `image` which should be the original image, and `prompt` which should be a text "
        "describing the elements that should be identified in the segmentation mask. The tool returns the mask as a "
        "black-and-white image."
    )
    default_checkpoint = "CIDAS/clipseg-rd64-refined"
    name = "image_segmenter"
    model_class = CLIPSegForImageSegmentation

    inputs = ["image", "text"]
    outputs = ["image"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    def encode(self, image: "Image", prompt: str):
        self.pre_processor.image_processor.size = {"width": image.size[0], "height": image.size[1]}
        return self.pre_processor(text=[prompt], images=[image], padding=True, return_tensors="pt")

    def forward(self, inputs):
        with torch.no_grad():
            logits = self.model(**inputs).logits
        return logits

    def decode(self, outputs):
        array = outputs.cpu().detach().numpy()
        array[array <= 0] = 0
        array[array > 0] = 1
        return Image.fromarray((array * 255).astype(np.uint8))

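The `decode` step above binarizes the logits at zero and rescales to 0/255. The same thresholding, sketched without torch or PIL on hypothetical logit values:

```python
# Positive logits become white (255), non-positive become black (0),
# mirroring the array manipulation in `ImageSegmentationTool.decode`.
logits = [[-1.2, 0.5], [2.0, -0.1]]
mask = [[255 if value > 0 else 0 for value in row] for row in logits]
```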
@@ -0,0 +1,186 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# docstyle-ignore
RUN_PROMPT_TEMPLATE = """I will ask you to perform a task, your job is to come up with a series of simple commands in Python that will perform the task.
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.

Tools:
<<all_tools>>


Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."

I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.

Answer:
```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
print(f"The translated question is {translated_question}.")
answer = image_qa(image=image, question=translated_question)
print(f"The answer is {answer}")
```

Task: "Identify the oldest person in the `document` and create an image showcasing the result."

I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.

Answer:
```py
answer = document_qa(document, question="What is the oldest person?")
print(f"The answer is {answer}.")
image = image_generator(answer)
```

Task: "Generate an image using the text given in the variable `caption`."

I will use the following tool: `image_generator` to generate an image.

Answer:
```py
image = image_generator(prompt=caption)
```

Task: "Summarize the text given in the variable `text` and read it out loud."

I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud.

Answer:
```py
summarized_text = summarizer(text)
print(f"Summary: {summarized_text}")
audio_summary = text_reader(summarized_text)
```

Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image."

I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer.

Answer:
```py
answer = text_qa(text=text, question=question)
print(f"The answer is {answer}.")
image = image_generator(answer)
```

Task: "Caption the following `image`."

I will use the following tool: `image_captioner` to generate a caption for the image.

Answer:
```py
caption = image_captioner(image)
```

Task: "<<prompt>>"

I will use the following"""


# docstyle-ignore
CHAT_PROMPT_TEMPLATE = """Below are a series of dialogues between various people and an AI assistant specialized in coding. The AI assistant tries to be helpful, polite, honest, and humble-but-knowledgeable.

The job of the AI assistant is to come up with a series of simple commands in Python that will perform the task the human wants to perform.
To help with that, the AI assistant has access to a set of tools. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
The AI assistant should first explain the tools it will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. The AI assistant can print intermediate results if it makes sense to do so.

Tools:
<<all_tools>>

=====

Human: Answer the question in the variable `question` about the image stored in the variable `image`.

Assistant: I will use the tool `image_qa` to answer the question on the input image.

```py
answer = image_qa(text=question, image=image)
print(f"The answer is {answer}")
```

Human: I tried this code, it worked but didn't give me a good result. The question is in French.

Assistant: In this case, the question needs to be translated first. I will use the tool `translator` to do this.

```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
print(f"The translated question is {translated_question}.")
answer = image_qa(text=translated_question, image=image)
print(f"The answer is {answer}")
```

=====

Human: Identify the oldest person in the `document`.

Assistant: I will use the tool `document_qa` to find the oldest person in the document.

```py
answer = document_qa(document, question="What is the oldest person?")
print(f"The answer is {answer}.")
```

Human: Can you generate an image with the result?

Assistant: I will use the tool `image_generator` to do that.

```py
image = image_generator(answer)
```

=====

Human: Summarize the text given in the variable `text` and read it out loud.

Assistant: I will use the tool `summarizer` to create a summary of the input text, then the tool `text_reader` to read it out loud.

```py
summarized_text = summarizer(text)
print(f"Summary: {summarized_text}")
audio_summary = text_reader(text=summary)
```

Human: I got the following error: "The variable `summary` is not defined."

Assistant: My bad! Let's try this code instead.

```py
summarized_text = summarizer(text)
print(f"Summary: {summarized_text}")
audio_summary = text_reader(text=summarized_text)
```

Human: It worked! Can you translate the summary into German?

Assistant: I will use the tool `translator` to translate the text into German.

```py
translated_summary = translator(summarized_text, src_lang="English", tgt_lang="German")
```

=====
"""


# docstyle-ignore
CHAT_MESSAGE_PROMPT = """
Human: <<task>>

Assistant: """

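A sketch of how the `<<all_tools>>` and `<<prompt>>` placeholders in the templates above might be substituted. The actual substitution happens in the agent's prompt-formatting code; the shortened template string and tool descriptions here are hypothetical stand-ins:

```python
# Shortened stand-in for RUN_PROMPT_TEMPLATE, with the same placeholders.
template = 'Tools:\n<<all_tools>>\n\nTask: "<<prompt>>"\n\nI will use the following'

tools_description = "- translator: translates text\n- image_qa: answers questions about images"
prompt = template.replace("<<all_tools>>", tools_description).replace("<<prompt>>", "Caption the `image`")
```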
@@ -0,0 +1,238 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import difflib
from collections.abc import Mapping
from typing import Any, Callable, Dict


class InterpretorError(ValueError):
    """
    An error raised when the interpreter cannot evaluate a Python expression, due to syntax error or unsupported
    operations.
    """

    pass


def evaluate(code: str, tools: Dict[str, Callable], state=None, chat_mode=False):
    """
    Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
    of functions.

    This function will recurse through the nodes of the tree provided.

    Args:
        code (`str`):
            The code to evaluate.
        tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation. Any call to another function will fail with an
            `InterpretorError`.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
            updated by this function to contain all variables as they are evaluated.
        chat_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the function is called from `Agent.chat`.
    """
    try:
        expression = ast.parse(code)
    except SyntaxError as e:
        print("The code generated by the agent is not valid.\n", e)
        return
    if state is None:
        state = {}
    result = None
    for idx, node in enumerate(expression.body):
        try:
            line_result = evaluate_ast(node, state, tools)
        except InterpretorError as e:
            msg = f"Evaluation of the code stopped at line {idx} before the end because of the following error"
            if chat_mode:
                msg += (
                    f". Copy paste the following error message and send it back to the agent:\nI get an error: '{e}'"
                )
            else:
                msg += f":\n{e}"
            print(msg)
            break
        if line_result is not None:
            result = line_result

    return result

def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Callable]):
    """
    Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
    set of functions.

    This function will recurse through the nodes of the tree provided.

    Args:
        expression (`ast.AST`):
            The code to evaluate, as an abstract syntax tree.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
            encounters assignments.
        tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation. Any call to another function will fail with an
            `InterpretorError`.
    """
    if isinstance(expression, ast.Assign):
        # Assignment -> we evaluate the assignment which should update the state
        # We return the variable assigned as it may be used to determine the final result.
        return evaluate_assign(expression, state, tools)
    elif isinstance(expression, ast.Call):
        # Function call -> we return the value of the function call
        return evaluate_call(expression, state, tools)
    elif isinstance(expression, ast.Constant):
        # Constant -> just return the value
        return expression.value
    elif isinstance(expression, ast.Dict):
        # Dict -> evaluate all keys and values
        keys = [evaluate_ast(k, state, tools) for k in expression.keys]
        values = [evaluate_ast(v, state, tools) for v in expression.values]
        return dict(zip(keys, values))
    elif isinstance(expression, ast.Expr):
        # Expression -> evaluate the content
        return evaluate_ast(expression.value, state, tools)
    elif isinstance(expression, ast.FormattedValue):
        # Formatted value (part of f-string) -> evaluate the content and return
        return evaluate_ast(expression.value, state, tools)
    elif isinstance(expression, ast.If):
        # If -> execute the right branch
        return evaluate_if(expression, state, tools)
    elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
        # Index (Python < 3.9) -> evaluate the wrapped value
        return evaluate_ast(expression.value, state, tools)
    elif isinstance(expression, ast.JoinedStr):
        # f-string -> evaluate and join all the parts
        return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values])
    elif isinstance(expression, ast.List):
        # List -> evaluate all elements
        return [evaluate_ast(elt, state, tools) for elt in expression.elts]
    elif isinstance(expression, ast.Name):
        # Name -> pick up the value in the state
        return evaluate_name(expression, state, tools)
    elif isinstance(expression, ast.Subscript):
        # Subscript -> return the value of the indexing
        return evaluate_subscript(expression, state, tools)
    else:
        # For now we refuse anything else. Let's add things as we need them.
        raise InterpretorError(f"{expression.__class__.__name__} is not supported.")


def evaluate_assign(assign, state, tools):
    var_names = assign.targets
    result = evaluate_ast(assign.value, state, tools)

    if len(var_names) == 1:
        state[var_names[0].id] = result
    else:
        if len(result) != len(var_names):
            raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
        for var_name, r in zip(var_names, result):
            state[var_name.id] = r
    return result


def evaluate_call(call, state, tools):
    if not isinstance(call.func, ast.Name):
        raise InterpretorError(
            f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func} of "
            f"type {type(call.func)})."
        )
    func_name = call.func.id
    if func_name not in tools:
        raise InterpretorError(
            f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func.id})."
        )

    func = tools[func_name]
    # TODO: deal with args
    args = [evaluate_ast(arg, state, tools) for arg in call.args]
    kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
    return func(*args, **kwargs)


def evaluate_subscript(subscript, state, tools):
    index = evaluate_ast(subscript.slice, state, tools)
    value = evaluate_ast(subscript.value, state, tools)
    if isinstance(value, (list, tuple)):
        return value[int(index)]
    if index in value:
        return value[index]
    if isinstance(index, str) and isinstance(value, Mapping):
        close_matches = difflib.get_close_matches(index, list(value.keys()))
        if len(close_matches) > 0:
            return value[close_matches[0]]

    raise InterpretorError(f"Could not index {value} with '{index}'.")


def evaluate_name(name, state, tools):
    if name.id in state:
        return state[name.id]
    close_matches = difflib.get_close_matches(name.id, list(state.keys()))
    if len(close_matches) > 0:
        return state[close_matches[0]]
    raise InterpretorError(f"The variable `{name.id}` is not defined.")


def evaluate_condition(condition, state, tools):
    if len(condition.ops) > 1:
        raise InterpretorError("Cannot evaluate conditions with multiple operators")

    left = evaluate_ast(condition.left, state, tools)
    comparator = condition.ops[0]
    right = evaluate_ast(condition.comparators[0], state, tools)

    if isinstance(comparator, ast.Eq):
        return left == right
    elif isinstance(comparator, ast.NotEq):
        return left != right
    elif isinstance(comparator, ast.Lt):
        return left < right
    elif isinstance(comparator, ast.LtE):
        return left <= right
    elif isinstance(comparator, ast.Gt):
        return left > right
    elif isinstance(comparator, ast.GtE):
        return left >= right
    elif isinstance(comparator, ast.Is):
        return left is right
    elif isinstance(comparator, ast.IsNot):
        return left is not right
    elif isinstance(comparator, ast.In):
        return left in right
    elif isinstance(comparator, ast.NotIn):
        return left not in right
    else:
        raise InterpretorError(f"Operator not supported: {comparator}")


def evaluate_if(if_statement, state, tools):
    result = None
    if evaluate_condition(if_statement.test, state, tools):
        for line in if_statement.body:
            line_result = evaluate_ast(line, state, tools)
            if line_result is not None:
                result = line_result
    else:
        for line in if_statement.orelse:
            line_result = evaluate_ast(line, state, tools)
            if line_result is not None:
                result = line_result
    return result

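The interpreter above can be exercised end to end. Since this module may not be importable here, the following is a simplified re-implementation of the same pattern (parse with `ast`, whitelist tool calls, track assignments in a state dict), with a hypothetical `add` tool; it handles only a fraction of the node types the real interpreter supports:

```python
import ast

def mini_evaluate(code, tools, state):
    # Walk top-level statements, keeping the last non-None result,
    # as the `evaluate` function above does.
    result = None
    for node in ast.parse(code).body:
        if isinstance(node, ast.Assign):
            value = _eval_node(node.value, tools, state)
            state[node.targets[0].id] = value
            result = value
        elif isinstance(node, ast.Expr):
            result = _eval_node(node.value, tools, state)
    return result

def _eval_node(node, tools, state):
    if isinstance(node, ast.Constant):
        return node.value
    if isinstance(node, ast.Name):
        return state[node.id]
    if isinstance(node, ast.Call):
        # Only whitelisted tools may be called, mirroring `evaluate_call`.
        func = tools[node.func.id]
        args = [_eval_node(a, tools, state) for a in node.args]
        kwargs = {kw.arg: _eval_node(kw.value, tools, state) for kw in node.keywords}
        return func(*args, **kwargs)
    raise ValueError(f"{node.__class__.__name__} is not supported.")

state = {}
result = mini_evaluate("x = add(1, 2)\ny = add(a=x, b=10)", {"add": lambda a, b: a + b}, state)
# result is 13, and state records both intermediate variables
```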
@@ -0,0 +1,41 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
from .base import PipelineTool


class SpeechToTextTool(PipelineTool):
    default_checkpoint = "openai/whisper-base"
    description = (
        "This is a tool that transcribes audio into text. It takes an input named `audio` and returns the "
        "transcribed text."
    )
    name = "transcriber"
    pre_processor_class = WhisperProcessor
    model_class = WhisperForConditionalGeneration

    inputs = ["audio"]
    outputs = ["text"]

    def encode(self, audio):
        return self.pre_processor(audio, return_tensors="pt").input_features

    def forward(self, inputs):
        return self.model.generate(inputs=inputs)

    def decode(self, outputs):
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]

@@ -0,0 +1,70 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from ..models.auto import AutoModelForSequenceClassification, AutoTokenizer
from .base import PipelineTool


class TextClassificationTool(PipelineTool):
    """
    Example:

    ```py
    from transformers.tools import TextClassificationTool

    classifier = TextClassificationTool()
    classifier("This is a super nice API!", labels=["positive", "negative"])
    ```
    """

    default_checkpoint = "facebook/bart-large-mnli"
    description = (
        "This is a tool that classifies an English text using provided labels. It takes two inputs: `text`, which "
        "should be the text to classify, and `labels`, which should be the list of labels to use for classification. "
        "It returns the most likely label in the list of provided `labels` for the input text."
    )
    name = "text_classifier"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSequenceClassification

    inputs = ["text", ["text"]]
    outputs = ["text"]

    def setup(self):
        super().setup()
        config = self.model.config
        self.entailment_id = -1
        for idx, label in config.id2label.items():
            if label.lower().startswith("entail"):
                self.entailment_id = int(idx)
        if self.entailment_id == -1:
            raise ValueError("Could not determine the entailment ID from the model config, please pass it at init.")

    def encode(self, text, labels):
        self._labels = labels
        return self.pre_processor(
            [text] * len(labels),
            [f"This example is {label}" for label in labels],
            return_tensors="pt",
            padding="max_length",
        )

    def decode(self, outputs):
        logits = outputs.logits
        label_id = torch.argmax(logits[:, self.entailment_id]).item()
        return self._labels[label_id]
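The zero-shot trick in `TextClassificationTool` can be sketched in plain Python: pair the text with a hypothesis per label, then pick the label whose pair scores highest on the entailment column. The logits below are hypothetical numbers, not real model output:

```python
# Hypothetical NLI logits for the pairs (text, "This example is <label>"),
# one row per candidate label; columns follow the MNLI convention
# [contradiction, neutral, entailment].
logits = [
    [2.1, 0.3, -1.0],  # pair for label "positive"
    [-1.5, 0.2, 3.4],  # pair for label "negative"
]
entailment_id = 2  # index of the "entailment" class, as found in setup()
labels = ["positive", "negative"]

# Mirror decode(): argmax over the entailment column picks the label
# whose hypothesis the model most strongly entails.
best = max(range(len(labels)), key=lambda i: logits[i][entailment_id])
print(labels[best])  # negative
```

This is why `setup()` goes looking for the entailment index in `config.id2label`: different NLI checkpoints order their classes differently.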
@@ -0,0 +1,52 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool


QA_PROMPT = """Here is a text containing a lot of information: '''{text}'''.

Can you answer this question about the text: '{question}'"""


class TextQuestionAnsweringTool(PipelineTool):
    default_checkpoint = "google/flan-t5-base"
    description = (
        "This is a tool that answers questions related to a text. It takes two arguments named `text`, which is the "
        "text in which to find the answer, and `question`, which is the question, and returns the answer to the question."
    )
    name = "text_qa"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM

    inputs = ["text", "text"]
    outputs = ["text"]

    def encode(self, text: str, question: str):
        prompt = QA_PROMPT.format(text=text, question=question)
        return self.pre_processor(prompt, return_tensors="pt")

    def forward(self, inputs):
        output_ids = self.model.generate(**inputs)

        in_b, _ = inputs["input_ids"].shape
        out_b = output_ids.shape[0]

        return output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])[0][0]

    def decode(self, outputs):
        return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
@@ -0,0 +1,52 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool


class TextSummarizationTool(PipelineTool):
    """
    Example:

    ```py
    from transformers.tools import TextSummarizationTool

    summarizer = TextSummarizationTool()
    summarizer(long_text)
    ```
    """

    default_checkpoint = "philschmid/bart-large-cnn-samsum"
    description = (
        "This is a tool that summarizes an English text. It takes an input `text` containing the text to summarize, "
        "and returns a summary of the text."
    )
    name = "summarizer"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM

    inputs = ["text"]
    outputs = ["text"]

    def encode(self, text):
        return self.pre_processor(text, return_tensors="pt", truncation=True)

    def forward(self, inputs):
        return self.model.generate(**inputs)[0]

    def decode(self, outputs):
        return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
@@ -0,0 +1,65 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
from ..utils import is_datasets_available
from .base import PipelineTool


if is_datasets_available():
    from datasets import load_dataset


class TextToSpeechTool(PipelineTool):
    default_checkpoint = "microsoft/speecht5_tts"
    description = (
        "This is a tool that reads an English text out loud. It takes an input named `text` which should contain the "
        "text to read (in English) and returns a waveform object containing the sound."
    )
    name = "text_reader"
    pre_processor_class = SpeechT5Processor
    model_class = SpeechT5ForTextToSpeech
    post_processor_class = SpeechT5HifiGan

    inputs = ["text"]
    outputs = ["audio"]

    def setup(self):
        if self.post_processor is None:
            self.post_processor = "microsoft/speecht5_hifigan"
        super().setup()

    def encode(self, text, speaker_embeddings=None):
        inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True)

        if speaker_embeddings is None:
            if not is_datasets_available():
                raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")

            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)

        return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}

    def forward(self, inputs):
        with torch.no_grad():
            return self.model.generate_speech(**inputs)

    def decode(self, outputs):
        with torch.no_grad():
            return self.post_processor(outputs).cpu().detach()
@@ -0,0 +1,271 @@
#!/usr/bin/env python
# coding=utf-8

# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool


LANGUAGE_CODES = {
    "Acehnese Arabic": "ace_Arab",
    "Acehnese Latin": "ace_Latn",
    "Mesopotamian Arabic": "acm_Arab",
    "Ta'izzi-Adeni Arabic": "acq_Arab",
    "Tunisian Arabic": "aeb_Arab",
    "Afrikaans": "afr_Latn",
    "South Levantine Arabic": "ajp_Arab",
    "Akan": "aka_Latn",
    "Amharic": "amh_Ethi",
    "North Levantine Arabic": "apc_Arab",
    "Modern Standard Arabic": "arb_Arab",
    "Modern Standard Arabic Romanized": "arb_Latn",
    "Najdi Arabic": "ars_Arab",
    "Moroccan Arabic": "ary_Arab",
    "Egyptian Arabic": "arz_Arab",
    "Assamese": "asm_Beng",
    "Asturian": "ast_Latn",
    "Awadhi": "awa_Deva",
    "Central Aymara": "ayr_Latn",
    "South Azerbaijani": "azb_Arab",
    "North Azerbaijani": "azj_Latn",
    "Bashkir": "bak_Cyrl",
    "Bambara": "bam_Latn",
    "Balinese": "ban_Latn",
    "Belarusian": "bel_Cyrl",
    "Bemba": "bem_Latn",
    "Bengali": "ben_Beng",
    "Bhojpuri": "bho_Deva",
    "Banjar Arabic": "bjn_Arab",
    "Banjar Latin": "bjn_Latn",
    "Standard Tibetan": "bod_Tibt",
    "Bosnian": "bos_Latn",
    "Buginese": "bug_Latn",
    "Bulgarian": "bul_Cyrl",
    "Catalan": "cat_Latn",
    "Cebuano": "ceb_Latn",
    "Czech": "ces_Latn",
    "Chokwe": "cjk_Latn",
    "Central Kurdish": "ckb_Arab",
    "Crimean Tatar": "crh_Latn",
    "Welsh": "cym_Latn",
    "Danish": "dan_Latn",
    "German": "deu_Latn",
    "Southwestern Dinka": "dik_Latn",
    "Dyula": "dyu_Latn",
    "Dzongkha": "dzo_Tibt",
    "Greek": "ell_Grek",
    "English": "eng_Latn",
    "Esperanto": "epo_Latn",
    "Estonian": "est_Latn",
    "Basque": "eus_Latn",
    "Ewe": "ewe_Latn",
    "Faroese": "fao_Latn",
    "Fijian": "fij_Latn",
    "Finnish": "fin_Latn",
    "Fon": "fon_Latn",
    "French": "fra_Latn",
    "Friulian": "fur_Latn",
    "Nigerian Fulfulde": "fuv_Latn",
    "Scottish Gaelic": "gla_Latn",
    "Irish": "gle_Latn",
    "Galician": "glg_Latn",
    "Guarani": "grn_Latn",
    "Gujarati": "guj_Gujr",
    "Haitian Creole": "hat_Latn",
    "Hausa": "hau_Latn",
    "Hebrew": "heb_Hebr",
    "Hindi": "hin_Deva",
    "Chhattisgarhi": "hne_Deva",
    "Croatian": "hrv_Latn",
    "Hungarian": "hun_Latn",
    "Armenian": "hye_Armn",
    "Igbo": "ibo_Latn",
    "Ilocano": "ilo_Latn",
    "Indonesian": "ind_Latn",
    "Icelandic": "isl_Latn",
    "Italian": "ita_Latn",
    "Javanese": "jav_Latn",
    "Japanese": "jpn_Jpan",
    "Kabyle": "kab_Latn",
    "Jingpho": "kac_Latn",
    "Kamba": "kam_Latn",
    "Kannada": "kan_Knda",
    "Kashmiri Arabic": "kas_Arab",
    "Kashmiri Devanagari": "kas_Deva",
    "Georgian": "kat_Geor",
    "Central Kanuri Arabic": "knc_Arab",
    "Central Kanuri Latin": "knc_Latn",
    "Kazakh": "kaz_Cyrl",
    "Kabiyè": "kbp_Latn",
    "Kabuverdianu": "kea_Latn",
    "Khmer": "khm_Khmr",
    "Kikuyu": "kik_Latn",
    "Kinyarwanda": "kin_Latn",
    "Kyrgyz": "kir_Cyrl",
    "Kimbundu": "kmb_Latn",
    "Northern Kurdish": "kmr_Latn",
    "Kikongo": "kon_Latn",
    "Korean": "kor_Hang",
    "Lao": "lao_Laoo",
    "Ligurian": "lij_Latn",
    "Limburgish": "lim_Latn",
    "Lingala": "lin_Latn",
    "Lithuanian": "lit_Latn",
    "Lombard": "lmo_Latn",
    "Latgalian": "ltg_Latn",
    "Luxembourgish": "ltz_Latn",
    "Luba-Kasai": "lua_Latn",
    "Ganda": "lug_Latn",
    "Luo": "luo_Latn",
    "Mizo": "lus_Latn",
    "Standard Latvian": "lvs_Latn",
    "Magahi": "mag_Deva",
    "Maithili": "mai_Deva",
    "Malayalam": "mal_Mlym",
    "Marathi": "mar_Deva",
    "Minangkabau Arabic": "min_Arab",
    "Minangkabau Latin": "min_Latn",
    "Macedonian": "mkd_Cyrl",
    "Plateau Malagasy": "plt_Latn",
    "Maltese": "mlt_Latn",
    "Meitei Bengali": "mni_Beng",
    "Halh Mongolian": "khk_Cyrl",
    "Mossi": "mos_Latn",
    "Maori": "mri_Latn",
    "Burmese": "mya_Mymr",
    "Dutch": "nld_Latn",
    "Norwegian Nynorsk": "nno_Latn",
    "Norwegian Bokmål": "nob_Latn",
    "Nepali": "npi_Deva",
    "Northern Sotho": "nso_Latn",
    "Nuer": "nus_Latn",
    "Nyanja": "nya_Latn",
    "Occitan": "oci_Latn",
    "West Central Oromo": "gaz_Latn",
    "Odia": "ory_Orya",
    "Pangasinan": "pag_Latn",
    "Eastern Panjabi": "pan_Guru",
    "Papiamento": "pap_Latn",
    "Western Persian": "pes_Arab",
    "Polish": "pol_Latn",
    "Portuguese": "por_Latn",
    "Dari": "prs_Arab",
    "Southern Pashto": "pbt_Arab",
    "Ayacucho Quechua": "quy_Latn",
    "Romanian": "ron_Latn",
    "Rundi": "run_Latn",
    "Russian": "rus_Cyrl",
    "Sango": "sag_Latn",
    "Sanskrit": "san_Deva",
    "Santali": "sat_Olck",
    "Sicilian": "scn_Latn",
    "Shan": "shn_Mymr",
    "Sinhala": "sin_Sinh",
    "Slovak": "slk_Latn",
    "Slovenian": "slv_Latn",
    "Samoan": "smo_Latn",
    "Shona": "sna_Latn",
    "Sindhi": "snd_Arab",
    "Somali": "som_Latn",
    "Southern Sotho": "sot_Latn",
    "Spanish": "spa_Latn",
    "Tosk Albanian": "als_Latn",
    "Sardinian": "srd_Latn",
    "Serbian": "srp_Cyrl",
    "Swati": "ssw_Latn",
    "Sundanese": "sun_Latn",
    "Swedish": "swe_Latn",
    "Swahili": "swh_Latn",
    "Silesian": "szl_Latn",
    "Tamil": "tam_Taml",
    "Tatar": "tat_Cyrl",
    "Telugu": "tel_Telu",
    "Tajik": "tgk_Cyrl",
    "Tagalog": "tgl_Latn",
    "Thai": "tha_Thai",
    "Tigrinya": "tir_Ethi",
    "Tamasheq Latin": "taq_Latn",
    "Tamasheq Tifinagh": "taq_Tfng",
    "Tok Pisin": "tpi_Latn",
    "Tswana": "tsn_Latn",
    "Tsonga": "tso_Latn",
    "Turkmen": "tuk_Latn",
    "Tumbuka": "tum_Latn",
    "Turkish": "tur_Latn",
    "Twi": "twi_Latn",
    "Central Atlas Tamazight": "tzm_Tfng",
    "Uyghur": "uig_Arab",
    "Ukrainian": "ukr_Cyrl",
    "Umbundu": "umb_Latn",
    "Urdu": "urd_Arab",
    "Northern Uzbek": "uzn_Latn",
    "Venetian": "vec_Latn",
    "Vietnamese": "vie_Latn",
    "Waray": "war_Latn",
    "Wolof": "wol_Latn",
    "Xhosa": "xho_Latn",
    "Eastern Yiddish": "ydd_Hebr",
    "Yoruba": "yor_Latn",
    "Yue Chinese": "yue_Hant",
    "Chinese Simplified": "zho_Hans",
    "Chinese Traditional": "zho_Hant",
    "Standard Malay": "zsm_Latn",
    "Zulu": "zul_Latn",
}


class TranslationTool(PipelineTool):
    """
    Example:

    ```py
    from transformers.tools import TranslationTool

    translator = TranslationTool()
    translator("This is a super nice API!", src_lang="English", tgt_lang="French")
    ```
    """

    default_checkpoint = "facebook/nllb-200-distilled-600M"
    description = (
        "This is a tool that translates text from one language to another. It takes three inputs: `text`, which "
        "should be the text to translate, `src_lang`, which should be the language of the text to translate, and "
        "`tgt_lang`, which should be the desired output language. Both `src_lang` and `tgt_lang` are written in "
        "plain English, such as 'Romanian' or 'Albanian'. It returns the text translated into `tgt_lang`."
    )
    name = "translator"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSeq2SeqLM
    lang_to_code = LANGUAGE_CODES

    inputs = ["text", "text", "text"]
    outputs = ["text"]

    def encode(self, text, src_lang, tgt_lang):
        if src_lang not in self.lang_to_code:
            raise ValueError(f"{src_lang} is not a supported language.")
        if tgt_lang not in self.lang_to_code:
            raise ValueError(f"{tgt_lang} is not a supported language.")
        src_lang = self.lang_to_code[src_lang]
        tgt_lang = self.lang_to_code[tgt_lang]
        return self.pre_processor._build_translation_inputs(
            text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
        )

    def forward(self, inputs):
        return self.model.generate(**inputs)

    def decode(self, outputs):
        return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)
@@ -121,6 +121,7 @@ from .import_utils import (
     is_natten_available,
     is_ninja_available,
     is_onnx_available,
+    is_openai_available,
     is_optimum_available,
     is_pandas_available,
     is_peft_available,
@@ -235,6 +235,7 @@ def try_to_load_from_cache(
     filename: str,
     cache_dir: Union[str, Path, None] = None,
     revision: Optional[str] = None,
+    repo_type: Optional[str] = None,
 ) -> Optional[str]:
     """
     Explores the cache to return the latest cached file for a given revision if found.

@@ -251,6 +252,8 @@ def try_to_load_from_cache(
         revision (`str`, *optional*):
             The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
             provided either.
+        repo_type (`str`, *optional*):
+            The type of the repo.

     Returns:
         `Optional[str]` or `_CACHED_NO_EXIST`:

@@ -266,7 +269,9 @@ def try_to_load_from_cache(
         cache_dir = TRANSFORMERS_CACHE

     object_id = repo_id.replace("/", "--")
-    repo_cache = os.path.join(cache_dir, f"models--{object_id}")
+    if repo_type is None:
+        repo_type = "model"
+    repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
     if not os.path.isdir(repo_cache):
         # No cache for this model
         return None

@@ -303,6 +308,7 @@ def cached_file(
     revision: Optional[str] = None,
     local_files_only: bool = False,
     subfolder: str = "",
+    repo_type: Optional[str] = None,
     user_agent: Optional[Union[str, Dict[str, str]]] = None,
     _raise_exceptions_for_missing_entries: bool = True,
     _raise_exceptions_for_connection_errors: bool = True,

@@ -342,6 +348,8 @@ def cached_file(
         subfolder (`str`, *optional*, defaults to `""`):
             In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
             specify the folder name here.
+        repo_type (`str`, *optional*):
+            Specify the repo type (useful when downloading from a space for instance).

         <Tip>

@@ -393,7 +401,7 @@ def cached_file(
     if _commit_hash is not None and not force_download:
         # If the file is cached under that commit hash, we return it directly.
         resolved_file = try_to_load_from_cache(
-            path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash
+            path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
         )
         if resolved_file is not None:
             if resolved_file is not _CACHED_NO_EXIST:

@@ -410,6 +418,7 @@ def cached_file(
             path_or_repo_id,
             filename,
             subfolder=None if len(subfolder) == 0 else subfolder,
+            repo_type=repo_type,
             revision=revision,
             cache_dir=cache_dir,
             user_agent=user_agent,
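The hunk above changes the cache folder naming from a hardcoded `models--` prefix to one derived from `repo_type`. A standalone sketch of that lookup (the repo ids below are just illustrative):

```python
import os


def cached_repo_dir(cache_dir, repo_id, repo_type=None):
    # Mirrors the path construction in try_to_load_from_cache:
    # repos live under "<repo_type>s--<org>--<name>", defaulting to "model".
    if repo_type is None:
        repo_type = "model"
    object_id = repo_id.replace("/", "--")
    return os.path.join(cache_dir, f"{repo_type}s--{object_id}")


print(cached_repo_dir("/cache", "openai/whisper-base"))
print(cached_repo_dir("/cache", "some-org/some-tool", "space"))
```

Pluralizing the type (`model` → `models--…`, `space` → `spaces--…`) keeps the old layout for models while letting Space-hosted tools share the same cache.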
@@ -125,6 +125,14 @@ except importlib_metadata.PackageNotFoundError:
     _datasets_available = False


+_diffusers_available = importlib.util.find_spec("diffusers") is not None
+try:
+    _diffusers_version = importlib_metadata.version("diffusers")
+    logger.debug(f"Successfully imported diffusers version {_diffusers_version}")
+except importlib_metadata.PackageNotFoundError:
+    _diffusers_available = False
+
+
 _detectron2_available = importlib.util.find_spec("detectron2") is not None
 try:
     _detectron2_version = importlib_metadata.version("detectron2")

@@ -185,6 +193,9 @@ except importlib_metadata.PackageNotFoundError:
     _onnx_available = False


+_opencv_available = importlib.util.find_spec("cv2") is not None
+
+
 _pytorch_quantization_available = importlib.util.find_spec("pytorch_quantization") is not None
 try:
     _pytorch_quantization_version = importlib_metadata.version("pytorch_quantization")

@@ -431,6 +442,10 @@ def is_onnx_available():
     return _onnx_available


+def is_openai_available():
+    return importlib.util.find_spec("openai") is not None
+
+
 def is_flax_available():
     return _flax_available

@@ -0,0 +1,57 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from datasets import load_dataset

from transformers import load_tool

from .test_tools_common import ToolTesterMixin


class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("document-question-answering")
        self.tool.setup()
        self.remote_tool = load_tool("document-question-answering", remote=True)

    def test_exact_match_arg(self):
        dataset = load_dataset("hf-internal-testing/example-documents", split="test")
        image = dataset[0]["image"]

        result = self.tool(image, "When is the coffee break?")
        self.assertEqual(result, "11-14 to 11:39 a.m.")

    def test_exact_match_arg_remote(self):
        dataset = load_dataset("hf-internal-testing/example-documents", split="test")
        image = dataset[0]["image"]

        result = self.remote_tool(image, "When is the coffee break?")
        self.assertEqual(result, "11-14 to 11:39 a.m.")

    def test_exact_match_kwarg(self):
        dataset = load_dataset("hf-internal-testing/example-documents", split="test")
        image = dataset[0]["image"]

        result = self.tool(image=image, question="When is the coffee break?")
        self.assertEqual(result, "11-14 to 11:39 a.m.")

    def test_exact_match_kwarg_remote(self):
        dataset = load_dataset("hf-internal-testing/example-documents", split="test")
        image = dataset[0]["image"]

        result = self.remote_tool(image=image, question="When is the coffee break?")
        self.assertEqual(result, "11-14 to 11:39 a.m.")
@@ -0,0 +1,53 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from pathlib import Path

from transformers import is_vision_available, load_tool
from transformers.testing_utils import get_tests_dir

from .test_tools_common import ToolTesterMixin


if is_vision_available():
    from PIL import Image


class ImageCaptioningToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("image-captioning")
        self.tool.setup()
        self.remote_tool = load_tool("image-captioning", remote=True)

    def test_exact_match_arg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.tool(image)
        self.assertEqual(result, "two cats sleeping on a couch")

    def test_exact_match_arg_remote(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.remote_tool(image)
        self.assertEqual(result, "two cats sleeping on a couch")

    def test_exact_match_kwarg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.tool(image=image)
        self.assertEqual(result, "two cats sleeping on a couch")

    def test_exact_match_kwarg_remote(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.remote_tool(image=image)
        self.assertEqual(result, "two cats sleeping on a couch")
@@ -0,0 +1,53 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from pathlib import Path

from transformers import is_vision_available, load_tool
from transformers.testing_utils import get_tests_dir

from .test_tools_common import ToolTesterMixin


if is_vision_available():
    from PIL import Image


class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("image-question-answering")
        self.tool.setup()
        self.remote_tool = load_tool("image-question-answering", remote=True)

    def test_exact_match_arg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.tool(image, "How many cats are sleeping on the couch?")
        self.assertEqual(result, "2")

    def test_exact_match_arg_remote(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.remote_tool(image, "How many cats are sleeping on the couch?")
        self.assertEqual(result, "2")

    def test_exact_match_kwarg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.tool(image=image, question="How many cats are sleeping on the couch?")
        self.assertEqual(result, "2")

    def test_exact_match_kwarg_remote(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
        result = self.remote_tool(image=image, question="How many cats are sleeping on the couch?")
        self.assertEqual(result, "2")
@@ -0,0 +1,53 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from pathlib import Path

from transformers import is_vision_available, load_tool
from transformers.testing_utils import get_tests_dir

from .test_tools_common import ToolTesterMixin


if is_vision_available():
    from PIL import Image


class ImageSegmentationToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("image-segmentation")
        self.tool.setup()
        self.remote_tool = load_tool("image-segmentation", remote=True)

    def test_exact_match_arg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
        result = self.tool(image, "cat")
        self.assertTrue(isinstance(result, Image.Image))

    def test_exact_match_arg_remote(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
        result = self.remote_tool(image, "cat")
        self.assertTrue(isinstance(result, Image.Image))

    def test_exact_match_kwarg(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
        result = self.tool(image=image, prompt="cat")
        self.assertTrue(isinstance(result, Image.Image))

    def test_exact_match_kwarg_remote(self):
        image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
        result = self.remote_tool(image=image, prompt="cat")
        self.assertTrue(isinstance(result, Image.Image))
@@ -0,0 +1,124 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers.testing_utils import CaptureStdout
from transformers.tools.python_interpreter import evaluate


# Fake function we will use as a tool
def add_two(x):
    return x + 2


class PythonInterpreterTester(unittest.TestCase):
    def test_evaluate_assign(self):
        code = "x = 3"
        state = {}
        result = evaluate(code, {}, state=state)
        assert result == 3
        self.assertDictEqual(state, {"x": 3})

        code = "x = y"
        state = {"y": 5}
        result = evaluate(code, {}, state=state)
        # evaluate returns the value of the last assignment.
        assert result == 5
        self.assertDictEqual(state, {"x": 5, "y": 5})

    def test_evaluate_call(self):
        code = "y = add_two(x)"
        state = {"x": 3}
        result = evaluate(code, {"add_two": add_two}, state=state)
        assert result == 5
        self.assertDictEqual(state, {"x": 3, "y": 5})

        # Won't work without the tool
        with CaptureStdout() as out:
            result = evaluate(code, {}, state=state)
        assert result is None
        assert "tried to execute add_two" in out.out

    def test_evaluate_constant(self):
        code = "x = 3"
        state = {}
        result = evaluate(code, {}, state=state)
        assert result == 3
        self.assertDictEqual(state, {"x": 3})

    def test_evaluate_dict(self):
        code = "test_dict = {'x': x, 'y': add_two(x)}"
        state = {"x": 3}
        result = evaluate(code, {"add_two": add_two}, state=state)
        self.assertDictEqual(result, {"x": 3, "y": 5})
        self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}})

    def test_evaluate_expression(self):
        code = "x = 3\ny = 5"
        state = {}
        result = evaluate(code, {}, state=state)
        # evaluate returns the value of the last assignment.
        assert result == 5
        self.assertDictEqual(state, {"x": 3, "y": 5})

    def test_evaluate_f_string(self):
        code = "text = f'This is x: {x}.'"
        state = {"x": 3}
        result = evaluate(code, {}, state=state)
        # evaluate returns the value of the last assignment.
        assert result == "This is x: 3."
        self.assertDictEqual(state, {"x": 3, "text": "This is x: 3."})

    def test_evaluate_if(self):
        code = "if x <= 3:\n    y = 2\nelse:\n    y = 5"
        state = {"x": 3}
        result = evaluate(code, {}, state=state)
        # evaluate returns the value of the last assignment.
        assert result == 2
        self.assertDictEqual(state, {"x": 3, "y": 2})

        state = {"x": 8}
        result = evaluate(code, {}, state=state)
        # evaluate returns the value of the last assignment.
        assert result == 5
        self.assertDictEqual(state, {"x": 8, "y": 5})

    def test_evaluate_list(self):
        code = "test_list = [x, add_two(x)]"
        state = {"x": 3}
        result = evaluate(code, {"add_two": add_two}, state=state)
        self.assertListEqual(result, [3, 5])
        self.assertDictEqual(state, {"x": 3, "test_list": [3, 5]})

    def test_evaluate_name(self):
        code = "y = x"
        state = {"x": 3}
        result = evaluate(code, {}, state=state)
        assert result == 3
        self.assertDictEqual(state, {"x": 3, "y": 3})

    def test_evaluate_subscript(self):
        code = "test_list = [x, add_two(x)]\ntest_list[1]"
        state = {"x": 3}
        result = evaluate(code, {"add_two": add_two}, state=state)
        assert result == 5
        self.assertDictEqual(state, {"x": 3, "test_list": [3, 5]})

        code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
        state = {"x": 3}
        result = evaluate(code, {"add_two": add_two}, state=state)
        assert result == 5
        self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}})
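The evaluator semantics these tests pin down (statements executed in order, names resolved from `state` or the tool dict, the value of the last statement returned, and a printed message when a missing tool is called) can be sketched with a small AST walker. `mini_evaluate` below is a hypothetical, heavily simplified stand-in for illustration, not the actual `transformers.tools.python_interpreter.evaluate`:

```python
import ast


def mini_evaluate(code, tools, state):
    """Illustrative sketch only: handles assignments, constants, names,
    and calls, mutating `state` in place and returning the last value."""
    result = None
    for node in ast.parse(code).body:
        if isinstance(node, ast.Assign):
            value = _eval_node(node.value, tools, state)
            state[node.targets[0].id] = value
            result = value
        elif isinstance(node, ast.Expr):
            result = _eval_node(node.value, tools, state)
    return result


def _eval_node(node, tools, state):
    if isinstance(node, ast.Constant):
        return node.value
    if isinstance(node, ast.Name):
        # State variables shadow tools of the same name.
        return state.get(node.id, tools.get(node.id))
    if isinstance(node, ast.Call):
        func = _eval_node(node.func, tools, state)
        if func is None:
            # Mirrors the "tried to execute ..." message asserted above.
            print(f"tried to execute {node.func.id}")
            return None
        return func(*[_eval_node(arg, tools, state) for arg in node.args])
    raise ValueError(f"Unsupported node: {ast.dump(node)}")
```

Under this sketch, `mini_evaluate("x = 3", {}, state={})` returns `3` and leaves `{"x": 3}` in the state, matching the first assertion in `test_evaluate_assign`.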
@@ -0,0 +1,38 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import is_torch_available, load_tool

from .test_tools_common import ToolTesterMixin


if is_torch_available():
    import torch


class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("speech-to-text")
        self.tool.setup()

    def test_exact_match_arg(self):
        result = self.tool(torch.ones(3000))
        self.assertEqual(result, " you")

    def test_exact_match_kwarg(self):
        result = self.tool(audio=torch.ones(3000))
        self.assertEqual(result, " you")
@@ -0,0 +1,43 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import load_tool

from .test_tools_common import ToolTesterMixin


class TextClassificationToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("text-classification")
        self.tool.setup()
        self.remote_tool = load_tool("text-classification", remote=True)

    def test_exact_match_arg(self):
        result = self.tool("That's quite cool", ["positive", "negative"])
        self.assertEqual(result, "positive")

    def test_exact_match_arg_remote(self):
        result = self.remote_tool("That's quite cool", ["positive", "negative"])
        self.assertEqual(result, "positive")

    def test_exact_match_kwarg(self):
        result = self.tool(text="That's quite cool", labels=["positive", "negative"])
        self.assertEqual(result, "positive")

    def test_exact_match_kwarg_remote(self):
        result = self.remote_tool(text="That's quite cool", labels=["positive", "negative"])
        self.assertEqual(result, "positive")
@@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import load_tool

from .test_tools_common import ToolTesterMixin


TEXT = """
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.

In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]

On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
"""


class TextQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("text-question-answering")
        self.tool.setup()
        self.remote_tool = load_tool("text-question-answering", remote=True)

    def test_exact_match_arg(self):
        result = self.tool(TEXT, "What did Hugging Face do in April 2021?")
        self.assertEqual(result, "launched the BigScience Research Workshop")

    def test_exact_match_arg_remote(self):
        result = self.remote_tool(TEXT, "What did Hugging Face do in April 2021?")
        self.assertEqual(result, "launched the BigScience Research Workshop")

    def test_exact_match_kwarg(self):
        result = self.tool(text=TEXT, question="What did Hugging Face do in April 2021?")
        self.assertEqual(result, "launched the BigScience Research Workshop")

    def test_exact_match_kwarg_remote(self):
        result = self.remote_tool(text=TEXT, question="What did Hugging Face do in April 2021?")
        self.assertEqual(result, "launched the BigScience Research Workshop")
@@ -0,0 +1,64 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import load_tool

from .test_tools_common import ToolTesterMixin


TEXT = """
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.

In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]

On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
"""


class TextSummarizationToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("summarization")
        self.tool.setup()
        self.remote_tool = load_tool("summarization", remote=True)

    def test_exact_match_arg(self):
        result = self.tool(TEXT)
        self.assertEqual(
            result,
            "Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
        )

    def test_exact_match_arg_remote(self):
        result = self.remote_tool(TEXT)
        self.assertEqual(
            result,
            "Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
        )

    def test_exact_match_kwarg(self):
        result = self.tool(text=TEXT)
        self.assertEqual(
            result,
            "Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
        )

    def test_exact_match_kwarg_remote(self):
        result = self.remote_tool(text=TEXT)
        self.assertEqual(
            result,
            "Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
        )
@@ -0,0 +1,54 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import load_tool
from transformers.testing_utils import require_torch
from transformers.utils import is_torch_available

from .test_tools_common import ToolTesterMixin


if is_torch_available():
    import torch


@require_torch
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("text-to-speech")
        self.tool.setup()

    def test_exact_match_arg(self):
        # SpeechT5 isn't deterministic
        torch.manual_seed(0)
        result = self.tool("hey")
        self.assertTrue(
            torch.allclose(
                result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293])
            )
        )

    def test_exact_match_kwarg(self):
        # SpeechT5 isn't deterministic
        torch.manual_seed(0)
        result = self.tool(text="hey")
        self.assertTrue(
            torch.allclose(
                result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293])
            )
        )
@@ -0,0 +1,100 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import List

from transformers import is_torch_available, is_vision_available
from transformers.testing_utils import get_tests_dir, is_tool_test


if is_torch_available():
    import torch

if is_vision_available():
    from PIL import Image


authorized_types = ["text", "image", "audio"]


def create_inputs(input_types: List[str]):
    inputs = []

    for input_type in input_types:
        if input_type == "text":
            inputs.append("Text input")
        elif input_type == "image":
            inputs.append(
                Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
            )
        elif input_type == "audio":
            inputs.append(torch.ones(3000))
        elif isinstance(input_type, list):
            inputs.append(create_inputs(input_type))
        else:
            raise ValueError(f"Invalid type requested: {input_type}")

    return inputs


def output_types(outputs: List):
    output_types = []

    for output in outputs:
        if isinstance(output, str):
            output_types.append("text")
        elif isinstance(output, Image.Image):
            output_types.append("image")
        elif isinstance(output, torch.Tensor):
            output_types.append("audio")
        else:
            raise ValueError(f"Invalid output: {output}")

    return output_types


@is_tool_test
class ToolTesterMixin:
    def test_inputs_outputs(self):
        self.assertTrue(hasattr(self.tool, "inputs"))
        self.assertTrue(hasattr(self.tool, "outputs"))

        inputs = self.tool.inputs
        for _input in inputs:
            if isinstance(_input, list):
                for __input in _input:
                    self.assertTrue(__input in authorized_types)
            else:
                self.assertTrue(_input in authorized_types)

        outputs = self.tool.outputs
        for _output in outputs:
            self.assertTrue(_output in authorized_types)

    def test_call(self):
        inputs = create_inputs(self.tool.inputs)
        outputs = self.tool(*inputs)

        # There is a single output
        if len(self.tool.outputs) == 1:
            outputs = [outputs]

        self.assertListEqual(output_types(outputs), self.tool.outputs)

    def test_common_attributes(self):
        self.assertTrue(hasattr(self.tool, "description"))
        self.assertTrue(hasattr(self.tool, "default_checkpoint"))
        self.assertTrue(self.tool.description.startswith("This is a tool that"))
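The mixin above relies on each concrete tester providing `self.tool` in `setUp`, so the shared assertions run against every tool. That pattern can be illustrated in isolation; `FakeTool`, `FakeToolTesterMixin`, and `FakeToolTester` below are hypothetical names used only for this sketch, not part of transformers:

```python
import unittest


class FakeTool:
    """Hypothetical stand-in for a tool: declares I/O types and a description."""

    description = "This is a tool that upper-cases text."
    inputs = ["text"]
    outputs = ["text"]

    def __call__(self, text):
        return text.upper()


class FakeToolTesterMixin:
    """Shared checks; concrete testers supply `self.tool` in setUp."""

    def test_description(self):
        # Same convention the real mixin enforces on tool descriptions.
        self.assertTrue(self.tool.description.startswith("This is a tool that"))

    def test_call(self):
        self.assertEqual(self.tool("hey"), "HEY")


class FakeToolTester(unittest.TestCase, FakeToolTesterMixin):
    def setUp(self):
        self.tool = FakeTool()
```

Because the mixin does not inherit from `unittest.TestCase` itself, its methods are only discovered and run through concrete subclasses like `FakeToolTester`.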
@@ -0,0 +1,53 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import load_tool

from .test_tools_common import ToolTesterMixin, output_types


class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
    def setUp(self):
        self.tool = load_tool("translation")
        self.tool.setup()
        self.remote_tool = load_tool("translation", remote=True)

    def test_exact_match_arg(self):
        result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French")
        self.assertEqual(result, "- Hé, comment ça va?")

    def test_exact_match_arg_remote(self):
        result = self.remote_tool("Hey, what's up?", src_lang="English", tgt_lang="French")
        self.assertEqual(result, "- Hé, comment ça va?")

    def test_exact_match_kwarg(self):
        result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
        self.assertEqual(result, "- Hé, comment ça va?")

    def test_exact_match_kwarg_remote(self):
        result = self.remote_tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
        self.assertEqual(result, "- Hé, comment ça va?")

    def test_call(self):
        inputs = ["Hey, what's up?", "English", "Spanish"]
        outputs = self.tool(*inputs)

        # There is a single output
        if len(self.tool.outputs) == 1:
            outputs = [outputs]

        self.assertListEqual(output_types(outputs), self.tool.outputs)