Test composition (#23214)

* Remove nestedness in tool config

* Really do it

* Use remote tools descriptions

* Work

* Clean up eval

* Changes

* Tools

* Tools

* tool

* Fix everything

* Use last result/assign for evaluation

* Prompt

* Remove hardcoded selection

* Evaluation for chat agents

* correct some spelling

* Small fixes

* Change summarization model (#23172)

* Fix link displayed

* Update description of the tool

* Fixes in chat prompt

* Custom tools, custom prompt

* Tool clean up

* save_pretrained and push_to_hub for tool

* Fix init

* Tests

* Fix tests

* Tool save/from_hub/push_to_hub and tool->load_tool

* Clean push_to_hub and add app file

* Custom inference API for endpoints too

* Clean up

* old remote tool and new remote tool

* Make a requirements

* return_code adds tool creation

* Avoid redundancy between global variables

* Remote tools can be loaded

* Tests

* Text summarization tests

* Quality

* Properly mark tests

* Test the python interpreter

* And the CI shall be green.

* fix loading of additional tools

* Work on RemoteTool and fix tests

* General clean up

* Guard imports

* Fix tools

* docs: Fix broken link in 'How to add a model...'  (#23216)

fix link

* Get default endpoint from the Hub

* Add guide

* Simplify tool config

* Docs

* Some fixes

* Docs

* Docs

* Docs

* Fix code returned by agent

* Try this

* Match args with signature in remote tool

* Should fix python interpreter for Python 3.8

* Fix push_to_hub for tools

* Other fixes to push_to_hub

* Add API doc page

* Docs

* Docs

* Custom tools

* Pin tensorflow-probability (#23220)

* Pin tensorflow-probability

* [all-test]

* [all-test] Fix syntax for bash

* PoC for some chaining API

* Text to speech

* J'ai pris des libertés

* Rename

* Basic python interpreter

* Add agents

* Quality

* Add translation tool

* temp

* GenQA + LID + S2T

* Quality + word missing in translation

* Add open assistance, support f-strings in evaluate

* captioning + s2t fixes

* Style

* Refactor descriptions and remove chain

* Support errors and rename OpenAssistantAgent

* Add setup

* Deal with typos + example of inference API

* Some rename + README

* Fixes

* Update prompt

* Unwanted change

* Make sure everyone has a default

* One prompt to rule them all.

* SD

* Description

* Clean up remote tools

* More remote tools

* Add option to return code and update doc

* Image segmentation

* ControlNet

* Gradio demo

* Diffusers protection

* Lib protection

* ControlNet description

* Cleanup

* Style

* Remove accelerate and try to be reproducible

* No randomness

* Male Basic optional in token

* Clean description

* Better prompts

* Fix args eval in interpreter

* Add tool wrapper

* Tool on the Hub

* Style post-rebase

* Big refactor of descriptions, batch generation and evaluation for agents

* Make problems easier - interface to debug

* More problems, add python primitives

* Back to one prompt

* Remove dict for translation

* Be consistent

* Add prompts

* New version of the agent

* Evaluate new agents

* New endpoints agents

* Make all tools a dict variable

* Typo

* Add problems

* Add to big prompt

* Harmonize

* Add tools

* New evaluation

* Add more tools

* Build prompt with tools descriptions

* Tools on the Hub

* Let's chat!

* Cleanup

* Temporary bs4 safeguard

* Cache agents and clean up

* Blank init

* Fix evaluation for agents

* New format for tools on the Hub

* Add method to reset state

* Remove nestedness in tool config

* Really do it

* Use remote tools descriptions

* Work

* Clean up eval

* Changes

* Tools

* Tools

* tool

* Fix everything

* Use last result/assign for evaluation

* Prompt

* Remove hardcoded selection

* Evaluation for chat agents

* correct some spelling

* Small fixes

* Change summarization model (#23172)

* Fix link displayed

* Update description of the tool

* Fixes in chat prompt

* Custom tools, custom prompt

* Tool clean up

* save_pretrained and push_to_hub for tool

* Fix init

* Tests

* Fix tests

* Tool save/from_hub/push_to_hub and tool->load_tool

* Clean push_to_hub and add app file

* Custom inference API for endpoints too

* Clean up

* old remote tool and new remote tool

* Make a requirements

* return_code adds tool creation

* Avoid redundancy between global variables

* Remote tools can be loaded

* Tests

* Text summarization tests

* Quality

* Properly mark tests

* Test the python interpreter

* And the CI shall be green.

* Work on RemoteTool and fix tests

* fix loading of additional tools

* General clean up

* Guard imports

* Fix tools

* Get default endpoint from the Hub

* Simplify tool config

* Add guide

* Docs

* Some fixes

* Docs

* Docs

* Fix code returned by agent

* Try this

* Docs

* Match args with signature in remote tool

* Should fix python interpreter for Python 3.8

* Fix push_to_hub for tools

* Other fixes to push_to_hub

* Add API doc page

* Fixes

* Doc fixes

* Docs

* Fix audio

* Custom tools

* Audio fix

* Improve custom tools docstring

* Docstrings

* Trigger CI

* Mode docstrings

* More docstrings

* Improve custom tools

* Fix for remote tools

* Style

* Fix repo consistency

* Quality

* Tip

* Cleanup on doc

* Cleanup toc

* Add disclaimer for starcoder vs openai

* Remove disclaimer

* Small fixed in the prompts

* 4.29

* Update src/transformers/tools/agents.py

Co-authored-by: Lysandre Debut <lysandre.debut@reseau.eseo.fr>

* Complete documentation

* Small fixes

* Agent evaluation

* Note about gradio-tools & LC

* Clean up agents and prompt

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Note about gradio-tools & LC

* Add copyrights and address review comments

* Quality

* Add all language codes

* Add remote tool tests

* Move custom prompts to other docs

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* TTS tests

* Quality

---------

Co-authored-by: Lysandre <hi@lyand.re>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
Co-authored-by: Connor Henderson <connor.henderson@talkiatry.com>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre <lysandre@huggingface.co>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Sylvain Gugger 2023-05-09 20:37:57 -04:00 committed by GitHub
parent 366a8ca09e
commit 3335724376
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 4933 additions and 8 deletions

View File

@ -43,6 +43,7 @@ def pytest_configure(config):
)
config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")
config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate")
config.addinivalue_line("markers", "tool_tests: mark the tool tests that are run on their specific schedule")
def pytest_addoption(parser):

View File

@ -21,6 +21,8 @@
title: Set up distributed training with 🤗 Accelerate
- local: model_sharing
title: Share your model
- local: transformers_agents
title: Agents
title: Tutorials
- sections:
- sections:
@ -99,6 +101,8 @@
title: Notebooks with examples
- local: community
title: Community resources
- local: custom_tools
title: Custom Tools
- local: troubleshooting
title: Troubleshoot
title: Developer guides
@ -179,6 +183,8 @@
title: Conceptual guides
- sections:
- sections:
- local: main_classes/agent
title: Agents and Tools
- local: model_doc/auto
title: Auto Classes
- local: main_classes/callback

View File

@ -0,0 +1,503 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Custom Tools and Prompts
<Tip>
If you are not aware of what tools and agents are in the context of transformers, we recommend you read the
[Transformers Agents](transformers_agents) page first.
</Tip>
<Tip warning={true}>
Transformers Agent is an experimental API which is subject to change at any time. Results returned by the agents
can vary as the APIs or underlying models are prone to change.
</Tip>
Creating and using custom tools and prompts is paramount to empowering the agent and having it perform new tasks.
In this guide we'll take a look at:
- How to customize the prompt
- How to use custom tools
- How to create custom tools
## Customizing the prompt
As explained in [Transformers Agents](transformers_agents) agents can run in [`~Agent.run`] and [`~Agent.chat`] mode.
Both the run and chat mode underlie the same logic. The language model powering the agent is conditioned on a long prompt
and simply asked to complete the prompt by generating next tokens until the stop token is reached.
The only difference between the `run` and `chat` mode is that during the `chat` mode the prompt is extended with
previous user inputs and model generations, which seemingly gives the agent a memory and allows it to refer to
past interactions.
Let's take a closer look into how the prompt is structured to understand how it can be best customized.
The prompt is structured broadly into four parts.
- 1. Introduction: how the agent should behave, explanation of the concept of tools.
- 2. Description of all the tools. This is defined by a `<<all_tools>>` token that is dynamically replaced at runtime with the tools defined/chosen by the user.
- 3. A set of examples of tasks and their solution
- 4. Current example, and request for solution.
To better understand each part, let's look at a shortened version of how such a prompt can look like in practice.
```
I will ask you to perform a task, your job is to come up with a series of simple commands in Python that will perform the task.
[...]
You can print intermediate results if it makes sense to do so.
Tools:
- document_qa: This is a tool that answers a question about an document (pdf). It takes an input named `document` which should be the document containing the information, as well as a `question` that is the question about the document. It returns a text that contains the answer to the question.
- image_captioner: This is a tool that generates a description of an image. It takes an input named `image` which should be the image to caption, and returns a text that contains the description in English.
[...]
Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."
I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.
Answer:
```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
print(f"The translated question is {translated_question}.")
answer = image_qa(image=image, question=translated_question)
print(f"The answer is {answer}")
```
Task: "Identify the oldest person in the `document` and create an image showcasing the result as a banner."
I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
Answer:
```py
answer = document_qa(document, question="What is the oldest person?")
print(f"The answer is {answer}.")
image = image_generator("A banner showing " + answer)
```
[...]
Task: "Draw me a picture of rivers and lakes"
I will use the following
```
The first part explains precisely how the model shall behave and what it should do. This part
most likely does not need to be customized.
TODO(PVP) - explain better how the .description and .name influence the prompt
### Customizing the tool descriptions
The performance of the agent is directly linked to the prompt itself. We structure the prompt so that it works well
with what we intend for the agent to do; but for maximum customization we also offer the ability to specify a different prompt when instantiating the agent.
### Customizing the single-execution prompt
In order to specify a custom single-execution prompt, one would so the following:
```py
template = """ [...] """
agent = HfAgent(your_endpoint, run_prompt_template=template)
```
<Tip>
Please make sure to have the `<<all_tools>>` string defined somewhere in the `template` so that the agent can be aware
of the tools it has available to it.
</Tip>
#### Chat-execution prompt
In order to specify a custom single-execution prompt, one would so the following:
```
template = """ [...] """
agent = HfAgent(
url_endpoint=your_endpoint,
token=your_hf_token,
chat_prompt_template=template
)
```
<Tip>
Please make sure to have the `<<all_tools>>` string defined somewhere in the `template` so that the agent can be
aware of the tools it has available to it.
</Tip>
## Using custom tools
In this section, we'll be leveraging two existing custom tools that are specific to image generation:
- We replace [huggingface-tools/image-transformation](https://huggingface.co/spaces/huggingface-tools/image-transformation),
with [diffusers/controlnet-canny-tool](https://huggingface.co/spaces/diffusers/controlnet-canny-tool)
to allow for more image modifications.
- We add a new tool for image upscaling to the default toolbox:
[diffusers/latent-upscaler-tool](https://huggingface.co/spaces/diffusers/latent-upscaler-tool) replace the existing image-transformation tool.
We'll start by loading the custom tools with the convenient [`load_tool`] function:
```py
from transformers import load_tool
controlnet_transformer = load_tool("diffusers/controlnet-canny-tool")
upscaler = load_tool("diffusers/latent-upscaler-tool")
```
Upon adding custom tools to an agent, the tools' descriptions and names are automatically
included in the agents' prompts. Thus, it is imperative that custom tools have
a well-written description and name in order for the agent to understand how to use them.
Let's take a look at the description and name of `controlnet_transformer`:
```py
print(f"Description: '{controlnet_transformer.description}'")
print(f"Name: '{controlnet_transformer.name}'")
```
gives
```
Description: 'This is a tool that transforms an image with ControlNet according to a prompt.
It takes two inputs: `image`, which should be the image to transform, and `prompt`, which should be the prompt to use to change it. It returns the modified image.'
Name: 'image_transformer'
```
The name and description is accurate and fits the style of the [curated set of tools](./transformers_agents#a-curated-set-of-tools).
Next, let's instantiate an agent with `controlnet_transformer` and `upscaler`:
```py
tools = [controlnet_transformer, upscaler]
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=tools)
```
This command should give you the following info:
```
image_transformer has been replaced by <transformers_modules.diffusers.controlnet-canny-tool.bd76182c7777eba9612fc03c0
8718a60c0aa6312.image_transformation.ControlNetTransformationTool object at 0x7f1d3bfa3a00> as provided in `additional_tools`
```
The set of curated tools already has a `image_transformer` tool which is hereby replaced with our custom tool.
<Tip>
Overwriting existing tools can be beneficial if we want to use a custom tool exactly for the same task as an existing tool
because the agent is well-versed in using the specific task. Beware that the custom tool should follow the exact same API
as the overwritten tool in this case.
</Tip>
The upscaler tool was given the name `image_upscaler` which is not yet present in the default toolbox and is therefore is simply added to the list of tools.
You can always have a look at the toolbox that is currently available to the agent via the `agent.toolbox` attribute:
```py
print("\n".join([f"- {a}" for a in agent.toolbox.keys()]))
```
```
- document_qa
- image_captioner
- image_qa
- image_segmenter
- transcriber
- summarizer
- text_classifier
- text_qa
- text_reader
- translator
- image_transformer
- text_downloader
- image_generator
- video_generator
- image_upscaler
```
Note how `image_upscaler` is now part of the agents' toolbox.
Let's now try out the new tools! We will re-use the image we generated in (Transformers Agents Quickstart)[./transformers_agents#single-execution-run].
```py
from diffusers.utils import load_image
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png"
)
```
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png" width=200>
Let's transform the image into a beautiful winter landscape:
```py
image = agent.run("Transform the image: 'A frozen lake and snowy forest'", image=image)
```
```
==Explanation from the agent==
I will use the following tool: `image_transformer` to transform the image.
==Code generated by the agent==
image = image_transformer(image, prompt="A frozen lake and snowy forest")
```
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_winter.png" width=200>
The new image processing tool is based on ControlNet which is can make very strong modifications to the image.
By default the image processing tool returns an image of size 512x512 pixels. Let's see if we can upscale it.
```py
image = agent.run("Upscale the image", image)
```
```
==Explanation from the agent==
I will use the following tool: `image_upscaler` to upscale the image.
==Code generated by the agent==
upscaled_image = image_upscaler(image)
```
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_winter_upscale.png" width=400>
The agent automatically mapped our prompt "Upscale the image" to the just added upscaler tool purely based on the description and name of the upscaler tool
and was able to correctly run it.
Next, let's have a look into how you can create a new custom tool.
### Adding new tools
In this section we show how to create a new tool that can be added to the agent.
#### Creating a new tool
We'll first start by creating a tool. We'll add the not-so-useful yet fun task of fetching the model on the Hugging Face
Hub with the most downloads for a given task.
We can do that with the following code:
```python
from huggingface_hub import list_models
task = "text-classification"
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
print(model.id)
```
For the task `text-classification`, this returns `'facebook/bart-large-mnli'`, for `translation` it returns `'t5-base`.
How do we convert this to a tool that the agent can leverage? All tools depend on the superclass `Tool` that holds the
main attributes necessary. We'll create a class that inherits from it:
```python
from transformers import Tool
class HFModelDownloadsTool(Tool):
pass
```
This class has a few needs:
- An attribute `name`, which corresponds to the name of the tool itself. To be in tune with other tools which have a
performative name, we'll name it `model_download_counter`.
- An attribute `description`, which will be used to populate the prompt of the agent.
- `inputs` and `outputs` attributes. Defining this will help the python interpreter make educated choices about types,
and will allow for a gradio-demo to be spawned when we push our tool to the Hub. They're both a list of expected
values, which can be `text`, `image`, or `audio`.
- A `__call__` method which contains the inference code. This is the code we've played with above!
Here's what our class looks like now:
```python
from transformers import Tool
from huggingface_hub import list_models
class HFModelDownloadsTool(Tool):
name = "model_download_counter"
description = (
"This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. "
"It takes the name of the category (such as text-classification, depth-estimation, etc), and "
"returns the name of the checkpoint."
)
inputs = ["text"]
outputs = ["text"]
def __call__(self, task: str):
model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
return model.id
```
We now have our tool handy. Save it in a file and import it from your main script. Let's name this file
`model_downloads.py`, so the resulting import code looks like this:
```python
from model_downloads import HFModelDownloadsTool
tool = HFModelDownloadsTool()
```
In order to let others benefit from it and for simpler initialization, we recommend pushing it to the Hub under your
namespace. To do so, just call `push_to_hub` on the `tool` variable:
```python
tool.push_to_hub("lysandre/hf-model-downloads")
```
You now have your code on the Hub! Let's take a look at the final step, which is to have the agent use it.
#### Having the agent use the tool
We now have our tool that lives on the Hub which can be instantiated as such:
```python
from transformers import load_tool
tool = load_tool("lysandre/hf-model-downloads")
```
In order to use it in the agent, simply pass it in the `additional_tools` parameter of the agent initialization method:
```python
from transformers import HfAgent
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=[tool])
agent.run(
"Can you read out loud the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub?"
)
```
which outputs the following:
```
==Code generated by the agent==
model = model_download_counter(task="text-to-video")
print(f"The model with the most downloads is {model}.")
audio_model = text_reader(model)
==Result==
The model with the most downloads is damo-vilab/text-to-video-ms-1.7b.
```
and generates the following audio.
| **Audio** |
|------------------------------------------------------------------------------------------------------------------------------------------------------|
| <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/damo.wav" type="audio/wav"/> |
<Tip>
Depending on the LLM, some are quite brittle and require very exact prompts in order to work well. Having a well-defined
description of the tool is paramount to having it be leveraged by the agent.
</Tip>
### Replacing existing tools
Replacing existing tools can be done simply by assigning a new item to the agent's toolbox. Here's how one would do so:
```python
from transformers import HfAgent, load_tool
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
agent.toolbox["image-transformation"] = load_tool("diffusers/controlnet-canny-tool")
```
<Tip>
Beware when replacing tools with others! This will also adjust the agent's prompt. This can be good if you have a better
prompt suited for the task, but it can also result in your tool being selected way more than others or for other
tools to be selected instead of the one you have defined.
</Tip>
## Leveraging gradio-tools
[gradio-tools](https://github.com/freddyaboulton/gradio-tools) is a powerful library that allows using Hugging
Face Spaces as tools. It supports many existing Spaces as well as custom Spaces to be designed with it.
We offer support for `gradio_tools` by using the `Tool.from_gradio` method. For example, we want to take
advantage of the `StableDiffusionPromptGeneratorTool` tool offered in the `gradio-tools` toolkit so as to
improve our prompts and generate better images.
We first import the tool from `gradio_tools` and instantiate it:
```python
from gradio_tools import StableDiffusionPromptGeneratorTool
gradio_tool = StableDiffusionPromptGeneratorTool()
```
We pass that instance to the `Tool.from_gradio` method:
```python
from transformers import Tool
tool = Tool.from_gradio(gradio_tools)
```
Now we can manage it exactly as we would a usual custom tool. We leverage it to improve our prompt
` a rabbit wearing a space suit`:
```python
from transformers import HfAgent
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder", additional_tools=[tool])
agent.run("Generate an image of the `prompt` after improving it.", prompt="A rabbit wearing a space suit")
```
The model adequately leverages the tool:
```
==Explanation from the agent==
I will use the following tools: `StableDiffusionPromptGenerator` to improve the prompt, then `image_generator` to generate an image according to the improved prompt.
==Code generated by the agent==
improved_prompt = StableDiffusionPromptGenerator(prompt)
print(f"The improved prompt is {improved_prompt}.")
image = image_generator(improved_prompt)
```
Before finally generating the image:
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png">
<Tip warning={true}>
gradio-tools requires *textual* inputs and outputs, even when working with different modalities. This implementation
works with image and audio objects. The two are currently incompatible, but will rapidly become compatible as we
work to improve the support.
</Tip>
## Future compatibility with Langchain
We love Langchain and think it has a very compelling suite of tools. In order to handle these tools,
Langchain requires *textual* inputs and outputs, even when working with different modalities.
This is often the serialized version (i.e., saved to disk) of the objects.
This difference means that multi-modality isn't handled between transformers-agents and langchain.
We aim for this limitation to be resolved in future versions, and welcome any help from avid langchain
users to help us achieve this compatibility.
We would love to have better support. If you would like to help, please
[open an issue](https://github.com/huggingface/transformers/issues/new) and share what you have in mind.

View File

@ -0,0 +1,64 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Agents & Tools
<Tip warning={true}>
Transformers Agent is an experimental API which is subject to change at any time. Results returned by the agents
can vary as the APIs or underlying models are prone to change.
</Tip>
To learn more about agents and tools make sure to read the [introductory guide](../agents_and_tools). This page
contains the API docs for the underlying classes.
## Agents
We provide two types of agents: [`HfAgent`] uses inference endpoints for opensource models and [`OpenAiAgent`] uses OpenAI closed models.
### HfAgent
[[autodoc]] HfAgent
### OpenAiAgent
[[autodoc]] OpenAiAgent
### Agent
[[autodoc]] Agent
- chat
- run
- prepare_for_new_chat
## Tools
### load_tool
[[autodoc]] load_tool
### Tool
[[autodoc]] Tool
### PipelineTool
[[autodoc]] PipelineTool
### RemoteTool
[[autodoc]] RemoteTool
### launch_gradio_demo
[[autodoc]] launch_gradio_demo

View File

@ -0,0 +1,329 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Transformers Agent
<Tip warning={true}>
Transformers Agent is an experimental API which is subject to change at any time. Results returned by the agents
can vary as the APIs or underlying models are prone to change.
</Tip>
Transformers version v4.29.0, building on the concept of *tools* and *agents*.
In short, it provides a natural language API on top of transformers: we define a set of curated tools, and design an
agent to interpret natural language and to use these tools. It is extensible by design; we curated some relevant tools,
but we'll show you how the system can be extended easily to use any tool developed by the community.
Let's start with a few examples of what can be achieved with this new API. It is particularly powerful when it comes
to multimodal tasks, so let's take it for a spin to generate images and read text out loud.
```py
agent.run("Caption the following image", image=image)
```
| **Input** | **Output** |
|-----------------------------------------------------------------------------------------------------------------------------|-----------------------------------|
| <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/beaver.png" width=200> | A beaver is swimming in the water |
---
```py
agent.run("Read the following text out loud", text=text)
```
| **Input** | **Output** |
|-------------------------------------------------------------------------------------------------------------------------|----------------------------------------------|
| A beaver is swimming in the water | <audio controls><source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tts_example.wav" type="audio/wav"> your browser does not support the audio element. </audio>
---
```py
agent.run(
"In the following `document`, where will the TRRF Scientific Advisory Council Meeting take place?",
document=document,
)
```
| **Input** | **Output** |
|-----------------------------------------------------------------------------------------------------------------------------|----------------|
| <img src="https://datasets-server.huggingface.co/assets/hf-internal-testing/example-documents/--/hf-internal-testing--example-documents/test/0/image/image.jpg" width=200> | ballroom foyer |
## Quickstart
Before being able to use `agent.run`, you will need to instantiate an agent, which is a large language model (LLM).
We recommend using the [bigcode/starcoder](https://huggingface.co/bigcode/starcoder) checkpoint as it works very well
for the task at hand and is open-source, but please find other examples below.
Start by logging-in to have access to the Inference API:
```py
from huggingface_hub import login
login("<YOUR_TOKEN>")
```
Then, instantiate the agent
```py
from transformers import HfAgent
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
```
This is using the inference API that Hugging Face provides for free at the moment, if you have your own inference
endpoint for this model (or another one) you can replace the url above by your url endpoint.
<Tip>
We're showcasing StarCoder as the default in the documentation as the model is free to use and performs admirably well
on simple tasks. However, the checkpoint doesn't hold up when handling more complex prompts. If you're facing such an
issue, we recommend trying out the OpenAI model which, while sadly not open-source, performs better at this given time.
</Tip>
You're now good to go! Let's dive into the two APIs that you now have at your disposal.
### Single execution (run)
The single execution method is when using the [`~Agent.run`] method of the agent:
```py
agent.run("Draw me a picture of rivers and lakes")
```
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png" width=200>
It automatically select the tool (or tools) appropriate for the task you want to perform and run them appropriately. It
can perform one or several tasks in the same instruction (though the more complex your instruction, the more likely
the agent is to fail).
```py
agent.chat("Draw me a picture of the sea then transform the picture to add an island.")
```
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/sea_and_island.png" width=200>
<br/>
Every [`~Agent.run`] operation is independent, so you can run it several times in a row with different tasks.
Note that your `agent` is just a large-language model, so small variations in your prompt might yield completely
different results. It's important to explain as clearly as possible the task you want to perform.
If you'd like to keep a state across executions or to pass non-text objects to the agent, you can do so by specifying
variables that you would like the agent to use. For example you could generate the first image of rivers and lakes,
and ask the model to update that picture to add an island by doing the following:
```python
picture = agent.run("Draw me a picture of rivers and lakes")
updated_picture = agent.chat("Take that `picture` and add an island to it", picture=picture)
```
<Tip>
This can be helpful when the model is unable to understand your request and mixes tools. An example would be:
```python
agent.run("Draw me the picture of a capybara swimming in the sea")
```
Here, the model could interpret it two ways:
- Have the `text-to-image` generate a capybara swimming in the sea
- Or, have the `text-to-image` generate capybara, then use the `image-transformation` tool to have it swim in the sea
In case you would like to force the first scenario, you could do so by passing it the prompt as an argument:
```python
agent.run("Draw me a picture of the `prompt`", prompt="a capybara swimming in the sea")
```
</Tip>
### Chat-based execution (chat)
The agent also has a chat-based approach, using the [`~Agent.chat`] method:
```py
agent.chat("Draw me a picture of rivers and lakes")
```
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes.png" width=200>
```py
agent.chat("Transform the picture so that there is a rock in there")
```
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rivers_and_lakes_and_beaver.png" width=200>
<br/>
This is an interesting approach when you want to keep the state across instructions. It's better for experimentation,
but will tend to be much better at single instructions rather than complex instructions (which the [`~Agent.run`]
method is better at handling).
This method can also take arguments if you would like to pass non-text types or specific prompts.
### ⚠️ Remote execution
For demonstration purposes and so that this can be used with all setups, we have created remote executors for several
of the default tools the agent has access to. These are created using
[inference endpoints](https://huggingface.co/inference-endpoints). To see how to setup remote executors tools yourself,
we recommend reading the custom tool guide [TODO LINK].
In order to run with remote tools, specifying `remote=True` to either [`~Agent.run`] or [`~Agent.chat`] is sufficient.
For example, the following command could be run on any device efficiently, without needing significant RAM or GPU:
```python
agent.run("Draw me a picture of rivers and lakes", remote=True)
```
The same can be said for [`~Agent.chat`]:
```py
agent.chat("Draw me a picture of rivers and lakes", remote=True)
```
### What's happening here? What are tools, and what are agents?
#### Agents
The "agent" here is a large language model, and we're prompting it so that it has access to a specific set of tools.
LLMs are pretty good at generating small samples of code, so this API takes advantage of that by prompting the
LLM to give a small sample of code performing a task with a set of tools. This prompt is then completed by the
task you give your agent and the description of the tools you give it. This way it gets access to the doc of the
tools you are using, especially their expected inputs and outputs and can generate the relevant code.
#### Tools
Tools are very simple: they're a single function, with a name, and a description. We then use these tools description
to prompt the agent. Through the prompt, we show the agent how it would leverage tools in order to perform what was
requests in the query.
This is using brand-new tools and not pipelines, because the agent writes better code with very atomic tools.
Pipelines are more refactored and often combine several tasks in one. Tools are really meant to be focused on
one very simple task only.
#### Code-execution?!
This code is then executed with our small Python interpreter on the set of inputs passed along with your tools.
We hear you screaming "Arbitrary code execution!" in the back, but let us explain why that is not the case.
The only functions that can be called are the tools you provided and the print function, so you're already
limited in what can be executed. You should be safe if it's limited to Hugging Face tools.
Then, we don't allow any attribute lookup or imports (which shouldn't be needed anyway for passing along
inputs/outputs to a small set of functions) so all the most obvious attacks (and you'd need to prompt the LLM
to output them anyway) shouldn't be an issue. If you want to be on the super safe side, you can execute the
run() method with the additional argument return_code=True, in which case the agent will just return the code
to execute and you can decide whether to do it or not.
The execution will stop at any line trying to perform an illegal operation or if there is a regular Python error
with the code generated by the agent.
### A curated set of tools
We identify a set of tools that can empower such agents. Here is an updated list of the tools we have integrated
in `transformers`:
- **Document question answering**: given a document (such as a PDF) in image format, answer a question on this document ([Donut](../model_doc/donut))
- **Text question answering**: given a long text and a question, answer the question in the text ([Flan-T5](../model_doc/flan-t5))
- **Unconditional image captioning**: Caption the image! ([BLIP](../model_doc/blip))
- **Image question answering**: given an image, answer a question on this image ([VILT](../model_doc/vilt))
- **Image segmentation**: given an image and a prompt, output the segmentation mask of that prompt ([CLIPSeg](../model_doc/clipseg))
- **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](../model_doc/whisper))
- **Text to speech**: convert text to speech ([SpeechT5](../model_doc/speecht5))
- **Zero-shot text classification**: given a text and a list of labels, identify to which label the text corresponds the most ([BART](../model_doc/bart))
- **Text summarization**: summarize a long text in one or a few sentences ([BART](../model_doc/bart))
- **Translation**: translate the text into a given language ([NLLB](../model_doc/nllb))
These tools have an integration in transformers, and can be used manually as well, for example:
```py
from transformers import load_tool
tool = load_tool("text-to-speech")
audio = tool("This is a text to speech tool")
```
### Custom tools
While we identify a curated set of tools, we strongly believe that the main value provided by this implementation is
the ability to quickly create and share custom tools.
By pushing the code of a tool to a Hugging Face Space or a model repository, you're then able to leverage the tool
directly with the agent. We've added a few
**transformers-agnostic** tools to the `huggingface-tools` organization:
- **Text downloader**: to download a text from a web URL
- **Text to image**: generate an image according to a prompt, leveraging stable diffusion
- **Image transformation**: modify an image given an initial image and a prompt, leveraging instruct pix2pix stable diffusion
The text-to-image tool we have been using since the beginning is actually a remote tool that lives in
[*huggingface-tools/text-to-image*](https://huggingface.co/spaces/huggingface-tools/text-to-image)! We will
continue releasing such tools on this and other organization, to further supercharge this implementation.
The agents have by default access to tools that reside on `huggingface-tools`.
We explain how to you can write and share your own tools as well as leverage any custom tool that resides on the Hub in [following guide](custom_tools).
[following guide](custom_tools).
### Leveraging different agents
We showcase here how to use the [bigcode/starcoder](https://huggingface.co/bigcode/starcoder) model as an LLM, but
it isn't the only model available. We also support the OpenAssistant model and OpenAI's davinci models (3.5 and 4).
We're planning on supporting local language models in an ulterior version.
The tools defined in this implementation are agnostic to the agent used; we are showcasing the agents that work with
our prompts below, but the tools can also be used with Langchain, Minichain, or any other Agent-based library.
#### Example code for the OpenAssistant model
```py
from transformers import HfAgent
agent = HfAgent(url_endpoint="https://OpenAssistant/oasst-sft-1-pythia-12b", token="<HF_TOKEN>")
```
#### Example code for OpenAI models
```py
from transformers import OpenAiAgent
agent = OpenAiAgent(model="text-davinci-003", api_key="<API_KEY>")
```
### Code generation
So far we have shown how to use the agents to perform actions for you. However, the agent is really only generating code
that we then execute using a very restricted Python interpreter. In case you would like to use the code generated in
a different setting, the agent can be prompted to return the code, along with tool definition and accurate imports.
For example, the following instruction
```python
agent.run("Draw me a picture of rivers and lakes", return_code=True)
```
returns the following code
```python
from transformers import load_tool
image_generator = load_tool("huggingface-tools/text-to-image")
image = image_generator(prompt="rivers and lakes")
```
that you can then modify and execute yourself.

View File

@ -610,6 +610,16 @@ _import_structure = {
"SpecialTokensMixin",
"TokenSpan",
],
"tools": [
"Agent",
"HfAgent",
"OpenAiAgent",
"PipelineTool",
"RemoteTool",
"Tool",
"launch_gradio_demo",
"load_tool",
],
"trainer_callback": [
"DefaultFlowCallback",
"EarlyStoppingCallback",
@ -4340,6 +4350,9 @@ if TYPE_CHECKING:
TokenSpan,
)
# Tools
from .tools import Agent, HfAgent, OpenAiAgent, PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
# Trainer
from .trainer_callback import (
DefaultFlowCallback,

View File

@ -115,9 +115,9 @@ def get_relative_import_files(module_file):
return all_relative_imports
def check_imports(filename):
def get_imports(filename):
"""
Check if the current Python environment contains all the libraries that are imported in a file.
Extracts all the libraries that are imported in a file.
"""
with open(filename, "r", encoding="utf-8") as f:
content = f.read()
@ -131,9 +131,14 @@ def check_imports(filename):
imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
# Only keep the top-level module
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
return list(set(imports))
# Unique-ify and test we got them all
imports = list(set(imports))
def check_imports(filename):
"""
Check if the current Python environment contains all the libraries that are imported in a file.
"""
imports = get_imports(filename)
missing_packages = []
for imp in imports:
try:
@ -169,6 +174,7 @@ def get_cached_module_file(
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
repo_type: Optional[str] = None,
_commit_hash: Optional[str] = None,
):
"""
@ -207,6 +213,8 @@ def get_cached_module_file(
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
repo_type (`str`, *optional*):
Specify the repo type (useful when downloading from a space for instance).
<Tip>
@ -229,7 +237,7 @@ def get_cached_module_file(
else:
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
cached_module = try_to_load_from_cache(
pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash
pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
)
new_files = []
@ -245,6 +253,7 @@ def get_cached_module_file(
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
repo_type=repo_type,
_commit_hash=_commit_hash,
)
if not is_local and cached_module != resolved_module_file:
@ -309,8 +318,10 @@ def get_cached_module_file(
if len(new_files) > 0:
new_files = "\n".join([f"- {f}" for f in new_files])
repo_type_str = "" if repo_type is None else f"{repo_type}/"
url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
logger.warning(
f"A new version of the following files was downloaded from {pretrained_model_name_or_path}:\n{new_files}"
f"A new version of the following files was downloaded from {url}:\n{new_files}"
"\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
"versions of the code file, you can pin a revision."
)
@ -328,6 +339,7 @@ def get_class_from_dynamic_module(
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
repo_type: Optional[str] = None,
**kwargs,
):
"""
@ -377,6 +389,8 @@ def get_class_from_dynamic_module(
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
repo_type (`str`, *optional*):
Specify the repo type (useful when downloading from a space for instance).
<Tip>
@ -418,6 +432,7 @@ def get_class_from_dynamic_module(
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
repo_type=repo_type,
)
return get_class_in_module(class_name, final_module.replace(".py", ""))
@ -439,6 +454,7 @@ def custom_object_save(obj, folder, config=None):
"this code in a separate module so we can include it in the saved folder and make it easier to share via "
"the Hub."
)
return
def _set_auto_map_in_config(_config):
module_name = obj.__class__.__module__
@ -478,12 +494,17 @@ def custom_object_save(obj, folder, config=None):
elif config is not None:
_set_auto_map_in_config(config)
result = []
# Copy module file to the output folder.
object_file = sys.modules[obj.__module__].__file__
dest_file = Path(folder) / (Path(object_file).name)
shutil.copy(object_file, dest_file)
result.append(dest_file)
# Gather all relative imports recursively and make sure they are copied as well.
for needed_file in get_relative_import_files(object_file):
dest_file = Path(folder) / (Path(needed_file).name)
shutil.copy(needed_file, dest_file)
result.append(dest_file)
return result

View File

@ -64,6 +64,10 @@ class ChannelDimension(ExplicitEnum):
LAST = "channels_last"
def is_pil_image(img):
return is_vision_available() and isinstance(img, PIL.Image.Image)
def is_valid_image(img):
return (
(is_vision_available() and isinstance(img, PIL.Image.Image))

View File

@ -148,6 +148,7 @@ _run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=Fa
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
_run_tool_tests = parse_flag_from_env("RUN_TOOL_TESTS", default=False)
def is_pt_tf_cross_test(test_case):
@ -221,6 +222,21 @@ def is_pipeline_test(test_case):
return pytest.mark.is_pipeline_test()(test_case)
def is_tool_test(test_case):
"""
Decorator marking a test as a tool test. If RUN_TOOL_TESTS is set to a falsy value, those tests will be skipped.
"""
if not _run_tool_tests:
return unittest.skip("test is a tool test")(test_case)
else:
try:
import pytest # We don't need a hard dependency on pytest in the main library
except ImportError:
return test_case
else:
return pytest.mark.is_tool_test()(test_case)
def slow(test_case):
"""
Decorator marking a test as slow.

View File

@ -0,0 +1,73 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ..utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"agents": ["Agent", "HfAgent", "OpenAiAgent"],
"base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["document_question_answering"] = ["DocumentQuestionAnsweringTool"]
_import_structure["image_captioning"] = ["ImageCaptioningTool"]
_import_structure["image_question_answering"] = ["ImageQuestionAnsweringTool"]
_import_structure["image_segmentation"] = ["ImageSegmentationTool"]
_import_structure["language_identifier"] = ["LanguageIdentificationTool"]
_import_structure["speech_to_text"] = ["SpeechToTextTool"]
_import_structure["text_classification"] = ["TextClassificationTool"]
_import_structure["text_question_answering"] = ["TextQuestionAnsweringTool"]
_import_structure["text_summarization"] = ["TextSummarizationTool"]
_import_structure["text_to_speech"] = ["TextToSpeechTool"]
_import_structure["translation"] = ["TranslationTool"]
if TYPE_CHECKING:
from .agents import Agent, HfAgent, OpenAiAgent
from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .document_question_answering import DocumentQuestionAnsweringTool
from .image_captioning import ImageCaptioningTool
from .image_question_answering import ImageQuestionAnsweringTool
from .image_segmentation import ImageSegmentationTool
from .language_identifier import LanguageIdentificationTool
from .speech_to_text import SpeechToTextTool
from .text_classification import TextClassificationTool
from .text_question_answering import TextQuestionAnsweringTool
from .text_summarization import TextSummarizationTool
from .text_to_speech import TextToSpeechTool
from .translation import TranslationTool
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@ -0,0 +1,489 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import json
import os
import time
from dataclasses import dataclass
import requests
from huggingface_hub import HfFolder, hf_hub_download, list_spaces
from ..utils import is_openai_available, logging
from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote
from .prompts import CHAT_MESSAGE_PROMPT, CHAT_PROMPT_TEMPLATE, RUN_PROMPT_TEMPLATE
from .python_interpreter import evaluate
logger = logging.get_logger(__name__)
if is_openai_available():
import openai
_tools_are_initialized = False
BASE_PYTHON_TOOLS = {
"print": print,
"float": float,
"int": int,
"bool": bool,
"str": str,
}
@dataclass
class PreTool:
task: str
description: str
repo_id: str
HUGGINGFACE_DEFAULT_TOOLS = {}
HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
"image-transformation",
"text-download",
"text-to-image",
"text-to-video",
]
def get_remote_tools(organization="huggingface-tools"):
spaces = list_spaces(author=organization)
tools = {}
for space_info in spaces:
repo_id = space_info.id
resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
with open(resolved_config_file, encoding="utf-8") as reader:
config = json.load(reader)
task = repo_id.split("/")[-1]
tools[config["name"]] = PreTool(task=task, description=config["description"], repo_id=repo_id)
return tools
def _setup_default_tools():
global HUGGINGFACE_DEFAULT_TOOLS
global _tools_are_initialized
if _tools_are_initialized:
return
main_module = importlib.import_module("transformers")
tools_module = main_module.tools
remote_tools = get_remote_tools()
for task_name in TASK_MAPPING:
tool_class_name = TASK_MAPPING.get(task_name)
tool_class = getattr(tools_module, tool_class_name)
description = tool_class.description
HUGGINGFACE_DEFAULT_TOOLS[tool_class.name] = PreTool(task=task_name, description=description, repo_id=None)
for task_name in HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB:
found = False
for tool_name, tool in remote_tools.items():
if tool.task == task_name:
HUGGINGFACE_DEFAULT_TOOLS[tool_name] = tool
found = True
break
if not found:
raise ValueError(f"{task_name} is not implemented on the Hub.")
_tools_are_initialized = True
def resolve_tools(code, toolbox, remote=False, cached_tools=None):
if cached_tools is None:
resolved_tools = BASE_PYTHON_TOOLS.copy()
else:
resolved_tools = cached_tools
for name, tool in toolbox.items():
if name not in code or name in resolved_tools:
continue
if isinstance(tool, Tool):
resolved_tools[name] = tool
else:
task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
_remote = remote and supports_remote(task_or_repo_id)
resolved_tools[name] = load_tool(task_or_repo_id, remote=_remote)
return resolved_tools
def get_tool_creation_code(code, toolbox, remote=False):
code_lines = ["from transformers import load_tool", ""]
for name, tool in toolbox.items():
if name not in code or isinstance(tool, Tool):
continue
task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
line = f'{name} = load_tool("{task_or_repo_id}"'
if remote:
line += ", remote=True"
line += ")"
code_lines.append(line)
return "\n".join(code_lines) + "\n"
def clean_code_for_chat(result):
lines = result.split("\n")
idx = 0
while idx < len(lines) and not lines[idx].lstrip().startswith("```"):
idx += 1
explanation = "\n".join(lines[:idx]).strip()
if idx == len(lines):
return explanation, None
idx += 1
start_idx = idx
while not lines[idx].lstrip().startswith("```"):
idx += 1
code = "\n".join(lines[start_idx:idx]).strip()
return explanation, code
def clean_code_for_run(result):
result = f"I will use the following {result}"
explanation, code = result.split("Answer:")
explanation = explanation.strip()
code = code.strip()
code_lines = code.split("\n")
if code_lines[0] in ["```", "```py", "```python"]:
code_lines = code_lines[1:]
if code_lines[-1] == "```":
code_lines = code_lines[:-1]
code = "\n".join(code_lines)
return explanation, code
class Agent:
"""
Base class for all agents which contains the main API methods.
Args:
chat_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `chat` method.
run_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `run` method.
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
one of the default tools, that default tool will be overridden.
"""
def __init__(self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
_setup_default_tools()
self.chat_prompt_template = CHAT_MESSAGE_PROMPT if chat_prompt_template is None else chat_prompt_template
self.run_prompt_template = RUN_PROMPT_TEMPLATE if run_prompt_template is None else run_prompt_template
self.toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy()
if additional_tools is not None:
if isinstance(additional_tools, (list, tuple)):
additional_tools = {t.name: t for t in additional_tools}
elif not isinstance(additional_tools, dict):
additional_tools = {additional_tools.name: additional_tools}
replacements = {name: tool for name, tool in additional_tools.items() if name in HUGGINGFACE_DEFAULT_TOOLS}
self.toolbox.update(additional_tools)
if len(replacements) > 1:
names = "\n".join([f"- {n}: {t}" for n, t in replacements.items()])
logger.warn(
f"The following tools have been replaced by the ones provided in `additional_tools`:\n{names}."
)
elif len(replacements) == 1:
name = list(replacements.keys())[0]
logger.warn(f"{name} has been replaced by {replacements[name]} as provided in `additional_tools`.")
self.prepare_for_new_chat()
def format_prompt(self, task, chat_mode=False):
description = "\n".join([f"- {name}: {tool.description}" for name, tool in self.toolbox.items()])
if chat_mode:
if self.chat_history is None:
prompt = CHAT_PROMPT_TEMPLATE.replace("<<all_tools>>", description)
else:
prompt = self.chat_history
prompt += CHAT_MESSAGE_PROMPT.replace("<<task>>", task)
else:
prompt = self.run_prompt_template.replace("<<all_tools>>", description)
prompt = prompt.replace("<<prompt>>", task)
return prompt
def chat(self, task, *, return_code=False, remote=False, **kwargs):
"""
Sends a new request to the agent in a chat. Will use the previous ones in its history.
Args:
task (`str`): The task to perform
return_code (`bool`, *optional*, defaults to `False`):
Whether to just return code and not evaluate it.
remote (`bool`, *optional*, defaults to `False`):
Whether or not to use remote tools (inference endpoints) instead of local ones.
kwargs:
Any keyword argument to send to the agent when evaluating the code.
Example:
```py
from transformers import HfAgent
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
agent.chat("Draw me a picture of rivers and lakes")
agent.chat("Transform the picture so that there is a rock in there")
```
"""
prompt = self.format_prompt(task, chat_mode=True)
result = self.generate_one(prompt, stop=["Human:", "====="])
self.chat_history = prompt + result + "\n"
explanation, code = clean_code_for_chat(result)
print(f"==Explanation from the agent==\n{explanation}")
if code is not None:
print(f"\n\n==Code generated by the agent==\n{code}")
if not return_code:
print("\n\n==Result==")
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
self.chat_state.update(kwargs)
return evaluate(code, self.cached_tools, self.chat_state, chat_mode=True)
else:
tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
return f"{tool_code}\n{code}"
def prepare_for_new_chat(self):
"""
Clears the history of prior calls to [`~Agent.chat`].
"""
self.chat_history = None
self.chat_state = {}
self.cached_tools = None
def run(self, task, *, return_code=False, remote=False, **kwargs):
"""
Sends a request to the agent.
Args:
task (`str`): The task to perform
return_code (`bool`, *optional*, defaults to `False`):
Whether to just return code and not evaluate it.
remote (`bool`, *optional*, defaults to `False`):
Whether or not to use remote tools (inference endpoints) instead of local ones.
kwargs:
Any keyword argument to send to the agent when evaluating the code.
Example:
```py
from transformers import HfAgent
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
agent.run("Draw me a picture of rivers and lakes")
```
"""
prompt = self.format_prompt(task)
result = self.generate_one(prompt, stop=["Task:"])
explanation, code = clean_code_for_run(result)
print(f"==Explanation from the agent==\n{explanation}")
print(f"\n\n==Code generated by the agent==\n{code}")
if not return_code:
print("\n\n==Result==")
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
return evaluate(code, self.cached_tools, state=kwargs.copy())
else:
tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
return f"{tool_code}\n{code}"
def generate_one(self, prompt, stop):
# This is the method to implement in your custom agent.
raise NotImplementedError
def generate_many(self, prompts, stop):
# Override if you have a way to do batch generation faster than one by one
return [self.generate_one(prompt, stop) for prompt in prompts]
class OpenAiAgent(Agent):
"""
Agent that uses the openai API to generate code.
<Tip warning={true}>
The openAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
`"text-davinci-003"` over the chat-GPT variant. Proper support for chat-GPT models will come in a next version.
</Tip>
Args:
model (`str`, *optional*, defaults to `"text-davinci-003"`):
The name of the OpenAI model to use.
api_key (`str`, *optional*):
The API key to use. If unset, will look for the environment variable `"OPENAI_API_KEY"`.
chat_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `chat` method.
run_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `run` method.
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
one of the default tools, that default tool will be overridden.
Example:
```py
from transformers import OpenAiAgent
agent = OpenAiAgent(model="text-davinci-003", api_key=xxx)
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
```
"""
def __init__(
self,
model="text-davinci-003",
api_key=None,
chat_prompt_template=None,
run_prompt_template=None,
additional_tools=None,
):
if not is_openai_available():
raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.")
if api_key is None:
api_key = os.environ.get("OPENAI_API_KEY", None)
if api_key is None:
raise ValueError(
"You need an openai key to use `OpenAIAgent`. You can get one here: Get one here "
"https://openai.com/api/`. If you have one, set it in your env with `os.environ['OPENAI_API_KEY'] = "
"xxx."
)
else:
openai.api_key = api_key
self.model = model
super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)
def generate_many(self, prompts, stop):
if "gpt" in self.model:
return [self._chat_generate(prompt, stop) for prompt in prompts]
else:
return self._completion_generate(prompts, stop)
def generate_one(self, prompt, stop):
if "gpt" in self.model:
return self._chat_generate(prompt, stop)
else:
return self._completion_generate([prompt], stop)[0]
def _chat_generate(self, prompt, stop):
result = openai.ChatCompletion.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0,
stop=stop,
)
return result["choices"][0]["message"]["content"]
def _completion_generate(self, prompts, stop):
result = openai.Completion.create(
model=self.model,
prompt=prompts,
temperature=0,
stop=stop,
max_tokens=200,
)
return [answer["text"] for answer in result["choices"]]
class HfAgent(Agent):
"""
Agent that uses and inference endpoint to generate code.
Args:
url_endpoint (`str`):
The name of the url endpoint to use.
token (`str`, *optional*):
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
running `huggingface-cli login` (stored in `~/.huggingface`).
chat_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `chat` method.
run_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `run` method.
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
one of the default tools, that default tool will be overridden.
Example:
```py
from transformers import HfAgent
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
```
"""
def __init__(
self, url_endpoint, token=None, chat_prompt_template=None, run_prompt_template=None, additional_tools=None
):
self.url_endpoint = url_endpoint
if token is None:
self.token = f"Bearer {HfFolder().get_token()}"
elif token.startswith("Bearer") or token.startswith("Basic"):
self.token = token
else:
self.token = f"Bearer {token}"
super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)
def generate_one(self, prompt, stop):
headers = {"Authorization": self.token}
inputs = {
"inputs": prompt,
"parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop},
}
response = requests.post(self.url_endpoint, json=inputs, headers=headers)
if response.status_code == 429:
print("Getting rate-limited, waiting a tiny bit before trying again.")
time.sleep(1)
return self._generate_one(prompt)
elif response.status_code != 200:
raise ValueError(f"Error {response.status_code}: {response.json()}")
result = response.json()[0]["generated_text"]
# Inference API returns the stop sequence
for stop_seq in stop:
if result.endswith(stop_seq):
result = result[: -len(stop_seq)]
return result

View File

@ -0,0 +1,722 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import importlib
import inspect
import io
import json
import os
import tempfile
from typing import Any, Dict, List, Optional, Union
from huggingface_hub import CommitOperationAdd, HfFolder, create_commit, create_repo, hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError, get_session
from ..dynamic_module_utils import custom_object_save, get_class_from_dynamic_module, get_imports
from ..image_utils import is_pil_image
from ..models.auto import AutoProcessor
from ..utils import (
CONFIG_NAME,
cached_file,
is_accelerate_available,
is_torch_available,
is_vision_available,
logging,
)
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
if is_accelerate_available():
from accelerate.utils import send_to_device
TOOL_CONFIG_FILE = "tool_config.json"
def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
if repo_type is not None:
return repo_type
try:
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
return "space"
except RepositoryNotFoundError:
try:
hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
return "model"
except RepositoryNotFoundError:
raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
except Exception:
return "model"
except Exception:
return "space"
# docstyle-ignore
APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
from {module_name} import {class_name}
launch_gradio_demo({class_name})
"""
class Tool:
"""
A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
following class attributes:
- **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it
will return. For instance 'This is a tool that downloads a file from a `url`. It takes the `url` as input, and
returns the text contained in the file'.
- **name** (`str`) -- A performative name that will be used for your tool in the prompt to the agent. For instance
`"text-classifier"` or `"image_generator"`.
- **inputs** (`List[str]`) -- The list of modalities expected for the inputs (in the same order as in the call).
Modalitiies should be `"text"`, `"image"` or `"audio"`. This is only used by `launch_gradio_demo` or to make a
nice space from your tool.
- **outputs** (`List[str]`) -- The list of modalities returned but the tool (in the same order as the return of the
call method). Modalitiies should be `"text"`, `"image"` or `"audio"`. This is only used by `launch_gradio_demo`
or to make a nice space from your tool.
You can also override the method [`~Tool.setup`] if your tool as an expensive operation to perform before being
usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
instantiation.
"""
description: str = "This is a tool that ..."
name: str = ""
inputs: List[str]
outputs: List[str]
def __init__(self, *args, **kwargs):
self.is_initialized = False
def __call__(self, *args, **kwargs):
return NotImplemented("Write this method in your subclass of `Tool`.")
def setup(self):
"""
Overwrite this method here for any operation that is expensive and needs to be executed before you start using
your tool. Such as loading a big model.
"""
self.is_initialized = True
def save(self, output_dir):
"""
Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
tool in `output_dir` as well as autogenerate:
- a config file named `tool_config.json`
- an `app.py` file so that your tool can be converted to a space
- a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its
code)
You should only use this method to save tools that are defined in a separate module (not `__main__`).
Args:
output_dir (`str`): The folder in which you want to save your tool.
"""
os.makedirs(output_dir, exist_ok=True)
# Save module file
if self.__module__ == "__main__":
raise ValueError(
f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You "
"have to put this code in a separate module so we can include it in the saved folder."
)
module_files = custom_object_save(self, output_dir)
module_name = self.__class__.__module__
last_module = module_name.split(".")[-1]
full_name = f"{last_module}.{self.__class__.__name__}"
# Save config file
config_file = os.path.join(output_dir, "tool_config.json")
if os.path.isfile(config_file):
with open(config_file, "r", encoding="utf-8") as f:
tool_config = json.load(f)
else:
tool_config = {}
tool_config = {"tool_class": full_name, "description": self.description, "name": self.name}
with open(config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")
# Save app file
app_file = os.path.join(output_dir, "app.py")
with open(app_file, "w", encoding="utf-8") as f:
f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__))
# Save requirements file
requirements_file = os.path.join(output_dir, "requirements.txt")
imports = []
for module in module_files:
imports.extend(get_imports(module))
imports = list(set(imports))
with open(requirements_file, "w", encoding="utf-8") as f:
f.write("\n".join(imports) + "\n")
@classmethod
def from_hub(cls, repo_id, model_repo_id=None, token=None, remote=False, **kwargs):
"""
Loads a tool defined on the Hub.
Args:
repo_id (`str`):
The name of the repo on the Hub where your tool is defined.
model_repo_id (`str`, *optional*):
If your tool uses a model and you want to use a different model than the default, you can pass a second
repo ID or an endpoint url to this argument.
token (`str`, *optional*):
The token to identify you on hf.co. If unset, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
remote (`bool`, *optional*, defaults to `False`):
Whether to use your tool by downloading the model or (if it is available) with an inference endpoint.
kwargs:
Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
others will be passed along to its init.
"""
if remote and model_repo_id is None:
endpoints = get_default_endpoints()
if repo_id not in endpoints:
raise ValueError(
f"Could not infer a default endpoint for {repo_id}, you need to pass one using the "
"`model_repo_id` argument."
)
model_repo_id = endpoints[repo_id]
hub_kwargs_names = [
"cache_dir",
"force_download",
"resume_download",
"proxies",
"revision",
"repo_type",
"subfolder",
"local_files_only",
]
hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}
# Try to get the tool config first.
hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
resolved_config_file = cached_file(
repo_id,
TOOL_CONFIG_FILE,
use_auth_token=token,
**hub_kwargs,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
is_tool_config = resolved_config_file is not None
if resolved_config_file is None:
resolved_config_file = cached_file(
repo_id,
CONFIG_NAME,
use_auth_token=token,
**hub_kwargs,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
if resolved_config_file is None:
raise EnvironmentError(
f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
)
with open(resolved_config_file, encoding="utf-8") as reader:
config = json.load(reader)
if not is_tool_config:
if "custom_tool" not in config:
raise EnvironmentError(
f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`."
)
custom_tool = config["custom_tool"]
else:
custom_tool = config
tool_class = custom_tool["tool_class"]
tool_class = get_class_from_dynamic_module(tool_class, repo_id, use_auth_token=token, **hub_kwargs)
if remote:
return RemoteTool(model_repo_id, token=token, tool_class=tool_class)
return tool_class(model_repo_id, token=token, **kwargs)
def push_to_hub(
self,
repo_id: str,
commit_message: str = "Upload tool",
private: Optional[bool] = None,
token: Optional[Union[bool, str]] = None,
create_pr: bool = False,
) -> str:
"""
Upload the tool to the Hub.
Parameters:
repo_id (`str`):
The name of the repository you want to push your tool to. It should contain your organization name when
pushing to a given organization.
commit_message (`str`, *optional*, defaults to `"Upload tool"`):
Message to commit while pushing.
private (`bool`, *optional*):
Whether or not the repository created should be private.
token (`bool` or `str`, *optional*):
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
create_pr (`bool`, *optional*, defaults to `False`):
Whether or not to create a PR with the uploaded files or directly commit.
"""
repo_url = create_repo(
repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="space", space_sdk="gradio"
)
repo_id = repo_url.repo_id
with tempfile.TemporaryDirectory() as work_dir:
# Save all files.
self.save(work_dir)
os.listdir(work_dir)
operations = [
CommitOperationAdd(path_or_fileobj=os.path.join(work_dir, f), path_in_repo=f)
for f in os.listdir(work_dir)
]
logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
return create_commit(
repo_id=repo_id,
operations=operations,
commit_message=commit_message,
token=token,
create_pr=create_pr,
repo_type="space",
)
@staticmethod
def from_gradio(gradio_tool):
"""
Creates a [`Tool`] from a gradio tool.
"""
class GradioToolWrapper(Tool):
def __init__(self, _gradio_tool):
super().__init__()
self.name = _gradio_tool.name
self.description = _gradio_tool.description
GradioToolWrapper.__call__ = gradio_tool.run
return GradioToolWrapper(gradio_tool)
class RemoteTool(Tool):
"""
A [`Tool`] that will make requests to an inference endpoint.
Args:
endpoint_url (`str`):
The url of the endpoint to use.
token (`str`, *optional*):
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
running `huggingface-cli login` (stored in `~/.huggingface`).
tool_class (`type`, *optional*):
The corresponding `tool_class` if this is a remote version of an existing tool. Will help determine when
the output should be converted to another type (like images).
"""
def __init__(self, endpoint_url=None, token=None, tool_class=None):
self.endpoint_url = endpoint_url
self.client = EndpointClient(endpoint_url, token=token)
self.tool_class = tool_class
def prepare_inputs(self, *args, **kwargs):
"""
Prepare the inputs received for the HTTP client sending data to the endpoint. Positional arguments will be
matched with the signature of the `tool_class` if it was provided at instantation. Images will be encoded into
bytes.
You can override this method in your custom class of [`RemoteTool`].
"""
inputs = kwargs.copy()
if len(args) > 0:
if self.tool_class is not None:
# Match args with the signature
if issubclass(self.tool_class, PipelineTool):
call_method = self.tool_class.encode
else:
call_method = self.tool_class.__call__
signature = inspect.signature(call_method).parameters
parameters = [
k
for k, p in signature.items()
if p.kind not in [inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD]
]
if parameters[0] == "self":
parameters = parameters[1:]
if len(args) > len(parameters):
raise ValueError(
f"{self.tool_class} only accepts {len(parameters)} arguments but {len(args)} were given."
)
for arg, name in zip(args, parameters):
inputs[name] = arg
elif len(args) > 1:
raise ValueError("A `RemoteTool` can only accept one positional input.")
elif len(args) == 1:
if is_pil_image(args[0]):
return {"inputs": self.client.encode_image(args[0])}
return {"inputs": args[0]}
for key, value in inputs.items():
if is_pil_image(value):
inputs[key] = self.client.encode_image(value)
return {"inputs": inputs}
def extract_outputs(self, outputs):
"""
You can override this method in your custom class of [`RemoteTool`] to apply some custom post-processing of the
outputs of the endpoint.
"""
return outputs
def __call__(self, *args, **kwargs):
output_image = self.tool_class is not None and self.tool_class.outputs == ["image"]
inputs = self.prepare_inputs(*args, **kwargs)
if isinstance(inputs, dict):
outputs = self.client(**inputs, output_image=output_image)
else:
outputs = self.client(inputs, output_image=output_image)
if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
outputs = outputs[0]
return self.extract_outputs(outputs)
class PipelineTool(Tool):
"""
A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
need to specify:
- **model_class** (`type`) -- The class to use to load the model in this tool.
- **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
- **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
pre-processor
- **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
post-processor (when different from the pre-processor).
Args:
model (`str` or [`PreTrainedModel`], *optional*):
The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
value of the class attribute `default_checkpoint`.
pre_processor (`str` or `Any`, *optional*):
The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
unset.
post_processor (`str` or `Any`, *optional*):
The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a
tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
unset.
device (`int`, `str` or `torch.device`, *optional*):
The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the
CPU otherwise.
device_map (`str` or `dict`, *optional*):
If passed along, will be used to instantiate the model.
model_kwargs (`dict`, *optional*):
Any keyword argument to send to the model instantiation.
token (`str`, *optional*):
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
running `huggingface-cli login` (stored in `~/.huggingface`).
hub_kwargs:
Any additional keyword argument to send to the methods that will load the data from the Hub.
"""
pre_processor_class = AutoProcessor
model_class = None
post_processor_class = AutoProcessor
default_checkpoint = None
def __init__(
self,
model=None,
pre_processor=None,
post_processor=None,
device=None,
device_map=None,
model_kwargs=None,
token=None,
**hub_kwargs,
):
if not is_torch_available():
raise ImportError("Please install torch in order to use this tool.")
if not is_accelerate_available():
raise ImportError("Please install accelerate in order to use this tool.")
if model is None:
if self.default_checkpoint is None:
raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
model = self.default_checkpoint
if pre_processor is None:
pre_processor = model
self.model = model
self.pre_processor = pre_processor
self.post_processor = post_processor
self.device = device
self.device_map = device_map
self.model_kwargs = {} if model_kwargs is None else model_kwargs
if device_map is not None:
self.model_kwargs["device_map"] = device_map
self.hub_kwargs = hub_kwargs
self.hub_kwargs["use_auth_token"] = token
self.is_initialized = False
def setup(self):
"""
Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
"""
if isinstance(self.pre_processor, str):
self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
if isinstance(self.model, str):
self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
if self.post_processor is None:
self.post_processor = self.pre_processor
elif isinstance(self.post_processor, str):
self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
if self.device is None:
if self.device_map is not None:
self.device = list(self.model.hf_device_map.values())[0]
else:
self.device = get_default_device()
if self.device_map is None:
self.model.to(self.device)
def encode(self, raw_inputs):
"""
Uses the `pre_processor` to prepare the inputs for the `model`.
"""
return self.pre_processor(raw_inputs)
def forward(self, inputs):
"""
Sends the inputs through the `model`.
"""
with torch.no_grad():
return self.model(**inputs)
def decode(self, outputs):
"""
Uses the `post_processor` to decode the model output.
"""
return self.post_processor(outputs)
def __call__(self, *args, **kwargs):
if not self.is_initialized:
self.setup()
encoded_inputs = self.encode(*args, **kwargs)
encoded_inputs = send_to_device(encoded_inputs, self.device)
outputs = self.forward(encoded_inputs)
outputs = send_to_device(outputs, "cpu")
return self.decode(outputs)
def launch_gradio_demo(tool_class: Tool):
"""
Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
`inputs` and `outputs`.
Args:
tool_class (`type`): The class of the tool for which to launch the demo.
"""
try:
import gradio as gr
except ImportError:
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
tool = tool_class()
def fn(*args, **kwargs):
return tool(*args, **kwargs)
gr.Interface(
fn=fn,
inputs=tool_class.inputs,
outputs=tool_class.outputs,
title=tool_class.__name__,
article=tool.description,
).launch()
# TODO: Migrate to Accelerate for this once `PartialState.default_device` makes its way into a release.
def get_default_device():
if not is_torch_available():
raise ImportError("Please install torch in order to use this tool.")
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
return torch.device("mps")
elif torch.cuda.is_available():
return torch.device("cuda")
else:
return torch.device("cpu")
TASK_MAPPING = {
"document-question-answering": "DocumentQuestionAnsweringTool",
"image-captioning": "ImageCaptioningTool",
"image-question-answering": "ImageQuestionAnsweringTool",
"image-segmentation": "ImageSegmentationTool",
"speech-to-text": "SpeechToTextTool",
"summarization": "TextSummarizationTool",
"text-classification": "TextClassificationTool",
"text-question-answering": "TextQuestionAnsweringTool",
"text-to-speech": "TextToSpeechTool",
"translation": "TranslationTool",
}
def get_default_endpoints():
endpoints_file = cached_file("huggingface-tools/default-endpoints", "default_endpoints.json", repo_type="dataset")
with open(endpoints_file, "r", encoding="utf-8") as f:
endpoints = json.load(f)
return endpoints
def supports_remote(task_or_repo_id):
endpoints = get_default_endpoints()
return task_or_repo_id in endpoints
def load_tool(task_or_repo_id, model_repo_id=None, remote=False, token=None, **kwargs):
"""
Main function to quickly load a tool, be it on the Hub or in the Transformers library.
Args:
task_or_repo_id (`str`):
The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers
are:
- `"document-question-answering"`
- `"image-captioning"`
- `"image-question-answering"`
- `"image-segmentation"`
- `"speech-to-text"`
- `"summarization"`
- `"text-classification"`
- `"text-question-answering"`
- `"text-to-speech"`
- `"translation"`
model_repo_id (`str`, *optional*):
Use this argument to use a different model than the default one for the tool you selected.
remote (`bool`, *optional*, defaults to `False`):
Whether to use your tool by downloading the model or (if it is available) with an inference endpoint.
token (`str`, *optional*):
The token to identify you on hf.co. If unset, will use the token generated when running `huggingface-cli
login` (stored in `~/.huggingface`).
kwargs:
Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
will be passed along to its init.
"""
if task_or_repo_id in TASK_MAPPING:
tool_class_name = TASK_MAPPING[task_or_repo_id]
main_module = importlib.import_module("transformers")
tools_module = main_module.tools
tool_class = getattr(tools_module, tool_class_name)
if remote:
if model_repo_id is None:
endpoints = get_default_endpoints()
if task_or_repo_id not in endpoints:
raise ValueError(
f"Could not infer a default endpoint for {task_or_repo_id}, you need to pass one using the "
"`model_repo_id` argument."
)
model_repo_id = endpoints[task_or_repo_id]
return RemoteTool(model_repo_id, token=token, tool_class=tool_class)
else:
return tool_class(model_repo_id, token=token, **kwargs)
else:
return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, remote=remote, **kwargs)
def add_description(description):
"""
A decorator that adds a description to a function.
"""
def inner(func):
func.description = description
func.name = func.__name__
return func
return inner
## Will move to the Hub
class EndpointClient:
def __init__(self, endpoint_url: str, token: Optional[str] = None):
if token is None:
token = HfFolder().get_token()
self.headers = {"authorization": f"Bearer {token}", "Content-Type": "application/json"}
self.endpoint_url = endpoint_url
@staticmethod
def encode_image(image):
_bytes = io.BytesIO()
image.save(_bytes, format="PNG")
b64 = base64.b64encode(_bytes.getvalue())
return b64.decode("utf-8")
@staticmethod
def decode_image(raw_image):
if not is_vision_available():
raise ImportError(
"This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)."
)
from PIL import Image
b64 = base64.b64decode(raw_image)
_bytes = io.BytesIO(b64)
return Image.open(_bytes)
def __call__(
self,
inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
params: Optional[Dict] = None,
data: Optional[bytes] = None,
output_image: bool = False,
) -> Any:
# Build payload
payload = {}
if inputs:
payload["inputs"] = inputs
if params:
payload["parameters"] = params
# Make API call
response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data)
# By default, parse the response for the user.
if output_image:
return self.decode_image(response.content)
else:
return response.json()

View File

@ -0,0 +1,80 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from ..models.auto import AutoProcessor
from ..models.vision_encoder_decoder import VisionEncoderDecoderModel
from ..utils import is_vision_available
from .base import PipelineTool
if is_vision_available():
from PIL import Image
class DocumentQuestionAnsweringTool(PipelineTool):
default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
description = (
"This is a tool that answers a question about an document (pdf). It takes an input named `document` which "
"should be the document containing the information, as well as a `question` that is the question about the "
"document. It returns a text that contains the answer to the question."
)
name = "document_qa"
pre_processor_class = AutoProcessor
model_class = VisionEncoderDecoderModel
inputs = ["image", "text"]
outputs = ["text"]
def __init__(self, *args, **kwargs):
if not is_vision_available():
raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.")
super().__init__(*args, **kwargs)
def encode(self, image: "Image", question: str):
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
prompt = task_prompt.replace("{user_input}", question)
decoder_input_ids = self.pre_processor.tokenizer(
prompt, add_special_tokens=False, return_tensors="pt"
).input_ids
pixel_values = self.pre_processor(image, return_tensors="pt").pixel_values
return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}
def forward(self, inputs):
return self.model.generate(
inputs["pixel_values"].to(self.device),
decoder_input_ids=inputs["decoder_input_ids"].to(self.device),
max_length=self.model.decoder.config.max_position_embeddings,
early_stopping=True,
pad_token_id=self.pre_processor.tokenizer.pad_token_id,
eos_token_id=self.pre_processor.tokenizer.eos_token_id,
use_cache=True,
num_beams=1,
bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]],
return_dict_in_generate=True,
).sequences
def decode(self, outputs):
sequence = self.pre_processor.batch_decode(outputs)[0]
sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "")
sequence = sequence.replace(self.pre_processor.tokenizer.pad_token, "")
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
sequence = self.pre_processor.token2json(sequence)
return sequence["answer"]

View File

@ -0,0 +1,692 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .agents import BASE_PYTHON_TOOLS, clean_code_for_chat, clean_code_for_run
from .python_interpreter import InterpretorError, evaluate
### Fake tools for test
def classifier(text, labels):
return f"This is the classification of {text} along {labels}."
def translator(text, src_lang, tgt_lang):
return f"This is the translation of {text} from {src_lang} to {tgt_lang}."
def speaker(text):
return f"This is actually a sound reading {text}."
def transcriber(audio):
if "sound" not in audio:
raise ValueError(f"`audio` ({audio}) is not a sound.")
return f"This is the transcribed text from {audio}."
def image_generator(prompt):
return f"This is actually an image representing {prompt}."
def image_captioner(image):
if "image" not in image:
raise ValueError(f"`image` ({image}) is not an image.")
return f"This is a description of {image}."
def image_transformer(image, prompt):
if "image" not in image:
raise ValueError(f"`image` ({image}) is not an image.")
return f"This is a transformation of {image} according to {prompt}."
def question_answerer(text, question):
return f"This is the answer to {question} from {text}."
def image_qa(image, question):
if "image" not in image:
raise ValueError(f"`image` ({image}) is not an image.")
return f"This is the answer to {question} from {image}."
def text_downloader(url):
return f"This is the content of {url}."
def summarizer(text):
return f"This is a summary of {text}."
def video_generator(prompt, seconds=2):
return f"A video of {prompt}"
def document_qa(image, question):
return f"This is the answer to {question} from the document {image}."
def image_segmenter(image, prompt):
return f"This is the mask of {prompt} in {image}"
TEST_TOOLS = {
"text_classifier": classifier,
"translator": translator,
"text_reader": speaker,
"summarizer": summarizer,
"transcriber": transcriber,
"image_generator": image_generator,
"image_captioner": image_captioner,
"image_transformer": image_transformer,
"text_qa": question_answerer,
"text_downloader": text_downloader,
"image_qa": image_qa,
"video_generator": video_generator,
"document_qa": document_qa,
"image_segmenter": image_segmenter,
}
class Problem:
"""
A class regrouping all the information to solve a problem on which we will evaluate agents.
Args:
task (`str` ou `list[str]`):
One or several descriptions of the task to perform. If a list, it should contain variations on the
phrasing, but for the same task.
inputs (`list[str]` or `dict[str, str]`):
The inputs that will be fed to the tools. For this testing environment, only strings are accepted as
values. Pass along a dictionary when you want to specify the values of each inputs, or just the list of
inputs expected (the value used will be `<<input_name>>` in this case).
answer (`str` or `list[str`]):
The theoretical answer (or list of possible valid answers) to the problem, as code.
"""
def __init__(self, task, inputs, answer):
self.task = task
self.inputs = inputs
self.answer = answer
### The list of problems the agent will be evaluated on.
EVALUATION_TASKS = [
Problem(
task=[
"Is the following `text` (in Spanish) positive or negative?",
"Is the text in the variable `text` (in Spanish) positive or negative?",
"Translate the following `text` from Spanish to English then tell me if its positive or negative.",
],
inputs=["text"],
answer="""text_classifier(translator(text, src_lang="Spanish", tgt_lang="English"), labels=["positive", "negative"])""",
),
Problem(
task=[
"Tell me out loud what the `image` contains.",
"Describe the following `image` out loud.",
"Find what is in the picture stored in `image` then read it out loud.",
],
inputs=["image"],
answer=[
"text_reader(image_captioner(image))",
"text_reader(image_qa(image, question='What is in the image?'))",
],
),
Problem(
task=[
"Generate an image from the text given in `text_input`. Then transform it according to the text in `prompt`.",
"Use the following `text_input` to generate an image, then transform it by using the text in `prompt`.",
],
inputs=["text_input", "prompt"],
answer="image_transformer(image_generator(text_input), prompt)",
),
Problem(
task=[
"Download the content of `url`, summarize it then generate an image from its content.",
"Use a summary of the web page at `url` to generate an image.",
"Summarize the content of the web page at `url`, and use the result to generate an image.",
],
inputs=["url"],
answer="image_generator(summarizer(text_downloader(url)))",
),
Problem(
task=[
"Transform the following `image` using the prompt in `text`. The prompt is in Spanish.",
"Use the text prompt in `text` (in Spanish) to transform the following `image`.",
"Translate the `text` from Spanish to English then use it to transform the picture in `image`.",
],
inputs=["text", "image"],
answer="image_transformer(image, translator(text, src_lang='Spanish', tgt_lang='English'))",
),
Problem(
task=[
"Download the content of `url`, summarize it then read it out loud to me.",
"Read me a summary of the web page at `url`.",
],
inputs=["url"],
answer="text_reader(summarizer(text_downloader(url)))",
),
Problem(
task=[
"Generate an image from the text given in `text_input`.",
],
inputs=["text_input"],
answer="image_generator(text_input)",
),
Problem(
task=[
"Replace the beaver in the `image` by the `prompt`.",
"Transform the `image` so that it contains the `prompt`.",
"Use `prompt` to transform this `image`.",
],
inputs=["image", "prompt"],
answer="image_transformer(image, prompt)",
),
Problem(
task=[
"Provide me the summary of the `text`, then read it to me before transcribing it and translating it in French.",
"Summarize `text`, read it out loud then transcribe the audio and translate it in French.",
"Read me a summary of the the `text` out loud. Transcribe this and translate it in French.",
],
inputs=["text"],
answer="translator(transcriber(text_reader(summarizer(text))), src_lang='English', tgt_lang='French')",
),
Problem(
task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
inputs={"prompt": "A lobster swimming"},
answer="video_generator('A lobster swimming')",
),
Problem(
task=[
"Download the following file `url`, summarize it in a few words and generate a video from it."
"Fetch the file at this `url`, summarize it, and create an animation out of it."
],
inputs=["url"],
answer="video_generator(summarizer(text_downloader(url)))",
),
]
EVALUATION_CHATS = [
[
Problem(
task=[
"Translate the following `text` from Spanish to English.",
"Translate the following `text` from Spanish to English.",
],
inputs=["text"],
answer="translated_text=translator(text, src_lang='Spanish', tgt_lang='English')",
),
Problem(
task=[
"Is it positive or negative?",
"Tell me if its positive or negative.",
],
inputs=[],
answer="text_classifier(translated_text, labels=['positive', 'negative'])",
),
],
[
Problem(
task=[
"What does this `image` contain?",
"Describe the following `image`.",
"Find what is in the picture stored in `image`",
],
inputs=["image"],
answer=[
"description=image_captioner(image)",
"description=image_qa(image, question='What is in the image?')",
],
),
Problem(
task=["Now, read the description out loud.", "Great! Can you read it out loud?", "Read it out loud."],
inputs=[],
answer=["audio=text_reader(description)", "audio=text_reader(description)"],
),
],
[
Problem(
task=[
"Generate an image from the text given in `text_input`.",
"Use the following `text_input` to generate an image",
],
inputs=["text_input"],
answer="image = image_generator(text_input)",
),
Problem(
task=[
"Transform it according to the text in `prompt`.",
"Transform it by using the text in `prompt`.",
],
inputs=["prompt"],
answer="image_transformer(image, prompt)",
),
],
[
Problem(
task=[
"Download the content of `url` and summarize it.",
"Summarize the content of the web page at `url`.",
],
inputs=["url"],
answer="summary = summarizer(text_downloader(url))",
),
Problem(
task=[
"Generate an image from its content.",
"Use the previous result to generate an image.",
],
inputs=[],
answer="image_generator(summary)",
),
],
[
Problem(
task=[
"Translate this Spanish `text` in English.",
"Translate the `text` from Spanish to English.",
],
inputs=["text"],
answer="translated_text = translator(text, src_lang='Spanish', tgt_lang='English')",
),
Problem(
task=[
"Transform the following `image` using the translated `text`.",
"Use the previous result to transform the following `image`.",
],
inputs=["image"],
answer="image_transformer(image, translated_text)",
),
],
[
Problem(
task=["Download the content of `url`.", "Get me the text on the weg page `url`."],
inputs=["url"],
answer="text = text_downloader(url)",
),
Problem(
task=["Summarize this text.", "Summarize this text."],
inputs=[],
answer="summary = summarizer(text)",
),
Problem(
task=["Read it out loud to me.", "Read me the previous result."],
inputs=[],
answer="text_reader(summary)",
),
],
[
Problem(
task=[
"Generate an image from the text given in `text_input`.",
],
inputs=["text_input"],
answer="image_generator(text_input)",
),
],
[
Problem(
task=[
"Replace the beaver in the `image` by the `prompt`.",
"Transform the `image` so that it contains the `prompt`.",
"Use `prompt` to transform this `image`.",
],
inputs=["image", "prompt"],
answer="image_transformer(image, prompt)",
),
],
[
Problem(
task=["Provide me the summary of the `text`.", "Summarize `text`."],
inputs=["text"],
answer="summary = summarizer(text)",
),
Problem(
task=["Read this summary to me.", "Read it out loud."],
inputs=[],
answer="audio = text_reader(summarizer(text))",
),
Problem(
task=["Transcribing the previous result back in text.", "Transcribe the audio."],
inputs=[],
answer="text = transcriber(audio)",
),
Problem(
task=["Translating the last result in French.", "Translate this in French."],
inputs=[],
answer="translator(text, src_lang='English', tgt_lang='French')",
),
],
[
Problem(
task=["Generate a video of the `prompt`", "Animate a `prompt`", "Make me a short video using `prompt`."],
inputs={"prompt": "A lobster swimming"},
answer="video_generator('A lobster swimming')",
),
],
[
Problem(
task=[
"Download the content of `url` and summarize it.",
"Summarize the content of the web page at `url`.",
],
inputs=["url"],
answer="summary = summarizer(text_downloader(url))",
),
Problem(
task=["generate a video from it.", "Create an animation from the last result."],
inputs=[],
answer="video_generator(summary)",
),
],
]
def get_theoretical_tools(agent_answer, theoretical_answer, code_answer):
if not isinstance(theoretical_answer, list):
return {name for name in TEST_TOOLS if name in code_answer}
if isinstance(agent_answer, dict):
for one_answer, one_code in zip(theoretical_answer, code_answer):
if one_answer in agent_answer.values():
return {name for name in TEST_TOOLS if name in one_code}
for one_answer, one_code in zip(theoretical_answer, code_answer):
if agent_answer == one_answer:
return {name for name in TEST_TOOLS if name in one_code}
return {name for name in TEST_TOOLS if name in code_answer[0]}
def evaluate_code(code, inputs=None, state=None, verbose=False, return_interpretor_error=False):
tools = BASE_PYTHON_TOOLS.copy()
for name, tool in TEST_TOOLS.items():
if name not in code:
continue
tools[name] = tool
if isinstance(inputs, dict):
inputs = inputs.copy()
elif inputs is not None:
inputs = {inp: f"<<{inp}>>" for inp in inputs}
if state is not None:
state.update(inputs)
else:
state = inputs
try:
return evaluate(code, tools, state)
except InterpretorError as e:
return str(e)
except Exception as e:
if verbose:
print(e)
return None
def score_code(agent_answer, theoretical_answer, verbose: bool = False):
if verbose:
print(agent_answer, theoretical_answer)
theoretical_answer = theoretical_answer if isinstance(theoretical_answer, list) else [theoretical_answer]
if agent_answer in theoretical_answer:
if verbose:
print("Perfect!")
return 1
elif isinstance(agent_answer, dict) and any(v in theoretical_answer for v in agent_answer.values()):
if verbose:
print("Almsot perfect, result in state!")
return 0.75
else:
if verbose:
print("Result is not the right one but code executed.")
return 0.3
def evaluate_one_result(explanation, code, agent_answer, theoretical_answer, answer, verbose=False):
tools_in_explanation = {name for name in TEST_TOOLS if f"`{name}`" in explanation}
theoretical_tools = get_theoretical_tools(agent_answer, theoretical_answer, answer)
if tools_in_explanation == theoretical_tools:
tool_selection_score = 1.0
tool_selection_errors = None
else:
missing_tools = len(theoretical_tools - tools_in_explanation)
unexpected_tools = len(tools_in_explanation - theoretical_tools)
tool_selection_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)
tool_selection_errors = {
"selected_tools": tools_in_explanation,
"theoretical_tools": theoretical_tools,
}
tools_in_code = {name for name in TEST_TOOLS if name in code}
if tools_in_code == theoretical_tools:
tool_used_score = 1.0
tool_used_errors = None
else:
missing_tools = len(theoretical_tools - tools_in_code)
unexpected_tools = len(tools_in_code - theoretical_tools)
tool_used_score = max(0, 1.0 - 0.25 * missing_tools - 0.25 * unexpected_tools)
tool_used_errors = {
"selected_tools": tools_in_explanation,
"theoretical_tools": theoretical_tools,
}
score = score_code(agent_answer, theoretical_answer, verbose=verbose)
if score < 1.0:
code_errors = {
"code_produced": code,
"evaluation": agent_answer,
"theoretical_answer": theoretical_answer,
}
else:
code_errors = None
return (tool_selection_score, tool_used_score, score), (tool_selection_errors, tool_used_errors, code_errors)
def evaluate_agent(agent, batch_size=8, verbose=False, return_errors=False):
"""
Evaluates a new agent on all `EVALUATION_TASKS`.
Example:
```py
agent = NewOpenAiAgent(model="text-davinci-003", api_key=your_api_key)
bads = new_evaluate_agent(agent)
for bad in bads:
print(bad)
```
"""
# Sanity check
agent_tools = set(agent.toolbox.keys())
if agent_tools != set(TEST_TOOLS):
missing_tools = set(TEST_TOOLS) - agent_tools
unexpected_tools = set(agent_tools) - TEST_TOOLS
raise ValueError(
f"Fix the test tools in the evaluate_agent module. Tools mising: {missing_tools}. Extra tools: {unexpected_tools}."
)
eval_tasks = []
eval_idx = []
for idx, pb in enumerate(EVALUATION_TASKS):
if isinstance(pb.task, list):
eval_tasks.extend(pb.task)
eval_idx.extend([idx] * len(pb.task))
else:
eval_tasks.append(pb.task)
eval_idx.append(idx)
tool_selection_score = 0
tool_used_score = 0
code_score = 0
if return_errors:
tool_selection_errors = {}
tool_used_errors = {}
code_errors = {}
for start_idx in range(0, len(eval_tasks), batch_size):
end_idx = min(start_idx + batch_size, len(eval_tasks))
batch_tasks = eval_tasks[start_idx:end_idx]
prompts = [agent.format_prompt(task) for task in batch_tasks]
results = agent.generate_many(prompts, stop=["Task:"])
for idx, result in enumerate(results):
problem = EVALUATION_TASKS[eval_idx[start_idx + idx]]
if verbose:
print(f"====Task {start_idx + idx}====\n{batch_tasks[idx]}\n")
explanation, code = clean_code_for_run(result)
# Evaluate agent answer and code answer
agent_answer = evaluate_code(code, problem.inputs, verbose=verbose)
if isinstance(problem.answer, list):
theoretical_answer = [evaluate_code(answer, problem.inputs) for answer in problem.answer]
else:
theoretical_answer = evaluate_code(problem.answer, problem.inputs)
scores, errors = evaluate_one_result(
explanation, code, agent_answer, theoretical_answer, problem.answer, verbose=verbose
)
tool_selection_score += scores[0]
tool_used_score += scores[1]
code_score += scores[2]
if return_errors:
if errors[0] is not None:
tool_selection_errors[batch_tasks[idx]] = errors[0]
if errors[1] is not None:
tool_used_errors[batch_tasks[idx]] = errors[1]
if errors[2] is not None:
code_errors[batch_tasks[idx]] = errors[2]
scores = {
"tool selection score": 100 * (tool_selection_score / len(eval_tasks)),
"tool used score": 100 * (tool_used_score / len(eval_tasks)),
"code score": 100 * (code_score / len(eval_tasks)),
}
if return_errors:
return scores, tool_selection_errors, tool_used_errors, code_errors
else:
return scores
def evaluate_chat_agent(agent, verbose=False, return_errors=False):
"""
Evaluates a new agent on all `EVALUATION_CHATS`.
Example:
```py
agent = NewOpenAiAgent(model="text-davinci-003", api_key=your_api_key)
bads = new_evaluate_agent(agent)
for bad in bads:
print(bad)
```
"""
# Sanity check
agent_tools = set(agent.toolbox.keys())
if agent_tools != set(TEST_TOOLS):
missing_tools = set(TEST_TOOLS) - agent_tools
unexpected_tools = agent_tools - set(TEST_TOOLS)
raise ValueError(
f"Fix the test tools in the evaluate_agent module. Tools mising: {missing_tools}. Extra tools: {unexpected_tools}."
)
tool_selection_score = 0
tool_used_score = 0
code_score = 0
total_steps = 0
if return_errors:
tool_selection_errors = {}
tool_used_errors = {}
code_errors = {}
for chat_problem in EVALUATION_CHATS:
if isinstance(chat_problem[0].task, str):
resolved_problems = [chat_problem]
else:
resolved_problems = [
[Problem(task=pb.task[i], inputs=pb.inputs, answer=pb.answer) for pb in chat_problem]
for i in range(len(chat_problem[0].task))
]
for problem in resolved_problems:
agent.prepare_for_new_chat()
agent_state = {}
theoretical_state = (
[{} for _ in range(len(problem[0].answer))] if isinstance(problem[0].answer, list) else {}
)
for step, step_problem in enumerate(problem):
if verbose:
print(step_problem.task)
total_steps += 1
prompt = agent.format_prompt(step_problem.task, chat_mode=True)
result = agent.generate_one(prompt, stop=["Human:", "====="])
agent.chat_history = prompt + result + "\n"
explanation, code = clean_code_for_chat(result)
if verbose:
print(f"==Explanation from the agent==\n{explanation}")
print(f"\n==Code generated by the agent==\n{code}")
# Evaluate agent answer and code answer
agent_answer = evaluate_code(code, step_problem.inputs, state=agent_state, verbose=verbose)
answer = step_problem.answer
if isinstance(answer, list):
theoretical_answer = [
evaluate_code(a, step_problem.inputs, state=state)
for a, state in zip(answer, theoretical_state)
]
else:
theoretical_answer = evaluate_code(answer, step_problem.inputs, state=theoretical_state)
scores, errors = evaluate_one_result(
explanation, code, agent_answer, theoretical_answer, answer, verbose=verbose
)
tool_selection_score += scores[0]
tool_used_score += scores[1]
code_score += scores[2]
if return_errors:
if errors[0] is not None:
tool_selection_errors[step_problem.task] = errors[0]
if errors[1] is not None:
tool_used_errors[step_problem.task] = errors[1]
if errors[2] is not None:
code_errors[step_problem.task] = errors[2]
scores = {
"tool selection score": 100 * (tool_selection_score / total_steps),
"tool used score": 100 * (tool_used_score / total_steps),
"code score": 100 * (code_score / total_steps),
}
if return_errors:
return scores, tool_selection_errors, tool_used_errors, code_errors
else:
return scores

View File

@ -0,0 +1,51 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ..models.auto import AutoModelForVision2Seq
from ..utils import requires_backends
from .base import PipelineTool
if TYPE_CHECKING:
from PIL import Image
class ImageCaptioningTool(PipelineTool):
default_checkpoint = "Salesforce/blip-image-captioning-base"
description = (
"This is a tool that generates a description of an image. It takes an input named `image` which should be the "
"image to caption, and returns a text that contains the description in English."
)
name = "image_captioner"
model_class = AutoModelForVision2Seq
inputs = ["image"]
outputs = ["text"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
super().__init__(*args, **kwargs)
def encode(self, image: "Image"):
return self.pre_processor(images=image, return_tensors="pt")
def forward(self, inputs):
return self.model.generate(**inputs)
def decode(self, outputs):
return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()

View File

@ -0,0 +1,57 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
import torch
from ..models.auto import AutoModelForVisualQuestionAnswering, AutoProcessor
from ..utils import requires_backends
from .base import PipelineTool
if TYPE_CHECKING:
from PIL import Image
class ImageQuestionAnsweringTool(PipelineTool):
default_checkpoint = "dandelin/vilt-b32-finetuned-vqa"
description = (
"This is a tool that answers a question about an image. It takes an input named `image` which should be the "
"image containing the information, as well as a `question` which should be the question in English. It "
"returns a text that is the answer to the question."
)
name = "image_qa"
pre_processor_class = AutoProcessor
model_class = AutoModelForVisualQuestionAnswering
inputs = ["image", "text"]
outputs = ["text"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
super().__init__(*args, **kwargs)
def encode(self, image: "Image", question: str):
return self.pre_processor(image, question, return_tensors="pt")
def forward(self, inputs):
with torch.no_grad():
return self.model(**inputs).logits
def decode(self, outputs):
idx = outputs.argmax(-1).item()
return self.model.config.id2label[idx]

View File

@ -0,0 +1,60 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from ..models.clipseg import CLIPSegForImageSegmentation
from ..utils import is_vision_available, requires_backends
from .base import PipelineTool
if is_vision_available():
from PIL import Image
class ImageSegmentationTool(PipelineTool):
description = (
"This is a tool that creates a segmentation mask identifiying elements inside an image according to a prompt. "
"It takes two arguments named `image` which should be the original image, and `prompt` which should be a text "
"describing the elements what should be identified in the segmentation mask. The tool returns the mask as a "
"black-and-white image."
)
default_checkpoint = "CIDAS/clipseg-rd64-refined"
name = "image_segmenter"
model_class = CLIPSegForImageSegmentation
inputs = ["image", "text"]
outputs = ["image"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
super().__init__(*args, **kwargs)
def encode(self, image: "Image", prompt: str):
self.pre_processor.image_processor.size = {"width": image.size[0], "height": image.size[1]}
return self.pre_processor(text=[prompt], images=[image], padding=True, return_tensors="pt")
def forward(self, inputs):
with torch.no_grad():
logits = self.model(**inputs).logits
return logits
def decode(self, outputs):
array = outputs.cpu().detach().numpy()
array[array <= 0] = 0
array[array > 0] = 1
return Image.fromarray((array * 255).astype(np.uint8))

View File

@ -0,0 +1,186 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# docstyle-ignore
RUN_PROMPT_TEMPLATE = """I will ask you to perform a task, your job is to come up with a series of simple commands in Python that will perform the task.
To help you, I will give you access to a set of tools that you can use. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
You should first explain which tool you will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. You can print intermediate results if it makes sense to do so.
Tools:
<<all_tools>>
Task: "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French."
I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image.
Answer:
```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
print(f"The translated question is {translated_question}.")
answer = image_qa(image=image, question=translated_question)
print(f"The answer is {answer}")
```
Task: "Identify the oldest person in the `document` and create an image showcasing the result."
I will use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
Answer:
```py
answer = document_qa(document, question="What is the oldest person?")
print(f"The answer is {answer}.")
image = image_generator(answer)
```
Task: "Generate an image using the text given in the variable `caption`."
I will use the following tool: `image_generator` to generate an image.
Answer:
```py
image = image_generator(prompt=caption)
```
Task: "Summarize the text given in the variable `text` and read it out loud."
I will use the following tools: `summarizer` to create a summary of the input text, then `text_reader` to read it out loud.
Answer:
```py
summarized_text = summarizer(text)
print(f"Summary: {summarized_text}")
audio_summary = text_reader(summarized_text)
```
Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image."
I will use the following tools: `text_qa` to create the answer, then `image_generator` to generate an image according to the answer.
Answer:
```py
answer = text_qa(text=text, question=question)
print(f"The answer is {answer}.")
image = image_generator(answer)
```
Task: "Caption the following `image`."
I will use the following tool: `image_captioner` to generate a caption for the image.
Answer:
```py
caption = image_captioner(image)
```
Task: "<<prompt>>"
I will use the following"""
# docstyle-ignore
CHAT_PROMPT_TEMPLATE = """Below are a series of dialogues between various people and an AI assistant specialized in coding. The AI assistant tries to be helpful, polite, honest, and humble-but-knowledgeable.
The job of the AI assistant is to come up with a series of simple commands in Python that will perform the task the human wants to perform.
To help with that, the AI assistant has access to a set of tools. Each tool is a Python function and has a description explaining the task it performs, the inputs it expects and the outputs it returns.
The AI assistant should first explain the tools it will use to perform the task and for what reason, then write the code in Python.
Each instruction in Python should be a simple assignment. The AI assistant can print intermediate results if it makes sense to do so.
Tools:
<<all_tools>>
=====
Human: Answer the question in the variable `question` about the image stored in the variable `image`.
Assistant: I will use the tool `image_qa` to answer the question on the input image.
```py
answer = image_qa(text=question, image=image)
print(f"The answer is {answer}")
```
Human: I tried this code, it worked but didn't give me a good result. The question is in French
Assistant: In this case, the question needs to be translated first. I will use the tool `translator` to do this.
```py
translated_question = translator(question=question, src_lang="French", tgt_lang="English")
print(f"The translated question is {translated_question}.")
answer = image_qa(text=translated_question, image=image)
print(f"The answer is {answer}")
```
=====
Human: Identify the oldest person in the `document`.
Assistant: I will use the tool `document_qa` to find the oldest person in the document.
```py
answer = document_qa(document, question="What is the oldest person?")
print(f"The answer is {answer}.")
```
Human: Can you generate an image with the result?
Assistant: I will use the tool `image_generator` to do that.
```py
image = image_generator(answer)
```
=====
Human: Summarize the text given in the variable `text` and read it out loud.
Assistant: I will use the tool `summarizer` to create a summary of the input text, then the tool `text_reader` to read it out loud.
```py
summarized_text = summarizer(text)
print(f"Summary: {summarized_text}")
audio_summary = text_reader(text=summary)
```
Human: I got the following error: "The variable `summary` is not defined."
Assistant: My bad! Let's try this code instead.
```py
summarized_text = summarizer(text)
print(f"Summary: {summarized_text}")
audio_summary = text_reader(text=summarized_text)
```
Human: It worked! Can you translate the summary in German?
Assistant: I will use the tool `translator` to translate the text in German.
```py
translated_summary = translator(summarized_text, src_lang="English", tgt_lang="German)
```
====
"""
# docstyle-ignore
CHAT_MESSAGE_PROMPT = """
Human: <<task>>
Assistant: """

View File

@ -0,0 +1,238 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import difflib
from collections.abc import Mapping
from typing import Any, Callable, Dict
class InterpretorError(ValueError):
"""
An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported
operations.
"""
pass
def evaluate(code: str, tools: Dict[str, Callable], state=None, chat_mode=False):
"""
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
of functions.
This function will recurse through the nodes of the tree provided.
Args:
code (`str`):
The code to evaluate.
tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation. Any call to another function will fail with an
`InterpretorError`.
state (`Dict[str, Any]`):
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
updated by this function to contain all variables as they are evaluated.
chat_mode (`bool`, *optional*, defaults to `False`):
Whether or not the function is called from `Agent.chat`.
"""
try:
expression = ast.parse(code)
except SyntaxError as e:
print("The code generated by the agent is not valid.\n", e)
return
if state is None:
state = {}
result = None
for idx, node in enumerate(expression.body):
try:
line_result = evaluate_ast(node, state, tools)
except InterpretorError as e:
msg = f"Evaluation of the code stopped at line {idx} before the end because of the following error"
if chat_mode:
msg += (
f". Copy paste the following error message and send it back to the agent:\nI get an error: '{e}'"
)
else:
msg += f":\n{e}"
print(msg)
break
if line_result is not None:
result = line_result
return result
def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Callable]):
"""
Evaluate an absract syntax tree using the content of the variables stored in a state and only evaluating a given
set of functions.
This function will recurse trough the nodes of the tree provided.
Args:
expression (`ast.AST`):
The code to evaluate, as an abastract syntax tree.
state (`Dict[str, Any]`):
A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
encounters assignements.
tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation. Any call to another function will fail with an
`InterpretorError`.
"""
if isinstance(expression, ast.Assign):
# Assignement -> we evaluate the assignement which should update the state
# We return the variable assigned as it may be used to determine the final result.
return evaluate_assign(expression, state, tools)
elif isinstance(expression, ast.Call):
# Function call -> we return the value of the function call
return evaluate_call(expression, state, tools)
elif isinstance(expression, ast.Constant):
# Constant -> just return the value
return expression.value
elif isinstance(expression, ast.Dict):
# Dict -> evaluate all keys and values
keys = [evaluate_ast(k, state, tools) for k in expression.keys]
values = [evaluate_ast(v, state, tools) for v in expression.values]
return dict(zip(keys, values))
elif isinstance(expression, ast.Expr):
# Expression -> evaluate the content
return evaluate_ast(expression.value, state, tools)
elif isinstance(expression, ast.FormattedValue):
# Formatted value (part of f-string) -> evaluate the content and return
return evaluate_ast(expression.value, state, tools)
elif isinstance(expression, ast.If):
# If -> execute the right branch
return evaluate_if(expression, state, tools)
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
return evaluate_ast(expression.value, state, tools)
elif isinstance(expression, ast.JoinedStr):
return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values])
elif isinstance(expression, ast.List):
# List -> evaluate all elements
return [evaluate_ast(elt, state, tools) for elt in expression.elts]
elif isinstance(expression, ast.Name):
# Name -> pick up the value in the state
return evaluate_name(expression, state, tools)
elif isinstance(expression, ast.Subscript):
# Subscript -> return the value of the indexing
return evaluate_subscript(expression, state, tools)
else:
# For now we refuse anything else. Let's add things as we need them.
raise InterpretorError(f"{expression.__class__.__name__} is not supported.")
def evaluate_assign(assign, state, tools):
var_names = assign.targets
result = evaluate_ast(assign.value, state, tools)
if len(var_names) == 1:
state[var_names[0].id] = result
else:
if len(result) != len(var_names):
raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
for var_name, r in zip(var_names, result):
state[var_name.id] = r
return result
def evaluate_call(call, state, tools):
if not isinstance(call.func, ast.Name):
raise InterpretorError(
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func} of "
f"type {type(call.func)}."
)
func_name = call.func.id
if func_name not in tools:
raise InterpretorError(
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func.id})."
)
func = tools[func_name]
# Todo deal with args
args = [evaluate_ast(arg, state, tools) for arg in call.args]
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
return func(*args, **kwargs)
def evaluate_subscript(subscript, state, tools):
index = evaluate_ast(subscript.slice, state, tools)
value = evaluate_ast(subscript.value, state, tools)
if isinstance(value, (list, tuple)):
return value[int(index)]
if index in value:
return value[index]
if isinstance(index, str) and isinstance(value, Mapping):
close_matches = difflib.get_close_matches(index, list(value.keys()))
if len(close_matches) > 0:
return value[close_matches[0]]
raise InterpretorError(f"Could not index {value} with '{index}'.")
def evaluate_name(name, state, tools):
if name.id in state:
return state[name.id]
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
if len(close_matches) > 0:
return state[close_matches[0]]
raise InterpretorError(f"The variable `{name.id}` is not defined.")
def evaluate_condition(condition, state, tools):
if len(condition.ops) > 1:
raise InterpretorError("Cannot evaluate conditions with multiple operators")
left = evaluate_ast(condition.left, state, tools)
comparator = condition.ops[0]
right = evaluate_ast(condition.comparators[0], state, tools)
if isinstance(comparator, ast.Eq):
return left == right
elif isinstance(comparator, ast.NotEq):
return left != right
elif isinstance(comparator, ast.Lt):
return left < right
elif isinstance(comparator, ast.LtE):
return left <= right
elif isinstance(comparator, ast.Gt):
return left > right
elif isinstance(comparator, ast.GtE):
return left >= right
elif isinstance(comparator, ast.Is):
return left is right
elif isinstance(comparator, ast.IsNot):
return left is not right
elif isinstance(comparator, ast.In):
return left in right
elif isinstance(comparator, ast.NotIn):
return left not in right
else:
raise InterpretorError(f"Operator not supported: {comparator}")
def evaluate_if(if_statement, state, tools):
result = None
if evaluate_condition(if_statement.test, state, tools):
for line in if_statement.body:
line_result = evaluate_ast(line, state, tools)
if line_result is not None:
result = line_result
else:
for line in if_statement.orelse:
line_result = evaluate_ast(line, state, tools)
if line_result is not None:
result = line_result
return result

View File

@ -0,0 +1,41 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
from .base import PipelineTool
class SpeechToTextTool(PipelineTool):
default_checkpoint = "openai/whisper-base"
description = (
"This is a tool that transcribes an audio into text. It takes an input named `audio` and returns the "
"transcribed text."
)
name = "transcriber"
pre_processor_class = WhisperProcessor
model_class = WhisperForConditionalGeneration
inputs = ["audio"]
outputs = ["text"]
def encode(self, audio):
return self.pre_processor(audio, return_tensors="pt").input_features
def forward(self, inputs):
return self.model.generate(inputs=inputs)
def decode(self, outputs):
return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]

View File

@ -0,0 +1,70 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ..models.auto import AutoModelForSequenceClassification, AutoTokenizer
from .base import PipelineTool
class TextClassificationTool(PipelineTool):
"""
Example:
```py
from transformers.tools import TextClassificationTool
classifier = TextClassificationTool()
classifier("This is a super nice API!", labels=["positive", "negative"])
```
"""
default_checkpoint = "facebook/bart-large-mnli"
description = (
"This is a tool that classifies an English text using provided labels. It takes two inputs: `text`, which "
"should be the text to classify, and `labels`, which should be the list of labels to use for classification. "
"It returns the most likely label in the list of provided `labels` for the input text."
)
name = "text_classifier"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSequenceClassification
inputs = ["text", ["text"]]
outputs = ["text"]
def setup(self):
super().setup()
config = self.model.config
self.entailment_id = -1
for idx, label in config.id2label.items():
if label.lower().startswith("entail"):
self.entailment_id = int(idx)
if self.entailment_id == -1:
raise ValueError("Could not determine the entailment ID from the model config, please pass it at init.")
def encode(self, text, labels):
self._labels = labels
return self.pre_processor(
[text] * len(labels),
[f"This example is {label}" for label in labels],
return_tensors="pt",
padding="max_length",
)
def decode(self, outputs):
logits = outputs.logits
label_id = torch.argmax(logits[:, 2]).item()
return self._labels[label_id]

View File

@ -0,0 +1,52 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool
QA_PROMPT = """Here is a text containing a lot of information: '''{text}'''.
Can you answer this question about the text: '{question}'"""
class TextQuestionAnsweringTool(PipelineTool):
default_checkpoint = "google/flan-t5-base"
description = (
"This is a tool that answers questions related to a text. It takes two arguments named `text`, which is the "
"text where to find the answer, and `question`, which is the question, and returns the answer to the question."
)
name = "text_qa"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
inputs = ["text", "text"]
outputs = ["text"]
def encode(self, text: str, question: str):
prompt = QA_PROMPT.format(text=text, question=question)
return self.pre_processor(prompt, return_tensors="pt")
def forward(self, inputs):
output_ids = self.model.generate(**inputs)
in_b, _ = inputs["input_ids"].shape
out_b = output_ids.shape[0]
return output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])[0][0]
def decode(self, outputs):
return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)

View File

@ -0,0 +1,52 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool
class TextSummarizationTool(PipelineTool):
"""
Example:
```py
from transformers.tools import TextSummarizationTool
summarizer = TextSummarizationTool()
summarizer(long_text)
```
"""
default_checkpoint = "philschmid/bart-large-cnn-samsum"
description = (
"This is a tool that summarizes an English text. It takes an input `text` containing the text to summarize, "
"and returns a summary of the text."
)
name = "summarizer"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
inputs = ["text"]
outputs = ["text"]
def encode(self, text):
return self.pre_processor(text, return_tensors="pt", truncation=True)
def forward(self, inputs):
return self.model.generate(**inputs)[0]
def decode(self, outputs):
return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)

View File

@ -0,0 +1,65 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
from ..utils import is_datasets_available
from .base import PipelineTool
if is_datasets_available():
from datasets import load_dataset
class TextToSpeechTool(PipelineTool):
default_checkpoint = "microsoft/speecht5_tts"
description = (
"This is a tool that reads an English text out loud. It takes an input named `text` which should contain the "
"text to read (in English) and returns a waveform object containing the sound."
)
name = "text_reader"
pre_processor_class = SpeechT5Processor
model_class = SpeechT5ForTextToSpeech
post_processor_class = SpeechT5HifiGan
inputs = ["text"]
outputs = ["audio"]
def setup(self):
if self.post_processor is None:
self.post_processor = "microsoft/speecht5_hifigan"
super().setup()
def encode(self, text, speaker_embeddings=None):
inputs = self.pre_processor(text=text, return_tensors="pt", truncation=True)
if speaker_embeddings is None:
if not is_datasets_available():
raise ImportError("Datasets needs to be installed if not passing speaker embeddings.")
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7305]["xvector"]).unsqueeze(0)
return {"input_ids": inputs["input_ids"], "speaker_embeddings": speaker_embeddings}
def forward(self, inputs):
with torch.no_grad():
return self.model.generate_speech(**inputs)
def decode(self, outputs):
with torch.no_grad():
return self.post_processor(outputs).cpu().detach()

View File

@ -0,0 +1,271 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool
LANGUAGE_CODES = {
"Acehnese Arabic": "ace_Arab",
"Acehnese Latin": "ace_Latn",
"Mesopotamian Arabic": "acm_Arab",
"Ta'izzi-Adeni Arabic": "acq_Arab",
"Tunisian Arabic": "aeb_Arab",
"Afrikaans": "afr_Latn",
"South Levantine Arabic": "ajp_Arab",
"Akan": "aka_Latn",
"Amharic": "amh_Ethi",
"North Levantine Arabic": "apc_Arab",
"Modern Standard Arabic": "arb_Arab",
"Modern Standard Arabic Romanized": "arb_Latn",
"Najdi Arabic": "ars_Arab",
"Moroccan Arabic": "ary_Arab",
"Egyptian Arabic": "arz_Arab",
"Assamese": "asm_Beng",
"Asturian": "ast_Latn",
"Awadhi": "awa_Deva",
"Central Aymara": "ayr_Latn",
"South Azerbaijani": "azb_Arab",
"North Azerbaijani": "azj_Latn",
"Bashkir": "bak_Cyrl",
"Bambara": "bam_Latn",
"Balinese": "ban_Latn",
"Belarusian": "bel_Cyrl",
"Bemba": "bem_Latn",
"Bengali": "ben_Beng",
"Bhojpuri": "bho_Deva",
"Banjar Arabic": "bjn_Arab",
"Banjar Latin": "bjn_Latn",
"Standard Tibetan": "bod_Tibt",
"Bosnian": "bos_Latn",
"Buginese": "bug_Latn",
"Bulgarian": "bul_Cyrl",
"Catalan": "cat_Latn",
"Cebuano": "ceb_Latn",
"Czech": "ces_Latn",
"Chokwe": "cjk_Latn",
"Central Kurdish": "ckb_Arab",
"Crimean Tatar": "crh_Latn",
"Welsh": "cym_Latn",
"Danish": "dan_Latn",
"German": "deu_Latn",
"Southwestern Dinka": "dik_Latn",
"Dyula": "dyu_Latn",
"Dzongkha": "dzo_Tibt",
"Greek": "ell_Grek",
"English": "eng_Latn",
"Esperanto": "epo_Latn",
"Estonian": "est_Latn",
"Basque": "eus_Latn",
"Ewe": "ewe_Latn",
"Faroese": "fao_Latn",
"Fijian": "fij_Latn",
"Finnish": "fin_Latn",
"Fon": "fon_Latn",
"French": "fra_Latn",
"Friulian": "fur_Latn",
"Nigerian Fulfulde": "fuv_Latn",
"Scottish Gaelic": "gla_Latn",
"Irish": "gle_Latn",
"Galician": "glg_Latn",
"Guarani": "grn_Latn",
"Gujarati": "guj_Gujr",
"Haitian Creole": "hat_Latn",
"Hausa": "hau_Latn",
"Hebrew": "heb_Hebr",
"Hindi": "hin_Deva",
"Chhattisgarhi": "hne_Deva",
"Croatian": "hrv_Latn",
"Hungarian": "hun_Latn",
"Armenian": "hye_Armn",
"Igbo": "ibo_Latn",
"Ilocano": "ilo_Latn",
"Indonesian": "ind_Latn",
"Icelandic": "isl_Latn",
"Italian": "ita_Latn",
"Javanese": "jav_Latn",
"Japanese": "jpn_Jpan",
"Kabyle": "kab_Latn",
"Jingpho": "kac_Latn",
"Kamba": "kam_Latn",
"Kannada": "kan_Knda",
"Kashmiri Arabic": "kas_Arab",
"Kashmiri Devanagari": "kas_Deva",
"Georgian": "kat_Geor",
"Central Kanuri Arabic": "knc_Arab",
"Central Kanuri Latin": "knc_Latn",
"Kazakh": "kaz_Cyrl",
"Kabiyè": "kbp_Latn",
"Kabuverdianu": "kea_Latn",
"Khmer": "khm_Khmr",
"Kikuyu": "kik_Latn",
"Kinyarwanda": "kin_Latn",
"Kyrgyz": "kir_Cyrl",
"Kimbundu": "kmb_Latn",
"Northern Kurdish": "kmr_Latn",
"Kikongo": "kon_Latn",
"Korean": "kor_Hang",
"Lao": "lao_Laoo",
"Ligurian": "lij_Latn",
"Limburgish": "lim_Latn",
"Lingala": "lin_Latn",
"Lithuanian": "lit_Latn",
"Lombard": "lmo_Latn",
"Latgalian": "ltg_Latn",
"Luxembourgish": "ltz_Latn",
"Luba-Kasai": "lua_Latn",
"Ganda": "lug_Latn",
"Luo": "luo_Latn",
"Mizo": "lus_Latn",
"Standard Latvian": "lvs_Latn",
"Magahi": "mag_Deva",
"Maithili": "mai_Deva",
"Malayalam": "mal_Mlym",
"Marathi": "mar_Deva",
"Minangkabau Arabic ": "min_Arab",
"Minangkabau Latin": "min_Latn",
"Macedonian": "mkd_Cyrl",
"Plateau Malagasy": "plt_Latn",
"Maltese": "mlt_Latn",
"Meitei Bengali": "mni_Beng",
"Halh Mongolian": "khk_Cyrl",
"Mossi": "mos_Latn",
"Maori": "mri_Latn",
"Burmese": "mya_Mymr",
"Dutch": "nld_Latn",
"Norwegian Nynorsk": "nno_Latn",
"Norwegian Bokmål": "nob_Latn",
"Nepali": "npi_Deva",
"Northern Sotho": "nso_Latn",
"Nuer": "nus_Latn",
"Nyanja": "nya_Latn",
"Occitan": "oci_Latn",
"West Central Oromo": "gaz_Latn",
"Odia": "ory_Orya",
"Pangasinan": "pag_Latn",
"Eastern Panjabi": "pan_Guru",
"Papiamento": "pap_Latn",
"Western Persian": "pes_Arab",
"Polish": "pol_Latn",
"Portuguese": "por_Latn",
"Dari": "prs_Arab",
"Southern Pashto": "pbt_Arab",
"Ayacucho Quechua": "quy_Latn",
"Romanian": "ron_Latn",
"Rundi": "run_Latn",
"Russian": "rus_Cyrl",
"Sango": "sag_Latn",
"Sanskrit": "san_Deva",
"Santali": "sat_Olck",
"Sicilian": "scn_Latn",
"Shan": "shn_Mymr",
"Sinhala": "sin_Sinh",
"Slovak": "slk_Latn",
"Slovenian": "slv_Latn",
"Samoan": "smo_Latn",
"Shona": "sna_Latn",
"Sindhi": "snd_Arab",
"Somali": "som_Latn",
"Southern Sotho": "sot_Latn",
"Spanish": "spa_Latn",
"Tosk Albanian": "als_Latn",
"Sardinian": "srd_Latn",
"Serbian": "srp_Cyrl",
"Swati": "ssw_Latn",
"Sundanese": "sun_Latn",
"Swedish": "swe_Latn",
"Swahili": "swh_Latn",
"Silesian": "szl_Latn",
"Tamil": "tam_Taml",
"Tatar": "tat_Cyrl",
"Telugu": "tel_Telu",
"Tajik": "tgk_Cyrl",
"Tagalog": "tgl_Latn",
"Thai": "tha_Thai",
"Tigrinya": "tir_Ethi",
"Tamasheq Latin": "taq_Latn",
"Tamasheq Tifinagh": "taq_Tfng",
"Tok Pisin": "tpi_Latn",
"Tswana": "tsn_Latn",
"Tsonga": "tso_Latn",
"Turkmen": "tuk_Latn",
"Tumbuka": "tum_Latn",
"Turkish": "tur_Latn",
"Twi": "twi_Latn",
"Central Atlas Tamazight": "tzm_Tfng",
"Uyghur": "uig_Arab",
"Ukrainian": "ukr_Cyrl",
"Umbundu": "umb_Latn",
"Urdu": "urd_Arab",
"Northern Uzbek": "uzn_Latn",
"Venetian": "vec_Latn",
"Vietnamese": "vie_Latn",
"Waray": "war_Latn",
"Wolof": "wol_Latn",
"Xhosa": "xho_Latn",
"Eastern Yiddish": "ydd_Hebr",
"Yoruba": "yor_Latn",
"Yue Chinese": "yue_Hant",
"Chinese Simplified": "zho_Hans",
"Chinese Traditional": "zho_Hant",
"Standard Malay": "zsm_Latn",
"Zulu": "zul_Latn",
}
class TranslationTool(PipelineTool):
"""
Example:
```py
from transformers.tools import TranslationTool
translator = TranslationTool()
translator("This is a super nice API!", src_lang="English", tgt_lang="French")
```
"""
default_checkpoint = "facebook/nllb-200-distilled-600M"
description = (
"This is a tool that translates text from a language to another. It takes three inputs: `text`, which should "
"be the text to translate, `src_lang`, which should be the language of the text to translate and `tgt_lang`, "
"which should be the language for the desired ouput language. Both `src_lang` and `tgt_lang` are written in "
"plain English, such as 'Romanian', or 'Albanian'. It returns the text translated in `tgt_lang`."
)
name = "translator"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
lang_to_code = LANGUAGE_CODES
inputs = ["text", "text", "text"]
outputs = ["text"]
def encode(self, text, src_lang, tgt_lang):
if src_lang not in self.lang_to_code:
raise ValueError(f"{src_lang} is not a supported language.")
if tgt_lang not in self.lang_to_code:
raise ValueError(f"{tgt_lang} is not a supported language.")
src_lang = self.lang_to_code[src_lang]
tgt_lang = self.lang_to_code[tgt_lang]
return self.pre_processor._build_translation_inputs(
text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang
)
def forward(self, inputs):
return self.model.generate(**inputs)
def decode(self, outputs):
return self.post_processor.decode(outputs[0].tolist(), skip_special_tokens=True)

View File

@ -121,6 +121,7 @@ from .import_utils import (
is_natten_available,
is_ninja_available,
is_onnx_available,
is_openai_available,
is_optimum_available,
is_pandas_available,
is_peft_available,

View File

@ -235,6 +235,7 @@ def try_to_load_from_cache(
filename: str,
cache_dir: Union[str, Path, None] = None,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
) -> Optional[str]:
"""
Explores the cache to return the latest cached file for a given revision if found.
@ -251,6 +252,8 @@ def try_to_load_from_cache(
revision (`str`, *optional*):
The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
provided either.
repo_type (`str`, *optional*):
The type of the repo.
Returns:
`Optional[str]` or `_CACHED_NO_EXIST`:
@ -266,7 +269,9 @@ def try_to_load_from_cache(
cache_dir = TRANSFORMERS_CACHE
object_id = repo_id.replace("/", "--")
repo_cache = os.path.join(cache_dir, f"models--{object_id}")
if repo_type is None:
repo_type = "model"
repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
if not os.path.isdir(repo_cache):
# No cache for this model
return None
@ -303,6 +308,7 @@ def cached_file(
revision: Optional[str] = None,
local_files_only: bool = False,
subfolder: str = "",
repo_type: Optional[str] = None,
user_agent: Optional[Union[str, Dict[str, str]]] = None,
_raise_exceptions_for_missing_entries: bool = True,
_raise_exceptions_for_connection_errors: bool = True,
@ -342,6 +348,8 @@ def cached_file(
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
specify the folder name here.
repo_type (`str`, *optional*):
Specify the repo type (useful when downloading from a space for instance).
<Tip>
@ -393,7 +401,7 @@ def cached_file(
if _commit_hash is not None and not force_download:
# If the file is cached under that commit hash, we return it directly.
resolved_file = try_to_load_from_cache(
path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash
path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
)
if resolved_file is not None:
if resolved_file is not _CACHED_NO_EXIST:
@ -410,6 +418,7 @@ def cached_file(
path_or_repo_id,
filename,
subfolder=None if len(subfolder) == 0 else subfolder,
repo_type=repo_type,
revision=revision,
cache_dir=cache_dir,
user_agent=user_agent,

View File

@ -125,6 +125,14 @@ except importlib_metadata.PackageNotFoundError:
_datasets_available = False
_diffusers_available = importlib.util.find_spec("diffusers") is not None
try:
_diffusers_version = importlib_metadata.version("diffusers")
logger.debug(f"Successfully imported diffusers version {_diffusers_version}")
except importlib_metadata.PackageNotFoundError:
_diffusers_available = False
_detectron2_available = importlib.util.find_spec("detectron2") is not None
try:
_detectron2_version = importlib_metadata.version("detectron2")
@ -185,6 +193,9 @@ except importlib_metadata.PackageNotFoundError:
_onnx_available = False
_opencv_available = importlib.util.find_spec("cv2") is not None
_pytorch_quantization_available = importlib.util.find_spec("pytorch_quantization") is not None
try:
_pytorch_quantization_version = importlib_metadata.version("pytorch_quantization")
@ -431,6 +442,10 @@ def is_onnx_available():
return _onnx_available
def is_openai_available():
return importlib.util.find_spec("openai") is not None
def is_flax_available():
return _flax_available

0
tests/tools/__init__.py Normal file
View File

View File

@ -0,0 +1,57 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from datasets import load_dataset
from transformers import load_tool
from .test_tools_common import ToolTesterMixin
class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("document-question-answering")
self.tool.setup()
self.remote_tool = load_tool("document-question-answering", remote=True)
def test_exact_match_arg(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[0]["image"]
result = self.tool(image, "When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")
def test_exact_match_arg_remote(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[0]["image"]
result = self.remote_tool(image, "When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")
def test_exact_match_kwarg(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[0]["image"]
result = self.tool(image=image, question="When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")
def test_exact_match_kwarg_remote(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[0]["image"]
result = self.remote_tool(image=image, question="When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")

View File

@ -0,0 +1,53 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from pathlib import Path
from transformers import is_vision_available, load_tool
from transformers.testing_utils import get_tests_dir
from .test_tools_common import ToolTesterMixin
if is_vision_available():
from PIL import Image
class ImageCaptioningToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("image-captioning")
self.tool.setup()
self.remote_tool = load_tool("image-captioning", remote=True)
def test_exact_match_arg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image)
self.assertEqual(result, "two cats sleeping on a couch")
def test_exact_match_arg_remote(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.remote_tool(image)
self.assertEqual(result, "two cats sleeping on a couch")
def test_exact_match_kwarg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image=image)
self.assertEqual(result, "two cats sleeping on a couch")
def test_exact_match_kwarg_remote(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.remote_tool(image=image)
self.assertEqual(result, "two cats sleeping on a couch")

View File

@ -0,0 +1,53 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from pathlib import Path
from transformers import is_vision_available, load_tool
from transformers.testing_utils import get_tests_dir
from .test_tools_common import ToolTesterMixin
if is_vision_available():
from PIL import Image
class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("image-question-answering")
self.tool.setup()
self.remote_tool = load_tool("image-question-answering", remote=True)
def test_exact_match_arg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image, "How many cats are sleeping on the couch?")
self.assertEqual(result, "2")
def test_exact_match_arg_remote(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.remote_tool(image, "How many cats are sleeping on the couch?")
self.assertEqual(result, "2")
def test_exact_match_kwarg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image=image, question="How many cats are sleeping on the couch?")
self.assertEqual(result, "2")
def test_exact_match_kwarg_remote(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.remote_tool(image=image, question="How many cats are sleeping on the couch?")
self.assertEqual(result, "2")

View File

@ -0,0 +1,53 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from pathlib import Path
from transformers import is_vision_available, load_tool
from transformers.testing_utils import get_tests_dir
from .test_tools_common import ToolTesterMixin
if is_vision_available():
from PIL import Image
class ImageSegmentationToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("image-segmentation")
self.tool.setup()
self.remote_tool = load_tool("image-segmentation", remote=True)
def test_exact_match_arg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
result = self.tool(image, "cat")
self.assertTrue(isinstance(result, Image.Image))
def test_exact_match_arg_remote(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
result = self.remote_tool(image, "cat")
self.assertTrue(isinstance(result, Image.Image))
def test_exact_match_kwarg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
result = self.tool(image=image, prompt="cat")
self.assertTrue(isinstance(result, Image.Image))
def test_exact_match_kwarg_remote(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
result = self.remote_tool(image=image, prompt="cat")
self.assertTrue(isinstance(result, Image.Image))

View File

@ -0,0 +1,124 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.testing_utils import CaptureStdout
from transformers.tools.python_interpreter import evaluate
# Fake function we will use as tool
def add_two(x):
return x + 2
class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_assign(self):
code = "x = 3"
state = {}
result = evaluate(code, {}, state=state)
assert result == 3
self.assertDictEqual(state, {"x": 3})
code = "x = y"
state = {"y": 5}
result = evaluate(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 5
self.assertDictEqual(state, {"x": 5, "y": 5})
def test_evaluate_call(self):
code = "y = add_two(x)"
state = {"x": 3}
result = evaluate(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqual(state, {"x": 3, "y": 5})
# Won't work without the tool
with CaptureStdout() as out:
result = evaluate(code, {}, state=state)
assert result is None
assert "tried to execute add_two" in out.out
def test_evaluate_constant(self):
code = "x = 3"
state = {}
result = evaluate(code, {}, state=state)
assert result == 3
self.assertDictEqual(state, {"x": 3})
def test_evaluate_dict(self):
code = "test_dict = {'x': x, 'y': add_two(x)}"
state = {"x": 3}
result = evaluate(code, {"add_two": add_two}, state=state)
self.assertDictEqual(result, {"x": 3, "y": 5})
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}})
def test_evaluate_expression(self):
code = "x = 3\ny = 5"
state = {}
result = evaluate(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 5
self.assertDictEqual(state, {"x": 3, "y": 5})
def test_evaluate_f_string(self):
code = "text = f'This is x: {x}.'"
state = {"x": 3}
result = evaluate(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == "This is x: 3."
self.assertDictEqual(state, {"x": 3, "text": "This is x: 3."})
def test_evaluate_if(self):
code = "if x <= 3:\n y = 2\nelse:\n y = 5"
state = {"x": 3}
result = evaluate(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 2
self.assertDictEqual(state, {"x": 3, "y": 2})
state = {"x": 8}
result = evaluate(code, {}, state=state)
# evaluate returns the value of the last assignment.
assert result == 5
self.assertDictEqual(state, {"x": 8, "y": 5})
def test_evaluate_list(self):
code = "test_list = [x, add_two(x)]"
state = {"x": 3}
result = evaluate(code, {"add_two": add_two}, state=state)
self.assertListEqual(result, [3, 5])
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5]})
def test_evaluate_name(self):
code = "y = x"
state = {"x": 3}
result = evaluate(code, {}, state=state)
assert result == 3
self.assertDictEqual(state, {"x": 3, "y": 3})
def test_evaluate_subscript(self):
code = "test_list = [x, add_two(x)]\ntest_list[1]"
state = {"x": 3}
result = evaluate(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqual(state, {"x": 3, "test_list": [3, 5]})
code = "test_dict = {'x': x, 'y': add_two(x)}\ntest_dict['y']"
state = {"x": 3}
result = evaluate(code, {"add_two": add_two}, state=state)
assert result == 5
self.assertDictEqual(state, {"x": 3, "test_dict": {"x": 3, "y": 5}})

View File

@ -0,0 +1,38 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_torch_available, load_tool
from .test_tools_common import ToolTesterMixin
if is_torch_available():
import torch
class SpeechToTextToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("speech-to-text")
self.tool.setup()
def test_exact_match_arg(self):
result = self.tool(torch.ones(3000))
self.assertEqual(result, " you")
def test_exact_match_kwarg(self):
result = self.tool(audio=torch.ones(3000))
self.assertEqual(result, " you")

View File

@ -0,0 +1,43 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import load_tool
from .test_tools_common import ToolTesterMixin
class TextClassificationToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("text-classification")
self.tool.setup()
self.remote_tool = load_tool("text-classification", remote=True)
def test_exact_match_arg(self):
result = self.tool("That's quite cool", ["positive", "negative"])
self.assertEqual(result, "positive")
def test_exact_match_arg_remote(self):
result = self.remote_tool("That's quite cool", ["positive", "negative"])
self.assertEqual(result, "positive")
def test_exact_match_kwarg(self):
result = self.tool(text="That's quite cool", labels=["positive", "negative"])
self.assertEqual(result, "positive")
def test_exact_match_kwarg_remote(self):
result = self.remote_tool(text="That's quite cool", labels=["positive", "negative"])
self.assertEqual(result, "positive")

View File

@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import load_tool
from .test_tools_common import ToolTesterMixin
TEXT = """
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.
In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]
On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
"""
class TextQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("text-question-answering")
self.tool.setup()
self.remote_tool = load_tool("text-question-answering", remote=True)
def test_exact_match_arg(self):
result = self.tool(TEXT, "What did Hugging Face do in April 2021?")
self.assertEqual(result, "launched the BigScience Research Workshop")
def test_exact_match_arg_remote(self):
result = self.remote_tool(TEXT, "What did Hugging Face do in April 2021?")
self.assertEqual(result, "launched the BigScience Research Workshop")
def test_exact_match_kwarg(self):
result = self.tool(text=TEXT, question="What did Hugging Face do in April 2021?")
self.assertEqual(result, "launched the BigScience Research Workshop")
def test_exact_match_kwarg_remote(self):
result = self.remote_tool(text=TEXT, question="What did Hugging Face do in April 2021?")
self.assertEqual(result, "launched the BigScience Research Workshop")

View File

@ -0,0 +1,64 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import load_tool
from .test_tools_common import ToolTesterMixin
TEXT = """
Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf originally as a company that developed a chatbot app targeted at teenagers.[2] After open-sourcing the model behind the chatbot, the company pivoted to focus on being a platform for machine learning.
In March 2021, Hugging Face raised $40 million in a Series B funding round.[3]
On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model.[4] In 2022, the workshop concluded with the announcement of BLOOM, a multilingual large language model with 176 billion parameters.[5]
"""
class TextSummarizationToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("summarization")
self.tool.setup()
self.remote_tool = load_tool("summarization", remote=True)
def test_exact_match_arg(self):
result = self.tool(TEXT)
self.assertEqual(
result,
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
)
def test_exact_match_arg_remote(self):
result = self.remote_tool(TEXT)
self.assertEqual(
result,
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
)
def test_exact_match_kwarg(self):
result = self.tool(text=TEXT)
self.assertEqual(
result,
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
)
def test_exact_match_kwarg_remote(self):
result = self.remote_tool(text=TEXT)
self.assertEqual(
result,
"Hugging Face was founded in 2016 by French entrepreneurs Clément Delangue, Julien Chaumond, and Thomas Wolf. In March 2021, Hugging Face raised $40 million in a Series B funding round. On April 28, 2021, the company launched the BigScience Research Workshop in collaboration with several other research groups to release an open large language model. In 2022, the workshop concluded with the announcement of BLOOM.",
)

View File

@ -0,0 +1,54 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import load_tool
from transformers.utils import is_torch_available
if is_torch_available():
import torch
from transformers.testing_utils import require_torch
from .test_tools_common import ToolTesterMixin
@require_torch
class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("text-to-speech")
self.tool.setup()
def test_exact_match_arg(self):
# SpeechT5 isn't deterministic
torch.manual_seed(0)
result = self.tool("hey")
self.assertTrue(
torch.allclose(
result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293])
)
)
def test_exact_match_kwarg(self):
# SpeechT5 isn't deterministic
torch.manual_seed(0)
result = self.tool("hey")
self.assertTrue(
torch.allclose(
result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293])
)
)

View File

@ -0,0 +1,100 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import List
from transformers import is_torch_available, is_vision_available
from transformers.testing_utils import get_tests_dir, is_tool_test
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
authorized_types = ["text", "image", "audio"]
def create_inputs(input_types: List[str]):
inputs = []
for input_type in input_types:
if input_type == "text":
inputs.append("Text input")
elif input_type == "image":
inputs.append(
Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize((512, 512))
)
elif input_type == "audio":
inputs.append(torch.ones(3000))
elif isinstance(input_type, list):
inputs.append(create_inputs(input_type))
else:
raise ValueError(f"Invalid type requested: {input_type}")
return inputs
def output_types(outputs: List):
output_types = []
for output in outputs:
if isinstance(output, str):
output_types.append("text")
elif isinstance(output, Image.Image):
output_types.append("image")
elif isinstance(output, torch.Tensor):
output_types.append("audio")
else:
raise ValueError(f"Invalid output: {output}")
return output_types
@is_tool_test
class ToolTesterMixin:
def test_inputs_outputs(self):
self.assertTrue(hasattr(self.tool, "inputs"))
self.assertTrue(hasattr(self.tool, "outputs"))
inputs = self.tool.inputs
for _input in inputs:
if isinstance(_input, list):
for __input in _input:
self.assertTrue(__input in authorized_types)
else:
self.assertTrue(_input in authorized_types)
outputs = self.tool.outputs
for _output in outputs:
self.assertTrue(_output in authorized_types)
def test_call(self):
inputs = create_inputs(self.tool.inputs)
outputs = self.tool(*inputs)
# There is a single output
if len(self.tool.outputs) == 1:
outputs = [outputs]
self.assertListEqual(output_types(outputs), self.tool.outputs)
def test_common_attributes(self):
self.assertTrue(hasattr(self.tool, "description"))
self.assertTrue(hasattr(self.tool, "default_checkpoint"))
self.assertTrue(self.tool.description.startswith("This is a tool that"))

View File

@ -0,0 +1,53 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import load_tool
from .test_tools_common import ToolTesterMixin, output_types
class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("translation")
self.tool.setup()
self.remote_tool = load_tool("translation", remote=True)
def test_exact_match_arg(self):
result = self.tool("Hey, what's up?", src_lang="English", tgt_lang="French")
self.assertEqual(result, "- Hé, comment ça va?")
def test_exact_match_arg_remote(self):
result = self.remote_tool("Hey, what's up?", src_lang="English", tgt_lang="French")
self.assertEqual(result, "- Hé, comment ça va?")
def test_exact_match_kwarg(self):
result = self.tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
self.assertEqual(result, "- Hé, comment ça va?")
def test_exact_match_kwarg_remote(self):
result = self.remote_tool(text="Hey, what's up?", src_lang="English", tgt_lang="French")
self.assertEqual(result, "- Hé, comment ça va?")
def test_call(self):
inputs = ["Hey, what's up?", "English", "Spanish"]
outputs = self.tool(*inputs)
# There is a single output
if len(self.tool.outputs) == 1:
outputs = [outputs]
self.assertListEqual(output_types(outputs), self.tool.outputs)