Compare commits
6 Commits
fix-llama-
...
main
Author | SHA1 | Date |
---|---|---|
Pavel Iakubovskii | cdc813113a | |
Aymeric Roucher | 9837a25481 | |
Marc Sun | f8e6ba454c | |
Younes Belkada | fc5d3e112a | |
Asif Ajrof | bd9d1ddf41 | |
Marc Sun | 48cada87c3 | |
|
@ -28,8 +28,8 @@ An agent is a system that uses an LLM as its engine, and it has access to functi
|
||||||
These *tools* are functions for performing a task, and they contain all necessary description for the agent to properly use them.
|
These *tools* are functions for performing a task, and they contain all necessary description for the agent to properly use them.
|
||||||
|
|
||||||
The agent can be programmed to:
|
The agent can be programmed to:
|
||||||
- devise a series of actions/tools and run them all at once like the `CodeAgent` for example
|
- devise a series of actions/tools and run them all at once like the [`CodeAgent`] for example
|
||||||
- plan and execute actions/tools one by one and wait for the outcome of each action before launching the next one like the `ReactJsonAgent` for example
|
- plan and execute actions/tools one by one and wait for the outcome of each action before launching the next one like the [`ReactJsonAgent`] for example
|
||||||
|
|
||||||
### Types of agents
|
### Types of agents
|
||||||
|
|
||||||
|
@ -42,8 +42,8 @@ This agent has a planning step, then generates python code to execute all its ac
|
||||||
This is the go-to agent to solve reasoning tasks, since the ReAct framework ([Yao et al., 2022](https://huggingface.co/papers/2210.03629)) makes it really efficient to think on the basis of its previous observations.
|
This is the go-to agent to solve reasoning tasks, since the ReAct framework ([Yao et al., 2022](https://huggingface.co/papers/2210.03629)) makes it really efficient to think on the basis of its previous observations.
|
||||||
|
|
||||||
We implement two versions of ReactJsonAgent:
|
We implement two versions of ReactJsonAgent:
|
||||||
- [`~ReactJsonAgent`] generates tool calls as a JSON in its output.
|
- [`ReactJsonAgent`] generates tool calls as a JSON in its output.
|
||||||
- [`~ReactCodeAgent`] is a new type of ReactJsonAgent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance.
|
- [`ReactCodeAgent`] is a new type of ReactJsonAgent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance.
|
||||||
|
|
||||||
> [!TIP]
|
> [!TIP]
|
||||||
> Read the [Open-source LLMs as LangChain Agents](https://huggingface.co/blog/open-source-llms-as-agents) blog post to learn more about the ReAct agent.
|
> Read the [Open-source LLMs as LangChain Agents](https://huggingface.co/blog/open-source-llms-as-agents) blog post to learn more about the ReAct agent.
|
||||||
|
@ -124,7 +124,7 @@ You could use any `llm_engine` method as long as:
|
||||||
|
|
||||||
You also need a `tools` argument which accepts a list of `Tools`. You can provide an empty list for `tools` and still add the default toolbox with the optional argument `add_base_tools=True`.
|
You also need a `tools` argument which accepts a list of `Tools`. You can provide an empty list for `tools` and still add the default toolbox with the optional argument `add_base_tools=True`.
|
||||||
|
|
||||||
Now you can create an agent, like `CodeAgent`, and run it. For convenience, we also provide the `HfEngine` class that uses `huggingface_hub.InferenceClient` under the hood.
|
Now you can create an agent, like [`CodeAgent`], and run it. For convenience, we also provide the [`HfEngine`] class that uses `huggingface_hub.InferenceClient` under the hood.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from transformers import CodeAgent, HfEngine
|
from transformers import CodeAgent, HfEngine
|
||||||
|
@ -139,7 +139,7 @@ agent.run(
|
||||||
```
|
```
|
||||||
|
|
||||||
This will be handy in case of emergency baguette need!
|
This will be handy in case of emergency baguette need!
|
||||||
You can even leave the argument `llm_engine` undefined, and an [~HfEngine] will be created by default.
|
You can even leave the argument `llm_engine` undefined, and an [`HfEngine`] will be created by default.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from transformers import CodeAgent
|
from transformers import CodeAgent
|
||||||
|
@ -181,13 +181,27 @@ You can also run an agent consecutively for different tasks: each time the attri
|
||||||
A Python interpreter executes the code on a set of inputs passed along with your tools.
|
A Python interpreter executes the code on a set of inputs passed along with your tools.
|
||||||
This should be safe because the only functions that can be called are the tools you provided (especially if it's only tools by Hugging Face) and the print function, so you're already limited in what can be executed.
|
This should be safe because the only functions that can be called are the tools you provided (especially if it's only tools by Hugging Face) and the print function, so you're already limited in what can be executed.
|
||||||
|
|
||||||
The Python interpreter also doesn't allow any attribute lookup or imports (which shouldn't be needed for passing inputs/outputs to a small set of functions) so all the most obvious attacks shouldn't be an issue.
|
The Python interpreter also doesn't allow imports by default outside of a safe list, so all the most obvious attacks shouldn't be an issue.
|
||||||
|
You can still authorize additional imports by passing the authorized modules as a list of strings in argument `additional_authorized_imports` upon initialization of your [`ReactCodeAgent`] or [`CodeAgent`]:
|
||||||
|
|
||||||
|
```py
|
||||||
|
>>> from transformers import ReactCodeAgent
|
||||||
|
|
||||||
|
>>> agent = ReactCodeAgent(tools=[], additional_authorized_imports=['requests', 'bs4'])
|
||||||
|
>>> agent.run("Could you get me the title of the page at url 'https://huggingface.co/blog'?")
|
||||||
|
|
||||||
|
(...)
|
||||||
|
'Hugging Face – Blog'
|
||||||
|
```
|
||||||
|
|
||||||
The execution will stop at any code trying to perform an illegal operation or if there is a regular Python error with the code generated by the agent.
|
The execution will stop at any code trying to perform an illegal operation or if there is a regular Python error with the code generated by the agent.
|
||||||
|
|
||||||
|
> [!WARNING]
|
||||||
|
> The LLM can generate arbitrary code that will then be executed: do not add any unsafe imports!
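
To make the allow-list idea concrete, here is a minimal sketch of how generated code could be screened for imports before it runs. It uses only Python's standard `ast` module and is not the interpreter shipped with Transformers; the function `find_unauthorized_imports` and the base list `BASE_AUTHORIZED` are hypothetical names chosen for illustration.

```python
import ast

# Hypothetical base allow-list, for illustration only.
BASE_AUTHORIZED = {"math", "re", "random"}


def find_unauthorized_imports(code, additional_authorized_imports=()):
    """Return the sorted top-level modules imported by `code` that are not allowed."""
    allowed = BASE_AUTHORIZED | set(additional_authorized_imports)
    found = set()
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.Import):
            # `import a.b` counts as importing the top-level package `a`
            found.update(alias.name.split(".")[0] for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.module:
            found.add(node.module.split(".")[0])
    return sorted(found - allowed)


generated = "import requests\nfrom bs4 import BeautifulSoup\nimport os"
print(find_unauthorized_imports(generated, ["requests", "bs4"]))  # ['os']
```

A check like this only inspects import statements; it says nothing about what an authorized module can do once imported, which is why the warning above still applies.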
|
||||||
|
|
||||||
### The system prompt
|
### The system prompt
|
||||||
|
|
||||||
An agent, or rather the LLM that drives the agent, generates an output based on the system prompt. The system prompt can be customized and tailored to the intended task. For example, check the system prompt for the `ReactCodeAgent` (the version below is slightly simplified).
|
An agent, or rather the LLM that drives the agent, generates an output based on the system prompt. The system prompt can be customized and tailored to the intended task. For example, check the system prompt for the [`ReactCodeAgent`] (the version below is slightly simplified).
|
||||||
|
|
||||||
```text
|
```text
|
||||||
You will be given a task to solve as best you can.
|
You will be given a task to solve as best you can.
|
||||||
|
@ -246,7 +260,7 @@ of the available tools.
|
||||||
|
|
||||||
A tool is an atomic function to be used by an agent.
|
A tool is an atomic function to be used by an agent.
|
||||||
|
|
||||||
You can for instance check the [~PythonInterpreterTool]: it has a name, a description, input descriptions, an output type, and a `__call__` method to perform the action.
|
You can for instance check the [`PythonInterpreterTool`]: it has a name, a description, input descriptions, an output type, and a `__call__` method to perform the action.
|
||||||
|
|
||||||
When the agent is initialized, the tool attributes are used to generate a tool description which is baked into the agent's system prompt. This lets the agent know which tools it can use and why.
|
When the agent is initialized, the tool attributes are used to generate a tool description which is baked into the agent's system prompt. This lets the agent know which tools it can use and why.
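
As a rough illustration of that step, the sketch below renders a tool's attributes into a text block and prepends it to a prompt. It is an assumption-laden stand-in, not the library's implementation: `ToolSpec`, `render_tool_descriptions` and the prompt wording are all hypothetical.

```python
from dataclasses import dataclass


@dataclass
class ToolSpec:
    """Hypothetical stand-in for a tool's descriptive attributes."""

    name: str
    description: str
    inputs: dict
    output_type: str


def render_tool_descriptions(tools):
    """Build one description line per tool, suitable for inclusion in a system prompt."""
    lines = []
    for tool in tools:
        args = ", ".join(f"{arg}: {desc}" for arg, desc in tool.inputs.items())
        lines.append(f"- {tool.name}: {tool.description} Takes inputs: {args}. Returns: {tool.output_type}.")
    return "\n".join(lines)


translator = ToolSpec(
    name="translator",
    description="Translates a sentence from a source language to a target language.",
    inputs={"text": "the sentence to translate"},
    output_type="text",
)
system_prompt = "You have access to the following tools:\n" + render_tool_descriptions([translator])
print(system_prompt)
```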
|
||||||
|
|
||||||
|
@ -259,7 +273,7 @@ Transformers comes with a default toolbox for empowering agents, that you can ad
|
||||||
- **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](./model_doc/whisper))
|
- **Speech to text**: given an audio recording of a person talking, transcribe the speech into text ([Whisper](./model_doc/whisper))
|
||||||
- **Text to speech**: convert text to speech ([SpeechT5](./model_doc/speecht5))
|
- **Text to speech**: convert text to speech ([SpeechT5](./model_doc/speecht5))
|
||||||
- **Translation**: translates a given sentence from source language to target language.
|
- **Translation**: translates a given sentence from source language to target language.
|
||||||
- **Python code interpreter**: runs the LLM-generated Python code in a secure environment. This tool will only be added to [~ReactJsonAgent] if you use `add_base_tools=True`, since code-based agents can already execute Python code.
|
- **Python code interpreter**: runs the LLM-generated Python code in a secure environment. This tool will only be added to [`ReactJsonAgent`] if you use `add_base_tools=True`, since code-based agents can already execute Python code.
|
||||||
|
|
||||||
|
|
||||||
You can manually use a tool by calling the [`load_tool`] function and a task to perform.
|
You can manually use a tool by calling the [`load_tool`] function and a task to perform.
|
||||||
|
|
|
@ -41,6 +41,7 @@ This model was contributed by [Shivalika Singh](https://huggingface.co/shivi) an
|
||||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Mask2Former.
|
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Mask2Former.
|
||||||
|
|
||||||
- Demo notebooks regarding inference + fine-tuning Mask2Former on custom data can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/Mask2Former).
|
- Demo notebooks regarding inference + fine-tuning Mask2Former on custom data can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/Mask2Former).
|
||||||
|
- Scripts for finetuning [`Mask2Former`] with [`Trainer`] or [Accelerate](https://huggingface.co/docs/accelerate/index) can be found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/instance-segmentation).
|
||||||
|
|
||||||
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we will review it.
|
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we will review it.
|
||||||
The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||||
|
|
|
@ -51,6 +51,7 @@ This model was contributed by [francesco](https://huggingface.co/francesco). The
|
||||||
<PipelineTag pipeline="image-segmentation"/>
|
<PipelineTag pipeline="image-segmentation"/>
|
||||||
|
|
||||||
- All notebooks that illustrate inference as well as fine-tuning on custom data with MaskFormer can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/MaskFormer).
|
- All notebooks that illustrate inference as well as fine-tuning on custom data with MaskFormer can be found [here](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/MaskFormer).
|
||||||
|
- Scripts for finetuning [`MaskFormer`] with [`Trainer`] or [Accelerate](https://huggingface.co/docs/accelerate/index) can be found [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/instance-segmentation).
|
||||||
|
|
||||||
## MaskFormer specific outputs
|
## MaskFormer specific outputs
|
||||||
|
|
||||||
|
|
|
@ -81,10 +81,10 @@ processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
||||||
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
||||||
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
|
||||||
mask_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
mask_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
|
||||||
segmentation_map = Image.open(requests.get(mask_url, stream=True).raw).convert("RGB")
|
segmentation_map = Image.open(requests.get(mask_url, stream=True).raw).convert("1")
|
||||||
input_points = [[[450, 600]]] # 2D location of a window in the image
|
input_points = [[[450, 600]]] # 2D location of a window in the image
|
||||||
|
|
||||||
inputs = processor(raw_image, input_points=input_points, segmentation_maps=mask, return_tensors="pt").to(device)
|
inputs = processor(raw_image, input_points=input_points, segmentation_maps=segmentation_map, return_tensors="pt").to(device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
|
|
@ -47,6 +47,7 @@ Coming soon!
|
||||||
| [**`image-classification`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) | [CIFAR-10](https://huggingface.co/datasets/cifar10) | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb)
|
| [**`image-classification`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) | [CIFAR-10](https://huggingface.co/datasets/cifar10) | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb)
|
||||||
| [**`semantic-segmentation`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/semantic-segmentation) | [SCENE_PARSE_150](https://huggingface.co/datasets/scene_parse_150) | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/semantic_segmentation.ipynb)
|
| [**`semantic-segmentation`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/semantic-segmentation) | [SCENE_PARSE_150](https://huggingface.co/datasets/scene_parse_150) | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/semantic_segmentation.ipynb)
|
||||||
| [**`object-detection`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection) | [CPPE-5](https://huggingface.co/datasets/cppe-5) | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/transformers_doc/en/pytorch/object_detection.ipynb)
|
| [**`object-detection`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/object-detection) | [CPPE-5](https://huggingface.co/datasets/cppe-5) | ✅ | ✅ |✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/transformers_doc/en/pytorch/object_detection.ipynb)
|
||||||
|
| [**`instance-segmentation`**](https://github.com/huggingface/transformers/tree/main/examples/pytorch/instance-segmentation) | [ADE20K sample](https://huggingface.co/datasets/qubvel-hf/ade20k-mini) | ✅ | ✅ |✅ |
|
||||||
|
|
||||||
|
|
||||||
## Running quick tests
|
## Running quick tests
|
||||||
|
|
|
@ -0,0 +1,235 @@
|
||||||
|
<!---
Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Instance Segmentation Examples

This directory contains two scripts that demonstrate how to fine-tune [MaskFormer](https://huggingface.co/docs/transformers/model_doc/maskformer) and [Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former) for instance segmentation using PyTorch.
For other instance segmentation models, such as [DETR](https://huggingface.co/docs/transformers/model_doc/detr) and [Conditional DETR](https://huggingface.co/docs/transformers/model_doc/conditional_detr), the scripts need to be adjusted to properly handle input and output data.

Content:
- [PyTorch Version with Trainer](#pytorch-version-with-trainer)
- [PyTorch Version with Accelerate](#pytorch-version-with-accelerate)
- [Reload and Perform Inference](#reload-and-perform-inference)
- [Note on Custom Data](#note-on-custom-data)

## PyTorch Version with Trainer

This example is based on the script [`run_instance_segmentation.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/instance-segmentation/run_instance_segmentation.py).

The script uses the [🤗 Trainer API](https://huggingface.co/docs/transformers/main_classes/trainer) to manage training automatically, including distributed environments.

Here, we show how to fine-tune a [Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former) model on a subsample of the [ADE20K](https://huggingface.co/datasets/zhoubolei/scene_parse_150) dataset. We created a [small dataset](https://huggingface.co/datasets/qubvel-hf/ade20k-mini) with approximately 2,000 images containing only "person" and "car" annotations; all other pixels are marked as "background."

Here is the `label2id` mapping for this dataset:

```python
label2id = {
    "background": 0,
    "person": 1,
    "car": 2,
}
```

Since the `background` label is not an instance and we don't want to predict it, we will use `do_reduce_labels` to remove it from the data.
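
To give an intuition for what reducing labels means, here is a minimal NumPy sketch. It assumes the common convention that background pixels are mapped to an ignore value of 255 and every other class id is shifted down by one; the helper `reduce_labels` is hypothetical and the image processor's actual implementation may differ in details.

```python
import numpy as np


def reduce_labels(semantic_map, ignore_index=255):
    """Map background (0) to `ignore_index` and shift the other class ids down by one."""
    semantic_map = np.asarray(semantic_map)
    return np.where(semantic_map == 0, ignore_index, semantic_map - 1)


# 0 = background, 1 = person, 2 = car (as in `label2id` above)
semantic_map = np.array([[0, 1], [2, 0]])
print(reduce_labels(semantic_map))
# [[255   0]
#  [  1 255]]
```

After this step "person" becomes class 0 and "car" becomes class 1, so the model only has to predict the two instance classes.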

Run the training with the following command:

```bash
python run_instance_segmentation.py \
    --model_name_or_path facebook/mask2former-swin-tiny-coco-instance \
    --output_dir finetune-instance-segmentation-ade20k-mini-mask2former \
    --dataset_name qubvel-hf/ade20k-mini \
    --do_reduce_labels \
    --image_height 256 \
    --image_width 256 \
    --do_train \
    --fp16 \
    --num_train_epochs 40 \
    --learning_rate 1e-5 \
    --lr_scheduler_type constant \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 2 \
    --dataloader_num_workers 8 \
    --dataloader_persistent_workers \
    --dataloader_prefetch_factor 4 \
    --do_eval \
    --evaluation_strategy epoch \
    --logging_strategy epoch \
    --save_strategy epoch \
    --save_total_limit 2 \
    --push_to_hub
```

The resulting model can be viewed [here](https://huggingface.co/qubvel-hf/finetune-instance-segmentation-ade20k-mini-mask2former). Always refer to the original paper for details on training hyperparameters. To improve model quality, consider:
- Changing image size parameters (`--image_height`/`--image_width`)
- Adjusting training parameters such as learning rate, batch size, warmup, optimizer, and more (see [TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments))
- Adding more image augmentations (we created a helpful [HF Space](https://huggingface.co/spaces/qubvel-hf/albumentations-demo) to choose some)

You can also replace the model [checkpoint](https://huggingface.co/models?search=maskformer).
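
When adjusting batch size and learning rate, it can help to keep the effective batch size in mind. The training command in this section combines `--per_device_train_batch_size 8` with `--gradient_accumulation_steps 2`, so each optimizer step aggregates gradients from 16 samples per device. The helper below is a hypothetical illustration of that arithmetic and is not part of the script.

```python
def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_devices=1):
    """Number of samples contributing to a single optimizer step."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


print(effective_batch_size(8, 2))                 # 16, single device
print(effective_batch_size(8, 2, num_devices=4))  # 64, hypothetical 4-GPU run
```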

## PyTorch Version with Accelerate

This example is based on the script [`run_instance_segmentation_no_trainer.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/instance-segmentation/run_instance_segmentation_no_trainer.py).

The script uses [🤗 Accelerate](https://github.com/huggingface/accelerate) to write your own training loop in PyTorch and run it on various environments, including CPU, multi-CPU, GPU, multi-GPU, and TPU, with support for mixed precision.

First, configure the environment:

```bash
accelerate config
```

Answer the questions regarding your training environment. Then, run:

```bash
accelerate test
```

This command ensures everything is ready for training. Finally, launch training with:

```bash
accelerate launch run_instance_segmentation_no_trainer.py \
    --model_name_or_path facebook/mask2former-swin-tiny-coco-instance \
    --output_dir finetune-instance-segmentation-ade20k-mini-mask2former-no-trainer \
    --dataset_name qubvel-hf/ade20k-mini \
    --do_reduce_labels \
    --image_height 256 \
    --image_width 256 \
    --num_train_epochs 40 \
    --learning_rate 1e-5 \
    --lr_scheduler_type constant \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 2 \
    --dataloader_num_workers 8 \
    --push_to_hub
```

With this setup, you can train on multiple GPUs, log everything to trackers (like Weights and Biases, Tensorboard), and regularly push your model to the hub (with the repo name set to `args.output_dir` under your HF username).
With the default settings, the script fine-tunes a [Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former) model on the sample of [ADE20K](https://huggingface.co/datasets/qubvel-hf/ade20k-mini) dataset. The resulting model can be viewed [here](https://huggingface.co/qubvel-hf/finetune-instance-segmentation-ade20k-mini-mask2former-no-trainer).

## Reload and Perform Inference

After training, you can easily load your trained model and perform inference as follows:

```python
import torch
import requests
import matplotlib.pyplot as plt

from PIL import Image
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor

# Load image
image = Image.open(requests.get("http://farm4.staticflickr.com/3017/3071497290_31f0393363_z.jpg", stream=True).raw)

# Load model and image processor
device = "cuda"
checkpoint = "qubvel-hf/finetune-instance-segmentation-ade20k-mini-mask2former"

model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint, device_map=device)
image_processor = Mask2FormerImageProcessor.from_pretrained(checkpoint)

# Run inference on image
inputs = image_processor(images=[image], return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

# Post-process outputs
outputs = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])

print("Mask shape: ", outputs[0]["segmentation"].shape)
print("Mask values: ", outputs[0]["segmentation"].unique())
for segment in outputs[0]["segments_info"]:
    print("Segment: ", segment)
```

```
Mask shape: torch.Size([427, 640])
Mask values: tensor([-1., 0., 1., 2., 3., 4., 5., 6.])
Segment: {'id': 0, 'label_id': 0, 'was_fused': False, 'score': 0.946127}
Segment: {'id': 1, 'label_id': 1, 'was_fused': False, 'score': 0.961582}
Segment: {'id': 2, 'label_id': 1, 'was_fused': False, 'score': 0.968367}
Segment: {'id': 3, 'label_id': 1, 'was_fused': False, 'score': 0.819527}
Segment: {'id': 4, 'label_id': 1, 'was_fused': False, 'score': 0.655761}
Segment: {'id': 5, 'label_id': 1, 'was_fused': False, 'score': 0.531299}
Segment: {'id': 6, 'label_id': 1, 'was_fused': False, 'score': 0.929477}
```
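
The segmentation map stores one id per pixel (with `-1` for pixels that belong to no instance), and `segments_info` describes each id. A minimal sketch of turning that pair into one boolean mask per instance is shown below. It runs on toy NumPy data rather than on real model output, and the helper `split_instances` and its `score_threshold` parameter are hypothetical; with the real output you would first convert the tensor, for example with `.cpu().numpy()`.

```python
import numpy as np


def split_instances(segmentation, segments_info, score_threshold=0.5):
    """Return (label_id, score, boolean mask) for every segment at or above the threshold."""
    instances = []
    for segment in segments_info:
        if segment["score"] < score_threshold:
            continue
        mask = segmentation == segment["id"]
        instances.append((segment["label_id"], segment["score"], mask))
    return instances


# Toy stand-in for outputs[0]; -1 marks pixels that belong to no instance.
segmentation = np.array([[-1, 0, 0], [1, 1, -1]])
segments_info = [
    {"id": 0, "label_id": 0, "was_fused": False, "score": 0.95},
    {"id": 1, "label_id": 1, "was_fused": False, "score": 0.40},
]
for label_id, score, mask in split_instances(segmentation, segments_info):
    print(label_id, score, int(mask.sum()))  # 0 0.95 2
```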

Use the following code to visualize the results:

```python
import numpy as np
import matplotlib.pyplot as plt

# Move the mask to CPU before converting it, since inference may have run on GPU
segmentation = outputs[0]["segmentation"].cpu().numpy()

plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
plt.imshow(np.array(image))
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(segmentation)
plt.axis("off")
plt.show()
```

![Result](https://i.imgur.com/rZmaRjD.png)

## Note on Custom Data

Here is a short script demonstrating how to create your own dataset for instance segmentation and push it to the hub:

> Note: Annotations should be represented as 3-channel images (similar to the [scene_parsing_150](https://huggingface.co/datasets/zhoubolei/scene_parse_150#instance_segmentation-1) dataset). The first channel is a semantic-segmentation map with values corresponding to `label2id`, the second is an instance-segmentation map where each instance has a unique value, and the third channel should be empty (filled with zeros).

```python
from datasets import Dataset, DatasetDict
from datasets import Image as DatasetImage

label2id = {
    "background": 0,
    "person": 1,
    "car": 2,
}

# The <PIL Image ...> entries are placeholders: replace them with your own PIL images
train_split = {
    "image": [<PIL Image 1>, <PIL Image 2>, <PIL Image 3>, ...],
    "annotation": [<PIL Image ann 1>, <PIL Image ann 2>, <PIL Image ann 3>, ...],
}

validation_split = {
    "image": [<PIL Image 101>, <PIL Image 102>, <PIL Image 103>, ...],
    "annotation": [<PIL Image ann 101>, <PIL Image ann 102>, <PIL Image ann 103>, ...],
}

def create_instance_segmentation_dataset(label2id, **splits):
    dataset_dict = {}
    for split_name, split in splits.items():
        split["semantic_class_to_id"] = [label2id] * len(split["image"])
        dataset_split = (
            Dataset.from_dict(split)
            .cast_column("image", DatasetImage())
            .cast_column("annotation", DatasetImage())
        )
        dataset_dict[split_name] = dataset_split
    return DatasetDict(dataset_dict)

dataset = create_instance_segmentation_dataset(label2id, train=train_split, validation=validation_split)
dataset.push_to_hub("qubvel-hf/ade20k-nano")
```

Use this dataset for fine-tuning by specifying its name with `--dataset_name <your_dataset_repo>`.
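
If your masks start out as separate semantic and instance maps, the 3-channel annotation format described in the note of this section can be assembled roughly as follows. This is a sketch under stated assumptions: the helper `build_annotation` is hypothetical, and it stores ids as `uint8`, so it only works for up to 255 classes and 255 instances per image.

```python
import numpy as np
from PIL import Image


def build_annotation(semantic_map, instance_map):
    """Stack semantic ids, instance ids and an empty channel into a 3-channel annotation image."""
    semantic_map = np.asarray(semantic_map, dtype=np.uint8)
    instance_map = np.asarray(instance_map, dtype=np.uint8)
    empty = np.zeros_like(semantic_map)  # third channel stays filled with zeros
    return Image.fromarray(np.stack([semantic_map, instance_map, empty], axis=-1))


# Toy 2x2 example: one person (instance 1) and two cars (instances 2 and 3)
semantic_map = [[0, 1], [2, 2]]
instance_map = [[0, 1], [2, 3]]
annotation = build_annotation(semantic_map, instance_map)
print(annotation.mode, annotation.size)  # RGB (2, 2)
```

The resulting PIL images can be used as the entries of the `"annotation"` lists in the script above.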

See also: [Dataset Creation Guide](https://huggingface.co/docs/datasets/image_dataset#create-an-image-dataset)
|
|
@ -0,0 +1,5 @@
|
||||||
|
albumentations >= 1.4.5
timm
datasets
torchmetrics
pycocotools
|
|
@ -0,0 +1,469 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
|
||||||
|
"""Finetuning 🤗 Transformers model for instance segmentation leveraging the Trainer API."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
|
import albumentations as A
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from datasets import load_dataset
|
||||||
|
from torchmetrics.detection.mean_ap import MeanAveragePrecision
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from transformers import (
|
||||||
|
AutoImageProcessor,
|
||||||
|
AutoModelForUniversalSegmentation,
|
||||||
|
HfArgumentParser,
|
||||||
|
Trainer,
|
||||||
|
TrainingArguments,
|
||||||
|
)
|
||||||
|
from transformers.image_processing_utils import BatchFeature
|
||||||
|
from transformers.trainer import EvalPrediction
|
||||||
|
from transformers.trainer_utils import get_last_checkpoint
|
||||||
|
from transformers.utils import check_min_version, send_example_telemetry
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||||
|
check_min_version("4.42.0.dev0")
|
||||||
|
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")


@dataclass
class Arguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify
    them on the command line.
    """

    model_name_or_path: str = field(
        default="facebook/mask2former-swin-tiny-coco-instance",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    dataset_name: str = field(
        default="qubvel-hf/ade20k-mini",
        metadata={
            "help": "Name of a dataset from the hub (could be your own, possibly private dataset hosted on the hub)."
        },
    )
    image_height: Optional[int] = field(default=512, metadata={"help": "Image height after resizing."})
    image_width: Optional[int] = field(default=512, metadata={"help": "Image width after resizing."})
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    do_reduce_labels: bool = field(
        default=False,
        metadata={
            "help": (
                "If background class is labeled as 0 and you want to remove it from the labels, set this flag to True."
            )
        },
    )


def augment_and_transform_batch(
    examples: Mapping[str, Any], transform: A.Compose, image_processor: AutoImageProcessor
) -> BatchFeature:
    batch = {
        "pixel_values": [],
        "mask_labels": [],
        "class_labels": [],
    }

    for pil_image, pil_annotation in zip(examples["image"], examples["annotation"]):
        image = np.array(pil_image)
        semantic_and_instance_masks = np.array(pil_annotation)[..., :2]

        # Apply augmentations
        output = transform(image=image, mask=semantic_and_instance_masks)

        aug_image = output["image"]
        aug_semantic_and_instance_masks = output["mask"]
        aug_instance_mask = aug_semantic_and_instance_masks[..., 1]

        # Create mapping from instance id to semantic id
        unique_semantic_id_instance_id_pairs = np.unique(aug_semantic_and_instance_masks.reshape(-1, 2), axis=0)
        instance_id_to_semantic_id = {
            instance_id: semantic_id for semantic_id, instance_id in unique_semantic_id_instance_id_pairs
        }

        # Apply the image processor transformations: resizing, rescaling, normalization
        model_inputs = image_processor(
            images=[aug_image],
            segmentation_maps=[aug_instance_mask],
            instance_id_to_semantic_id=instance_id_to_semantic_id,
            return_tensors="pt",
        )

        batch["pixel_values"].append(model_inputs.pixel_values[0])
        batch["mask_labels"].append(model_inputs.mask_labels[0])
        batch["class_labels"].append(model_inputs.class_labels[0])

    return batch
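
The instance-to-semantic mapping built in `augment_and_transform_batch` can be illustrated without NumPy. The sketch below is a simplified stand-in, not the script's code: it assumes the two-channel annotation is given as nested lists of `(semantic_id, instance_id)` pairs, and the helper name `build_instance_to_semantic` is hypothetical.

```python
from typing import Dict, List, Tuple

Pixel = Tuple[int, int]  # (semantic_id, instance_id), mirroring the first two annotation channels


def build_instance_to_semantic(mask: List[List[Pixel]]) -> Dict[int, int]:
    """Map every instance id in the mask to the semantic id it appears with."""
    # Collect the unique (semantic_id, instance_id) pairs in sorted order,
    # which plays the role of np.unique(..., axis=0) in the script.
    unique_pairs = sorted({pixel for row in mask for pixel in row})
    return {instance_id: semantic_id for semantic_id, instance_id in unique_pairs}


# Hypothetical 2x3 annotation: two instances of class 5 and one instance of class 7.
example_mask = [
    [(5, 1), (5, 1), (7, 3)],
    [(5, 2), (5, 2), (7, 3)],
]
mapping = build_instance_to_semantic(example_mask)
print(mapping)
```

Because a dictionary keeps one value per key, an instance id that occurred with two different semantic ids would keep only the last pair in sorted order, so this approach presumes each instance id belongs to a single class within an image.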


def collate_fn(examples):
    batch = {}
    batch["pixel_values"] = torch.stack([example["pixel_values"] for example in examples])
    batch["class_labels"] = [example["class_labels"] for example in examples]
    batch["mask_labels"] = [example["mask_labels"] for example in examples]
    if "pixel_mask" in examples[0]:
        batch["pixel_mask"] = torch.stack([example["pixel_mask"] for example in examples])
    return batch


@dataclass
class ModelOutput:
    class_queries_logits: torch.Tensor
    masks_queries_logits: torch.Tensor


def nested_cpu(tensors):
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_cpu(t) for t in tensors)
    elif isinstance(tensors, Mapping):
        return type(tensors)({k: nested_cpu(t) for k, t in tensors.items()})
    elif isinstance(tensors, torch.Tensor):
        return tensors.cpu().detach()
    else:
        return tensors
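
The recursion pattern used by `nested_cpu` does not depend on PyTorch. As a rough sketch under that observation, the hypothetical helper `nested_apply` below applies an arbitrary function to every leaf of a chosen type while preserving the container types; it uses plain integers in place of tensors so it runs with the standard library only.

```python
from collections.abc import Mapping
from typing import Any, Callable


def nested_apply(data: Any, leaf_fn: Callable[[Any], Any], leaf_type: type) -> Any:
    """Apply `leaf_fn` to every `leaf_type` value inside nested lists, tuples and mappings."""
    if isinstance(data, (list, tuple)):
        # Rebuild the same container type from the processed items.
        return type(data)(nested_apply(item, leaf_fn, leaf_type) for item in data)
    if isinstance(data, Mapping):
        return type(data)({key: nested_apply(value, leaf_fn, leaf_type) for key, value in data.items()})
    if isinstance(data, leaf_type):
        return leaf_fn(data)
    return data  # anything else is passed through unchanged


nested = {"a": [1, 2, (3, "keep")], "b": 4}
doubled = nested_apply(nested, lambda x: x * 2, int)
print(doubled)
```

In the script, the leaf function is `tensor.cpu().detach()` and the leaf type is `torch.Tensor`, which moves every tensor in a batch of arbitrary nesting off the accelerator before metric computation.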


class Evaluator:
    """
    Compute metrics for the instance segmentation task.
    """

    def __init__(
        self,
        image_processor: AutoImageProcessor,
        id2label: Mapping[int, str],
        threshold: float = 0.0,
    ):
        """
        Initialize evaluator with image processor, id2label mapping and threshold for filtering predictions.

        Args:
            image_processor (AutoImageProcessor): Image processor for
                `post_process_instance_segmentation` method.
            id2label (Mapping[int, str]): Mapping from class id to class name.
            threshold (float): Threshold to filter predicted instance masks by confidence. Defaults to 0.0.
        """
        self.image_processor = image_processor
        self.id2label = id2label
        self.threshold = threshold
        self.metric = self.get_metric()

    def get_metric(self):
        metric = MeanAveragePrecision(iou_type="segm", class_metrics=True)
        return metric

    def reset_metric(self):
        self.metric.reset()

    def postprocess_target_batch(self, target_batch) -> List[Dict[str, torch.Tensor]]:
        """Collect targets in a form of list of dictionaries with keys "masks", "labels"."""
        batch_masks = target_batch[0]
        batch_labels = target_batch[1]
        post_processed_targets = []
        for masks, labels in zip(batch_masks, batch_labels):
            post_processed_targets.append(
                {
                    "masks": masks.to(dtype=torch.bool),
                    "labels": labels,
                }
            )
        return post_processed_targets

    def get_target_sizes(self, post_processed_targets) -> List[List[int]]:
        target_sizes = []
        for target in post_processed_targets:
            target_sizes.append(target["masks"].shape[-2:])
        return target_sizes

    def postprocess_prediction_batch(self, prediction_batch, target_sizes) -> List[Dict[str, torch.Tensor]]:
        """Collect predictions in a form of list of dictionaries with keys "masks", "labels", "scores"."""

        model_output = ModelOutput(class_queries_logits=prediction_batch[0], masks_queries_logits=prediction_batch[1])
        post_processed_output = self.image_processor.post_process_instance_segmentation(
            model_output,
            threshold=self.threshold,
            target_sizes=target_sizes,
            return_binary_maps=True,
        )

        post_processed_predictions = []
        for image_predictions, target_size in zip(post_processed_output, target_sizes):
            if image_predictions["segments_info"]:
                post_processed_image_prediction = {
                    "masks": image_predictions["segmentation"].to(dtype=torch.bool),
                    "labels": torch.tensor([x["label_id"] for x in image_predictions["segments_info"]]),
                    "scores": torch.tensor([x["score"] for x in image_predictions["segments_info"]]),
                }
            else:
                # for void predictions, we need to provide empty tensors
                post_processed_image_prediction = {
                    "masks": torch.zeros([0, *target_size], dtype=torch.bool),
                    "labels": torch.tensor([]),
                    "scores": torch.tensor([]),
                }
            post_processed_predictions.append(post_processed_image_prediction)

        return post_processed_predictions

    @torch.no_grad()
    def __call__(self, evaluation_results: EvalPrediction, compute_result: bool = False) -> Mapping[str, float]:
        """
        Update metrics with current evaluation results and return metrics if `compute_result` is True.

        Args:
            evaluation_results (EvalPrediction): Predictions and targets from evaluation.
            compute_result (bool): Whether to compute and return metrics.

        Returns:
            Mapping[str, float]: Metrics in a form of dictionary {<metric_name>: <metric_value>}
        """
        prediction_batch = nested_cpu(evaluation_results.predictions)
        target_batch = nested_cpu(evaluation_results.label_ids)

        # For metric computation we need to provide:
        # - targets in a form of list of dictionaries with keys "masks", "labels"
        # - predictions in a form of list of dictionaries with keys "masks", "labels", "scores"
        post_processed_targets = self.postprocess_target_batch(target_batch)
        target_sizes = self.get_target_sizes(post_processed_targets)
        post_processed_predictions = self.postprocess_prediction_batch(prediction_batch, target_sizes)

        # Compute metrics
        self.metric.update(post_processed_predictions, post_processed_targets)

        if not compute_result:
            return

        metrics = self.metric.compute()

        # Replace list of per class metrics with separate metric for each class
        classes = metrics.pop("classes")
        map_per_class = metrics.pop("map_per_class")
        mar_100_per_class = metrics.pop("mar_100_per_class")
        for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
            class_name = self.id2label[class_id.item()] if self.id2label is not None else class_id.item()
            metrics[f"map_{class_name}"] = class_map
            metrics[f"mar_100_{class_name}"] = class_mar

        metrics = {k: round(v.item(), 4) for k, v in metrics.items()}

        # Reset metric for next evaluation
        self.reset_metric()

        return metrics
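
The per-class flattening at the end of `Evaluator.__call__` can be sketched with plain Python values. This is an illustration only: it assumes the metric dictionary holds floats and lists rather than tensors, the sample numbers are made up, and `flatten_per_class_metrics` is a hypothetical name.

```python
from typing import Any, Dict, Mapping, Optional


def flatten_per_class_metrics(metrics: Dict[str, Any], id2label: Optional[Mapping[int, str]]) -> Dict[str, float]:
    """Replace the per-class lists with one `map_<name>` / `mar_100_<name>` entry per class."""
    flat = dict(metrics)  # work on a copy so the caller's dictionary is untouched
    classes = flat.pop("classes")
    map_per_class = flat.pop("map_per_class")
    mar_100_per_class = flat.pop("mar_100_per_class")
    for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
        class_name = id2label[class_id] if id2label is not None else class_id
        flat[f"map_{class_name}"] = class_map
        flat[f"mar_100_{class_name}"] = class_mar
    return {key: round(value, 4) for key, value in flat.items()}


# Made-up values, shaped like a metric result with class_metrics enabled.
raw = {
    "map": 0.41237,
    "classes": [0, 1],
    "map_per_class": [0.5, 0.32474],
    "mar_100_per_class": [0.6, 0.4],
}
flat = flatten_per_class_metrics(raw, {0: "person", 1: "car"})
print(flat)
```

Flattening matters because the `Trainer` logs and saves metrics as a flat mapping of names to scalars, so list-valued entries have to be expanded into one key per class first.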


def setup_logging(training_args: TrainingArguments) -> None:
    """Setup logging according to `training_args`."""

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()


def find_last_checkpoint(training_args: TrainingArguments) -> Optional[str]:
    """Find the last checkpoint in the output directory according to parameters specified in `training_args`."""

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
        checkpoint = get_last_checkpoint(training_args.output_dir)
        if checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    return checkpoint
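
`find_last_checkpoint` delegates the directory scan to `get_last_checkpoint` from `transformers.trainer_utils`. The sketch below is not that function; it is a standard-library approximation that assumes checkpoints are stored as `checkpoint-<step>` subfolders, and the name `latest_checkpoint_dir` is hypothetical.

```python
import os
import re
import tempfile
from typing import Optional

CHECKPOINT_PATTERN = re.compile(r"^checkpoint-(\d+)$")  # assumed folder naming scheme


def latest_checkpoint_dir(output_dir: str) -> Optional[str]:
    """Return the `checkpoint-<step>` subfolder with the highest step, or None if there is none."""
    candidates = []
    for name in os.listdir(output_dir):
        match = CHECKPOINT_PATTERN.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    # Compare steps numerically so that 1000 ranks above 90.
    return os.path.join(output_dir, max(candidates)[1])


with tempfile.TemporaryDirectory() as output_dir:
    for step in (500, 1000, 90):
        os.makedirs(os.path.join(output_dir, f"checkpoint-{step}"))
    found_name = os.path.basename(latest_checkpoint_dir(output_dir))
    print(found_name)

with tempfile.TemporaryDirectory() as empty_dir:
    empty_result = latest_checkpoint_dir(empty_dir)
    print(empty_result)
```

The `None` result for a directory without checkpoints is what lets the script distinguish "nothing to resume" from "non-empty output directory", which is the case that raises the `ValueError` above.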


def main():
    # See all possible arguments in https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
    # or by passing the --help flag to this script.

    parser = HfArgumentParser([Arguments, TrainingArguments])
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        args, training_args = parser.parse_args_into_dataclasses()

    # Set default training arguments for instance segmentation
    training_args.eval_do_concat_batches = False
    training_args.batch_eval_metrics = True
    training_args.remove_unused_columns = False

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_instance_segmentation", args)

    # Set up logging and log a short summary on each process:
    setup_logging(training_args)
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Load last checkpoint from output_dir if it exists (and we are not overwriting it)
    checkpoint = find_last_checkpoint(training_args)

    # ------------------------------------------------------------------------------------------------
    # Load dataset, prepare splits
    # ------------------------------------------------------------------------------------------------

    dataset = load_dataset(args.dataset_name)

    # We need to specify the label2id mapping for the model
    # it is a mapping from semantic class name to class index.
    # In case your dataset does not provide it, you can create it manually:
    # label2id = {"background": 0, "cat": 1, "dog": 2}
    label2id = dataset["train"][0]["semantic_class_to_id"]

    if args.do_reduce_labels:
        label2id = {name: idx for name, idx in label2id.items() if idx != 0}  # remove background class
        label2id = {name: idx - 1 for name, idx in label2id.items()}  # shift class indices by -1

    id2label = {v: k for k, v in label2id.items()}

    # ------------------------------------------------------------------------------------------------
    # Load pretrained config, model and image processor
    # ------------------------------------------------------------------------------------------------
    model = AutoModelForUniversalSegmentation.from_pretrained(
        args.model_name_or_path,
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
        token=args.token,
    )

    image_processor = AutoImageProcessor.from_pretrained(
        args.model_name_or_path,
        do_resize=True,
        size={"height": args.image_height, "width": args.image_width},
        do_reduce_labels=args.do_reduce_labels,
        reduce_labels=args.do_reduce_labels,  # TODO: remove when mask2former support `do_reduce_labels`
        token=args.token,
    )

    # ------------------------------------------------------------------------------------------------
    # Define image augmentations and dataset transforms
    # ------------------------------------------------------------------------------------------------
    train_augment_and_transform = A.Compose(
        [
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.HueSaturationValue(p=0.1),
        ],
    )
    validation_transform = A.Compose(
        [A.NoOp()],
    )

    # Make transform functions for batch and apply for dataset splits
    train_transform_batch = partial(
        augment_and_transform_batch, transform=train_augment_and_transform, image_processor=image_processor
    )
    validation_transform_batch = partial(
        augment_and_transform_batch, transform=validation_transform, image_processor=image_processor
    )

    dataset["train"] = dataset["train"].with_transform(train_transform_batch)
    dataset["validation"] = dataset["validation"].with_transform(validation_transform_batch)

    # ------------------------------------------------------------------------------------------------
    # Model training and evaluation with Trainer API
    # ------------------------------------------------------------------------------------------------

    compute_metrics = Evaluator(image_processor=image_processor, id2label=id2label, threshold=0.0)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"] if training_args.do_train else None,
        eval_dataset=dataset["validation"] if training_args.do_eval else None,
        tokenizer=image_processor,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    # Final evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(eval_dataset=dataset["validation"], metric_key_prefix="test")
        trainer.log_metrics("test", metrics)
        trainer.save_metrics("test", metrics)

    # Write model card and (optionally) push to hub
    kwargs = {
        "finetuned_from": args.model_name_or_path,
        "dataset": args.dataset_name,
        "tags": ["image-segmentation", "instance-segmentation", "vision"],
    }
    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


if __name__ == "__main__":
    main()
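
The label handling performed in `main` when `--do_reduce_labels` is set can be tried in isolation. The following is a small sketch of the same two dictionary steps, using the example mapping from the script's own comment; the function name `reduce_labels` is hypothetical and assumes the background class has index 0.

```python
from typing import Dict, Tuple


def reduce_labels(label2id: Dict[str, int]) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Drop the class with index 0 and shift the remaining indices down by one."""
    without_background = {name: idx for name, idx in label2id.items() if idx != 0}
    shifted = {name: idx - 1 for name, idx in without_background.items()}
    id2label = {idx: name for name, idx in shifted.items()}
    return shifted, id2label


label2id, id2label = reduce_labels({"background": 0, "cat": 1, "dog": 2})
print(label2id)
print(id2label)
```

Keeping the class indices contiguous from zero matters here because the model head is resized to the number of entries in `label2id`, so any gap left by the removed background class would otherwise point at an unused output.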
@ -0,0 +1,734 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Finetuning 🤗 Transformers model for instance segmentation with Accelerate 🚀."""

import argparse
import json
import logging
import math
import os
import sys
from functools import partial
from pathlib import Path
from typing import Any, Mapping

import albumentations as A
import datasets
import numpy as np
import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
from datasets import load_dataset
from huggingface_hub import HfApi
from torch.utils.data import DataLoader
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from tqdm import tqdm

import transformers
from transformers import (
    AutoImageProcessor,
    AutoModelForUniversalSegmentation,
    SchedulerType,
    get_scheduler,
)
from transformers.image_processing_utils import BatchFeature
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version


logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.42.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")


def parse_args():
    parser = argparse.ArgumentParser(description="Finetune a transformers model for instance segmentation task")

    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to a pretrained model or model identifier from huggingface.co/models.",
        default="facebook/mask2former-swin-tiny-coco-instance",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        help="Name of the dataset on the hub.",
        default="qubvel-hf/ade20k-mini",
    )
    parser.add_argument(
        "--image_height",
        type=int,
        default=384,
        help="The height of the images to feed the model.",
    )
    parser.add_argument(
        "--image_width",
        type=int,
        default=384,
        help="The width of the images to feed the model.",
    )
    parser.add_argument(
        "--do_reduce_labels",
        action="store_true",
        help="Whether to reduce the number of labels by removing the background class.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        help="Path to a folder in which the model and dataset will be cached.",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=4,
        help="Number of workers to use for the dataloaders.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="Beta1 for AdamW optimizer",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="Beta2 for AdamW optimizer",
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-8,
        help="Epsilon for AdamW optimizer",
    )
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument(
        "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
    )
    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default=None,
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )
    parser.add_argument(
        "--with_tracking",
        required=False,
        action="store_true",
        help="Whether to enable experiment trackers for logging.",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. '
            "Only applicable when `--with_tracking` is passed."
        ),
    )
    args = parser.parse_args()

    # Sanity checks
    if args.push_to_hub or args.with_tracking:
        if args.output_dir is None:
            raise ValueError(
                "Need an `output_dir` to create a repo when `--push_to_hub` or `--with_tracking` is specified."
            )

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    return args
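
The sanity check at the end of `parse_args` can be exercised on its own. Here is a reduced sketch that keeps only the three flags involved and accepts an explicit argument list so it can run outside the command line; the function name `parse_demo_args` and the path `runs/demo` are made up for illustration, and unlike the script it does not create the output directory.

```python
import argparse
from typing import List, Optional


def parse_demo_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Reduced argument parser for illustration")
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--with_tracking", action="store_true")
    args = parser.parse_args(argv)

    # Same rule as in the script: both features need a directory to write to.
    if (args.push_to_hub or args.with_tracking) and args.output_dir is None:
        raise ValueError("Need an `output_dir` when `--push_to_hub` or `--with_tracking` is specified.")
    return args


ok_args = parse_demo_args(["--with_tracking", "--output_dir", "runs/demo"])
print(ok_args.output_dir)

try:
    parse_demo_args(["--push_to_hub"])
    rejected = False
except ValueError:
    rejected = True
print(rejected)
```

Validating right after parsing means a misconfigured run fails before any model or dataset is downloaded.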


def augment_and_transform_batch(
    examples: Mapping[str, Any], transform: A.Compose, image_processor: AutoImageProcessor
) -> BatchFeature:
    batch = {
        "pixel_values": [],
        "mask_labels": [],
        "class_labels": [],
    }

    for pil_image, pil_annotation in zip(examples["image"], examples["annotation"]):
        image = np.array(pil_image)
        semantic_and_instance_masks = np.array(pil_annotation)[..., :2]

        # Apply augmentations
        output = transform(image=image, mask=semantic_and_instance_masks)

        aug_image = output["image"]
        aug_semantic_and_instance_masks = output["mask"]
        aug_instance_mask = aug_semantic_and_instance_masks[..., 1]

        # Create mapping from instance id to semantic id
        unique_semantic_id_instance_id_pairs = np.unique(aug_semantic_and_instance_masks.reshape(-1, 2), axis=0)
        instance_id_to_semantic_id = {
            instance_id: semantic_id for semantic_id, instance_id in unique_semantic_id_instance_id_pairs
        }

        # Apply the image processor transformations: resizing, rescaling, normalization
        model_inputs = image_processor(
            images=[aug_image],
            segmentation_maps=[aug_instance_mask],
            instance_id_to_semantic_id=instance_id_to_semantic_id,
            return_tensors="pt",
        )

        batch["pixel_values"].append(model_inputs.pixel_values[0])
        batch["mask_labels"].append(model_inputs.mask_labels[0])
        batch["class_labels"].append(model_inputs.class_labels[0])

    return batch


def collate_fn(examples):
    batch = {}
    batch["pixel_values"] = torch.stack([example["pixel_values"] for example in examples])
    batch["class_labels"] = [example["class_labels"] for example in examples]
    batch["mask_labels"] = [example["mask_labels"] for example in examples]
    if "pixel_mask" in examples[0]:
        batch["pixel_mask"] = torch.stack([example["pixel_mask"] for example in examples])
    return batch


def nested_cpu(tensors):
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_cpu(t) for t in tensors)
    elif isinstance(tensors, Mapping):
        return type(tensors)({k: nested_cpu(t) for k, t in tensors.items()})
    elif isinstance(tensors, torch.Tensor):
        return tensors.cpu().detach()
    else:
        return tensors


def evaluation_loop(model, image_processor, accelerator: Accelerator, dataloader, id2label):
    metric = MeanAveragePrecision(iou_type="segm", class_metrics=True)

    for inputs in tqdm(dataloader, total=len(dataloader), disable=not accelerator.is_local_main_process):
        with torch.no_grad():
            outputs = model(**inputs)

        inputs = accelerator.gather_for_metrics(inputs)
        inputs = nested_cpu(inputs)

        outputs = accelerator.gather_for_metrics(outputs)
        outputs = nested_cpu(outputs)

        # For metric computation we need to provide:
        # - targets in a form of list of dictionaries with keys "masks", "labels"
        # - predictions in a form of list of dictionaries with keys "masks", "labels", "scores"

        post_processed_targets = []
        post_processed_predictions = []
        target_sizes = []

        # Collect targets
        for masks, labels in zip(inputs["mask_labels"], inputs["class_labels"]):
            post_processed_targets.append(
                {
                    "masks": masks.to(dtype=torch.bool),
                    "labels": labels,
                }
            )
            target_sizes.append(masks.shape[-2:])

        # Collect predictions
        post_processed_output = image_processor.post_process_instance_segmentation(
            outputs,
            threshold=0.0,
            target_sizes=target_sizes,
            return_binary_maps=True,
        )

        for image_predictions, target_size in zip(post_processed_output, target_sizes):
            if image_predictions["segments_info"]:
                post_processed_image_prediction = {
                    "masks": image_predictions["segmentation"].to(dtype=torch.bool),
                    "labels": torch.tensor([x["label_id"] for x in image_predictions["segments_info"]]),
                    "scores": torch.tensor([x["score"] for x in image_predictions["segments_info"]]),
                }
            else:
                # for void predictions, we need to provide empty tensors
                post_processed_image_prediction = {
                    "masks": torch.zeros([0, *target_size], dtype=torch.bool),
                    "labels": torch.tensor([]),
                    "scores": torch.tensor([]),
                }
            post_processed_predictions.append(post_processed_image_prediction)

        # Update metric for batch targets and predictions
        metric.update(post_processed_predictions, post_processed_targets)

    # Compute metrics
    metrics = metric.compute()

    # Replace list of per class metrics with separate metric for each class
    classes = metrics.pop("classes")
    map_per_class = metrics.pop("map_per_class")
    mar_100_per_class = metrics.pop("mar_100_per_class")
    for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
        class_name = id2label[class_id.item()] if id2label is not None else class_id.item()
        metrics[f"map_{class_name}"] = class_map
        metrics[f"mar_100_{class_name}"] = class_mar

    metrics = {k: round(v.item(), 4) for k, v in metrics.items()}

    return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(accelerator: Accelerator) -> None:
|
||||||
|
"""Setup logging according to `training_args`."""
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||||
|
datefmt="%m/%d/%Y %H:%M:%S",
|
||||||
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
|
)
|
||||||
|
|
||||||
|
if accelerator.is_local_main_process:
|
||||||
|
datasets.utils.logging.set_verbosity_warning()
|
||||||
|
transformers.utils.logging.set_verbosity_info()
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
else:
|
||||||
|
datasets.utils.logging.set_verbosity_error()
|
||||||
|
transformers.utils.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
|
||||||
|
def handle_repository_creation(accelerator: Accelerator, args: argparse.Namespace):
|
||||||
|
"""Create a repository for the model and dataset if `args.push_to_hub` is set."""
|
||||||
|
|
||||||
|
repo_id = None
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
if args.push_to_hub:
|
||||||
|
# Retrieve or infer repo_name
|
||||||
|
repo_name = args.hub_model_id
|
||||||
|
if repo_name is None:
|
||||||
|
repo_name = Path(args.output_dir).absolute().name
|
||||||
|
# Create repo and retrieve repo_id
|
||||||
|
api = HfApi()
|
||||||
|
repo_id = api.create_repo(repo_name, exist_ok=True, token=args.hub_token).repo_id
|
||||||
|
|
||||||
|
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||||
|
if "step_*" not in gitignore:
|
||||||
|
gitignore.write("step_*\n")
|
||||||
|
if "epoch_*" not in gitignore:
|
||||||
|
gitignore.write("epoch_*\n")
|
||||||
|
elif args.output_dir is not None:
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
return repo_id
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
||||||
|
# information sent is the one passed as arguments along with your Python/PyTorch versions.
|
||||||
|
send_example_telemetry("run_instance_segmentation_no_trainer", args)
|
||||||
|
|
||||||
|
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
|
||||||
|
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
|
||||||
|
# in the environment
|
||||||
|
accelerator_log_kwargs = {}
|
||||||
|
|
||||||
|
if args.with_tracking:
|
||||||
|
accelerator_log_kwargs["log_with"] = args.report_to
|
||||||
|
accelerator_log_kwargs["project_dir"] = args.output_dir
|
||||||
|
|
||||||
|
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
|
||||||
|
setup_logging(accelerator)
|
||||||
|
|
||||||
|
# If passed along, set the training seed now.
|
||||||
|
# We set device_specific to True as we want different data augmentation per device.
|
||||||
|
if args.seed is not None:
|
||||||
|
set_seed(args.seed, device_specific=True)
|
||||||
|
|
||||||
|
# Create repository if push to hub is specified
|
||||||
|
repo_id = handle_repository_creation(accelerator, args)
|
||||||
|
|
||||||
|
if args.push_to_hub:
|
||||||
|
api = HfApi()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------------------------
|
||||||
|
# Load dataset, prepare splits
|
||||||
|
# ------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
||||||
|
# download the dataset.
|
||||||
|
dataset = load_dataset(args.dataset_name, cache_dir=args.cache_dir)
|
||||||
|
|
||||||
|
# We need to specify the label2id mapping for the model
|
||||||
|
# it is a mapping from semantic class name to class index.
|
||||||
|
# In case your dataset does not provide it, you can create it manually:
|
||||||
|
# label2id = {"background": 0, "cat": 1, "dog": 2}
|
||||||
|
label2id = dataset["train"][0]["semantic_class_to_id"]
|
||||||
|
|
||||||
|
if args.do_reduce_labels:
|
||||||
|
label2id = {name: idx for name, idx in label2id.items() if idx != 0} # remove background class
|
||||||
|
label2id = {name: idx - 1 for name, idx in label2id.items()} # shift class indices by -1
|
||||||
|
|
||||||
|
id2label = {v: k for k, v in label2id.items()}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------------------------
|
||||||
|
# Load pretrained model and image processor
|
||||||
|
# ------------------------------------------------------------------------------------------------
|
||||||
|
model = AutoModelForUniversalSegmentation.from_pretrained(
|
||||||
|
args.model_name_or_path,
|
||||||
|
label2id=label2id,
|
||||||
|
id2label=id2label,
|
||||||
|
ignore_mismatched_sizes=True,
|
||||||
|
token=args.hub_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_processor = AutoImageProcessor.from_pretrained(
|
||||||
|
args.model_name_or_path,
|
||||||
|
do_resize=True,
|
||||||
|
size={"height": args.image_height, "width": args.image_width},
|
||||||
|
do_reduce_labels=args.do_reduce_labels,
|
||||||
|
reduce_labels=args.do_reduce_labels, # TODO: remove when mask2former support `do_reduce_labels`
|
||||||
|
token=args.hub_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------------------------
|
||||||
|
# Define image augmentations and dataset transforms
|
||||||
|
# ------------------------------------------------------------------------------------------------
|
||||||
|
train_augment_and_transform = A.Compose(
|
||||||
|
[
|
||||||
|
A.HorizontalFlip(p=0.5),
|
||||||
|
A.RandomBrightnessContrast(p=0.5),
|
||||||
|
A.HueSaturationValue(p=0.1),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
validation_transform = A.Compose(
|
||||||
|
[A.NoOp()],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make transform functions for batch and apply for dataset splits
|
||||||
|
train_transform_batch = partial(
|
||||||
|
augment_and_transform_batch, transform=train_augment_and_transform, image_processor=image_processor
|
||||||
|
)
|
||||||
|
validation_transform_batch = partial(
|
||||||
|
augment_and_transform_batch, transform=validation_transform, image_processor=image_processor
|
||||||
|
)
|
||||||
|
|
||||||
|
with accelerator.main_process_first():
|
||||||
|
dataset["train"] = dataset["train"].with_transform(train_transform_batch)
|
||||||
|
dataset["validation"] = dataset["validation"].with_transform(validation_transform_batch)
|
||||||
|
|
||||||
|
dataloader_common_args = {
|
||||||
|
"num_workers": args.dataloader_num_workers,
|
||||||
|
"persistent_workers": True,
|
||||||
|
"collate_fn": collate_fn,
|
||||||
|
}
|
||||||
|
train_dataloader = DataLoader(
|
||||||
|
dataset["train"], shuffle=True, batch_size=args.per_device_train_batch_size, **dataloader_common_args
|
||||||
|
)
|
||||||
|
valid_dataloader = DataLoader(
|
||||||
|
dataset["validation"], shuffle=False, batch_size=args.per_device_eval_batch_size, **dataloader_common_args
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------------------------
|
||||||
|
# Define optimizer, scheduler and prepare everything with the accelerator
|
||||||
|
# ------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Optimizer
|
||||||
|
optimizer = torch.optim.AdamW(
|
||||||
|
list(model.parameters()),
|
||||||
|
lr=args.learning_rate,
|
||||||
|
betas=[args.adam_beta1, args.adam_beta2],
|
||||||
|
eps=args.adam_epsilon,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Figure out how often (in steps) we should save the Accelerator state
|
||||||
|
checkpointing_steps = args.checkpointing_steps
|
||||||
|
if checkpointing_steps is not None and checkpointing_steps.isdigit():
|
||||||
|
checkpointing_steps = int(checkpointing_steps)
|
||||||
|
|
||||||
|
# Scheduler and math around the number of training steps.
|
||||||
|
overrode_max_train_steps = False
|
||||||
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
if args.max_train_steps is None:
|
||||||
|
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||||
|
overrode_max_train_steps = True
|
||||||
|
|
||||||
|
lr_scheduler = get_scheduler(
|
||||||
|
name=args.lr_scheduler_type,
|
||||||
|
optimizer=optimizer,
|
||||||
|
num_warmup_steps=args.num_warmup_steps * accelerator.num_processes,
|
||||||
|
num_training_steps=args.max_train_steps
|
||||||
|
if overrode_max_train_steps
|
||||||
|
else args.max_train_steps * accelerator.num_processes,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare everything with our `accelerator`.
|
||||||
|
model, optimizer, train_dataloader, valid_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
model, optimizer, train_dataloader, valid_dataloader, lr_scheduler
|
||||||
|
)
|
||||||
|
|
||||||
|
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||||
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
if overrode_max_train_steps:
|
||||||
|
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||||
|
# Afterwards we recalculate our number of training epochs
|
||||||
|
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
|
||||||
|
# We need to initialize the trackers we use, and also store our configuration.
|
||||||
|
# The trackers initialize automatically on the main process.
|
||||||
|
if args.with_tracking:
|
||||||
|
experiment_config = vars(args)
|
||||||
|
# TensorBoard cannot log Enums, need the raw value
|
||||||
|
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
|
||||||
|
accelerator.init_trackers("instance_segmentation_no_trainer", experiment_config)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------------------------
|
||||||
|
# Run training with evaluation on each epoch
|
||||||
|
# ------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
|
|
||||||
|
logger.info("***** Running training *****")
|
||||||
|
logger.info(f" Num examples = {len(dataset['train'])}")
|
||||||
|
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||||
|
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
|
||||||
|
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||||
|
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||||
|
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||||
|
|
||||||
|
# Only show the progress bar once on each machine.
|
||||||
|
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||||
|
completed_steps = 0
|
||||||
|
starting_epoch = 0
|
||||||
|
|
||||||
|
# Potentially load in the weights and states from a previous save
|
||||||
|
if args.resume_from_checkpoint:
|
||||||
|
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
|
||||||
|
checkpoint_path = args.resume_from_checkpoint
|
||||||
|
path = os.path.basename(args.resume_from_checkpoint)
|
||||||
|
else:
|
||||||
|
# Get the most recent checkpoint
|
||||||
|
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
|
||||||
|
dirs.sort(key=os.path.getctime)
|
||||||
|
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
|
||||||
|
checkpoint_path = path
|
||||||
|
path = os.path.basename(checkpoint_path)
|
||||||
|
|
||||||
|
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
|
||||||
|
accelerator.load_state(checkpoint_path)
|
||||||
|
# Extract `epoch_{i}` or `step_{i}`
|
||||||
|
training_difference = os.path.splitext(path)[0]
|
||||||
|
|
||||||
|
if "epoch" in training_difference:
|
||||||
|
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
|
||||||
|
resume_step = None
|
||||||
|
completed_steps = starting_epoch * num_update_steps_per_epoch
|
||||||
|
else:
|
||||||
|
# need to multiply `gradient_accumulation_steps` to reflect real steps
|
||||||
|
resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
|
||||||
|
starting_epoch = resume_step // len(train_dataloader)
|
||||||
|
completed_steps = resume_step // args.gradient_accumulation_steps
|
||||||
|
resume_step -= starting_epoch * len(train_dataloader)
|
||||||
|
|
||||||
|
# Update the progress bar if resuming from a checkpoint
|
||||||
|
progress_bar.update(completed_steps)
|
||||||
|
|
||||||
|
for epoch in range(starting_epoch, args.num_train_epochs):
|
||||||
|
model.train()
|
||||||
|
if args.with_tracking:
|
||||||
|
total_loss = 0
|
||||||
|
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
|
||||||
|
# We skip the first `n` batches in the dataloader when resuming from a checkpoint
|
||||||
|
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
|
||||||
|
else:
|
||||||
|
active_dataloader = train_dataloader
|
||||||
|
|
||||||
|
for step, batch in enumerate(active_dataloader):
|
||||||
|
with accelerator.accumulate(model):
|
||||||
|
outputs = model(**batch)
|
||||||
|
loss = outputs.loss
|
||||||
|
# We keep track of the loss at each epoch
|
||||||
|
if args.with_tracking:
|
||||||
|
total_loss += loss.detach().float()
|
||||||
|
accelerator.backward(loss)
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||||
|
if accelerator.sync_gradients:
|
||||||
|
progress_bar.update(1)
|
||||||
|
completed_steps += 1
|
||||||
|
|
||||||
|
if isinstance(checkpointing_steps, int):
|
||||||
|
if completed_steps % checkpointing_steps == 0:
|
||||||
|
output_dir = f"step_{completed_steps}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
|
if args.push_to_hub and epoch < args.num_train_epochs - 1:
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
|
unwrapped_model.save_pretrained(
|
||||||
|
args.output_dir,
|
||||||
|
is_main_process=accelerator.is_main_process,
|
||||||
|
save_function=accelerator.save,
|
||||||
|
)
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
image_processor.save_pretrained(args.output_dir)
|
||||||
|
api.upload_folder(
|
||||||
|
repo_id=repo_id,
|
||||||
|
commit_message=f"Training in progress epoch {epoch}",
|
||||||
|
folder_path=args.output_dir,
|
||||||
|
repo_type="model",
|
||||||
|
token=args.hub_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
if completed_steps >= args.max_train_steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info("***** Running evaluation *****")
|
||||||
|
metrics = evaluation_loop(model, image_processor, accelerator, valid_dataloader, id2label)
|
||||||
|
|
||||||
|
logger.info(f"epoch {epoch}: {metrics}")
|
||||||
|
|
||||||
|
if args.with_tracking:
|
||||||
|
accelerator.log(
|
||||||
|
{
|
||||||
|
"train_loss": total_loss.item() / len(train_dataloader),
|
||||||
|
**metrics,
|
||||||
|
"epoch": epoch,
|
||||||
|
"step": completed_steps,
|
||||||
|
},
|
||||||
|
step=completed_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.push_to_hub and epoch < args.num_train_epochs - 1:
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
|
unwrapped_model.save_pretrained(
|
||||||
|
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
|
||||||
|
)
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
image_processor.save_pretrained(args.output_dir)
|
||||||
|
api.upload_folder(
|
||||||
|
commit_message=f"Training in progress epoch {epoch}",
|
||||||
|
folder_path=args.output_dir,
|
||||||
|
repo_id=repo_id,
|
||||||
|
repo_type="model",
|
||||||
|
token=args.hub_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.checkpointing_steps == "epoch":
|
||||||
|
output_dir = f"epoch_{epoch}"
|
||||||
|
if args.output_dir is not None:
|
||||||
|
output_dir = os.path.join(args.output_dir, output_dir)
|
||||||
|
accelerator.save_state(output_dir)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------------------------
|
||||||
|
# Run evaluation on test dataset and save the model
|
||||||
|
# ------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
logger.info("***** Running evaluation on test dataset *****")
|
||||||
|
metrics = evaluation_loop(model, image_processor, accelerator, valid_dataloader, id2label)
|
||||||
|
metrics = {f"test_{k}": v for k, v in metrics.items()}
|
||||||
|
|
||||||
|
logger.info(f"Test metrics: {metrics}")
|
||||||
|
|
||||||
|
if args.with_tracking:
|
||||||
|
accelerator.end_training()
|
||||||
|
|
||||||
|
if args.output_dir is not None:
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
|
unwrapped_model.save_pretrained(
|
||||||
|
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
|
||||||
|
)
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
||||||
|
json.dump(metrics, f, indent=2)
|
||||||
|
|
||||||
|
image_processor.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
|
if args.push_to_hub:
|
||||||
|
api.upload_folder(
|
||||||
|
commit_message="End of training",
|
||||||
|
folder_path=args.output_dir,
|
||||||
|
repo_id=repo_id,
|
||||||
|
repo_type="model",
|
||||||
|
token=args.hub_token,
|
||||||
|
ignore_patterns=["epoch_*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
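The resume-from-checkpoint arithmetic in `main()` above is easy to misread, so here is a minimal standalone sketch of it. The helper name `resume_position` and its parameters are hypothetical and not part of the script. It assumes checkpoint folders are named `epoch_{i}` or `step_{i}`, as the script's `save_state` calls produce, and it uses `startswith` where the script uses a substring check.

```python
import math
from typing import Optional, Tuple


def resume_position(
    checkpoint_name: str, batches_per_epoch: int, grad_accum_steps: int
) -> Tuple[int, Optional[int], int]:
    """Return (starting_epoch, resume_step, completed_steps) for a checkpoint folder name.

    Hypothetical helper: it mirrors the script's arithmetic but is not part of it.
    """
    updates_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)

    if checkpoint_name.startswith("epoch_"):
        # An epoch checkpoint is saved at the end of that epoch, so resume at the next one.
        starting_epoch = int(checkpoint_name.replace("epoch_", "")) + 1
        return starting_epoch, None, starting_epoch * updates_per_epoch

    # `step_{i}` counts optimizer updates; multiply to get the number of dataloader batches.
    resume_step = int(checkpoint_name.replace("step_", "")) * grad_accum_steps
    starting_epoch = resume_step // batches_per_epoch
    completed_steps = resume_step // grad_accum_steps
    # Keep only the batches to skip inside the epoch being resumed.
    resume_step -= starting_epoch * batches_per_epoch
    return starting_epoch, resume_step, completed_steps


print(resume_position("epoch_2", batches_per_epoch=10, grad_accum_steps=2))
print(resume_position("step_7", batches_per_epoch=10, grad_accum_steps=2))
```

In the script, a non-`None` `resume_step` is the value passed to `accelerator.skip_first_batches` for the first resumed epoch.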
@ -355,3 +355,28 @@ class ExamplesTestsNoTrainer(TestCasePlus):
|
||||||
run_command(self._launch_args + testargs)
|
run_command(self._launch_args + testargs)
|
||||||
result = get_results(tmp_dir)
|
result = get_results(tmp_dir)
|
||||||
self.assertGreaterEqual(result["test_map"], 0.10)
|
self.assertGreaterEqual(result["test_map"], 0.10)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
|
||||||
|
def test_run_instance_segmentation_no_trainer(self):
|
||||||
|
stream_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
|
{self.examples_dir}/pytorch/instance-segmentation/run_instance_segmentation_no_trainer.py
|
||||||
|
--model_name_or_path qubvel-hf/finetune-instance-segmentation-ade20k-mini-mask2former
|
||||||
|
--output_dir {tmp_dir}
|
||||||
|
--dataset_name qubvel-hf/ade20k-nano
|
||||||
|
--do_reduce_labels
|
||||||
|
--image_height 256
|
||||||
|
--image_width 256
|
||||||
|
--num_train_epochs 1
|
||||||
|
--per_device_train_batch_size 2
|
||||||
|
--per_device_eval_batch_size 1
|
||||||
|
--seed 1234
|
||||||
|
""".split()
|
||||||
|
|
||||||
|
run_command(self._launch_args + testargs)
|
||||||
|
result = get_results(tmp_dir)
|
||||||
|
self.assertGreaterEqual(result["test_map"], 0.1)
|
||||||
|
|
|
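The test above passes `--do_reduce_labels`, which in the script drops the background class and shifts the remaining class indices down by one. A minimal sketch of that remapping follows. The function name `reduce_label2id` is hypothetical, and the example mapping is the one given in the script's own comment.

```python
from typing import Dict


def reduce_label2id(label2id: Dict[str, int]) -> Dict[str, int]:
    """Drop the class with index 0 (background) and shift the other indices by -1."""
    without_background = {name: idx for name, idx in label2id.items() if idx != 0}
    return {name: idx - 1 for name, idx in without_background.items()}


label2id = reduce_label2id({"background": 0, "cat": 1, "dog": 2})
id2label = {idx: name for name, idx in label2id.items()}
print(label2id)
print(id2label)
```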
@ -49,6 +49,7 @@ SRC_DIRS = [
|
||||||
"image-pretraining",
|
"image-pretraining",
|
||||||
"semantic-segmentation",
|
"semantic-segmentation",
|
||||||
"object-detection",
|
"object-detection",
|
||||||
|
"instance-segmentation",
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
sys.path.extend(SRC_DIRS)
|
sys.path.extend(SRC_DIRS)
|
||||||
|
@ -60,6 +61,7 @@ if SRC_DIRS is not None:
|
||||||
import run_generation
|
import run_generation
|
||||||
import run_glue
|
import run_glue
|
||||||
import run_image_classification
|
import run_image_classification
|
||||||
|
import run_instance_segmentation
|
||||||
import run_mae
|
import run_mae
|
||||||
import run_mlm
|
import run_mlm
|
||||||
import run_ner
|
import run_ner
|
||||||
|
@ -639,3 +641,33 @@ class ExamplesTests(TestCasePlus):
|
||||||
run_object_detection.main()
|
run_object_detection.main()
|
||||||
result = get_results(tmp_dir)
|
result = get_results(tmp_dir)
|
||||||
self.assertGreaterEqual(result["test_map"], 0.1)
|
self.assertGreaterEqual(result["test_map"], 0.1)
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"WANDB_DISABLED": "true"})
|
||||||
|
def test_run_instance_segmentation(self):
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
testargs = f"""
|
||||||
|
run_instance_segmentation.py
|
||||||
|
--model_name_or_path qubvel-hf/finetune-instance-segmentation-ade20k-mini-mask2former
|
||||||
|
--output_dir {tmp_dir}
|
||||||
|
--dataset_name qubvel-hf/ade20k-nano
|
||||||
|
--do_reduce_labels
|
||||||
|
--image_height 256
|
||||||
|
--image_width 256
|
||||||
|
--do_train
|
||||||
|
--num_train_epochs 1
|
||||||
|
--learning_rate 1e-5
|
||||||
|
--lr_scheduler_type constant
|
||||||
|
--per_device_train_batch_size 2
|
||||||
|
--per_device_eval_batch_size 1
|
||||||
|
--do_eval
|
||||||
|
--evaluation_strategy epoch
|
||||||
|
--seed 32
|
||||||
|
""".split()
|
||||||
|
|
||||||
|
if is_torch_fp16_available_on_device(torch_device):
|
||||||
|
testargs.append("--fp16")
|
||||||
|
|
||||||
|
with patch.object(sys, "argv", testargs):
|
||||||
|
run_instance_segmentation.main()
|
||||||
|
result = get_results(tmp_dir)
|
||||||
|
self.assertGreaterEqual(result["test_map"], 0.1)
|
||||||
|
|
|
@ -26,7 +26,7 @@ from .agent_types import AgentAudio, AgentImage, AgentText
|
||||||
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
|
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
|
||||||
from .llm_engine import HfEngine, MessageRole
|
from .llm_engine import HfEngine, MessageRole
|
||||||
from .prompts import DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT
|
from .prompts import DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT
|
||||||
from .python_interpreter import evaluate_python_code
|
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
|
||||||
from .tools import (
|
from .tools import (
|
||||||
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
Tool,
|
Tool,
|
||||||
|
@ -84,8 +84,14 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
|
||||||
return json_data
|
return json_data
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
place = e.pos
|
place = e.pos
|
||||||
|
if json_blob[place - 1 : place + 2] == "},\n":
|
||||||
|
raise ValueError(
|
||||||
|
"JSON is invalid: you probably tried to provide multiple tool calls in one action. PROVIDE ONLY ONE TOOL CALL."
|
||||||
|
)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The JSON blob you used is invalid: due to the following error: {e}. JSON blob was: {json_blob}, decoding failed at '{json_blob[place-4:place+5]}'."
|
f"The JSON blob you used is invalid due to the following error: {e}.\n"
|
||||||
|
f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
|
||||||
|
f"'{json_blob[place-4:place+5]}'."
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Error in parsing the JSON blob: {e}")
|
raise ValueError(f"Error in parsing the JSON blob: {e}")
|
||||||
|
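The new branch in `parse_json_blob` relies on where `json.JSONDecodeError.pos` lands when a second JSON object follows the first. The sketch below isolates that check. The function name `parse_single_tool_call` and the shortened error messages are illustrative, and clamping the slice start with `max(0, ...)` is an addition that the diff does not have.

```python
import json


def parse_single_tool_call(json_blob: str) -> dict:
    """Parse one JSON tool call; raise ValueError with a targeted hint on failure."""
    try:
        return json.loads(json_blob)
    except json.JSONDecodeError as e:
        place = e.pos  # index at which decoding stopped
        # A "}" right before the failure point, followed by ",\n", suggests that a
        # second object follows the first one (several tool calls in one action).
        if json_blob[place - 1 : place + 2] == "},\n":
            raise ValueError("JSON is invalid: multiple tool calls found. PROVIDE ONLY ONE TOOL CALL.")
        start = max(0, place - 4)  # clamp so a failure near the start cannot wrap around
        raise ValueError(f"Invalid JSON blob: {e}. Decoding failed near '{json_blob[start:place + 5]}'.")


print(parse_single_tool_call('{"action": "final_answer", "action_input": "4"}'))
```

With two objects separated by `,\n`, the decoder reports "Extra data" at the comma, so the three-character slice starting one position earlier is exactly `},\n`.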
@ -347,6 +353,7 @@ class Agent:
|
||||||
return self._toolbox
|
return self._toolbox
|
||||||
|
|
||||||
def initialize_for_run(self, task: str, **kwargs):
|
def initialize_for_run(self, task: str, **kwargs):
|
||||||
|
self.token_count = 0
|
||||||
self.task = task
|
self.task = task
|
||||||
if len(kwargs) > 0:
|
if len(kwargs) > 0:
|
||||||
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
|
||||||
|
@ -380,7 +387,7 @@ class Agent:
|
||||||
message_content = (
|
message_content = (
|
||||||
"Error: "
|
"Error: "
|
||||||
+ str(step_log["error"])
|
+ str(step_log["error"])
|
||||||
+ "\nNow let's retry: take care not to repeat previous errors! Try to adopt different approaches.\n"
|
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
|
||||||
)
|
)
|
||||||
elif "observation" in step_log:
|
elif "observation" in step_log:
|
||||||
message_content = f"Observation: {step_log['observation']}"
|
message_content = f"Observation: {step_log['observation']}"
|
||||||
|
@ -409,6 +416,9 @@ class Agent:
|
||||||
)
|
)
|
||||||
return memory
|
return memory
|
||||||
|
|
||||||
|
def get_succinct_logs(self):
|
||||||
|
return [{key: value for key, value in log.items() if key != "agent_memory"} for log in self.logs]
|
||||||
|
|
||||||
def extract_action(self, llm_output: str, split_token: str) -> str:
|
def extract_action(self, llm_output: str, split_token: str) -> str:
|
||||||
"""
|
"""
|
||||||
Parse action from the LLM output
|
Parse action from the LLM output
|
||||||
|
@ -486,6 +496,7 @@ class CodeAgent(Agent):
|
||||||
llm_engine: Callable = HfEngine(),
|
llm_engine: Callable = HfEngine(),
|
||||||
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
|
system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
|
||||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
|
additional_authorized_imports: List[str] = [],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
@ -504,6 +515,7 @@ class CodeAgent(Agent):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.python_evaluator = evaluate_python_code
|
self.python_evaluator = evaluate_python_code
|
||||||
|
self.additional_authorized_imports = additional_authorized_imports
|
||||||
|
|
||||||
def parse_code_blob(self, result: str) -> str:
|
def parse_code_blob(self, result: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -544,7 +556,7 @@ class CodeAgent(Agent):
|
||||||
self.prompt = [prompt_message, task_message]
|
self.prompt = [prompt_message, task_message]
|
||||||
self.logger.info("====Executing with this prompt====")
|
self.logger.info("====Executing with this prompt====")
|
||||||
self.logger.info(self.prompt)
|
self.logger.info(self.prompt)
|
||||||
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_code>"])
|
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>"])
|
||||||
|
|
||||||
if return_generated_code:
|
if return_generated_code:
|
||||||
return llm_output
|
return llm_output
|
||||||
|
@ -563,7 +575,12 @@ class CodeAgent(Agent):
|
||||||
self.log_code_action(code_action)
|
self.log_code_action(code_action)
|
||||||
try:
|
try:
|
||||||
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
|
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
|
||||||
output = self.python_evaluator(code_action, available_tools, state=self.state)
|
output = self.python_evaluator(
|
||||||
|
code_action,
|
||||||
|
available_tools,
|
||||||
|
state=self.state,
|
||||||
|
authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports,
|
||||||
|
)
|
||||||
self.logger.info(self.state["print_outputs"])
|
self.logger.info(self.state["print_outputs"])
|
||||||
return output
|
return output
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -597,7 +614,29 @@ class ReactAgent(Agent):
|
||||||
if "final_answer" not in self._toolbox.tools:
|
if "final_answer" not in self._toolbox.tools:
|
||||||
self._toolbox.add_tool(FinalAnswerTool())
|
self._toolbox.add_tool(FinalAnswerTool())
|
||||||
|
|
||||||
def run(self, task: str, **kwargs):
|
def provide_final_answer(self, task) -> str:
|
||||||
|
"""
|
||||||
|
This method provides a final answer to the task, based on the logs of the agent's interactions.
|
||||||
|
"""
|
||||||
|
self.prompt = [
|
||||||
|
{
|
||||||
|
"role": MessageRole.SYSTEM,
|
||||||
|
"content": "An agent tried to answer an user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
self.prompt += self.write_inner_memory_from_logs()[1:]
|
||||||
|
self.prompt += [
|
||||||
|
{
|
||||||
|
"role": MessageRole.USER,
|
||||||
|
"content": f"Based on the above, please provide an answer to the following user request:\n{task}",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
return self.llm_engine(self.prompt)
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error in generating final llm output: {e}."
|
||||||
|
|
||||||
|
def run(self, task: str, stream: bool = False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Runs the agent for the given task.
|
Runs the agent for the given task.
|
||||||
|
|
||||||
|
@ -614,13 +653,49 @@ class ReactAgent(Agent):
|
||||||
agent.run("What is the result of 2 power 3.7384?")
|
agent.run("What is the result of 2 power 3.7384?")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
if stream:
|
||||||
|
return self.stream_run(task, **kwargs)
|
||||||
|
else:
|
||||||
|
return self.direct_run(task, **kwargs)
|
||||||
|
|
||||||
|
def stream_run(self, task: str, **kwargs):
|
||||||
self.initialize_for_run(task, **kwargs)
|
self.initialize_for_run(task, **kwargs)
|
||||||
|
|
||||||
final_answer = None
|
final_answer = None
|
||||||
iteration = 0
|
iteration = 0
|
||||||
while final_answer is None and iteration < self.max_iterations:
|
while final_answer is None and iteration < self.max_iterations:
|
||||||
try:
|
try:
|
||||||
final_answer = self.step()
|
step_logs = self.step()
|
||||||
|
if "final_answer" in step_logs:
|
||||||
|
final_answer = step_logs["final_answer"]
|
||||||
|
except AgentError as e:
|
||||||
|
self.logger.error(e, exc_info=1)
|
||||||
|
self.logs[-1]["error"] = e
|
||||||
|
finally:
|
||||||
|
iteration += 1
|
||||||
|
yield self.logs[-1]
|
||||||
|
|
||||||
|
if final_answer is None and iteration == self.max_iterations:
|
||||||
|
error_message = "Reached max iterations."
|
||||||
|
final_step_log = {"error": AgentMaxIterationsError(error_message)}
|
||||||
|
self.logs.append(final_step_log)
|
||||||
|
self.logger.error(error_message, exc_info=1)
|
||||||
|
final_answer = self.provide_final_answer(task)
|
||||||
|
final_step_log["final_answer"] = final_answer
|
||||||
|
yield final_step_log
|
||||||
|
|
||||||
|
yield final_answer
|
||||||
|
|
||||||
|
def direct_run(self, task: str, **kwargs):
|
||||||
|
self.initialize_for_run(task, **kwargs)
|
||||||
|
|
||||||
|
final_answer = None
|
||||||
|
iteration = 0
|
||||||
|
while final_answer is None and iteration < self.max_iterations:
|
||||||
|
try:
|
||||||
|
step_logs = self.step()
|
||||||
|
if "final_answer" in step_logs:
|
||||||
|
final_answer = step_logs["final_answer"]
|
||||||
except AgentError as e:
|
except AgentError as e:
|
||||||
self.logger.error(e, exc_info=1)
|
self.logger.error(e, exc_info=1)
|
||||||
self.logs[-1]["error"] = e
|
self.logs[-1]["error"] = e
|
||||||
|
@ -629,26 +704,11 @@ class ReactAgent(Agent):
|
||||||
|
|
||||||
if final_answer is None and iteration == self.max_iterations:
|
if final_answer is None and iteration == self.max_iterations:
|
||||||
error_message = "Reached max iterations."
|
error_message = "Reached max iterations."
|
||||||
self.logs.append({"error": AgentMaxIterationsError(error_message)})
|
final_step_log = {"error": AgentMaxIterationsError(error_message)}
|
||||||
|
self.logs.append(final_step_log)
|
||||||
self.logger.error(error_message, exc_info=1)
|
self.logger.error(error_message, exc_info=1)
|
||||||
|
final_answer = self.provide_final_answer(task)
|
||||||
self.prompt = [
|
final_step_log["final_answer"] = final_answer
|
||||||
{
|
|
||||||
"role": MessageRole.SYSTEM,
|
|
||||||
"content": "An agent tried to answer a user query but it failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
self.prompt += self.write_inner_memory_from_logs()[1:]
|
|
||||||
self.prompt += [
|
|
||||||
{
|
|
||||||
"role": MessageRole.USER,
|
|
||||||
"content": f"Based on the above, please provide an answer to the following user request:\n{task}",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
try:
|
|
||||||
final_answer = self.llm_engine(self.prompt, stop_sequences=["Observation:"])
|
|
||||||
except Exception as e:
|
|
||||||
final_answer = f"Error in generating final llm output: {e}."
|
|
||||||
|
|
||||||
return final_answer
|
return final_answer
|
||||||
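The hunk above splits `run` into a streaming path and a direct path. A minimal sketch of that dispatch pattern follows, assuming a hypothetical `ToyAgent` whose `step` returns its own log dict. It is not the real `ReactAgent` and leaves out error handling and the max-iterations fallback.

```python
from typing import Any, Dict, Iterator, List


class ToyAgent:
    """Hypothetical stand-in for the agent in the diff; not the real class."""

    def __init__(self, max_iterations: int = 3):
        self.max_iterations = max_iterations
        self.logs: List[Dict[str, Any]] = []

    def step(self) -> Dict[str, Any]:
        # Each step appends its own log dict and returns it.
        step_log: Dict[str, Any] = {"step": len(self.logs)}
        self.logs.append(step_log)
        if len(self.logs) == 2:  # pretend the second step finds the answer
            step_log["final_answer"] = "done"
        return step_log

    def run(self, task: str, stream: bool = False):
        # Plain function: returns a generator when streaming, a value otherwise.
        if stream:
            return self.stream_run(task)
        return self.direct_run(task)

    def stream_run(self, task: str) -> Iterator[Any]:
        self.logs = []
        final_answer = None
        iteration = 0
        while final_answer is None and iteration < self.max_iterations:
            step_log = self.step()
            final_answer = step_log.get("final_answer")
            iteration += 1
            yield step_log  # one log dict per step
        yield final_answer  # last item is the answer itself

    def direct_run(self, task: str) -> Any:
        self.logs = []
        final_answer = None
        iteration = 0
        while final_answer is None and iteration < self.max_iterations:
            final_answer = self.step().get("final_answer")
            iteration += 1
        return final_answer
```

Keeping `run` free of `yield` matters here: a function containing `yield` always returns a generator, so the non-streaming caller would never receive a plain value.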
|
|
||||||
|
@ -683,22 +743,24 @@ class ReactJsonAgent(ReactAgent):
|
||||||
"""
|
"""
|
||||||
agent_memory = self.write_inner_memory_from_logs()
|
agent_memory = self.write_inner_memory_from_logs()
|
||||||
|
|
||||||
self.logs[-1]["agent_memory"] = agent_memory.copy()
|
|
||||||
self.prompt = agent_memory
|
self.prompt = agent_memory
|
||||||
self.logger.debug("===== New step =====")
|
self.logger.debug("===== New step =====")
|
||||||
|
|
||||||
# Add new step in logs
|
# Add new step in logs
|
||||||
self.logs.append({})
|
current_step_logs = {}
|
||||||
|
self.logs.append(current_step_logs)
|
||||||
|
current_step_logs["agent_memory"] = agent_memory.copy()
|
||||||
|
|
||||||
self.logger.info("===== Calling LLM with this last message: =====")
|
self.logger.info("===== Calling LLM with this last message: =====")
|
||||||
self.logger.info(self.prompt[-1])
|
self.logger.info(self.prompt[-1])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
llm_output = self.llm_engine(self.prompt, stop_sequences=["Observation:"])
|
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>", "Observation:"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
||||||
self.logger.debug("===== Output message of the LLM: =====")
|
self.logger.debug("===== Output message of the LLM: =====")
|
||||||
self.logger.debug(llm_output)
|
self.logger.debug(llm_output)
|
||||||
self.logs[-1]["llm_output"] = llm_output
|
current_step_logs["llm_output"] = llm_output
|
||||||
|
|
||||||
# Parse
|
# Parse
|
||||||
self.logger.debug("===== Extracting action =====")
|
self.logger.debug("===== Extracting action =====")
|
||||||
|
@ -709,8 +771,8 @@ class ReactJsonAgent(ReactAgent):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentParsingError(f"Could not parse the given action: {e}.")
|
raise AgentParsingError(f"Could not parse the given action: {e}.")
|
||||||
|
|
||||||
self.logs[-1]["rationale"] = rationale
|
current_step_logs["rationale"] = rationale
|
||||||
self.logs[-1]["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
|
current_step_logs["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}")
|
self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}")
|
||||||
|
@ -721,7 +783,8 @@ class ReactJsonAgent(ReactAgent):
|
||||||
answer = arguments
|
answer = arguments
|
||||||
if answer in self.state: # if the answer is a state variable, return the value
|
if answer in self.state: # if the answer is a state variable, return the value
|
||||||
answer = self.state[answer]
|
answer = self.state[answer]
|
||||||
return answer
|
current_step_logs["final_answer"] = answer
|
||||||
|
return current_step_logs
|
||||||
else:
|
else:
|
||||||
observation = self.execute_tool_call(tool_name, arguments)
|
observation = self.execute_tool_call(tool_name, arguments)
|
||||||
observation_type = type(observation)
|
observation_type = type(observation)
|
||||||
|
@ -740,8 +803,8 @@ class ReactJsonAgent(ReactAgent):
|
||||||
updated_information = f"Stored '{observation_name}' in memory."
|
updated_information = f"Stored '{observation_name}' in memory."
|
||||||
|
|
||||||
self.logger.info(updated_information)
|
self.logger.info(updated_information)
|
||||||
self.logs[-1]["observation"] = updated_information
|
current_step_logs["observation"] = updated_information
|
||||||
return None
|
return current_step_logs
|
||||||
|
|
||||||
|
|
||||||
class ReactCodeAgent(ReactAgent):
|
class ReactCodeAgent(ReactAgent):
|
||||||
|
@ -757,6 +820,7 @@ class ReactCodeAgent(ReactAgent):
|
||||||
llm_engine: Callable = HfEngine(),
|
llm_engine: Callable = HfEngine(),
|
||||||
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
|
||||||
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
|
||||||
|
additional_authorized_imports: List[str] = [],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
@ -775,6 +839,7 @@ class ReactCodeAgent(ReactAgent):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.python_evaluator = evaluate_python_code
|
self.python_evaluator = evaluate_python_code
|
||||||
|
self.additional_authorized_imports = additional_authorized_imports
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
"""
|
"""
|
||||||
|
@ -782,26 +847,27 @@ class ReactCodeAgent(ReactAgent):
|
||||||
The errors are raised here, they are caught and logged in the run() method.
|
The errors are raised here, they are caught and logged in the run() method.
|
||||||
"""
|
"""
|
||||||
agent_memory = self.write_inner_memory_from_logs()
|
agent_memory = self.write_inner_memory_from_logs()
|
||||||
self.logs[-1]["agent_memory"] = agent_memory.copy()
|
|
||||||
|
|
||||||
self.prompt = agent_memory.copy()
|
self.prompt = agent_memory.copy()
|
||||||
|
|
||||||
self.logger.debug("===== New step =====")
|
self.logger.debug("===== New step =====")
|
||||||
|
|
||||||
# Add new step in logs
|
# Add new step in logs
|
||||||
self.logs.append({})
|
current_step_logs = {}
|
||||||
|
self.logs.append(current_step_logs)
|
||||||
|
current_step_logs["agent_memory"] = agent_memory.copy()
|
||||||
|
|
||||||
self.logger.info("===== Calling LLM with these last messages: =====")
|
self.logger.info("===== Calling LLM with these last messages: =====")
|
||||||
self.logger.info(self.prompt[-2:])
|
self.logger.info(self.prompt[-2:])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_code>", "Observation:"])
|
llm_output = self.llm_engine(self.prompt, stop_sequences=["<end_action>", "Observation:"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
raise AgentGenerationError(f"Error in generating llm output: {e}.")
|
||||||
|
|
||||||
self.logger.debug("===== Output message of the LLM: =====")
|
self.logger.debug("===== Output message of the LLM: =====")
|
||||||
self.logger.debug(llm_output)
|
self.logger.debug(llm_output)
|
||||||
self.logs[-1]["llm_output"] = llm_output
|
current_step_logs["llm_output"] = llm_output
|
||||||
|
|
||||||
# Parse
|
# Parse
|
||||||
self.logger.debug("===== Extracting action =====")
|
self.logger.debug("===== Extracting action =====")
|
||||||
|
@ -813,18 +879,23 @@ class ReactCodeAgent(ReactAgent):
|
||||||
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
|
error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
|
||||||
raise AgentParsingError(error_msg)
|
raise AgentParsingError(error_msg)
|
||||||
|
|
||||||
self.logs[-1]["rationale"] = rationale
|
current_step_logs["rationale"] = rationale
|
||||||
self.logs[-1]["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}
|
current_step_logs["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
self.log_code_action(code_action)
|
self.log_code_action(code_action)
|
||||||
try:
|
try:
|
||||||
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
|
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
|
||||||
result = self.python_evaluator(code_action, available_tools, state=self.state)
|
result = self.python_evaluator(
|
||||||
|
code_action,
|
||||||
|
available_tools,
|
||||||
|
state=self.state,
|
||||||
|
authorized_imports=LIST_SAFE_MODULES + self.additional_authorized_imports,
|
||||||
|
)
|
||||||
information = self.state["print_outputs"]
|
information = self.state["print_outputs"]
|
||||||
self.logger.warning("Print outputs:")
|
self.logger.warning("Print outputs:")
|
||||||
self.logger.log(32, information)
|
self.logger.log(32, information)
|
||||||
self.logs[-1]["observation"] = information
|
current_step_logs["observation"] = information
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Failed while trying to execute the code below:\n{CustomFormatter.reset + code_action + CustomFormatter.reset}\nThis failed due to the following error:\n{str(e)}"
|
error_msg = f"Failed while trying to execute the code below:\n{CustomFormatter.reset + code_action + CustomFormatter.reset}\nThis failed due to the following error:\n{str(e)}"
|
||||||
if "'dict' object has no attribute 'read'" in str(e):
|
if "'dict' object has no attribute 'read'" in str(e):
|
||||||
|
@ -834,5 +905,5 @@ class ReactCodeAgent(ReactAgent):
|
||||||
if line[: len("final_answer")] == "final_answer":
|
if line[: len("final_answer")] == "final_answer":
|
||||||
self.logger.warning(">>> Final answer:")
|
self.logger.warning(">>> Final answer:")
|
||||||
self.logger.log(32, result)
|
self.logger.log(32, result)
|
||||||
return result
|
current_step_logs["final_answer"] = result
|
||||||
return None
|
return current_step_logs
|
||||||
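The `ReactCodeAgent` hunks pass `LIST_SAFE_MODULES + self.additional_authorized_imports` to the evaluator. As a rough illustration of what an allow-list of imports can mean, here is a sketch that scans code with the standard `ast` module. The names `BASE_SAFE_MODULES` and `find_unauthorized_imports` are hypothetical, and this is not how `evaluate_python_code` is implemented.

```python
import ast
from typing import List, Optional

# Placeholder allow-list for illustration; the real LIST_SAFE_MODULES is not shown in the diff.
BASE_SAFE_MODULES = ["math", "re"]


def find_unauthorized_imports(
    code: str, additional_authorized_imports: Optional[List[str]] = None
) -> List[str]:
    """Return the imported module names that are not on the allow-list."""
    allowed = set(BASE_SAFE_MODULES) | set(additional_authorized_imports or [])
    rejected: List[str] = []
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            names = [node.module or ""]  # relative imports have no module name
        else:
            continue
        # Compare on the top-level package, so "os.path" is checked as "os".
        rejected += [name for name in names if name.split(".")[0] not in allowed]
    return rejected
```

The sketch defaults the extra list to `None` rather than `[]`, which avoids sharing one list object between calls.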
|
|
|
@ -61,7 +61,6 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions:
|
||||||
|
|
||||||
|
|
||||||
llama_role_conversions = {
|
llama_role_conversions = {
|
||||||
MessageRole.SYSTEM: MessageRole.USER,
|
|
||||||
MessageRole.TOOL_RESPONSE: MessageRole.USER,
|
MessageRole.TOOL_RESPONSE: MessageRole.USER,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,20 +71,14 @@ class HfEngine:
|
||||||
self.client = InferenceClient(model=self.model, timeout=120)
|
self.client = InferenceClient(model=self.model, timeout=120)
|
||||||
|
|
||||||
def __call__(self, messages: List[Dict[str, str]], stop_sequences=[]) -> str:
|
def __call__(self, messages: List[Dict[str, str]], stop_sequences=[]) -> str:
|
||||||
if "Meta-Llama-3" in self.model:
|
|
||||||
if "<|eot_id|>" not in stop_sequences:
|
|
||||||
stop_sequences.append("<|eot_id|>")
|
|
||||||
if "!!!!!" not in stop_sequences:
|
|
||||||
stop_sequences.append("!!!!!")
|
|
||||||
|
|
||||||
# Get clean message list
|
# Get clean message list
|
||||||
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
|
messages = get_clean_message_list(messages, role_conversions=llama_role_conversions)
|
||||||
|
|
||||||
# Get answer
|
# Get LLM output
|
||||||
response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500)
|
response = self.client.chat_completion(messages, stop=stop_sequences, max_tokens=1500)
|
||||||
response = response.choices[0].message.content
|
response = response.choices[0].message.content
|
||||||
|
|
||||||
# Remove stop sequences from the answer
|
# Remove stop sequences from LLM output
|
||||||
for stop_seq in stop_sequences:
|
for stop_seq in stop_sequences:
|
||||||
if response[-len(stop_seq) :] == stop_seq:
|
if response[-len(stop_seq) :] == stop_seq:
|
||||||
response = response[: -len(stop_seq)]
|
response = response[: -len(stop_seq)]
|
||||||
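The loop at the end of `HfEngine.__call__` trims a trailing stop sequence from the model output. The same idea as a standalone helper, under the assumption that only a suffix match should be removed; `strip_stop_sequences` is a hypothetical name, not part of the library.

```python
from typing import List, Optional


def strip_stop_sequences(response: str, stop_sequences: Optional[List[str]] = None) -> str:
    """Remove a trailing stop sequence, checking each entry once in order."""
    for stop_seq in stop_sequences or []:
        # Skip empty strings: slicing with -0 would wipe the whole response.
        if stop_seq and response.endswith(stop_seq):
            response = response[: -len(stop_seq)]
    return response
```

Using `None` as the default instead of a list literal also sidesteps the shared-default pitfall that the removed `stop_sequences.append(...)` lines could trigger.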
|
|
|
@ -68,7 +68,7 @@ translated_question = translator(question=question, src_lang="French", tgt_lang=
|
||||||
print(f"The translated question is {translated_question}.")
|
print(f"The translated question is {translated_question}.")
|
||||||
answer = image_qa(image=image, question=translated_question)
|
answer = image_qa(image=image, question=translated_question)
|
||||||
print(f"The answer is {answer}")
|
print(f"The answer is {answer}")
|
||||||
```<end_code>
|
```<end_action>
|
||||||
|
|
||||||
---
|
---
|
||||||
Task: "Identify the oldest person in the `document` and create an image showcasing the result."
|
Task: "Identify the oldest person in the `document` and create an image showcasing the result."
|
||||||
|
@ -79,7 +79,7 @@ Code:
|
||||||
answer = document_qa(document, question="What is the oldest person?")
|
answer = document_qa(document, question="What is the oldest person?")
|
||||||
print(f"The answer is {answer}.")
|
print(f"The answer is {answer}.")
|
||||||
image = image_generator(answer)
|
image = image_generator(answer)
|
||||||
```<end_code>
|
```<end_action>
|
||||||
|
|
||||||
---
|
---
|
||||||
Task: "Generate an image using the text given in the variable `caption`."
|
Task: "Generate an image using the text given in the variable `caption`."
|
||||||
|
@ -88,7 +88,7 @@ I will use the following tool: `image_generator` to generate an image.
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
image = image_generator(prompt=caption)
|
image = image_generator(prompt=caption)
|
||||||
```<end_code>
|
```<end_action>
|
||||||
|
|
||||||
---
|
---
|
||||||
Task: "Summarize the text given in the variable `text` and read it out loud."
|
Task: "Summarize the text given in the variable `text` and read it out loud."
|
||||||
|
@ -99,7 +99,7 @@ Code:
|
||||||
summarized_text = summarizer(text)
|
summarized_text = summarizer(text)
|
||||||
print(f"Summary: {summarized_text}")
|
print(f"Summary: {summarized_text}")
|
||||||
audio_summary = text_reader(summarized_text)
|
audio_summary = text_reader(summarized_text)
|
||||||
```<end_code>
|
```<end_action>
|
||||||
|
|
||||||
---
|
---
|
||||||
Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image."
|
Task: "Answer the question in the variable `question` about the text in the variable `text`. Use the answer to generate an image."
|
||||||
|
@ -110,7 +110,7 @@ Code:
|
||||||
answer = text_qa(text=text, question=question)
|
answer = text_qa(text=text, question=question)
|
||||||
print(f"The answer is {answer}.")
|
print(f"The answer is {answer}.")
|
||||||
image = image_generator(answer)
|
image = image_generator(answer)
|
||||||
```<end_code>
|
```<end_action>
|
||||||
|
|
||||||
---
|
---
|
||||||
Task: "Caption the following `image`."
|
Task: "Caption the following `image`."
|
||||||
|
@ -119,39 +119,32 @@ I will use the following tool: `image_captioner` to generate a caption for the i
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
caption = image_captioner(image)
|
caption = image_captioner(image)
|
||||||
```<end_code>
|
```<end_action>
|
||||||
|
|
||||||
---
|
---
|
||||||
Above examples were using tools that might not exist for you. You only have access to those Tools:
|
Above examples were using tools that might not exist for you. You only have access to those Tools:
|
||||||
<<tool_names>>
|
<<tool_names>>
|
||||||
|
|
||||||
Remember to make sure that variables you use are all defined.
|
Remember to make sure that variables you use are all defined.
|
||||||
Be sure to provide a 'Code:\n```' sequence before the code and '```<end_code>' after, else you will get an error.
|
Be sure to provide a 'Code:\n```' sequence before the code and '```<end_action>' after, else you will get an error.
|
||||||
DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'.
|
DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'.
|
||||||
|
|
||||||
Now Begin!
|
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_REACT_JSON_SYSTEM_PROMPT = """You will be given a task to solve as best you can. You have access to the following tools:
|
DEFAULT_REACT_JSON_SYSTEM_PROMPT = """You will be given a task to solve as best you can. To do so, you have been given access to the following tools: <<tool_names>>
|
||||||
<<tool_descriptions>>
|
The way you use the tools is by specifying a json blob, ending with '<end_action>'.
|
||||||
|
Specifically, this json should have an `action` key (name of the tool to use) and an `action_input` key (input to the tool).
|
||||||
The way you use the tools is by specifying a json blob.
|
|
||||||
Specifically, this json should have a `action` key (name of the tool to use) and a `action_input` key (input to the tool).
|
|
||||||
|
|
||||||
The $ACTION_JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. It should be formatted in json. Do not try to escape special characters. Here is the template of a valid $ACTION_JSON_BLOB:
|
The $ACTION_JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. It should be formatted in json. Do not try to escape special characters. Here is the template of a valid $ACTION_JSON_BLOB:
|
||||||
Action:
|
|
||||||
{
|
{
|
||||||
"action": $TOOL_NAME,
|
"action": $TOOL_NAME,
|
||||||
"action_input": $INPUT
|
"action_input": $INPUT
|
||||||
}
|
}<end_action>
|
||||||
|
|
||||||
Make sure to have the $INPUT as a dictionary in the right format for the tool you are using, and do not put variable names as input if you can find the right values.
|
Make sure to have the $INPUT as a dictionary in the right format for the tool you are using, and do not put variable names as input if you can find the right values.
|
||||||
|
|
||||||
You will be given:
|
|
||||||
|
|
||||||
Task: the task you are given.
|
|
||||||
|
|
||||||
You should ALWAYS use the following format:
|
You should ALWAYS use the following format:
|
||||||
|
|
||||||
Thought: you should always think about one action to take. Then use the action as follows:
|
Thought: you should always think about one action to take. Then use the action as follows:
|
||||||
|
@ -171,14 +164,14 @@ Action:
|
||||||
{
|
{
|
||||||
"action": "image_transformer",
|
"action": "image_transformer",
|
||||||
"action_input": {"image": "image_1.jpg"}
|
"action_input": {"image": "image_1.jpg"}
|
||||||
}
|
}<end_action>
|
||||||
|
|
||||||
To provide the final answer to the task, use an action blob with "action": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this:
|
To provide the final answer to the task, use an action blob with "action": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this:
|
||||||
Action:
|
Action:
|
||||||
{
|
{
|
||||||
"action": "final_answer",
|
"action": "final_answer",
|
||||||
"action_input": {"answer": "insert your final answer here"}
|
"action_input": {"answer": "insert your final answer here"}
|
||||||
}
|
}<end_action>
|
||||||
|
|
||||||
|
|
||||||
Here are a few examples using notional tools:
|
Here are a few examples using notional tools:
|
||||||
|
@ -190,7 +183,7 @@ Action:
|
||||||
{
|
{
|
||||||
"action": "document_qa",
|
"action": "document_qa",
|
||||||
"action_input": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"}
|
"action_input": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"}
|
||||||
}
|
}<end_action>
|
||||||
Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
|
Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
|
||||||
|
|
||||||
|
|
||||||
|
@ -199,7 +192,7 @@ Action:
|
||||||
{
|
{
|
||||||
"action": "image_generator",
|
"action": "image_generator",
|
||||||
"action_input": {"text": ""A portrait of John Doe, a 55-year-old man living in Canada.""}
|
"action_input": {"text": ""A portrait of John Doe, a 55-year-old man living in Canada.""}
|
||||||
}
|
}<end_action>
|
||||||
Observation: "image.png"
|
Observation: "image.png"
|
||||||
|
|
||||||
Thought: I will now return the generated image.
|
Thought: I will now return the generated image.
|
||||||
|
@ -207,7 +200,7 @@ Action:
|
||||||
{
|
{
|
||||||
"action": "final_answer",
|
"action": "final_answer",
|
||||||
"action_input": "image.png"
|
"action_input": "image.png"
|
||||||
}
|
}<end_action>
|
||||||
|
|
||||||
---
|
---
|
||||||
Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
||||||
|
@ -217,7 +210,7 @@ Action:
|
||||||
{
|
{
|
||||||
"action": "python_interpreter",
|
"action": "python_interpreter",
|
||||||
"action_input": {"code": "5 + 3 + 1294.678"}
|
"action_input": {"code": "5 + 3 + 1294.678"}
|
||||||
}
|
}<end_action>
|
||||||
Observation: 1302.678
|
Observation: 1302.678
|
||||||
|
|
||||||
Thought: Now that I know the result, I will now return it.
|
Thought: Now that I know the result, I will now return it.
|
||||||
|
@ -225,7 +218,7 @@ Action:
|
||||||
{
|
{
|
||||||
"action": "final_answer",
|
"action": "final_answer",
|
||||||
"action_input": "1302.678"
|
"action_input": "1302.678"
|
||||||
}
|
}<end_action>
|
||||||
|
|
||||||
---
|
---
|
||||||
Task: "Which city has the highest population , Guangzhou or Shanghai?"
|
Task: "Which city has the highest population , Guangzhou or Shanghai?"
|
||||||
|
@ -235,7 +228,7 @@ Action:
|
||||||
{
|
{
|
||||||
"action": "search",
|
"action": "search",
|
||||||
"action_input": "Population Guangzhou"
|
"action_input": "Population Guangzhou"
|
||||||
}
|
}<end_action>
|
||||||
Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
|
Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
|
||||||
|
|
||||||
|
|
||||||
|
@ -252,28 +245,30 @@ Action:
|
||||||
{
|
{
|
||||||
"action": "final_answer",
|
"action": "final_answer",
|
||||||
"action_input": "Shanghai"
|
"action_input": "Shanghai"
|
||||||
}
|
}<end_action>
|
||||||
|
|
||||||
|
|
||||||
Above examples were using notional tools that might not exist for you. You only have access to those tools:
|
Above examples were using notional tools that might not exist for you. You only have access to those tools:
|
||||||
<<tool_names>>
|
<<tool_descriptions>>
|
||||||
ALWAYS provide a 'Thought:' and an 'Action:' sequence. You MUST provide at least the 'Action:' sequence to move forward.
|
|
||||||
|
|
||||||
Now begin!
|
Here are the rules you should always follow to solve your task:
|
||||||
|
1. ALWAYS provide a 'Thought:' sequence, and an 'Action:' sequence that ends with <end_action>, else you will fail.
|
||||||
|
2. Always use the right arguments for the tools. Never use variable names in the 'action_input' field, use the value instead.
|
||||||
|
3. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
|
||||||
|
4. Never re-do a tool call that you previously did with the exact same parameters.
|
||||||
|
|
||||||
|
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
|
||||||
"""
|
"""
|
||||||
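The revised JSON prompt asks the model to end every action blob with `<end_action>`. A minimal sketch of how such an output could be parsed, assuming the blob follows the first `Action:` marker; `parse_action_blob` is a hypothetical helper, not the library's own parser.

```python
import json
from typing import Any, Tuple


def parse_action_blob(llm_output: str) -> Tuple[str, Any]:
    """Extract (tool_name, tool_input) from a 'Thought/Action' style output."""
    # Everything after the first 'Action:' marker is treated as the JSON blob.
    _, _, blob = llm_output.partition("Action:")
    # The stop sequence may or may not have been stripped upstream.
    blob = blob.replace("<end_action>", "").strip()
    data = json.loads(blob)  # raises json.JSONDecodeError on a malformed or missing blob
    return data["action"], data["action_input"]
```

For the arithmetic example in the prompt, this returns the pair `("final_answer", "1302.678")`.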
|
|
||||||
|
|
||||||
DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You will be given a task to solve as best you can.
|
DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You will be given a task to solve as best you can.
|
||||||
You have access to the following tools:
|
To do so, you have been given access to *tools*: these tools are basically Python functions which you can call with code.
|
||||||
<<tool_descriptions>>
|
|
||||||
|
|
||||||
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
|
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
|
||||||
|
|
||||||
At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task, then the tools that you want to use.
|
At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use.
|
||||||
Then in the 'Code:' sequence, you shold write the code in simple Python. The code sequence must end with '/End code' sequence.
|
Then in the 'Code:' sequence, you should write the code in simple Python. The code sequence must end with '<end_action>' sequence.
|
||||||
During each intermediate step, you can use 'print()' to save whatever important information you will then need.
|
During each intermediate step, you can use 'print()' to save whatever important information you will then need.
|
||||||
These print outputs will then be available in the 'Observation:' field, for using this information as input for the next step.
|
These print outputs will then appear in the 'Observation:' field, which will be available as input for the next step.
|
||||||
|
|
||||||
In the end you have to return a final answer using the `final_answer` tool.
|
In the end you have to return a final answer using the `final_answer` tool.
|
||||||
|
|
||||||
Here are a few examples using notional tools:
|
Here are a few examples using notional tools:
|
||||||
|
@ -285,7 +280,7 @@ Code:
|
||||||
```py
|
```py
|
||||||
answer = document_qa(document=document, question="Who is the oldest person mentioned?")
|
answer = document_qa(document=document, question="Who is the oldest person mentioned?")
|
||||||
print(answer)
|
print(answer)
|
||||||
```<end_code>
|
```<end_action>
|
||||||
Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
|
Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
|
||||||
|
|
||||||
Thought: I will now generate an image showcasing the oldest person.
|
Thought: I will now generate an image showcasing the oldest person.
|
||||||
|
@ -294,7 +289,7 @@ Code:
|
||||||
```py
|
```py
|
||||||
image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.")
|
image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.")
|
||||||
final_answer(image)
|
final_answer(image)
|
||||||
```<end_code>
|
```<end_action>
|
||||||
|
|
||||||
---
|
---
|
||||||
Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
||||||
|
@ -305,10 +300,10 @@ Code:
|
||||||
```py
|
```py
|
||||||
result = 5 + 3 + 1294.678
|
result = 5 + 3 + 1294.678
|
||||||
final_answer(result)
|
final_answer(result)
|
||||||
```<end_code>
|
```<end_action>
|
||||||
|
|
||||||
---
|
---
|
||||||
Task: "Which city has the highest population , Guangzhou or Shanghai?"
|
Task: "Which city has the highest population: Guangzhou or Shanghai?"
|
||||||
|
|
||||||
Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.
|
Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.
|
||||||
Code:
|
Code:
|
||||||
|
@ -317,7 +312,7 @@ population_guangzhou = search("Guangzhou population")
|
||||||
print("Population Guangzhou:", population_guangzhou)
|
print("Population Guangzhou:", population_guangzhou)
|
||||||
population_shanghai = search("Shanghai population")
|
population_shanghai = search("Shanghai population")
|
||||||
print("Population Shanghai:", population_shanghai)
|
print("Population Shanghai:", population_shanghai)
|
||||||
```<end_code>
|
```<end_action>
|
||||||
Observation:
|
Observation:
|
||||||
Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
|
Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
|
||||||
Population Shanghai: '26 million (2019)'
|
Population Shanghai: '26 million (2019)'
|
||||||
|
@ -326,7 +321,7 @@ Thought: Now I know that Shanghai has the highest population.
|
||||||
Code:
|
Code:
|
||||||
```py
|
```py
|
||||||
final_answer("Shanghai")
|
final_answer("Shanghai")
|
||||||
```<end_code>
|
```<end_action>
|
||||||
|
|
||||||
---
|
---
|
||||||
Task: "What is the current age of the pope, raised to the power 0.36?"
|
Task: "What is the current age of the pope, raised to the power 0.36?"
|
||||||
|
@ -336,7 +331,7 @@ Code:
|
||||||
```py
|
```py
|
||||||
pope_age = search(query="current pope age")
|
pope_age = search(query="current pope age")
|
||||||
print("Pope age:", pope_age)
|
print("Pope age:", pope_age)
|
||||||
```<end_code>
|
```<end_action>
|
||||||
Observation:
|
Observation:
|
||||||
Pope age: "The pope Francis is currently 85 years old."
|
Pope age: "The pope Francis is currently 85 years old."
|
||||||
|
|
||||||
|
@ -345,20 +340,21 @@ Code:
|
||||||
```py
|
```py
|
||||||
pope_current_age = 85 ** 0.36
|
pope_current_age = 85 ** 0.36
|
||||||
final_answer(pope_current_age)
|
final_answer(pope_current_age)
|
||||||
```<end_code>
|
```<end_action>
|
||||||
|
|
||||||
|
|
||||||
Above examples were using notional tools that might not exist for you. You only have access to those tools:
|
Above examples were using notional tools that might not exist for you. You only have access to those tools:
|
||||||
<<tool_names>>
|
|
||||||
You also can perform computations in the python code you generate.
|
|
||||||
|
|
||||||
Always provide a 'Thought:' and a 'Code:\n```py' sequence ending with '```<end_code>' sequence. You MUST provide at least the 'Code:' sequence to move forward.
|
<<tool_descriptions>>
|
||||||
|
|
||||||
Remember to not perform too many operations in a single code block! You should split the task into intermediate code blocks.
|
You also can perform computations in the Python code that you generate.
|
||||||
Print results at the end of each step to save the intermediate results. Then use final_answer() to return the final result.
|
|
||||||
|
|
||||||
Remember to make sure that variables you use are all defined.
|
Here are the rules you should always follow to solve your task:
|
||||||
DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'.
|
1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```<end_action>' sequence, else you will fail.
|
||||||
|
2. Use only variables that you have defined!
|
||||||
|
3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = ask_search_agent({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = ask_search_agent(query="What is the place where James Bond lives?")'.
|
||||||
|
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
|
||||||
|
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
|
||||||
|
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
||||||
|
|
||||||
Now Begin!
|
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -15,9 +15,10 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import ast
|
import ast
|
||||||
|
import builtins
|
||||||
import difflib
|
import difflib
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Callable, Dict, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
class InterpretorError(ValueError):
|
class InterpretorError(ValueError):
|
||||||
|
@ -29,7 +30,25 @@ class InterpretorError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
LIST_SAFE_MODULES = ["random", "math", "time", "queue", "itertools", "re", "stat", "statistics", "unicodedata"]
|
ERRORS = {
|
||||||
|
name: getattr(builtins, name)
|
||||||
|
for name in dir(builtins)
|
||||||
|
if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
LIST_SAFE_MODULES = [
|
||||||
|
"random",
|
||||||
|
"collections",
|
||||||
|
"math",
|
||||||
|
"time",
|
||||||
|
"queue",
|
||||||
|
"itertools",
|
||||||
|
"re",
|
||||||
|
"stat",
|
||||||
|
"statistics",
|
||||||
|
"unicodedata",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class BreakException(Exception):
|
class BreakException(Exception):
|
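The `ERRORS` mapping added in this hunk can be tried in isolation. The sketch below reproduces the comprehension from the diff and adds a hypothetical `resolve_name` helper (not part of the diff) to show how a name lookup might fall back to built-in exception classes once the state has been checked.

```python
import builtins

# Collect every built-in exception class, keyed by its name (same expression as in the diff).
ERRORS = {
    name: getattr(builtins, name)
    for name in dir(builtins)
    if isinstance(getattr(builtins, name), type) and issubclass(getattr(builtins, name), BaseException)
}


def resolve_name(name, state):
    """Hypothetical helper: look a name up in the state first, then in ERRORS."""
    if name in state:
        return state[name]
    if name in ERRORS:
        return ERRORS[name]
    raise NameError(f"The variable `{name}` is not defined.")
```

Because the filter keeps only subclasses of `BaseException`, ordinary built-ins such as `print` or `len` never end up in the mapping.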
||||||
|
@ -87,21 +106,62 @@ def evaluate_while(while_loop, state, tools):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def evaluate_function_def(function_def, state, tools):
|
def create_function(func_def, state, tools):
|
||||||
def create_function(func_def, state, tools):
|
def new_func(*args, **kwargs):
|
||||||
def new_func(*args):
|
func_state = state.copy()
|
||||||
new_state = state.copy()
|
arg_names = [arg.arg for arg in func_def.args.args]
|
||||||
for arg, val in zip(func_def.args.args, args):
|
for name, value in zip(arg_names, args):
|
||||||
new_state[arg.arg] = val
|
func_state[name] = value
|
||||||
result = None
|
if func_def.args.vararg:
|
||||||
for node in func_def.body:
|
vararg_name = func_def.args.vararg.arg
|
||||||
result = evaluate_ast(node, new_state, tools)
|
func_state[vararg_name] = args
|
||||||
return result
|
if func_def.args.kwarg:
|
||||||
|
kwarg_name = func_def.args.kwarg.arg
|
||||||
|
func_state[kwarg_name] = kwargs
|
||||||
|
|
||||||
return new_func
|
# Update function state with self and __class__
|
||||||
|
if func_def.args.args and func_def.args.args[0].arg == "self":
|
||||||
|
if args:
|
||||||
|
func_state["self"] = args[0]
|
||||||
|
func_state["__class__"] = args[0].__class__
|
||||||
|
|
||||||
tools[function_def.name] = create_function(function_def, state, tools)
|
result = None
|
||||||
return None
|
for stmt in func_def.body:
|
||||||
|
result = evaluate_ast(stmt, func_state, tools)
|
||||||
|
return result
|
||||||
|
|
||||||
|
return new_func
|
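The argument binding performed by the new `new_func` can be sketched on its own. The helper name `bind_arguments` and the sample function `f` are assumptions made for illustration; the binding logic follows the added lines above.

```python
import ast


def bind_arguments(func_def, args, kwargs, state):
    """Copy `state` and bind call arguments the way the new `new_func` does."""
    func_state = state.copy()
    arg_names = [arg.arg for arg in func_def.args.args]
    for name, value in zip(arg_names, args):
        func_state[name] = value
    if func_def.args.vararg:
        # As in the diff, the *args name receives the whole positional tuple.
        func_state[func_def.args.vararg.arg] = args
    if func_def.args.kwarg:
        func_state[func_def.args.kwarg.arg] = kwargs
    return func_state


# Hypothetical function definition, parsed but never executed.
source = "def f(a, b, *rest, **options):\n    pass\n"
func_def = ast.parse(source).body[0]
bound = bind_arguments(func_def, (1, 2, 3), {"flag": True}, {"outer": 0})
```

One consequence worth noting: with this binding, `rest` holds all positional arguments rather than only the surplus ones, which differs from what CPython would do for the same signature.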
||||||
|
|
||||||
|
|
||||||
|
def create_class(class_name, class_bases, class_body):
|
||||||
|
class_dict = {}
|
||||||
|
for key, value in class_body.items():
|
||||||
|
class_dict[key] = value
|
||||||
|
return type(class_name, tuple(class_bases), class_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_function_def(func_def, state, tools):
|
||||||
|
tools[func_def.name] = create_function(func_def, state, tools)
|
||||||
|
return tools[func_def.name]
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_class_def(class_def, state, tools):
|
||||||
|
class_name = class_def.name
|
||||||
|
bases = [evaluate_ast(base, state, tools) for base in class_def.bases]
|
||||||
|
class_dict = {}
|
||||||
|
|
||||||
|
for stmt in class_def.body:
|
||||||
|
if isinstance(stmt, ast.FunctionDef):
|
||||||
|
class_dict[stmt.name] = evaluate_function_def(stmt, state, tools)
|
||||||
|
elif isinstance(stmt, ast.Assign):
|
||||||
|
for target in stmt.targets:
|
||||||
|
class_dict[target.id] = evaluate_ast(stmt.value, state, tools)
|
||||||
|
else:
|
||||||
|
raise InterpretorError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
|
||||||
|
|
||||||
|
new_class = type(class_name, tuple(bases), class_dict)
|
||||||
|
state[class_name] = new_class
|
||||||
|
return new_class
|
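Both `create_class` and `evaluate_class_def` rely on the three-argument form of the built-in `type` to build a class from a name, a tuple of bases, and a dictionary of attributes. A minimal sketch of that call, with a hypothetical `build_class` wrapper and toy classes:

```python
def build_class(class_name, bases, class_dict):
    """Create a class dynamically, as the diff does with type(name, bases, dict)."""
    return type(class_name, tuple(bases), class_dict)


def greet(self):
    return f"hello from {self.name}"


# Toy classes used only for illustration.
Animal = build_class("Animal", [], {"name": "animal", "greet": greet})
Dog = build_class("Dog", [Animal], {"name": "dog"})
```

Methods and class attributes are passed in the same dictionary, which is why the class body loop in the diff stores both function definitions and assignments in `class_dict`.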
||||||
|
|
||||||
|
|
||||||
def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]):
|
def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]):
|
||||||
|
@ -176,11 +236,20 @@ def evaluate_assign(assign, state, tools):
|
||||||
var_names = assign.targets
|
var_names = assign.targets
|
||||||
result = evaluate_ast(assign.value, state, tools)
|
result = evaluate_ast(assign.value, state, tools)
|
||||||
if len(var_names) == 1:
|
if len(var_names) == 1:
|
||||||
if isinstance(var_names[0], ast.Tuple):
|
target = var_names[0]
|
||||||
for i, elem in enumerate(var_names[0].elts):
|
if isinstance(target, ast.Tuple):
|
||||||
|
for i, elem in enumerate(target.elts):
|
||||||
state[elem.id] = result[i]
|
state[elem.id] = result[i]
|
||||||
|
elif isinstance(target, ast.Attribute):
|
||||||
|
obj = evaluate_ast(target.value, state, tools)
|
||||||
|
setattr(obj, target.attr, result)
|
||||||
|
elif isinstance(target, ast.Subscript):
|
||||||
|
obj = evaluate_ast(target.value, state, tools)
|
||||||
|
key = evaluate_ast(target.slice, state, tools)
|
||||||
|
obj[key] = result
|
||||||
else:
|
else:
|
||||||
state[var_names[0].id] = result
|
state[target.id] = result
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if len(result) != len(var_names):
|
if len(result) != len(var_names):
|
||||||
raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
|
raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
|
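The single-target branch of `evaluate_assign` now distinguishes attribute, subscript, and plain-name targets. The sketch below mirrors those three cases; `eval_simple` is a hypothetical stand-in for `evaluate_ast` that handles only names and constants, and the `Box` class is illustrative.

```python
import ast


def eval_simple(node, state):
    """Hypothetical stand-in for evaluate_ast: names and constants only."""
    if isinstance(node, ast.Name):
        return state[node.id]
    if isinstance(node, ast.Constant):
        return node.value
    if node.__class__.__name__ == "Index":  # Python 3.8 wraps subscript keys in ast.Index
        return eval_simple(node.value, state)
    raise ValueError(f"{node.__class__.__name__} is not supported in this sketch.")


def assign_target(target, value, state):
    """Store `value` into `target`, following the attribute, subscript, and name cases."""
    if isinstance(target, ast.Attribute):
        setattr(eval_simple(target.value, state), target.attr, value)
    elif isinstance(target, ast.Subscript):
        obj = eval_simple(target.value, state)
        obj[eval_simple(target.slice, state)] = value
    else:
        state[target.id] = value


class Box:
    pass


state = {"box": Box(), "data": {}}
for line in ["box.size = 3", "data['k'] = 7", "x = 1"]:
    stmt = ast.parse(line).body[0]
    assign_target(stmt.targets[0], eval_simple(stmt.value, state), state)
```

Only the plain-name case writes to `state` directly; the other two mutate an object that is already reachable from it.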
||||||
|
@ -190,41 +259,64 @@ def evaluate_assign(assign, state, tools):
|
||||||
|
|
||||||
|
|
||||||
def evaluate_call(call, state, tools):
|
def evaluate_call(call, state, tools):
|
||||||
|
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
|
||||||
|
raise InterpretorError(
|
||||||
|
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})."
|
||||||
|
)
|
||||||
if isinstance(call.func, ast.Attribute):
|
if isinstance(call.func, ast.Attribute):
|
||||||
obj = evaluate_ast(call.func.value, state, tools)
|
obj = evaluate_ast(call.func.value, state, tools)
|
||||||
func_name = call.func.attr
|
func_name = call.func.attr
|
||||||
if not hasattr(obj, func_name):
|
if not hasattr(obj, func_name):
|
||||||
raise InterpretorError(f"Object {obj} has no attribute {func_name}")
|
raise InterpretorError(f"Object {obj} has no attribute {func_name}")
|
||||||
func = getattr(obj, func_name)
|
func = getattr(obj, func_name)
|
||||||
args = [evaluate_ast(arg, state, tools) for arg in call.args]
|
|
||||||
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
elif isinstance(call.func, ast.Name):
|
elif isinstance(call.func, ast.Name):
|
||||||
func_name = call.func.id
|
func_name = call.func.id
|
||||||
|
|
||||||
if func_name in state:
|
if func_name in state:
|
||||||
func = state[func_name]
|
func = state[func_name]
|
||||||
elif func_name in tools:
|
elif func_name in tools:
|
||||||
func = tools[func_name]
|
func = tools[func_name]
|
||||||
|
elif func_name in ERRORS:
|
||||||
|
func = ERRORS[func_name]
|
||||||
else:
|
else:
|
||||||
raise InterpretorError(
|
raise InterpretorError(
|
||||||
f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})."
|
f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})."
|
||||||
)
|
)
|
||||||
# Todo deal with args
|
|
||||||
args = [evaluate_ast(arg, state, tools) for arg in call.args]
|
|
||||||
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
|
|
||||||
output = func(*args, **kwargs)
|
|
||||||
|
|
||||||
# store logs of print statements
|
args = [evaluate_ast(arg, state, tools) for arg in call.args]
|
||||||
if func_name == "print":
|
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
|
||||||
state["print_outputs"] += output + "\n"
|
|
||||||
|
|
||||||
return output
|
if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes
|
||||||
|
# Instantiate the class using its constructor
|
||||||
|
obj = func.__new__(func) # Create a new instance of the class
|
||||||
|
if hasattr(obj, "__init__"): # Check if the class has an __init__ method
|
||||||
|
obj.__init__(*args, **kwargs) # Call the __init__ method correctly
|
||||||
|
return obj
|
||||||
else:
|
else:
|
||||||
raise InterpretorError(
|
if func_name == "super":
|
||||||
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})."
|
if not args:
|
||||||
)
|
if "__class__" in state and "self" in state:
|
||||||
|
return super(state["__class__"], state["self"])
|
||||||
|
else:
|
||||||
|
raise InterpretorError("super() needs at least one argument")
|
||||||
|
cls = args[0]
|
||||||
|
if not isinstance(cls, type):
|
||||||
|
raise InterpretorError("super() argument 1 must be type")
|
||||||
|
if len(args) == 1:
|
||||||
|
return super(cls)
|
||||||
|
elif len(args) == 2:
|
||||||
|
instance = args[1]
|
||||||
|
return super(cls, instance)
|
||||||
|
else:
|
||||||
|
raise InterpretorError("super() takes at most 2 arguments")
|
||||||
|
|
||||||
|
else:
|
||||||
|
if func_name == "print":
|
||||||
|
output = " ".join(map(str, args))
|
||||||
|
state["print_outputs"] += output + "\n"
|
||||||
|
return output
|
||||||
|
else: # Assume it's a callable object
|
||||||
|
output = func(*args, **kwargs)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
def evaluate_subscript(subscript, state, tools):
|
def evaluate_subscript(subscript, state, tools):
|
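In the new `evaluate_call`, a call to `print` no longer goes through a callable: the arguments are joined with spaces and appended to `state["print_outputs"]`. A small sketch of that behavior, using a hypothetical `capture_print` helper:

```python
def capture_print(state, *args):
    """Join the arguments like print() would and log them in the state."""
    output = " ".join(map(str, args))
    state["print_outputs"] += output + "\n"
    return output


state = {"print_outputs": ""}
capture_print(state, "Pope age:", 85)
capture_print(state, 4.95)
```

Keyword arguments such as `sep` or `end` are not taken into account by this branch, so the log always uses a single space as separator and a newline as terminator.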
||||||
|
@ -248,6 +340,10 @@ def evaluate_subscript(subscript, state, tools):
|
||||||
def evaluate_name(name, state, tools):
|
def evaluate_name(name, state, tools):
|
||||||
if name.id in state:
|
if name.id in state:
|
||||||
return state[name.id]
|
return state[name.id]
|
||||||
|
elif name.id in tools:
|
||||||
|
return tools[name.id]
|
||||||
|
elif name.id in ERRORS:
|
||||||
|
return ERRORS[name.id]
|
||||||
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
|
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
|
||||||
if len(close_matches) > 0:
|
if len(close_matches) > 0:
|
||||||
return state[close_matches[0]]
|
return state[close_matches[0]]
|
||||||
|
@ -307,7 +403,11 @@ def evaluate_for(for_loop, state, tools):
|
||||||
result = None
|
result = None
|
||||||
iterator = evaluate_ast(for_loop.iter, state, tools)
|
iterator = evaluate_ast(for_loop.iter, state, tools)
|
||||||
for counter in iterator:
|
for counter in iterator:
|
||||||
state[for_loop.target.id] = counter
|
if isinstance(for_loop.target, ast.Tuple):
|
||||||
|
for i, elem in enumerate(for_loop.target.elts):
|
||||||
|
state[elem.id] = counter[i]
|
||||||
|
else:
|
||||||
|
state[for_loop.target.id] = counter
|
||||||
for node in for_loop.body:
|
for node in for_loop.body:
|
||||||
try:
|
try:
|
||||||
line_result = evaluate_ast(node, state, tools)
|
line_result = evaluate_ast(node, state, tools)
|
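The change to `evaluate_for` adds tuple unpacking of the loop target. The sketch below isolates that step; `bind_loop_target` and the sample loop are assumptions for illustration.

```python
import ast


def bind_loop_target(target, item, state):
    """Bind one iteration value to the loop target, unpacking tuples element by element."""
    if isinstance(target, ast.Tuple):
        for i, elem in enumerate(target.elts):
            state[elem.id] = item[i]
    else:
        state[target.id] = item


# Hypothetical loop, parsed only to obtain its target node.
loop = ast.parse("for key, value in pairs:\n    pass\n").body[0]
state = {}
seen = []
for item in [("a", 1), ("b", 2)]:
    bind_loop_target(loop.target, item, state)
    seen.append((state["key"], state["value"]))
```

As in the diff, only one level of unpacking is handled: a nested target such as `for a, (b, c) in ...` would fail because the inner element is a tuple node without an `id`.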
||||||
|
@ -337,7 +437,56 @@ def evaluate_listcomp(listcomp, state, tools):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Callable]):
|
def evaluate_try(try_node, state, tools):
|
||||||
|
try:
|
||||||
|
for stmt in try_node.body:
|
||||||
|
evaluate_ast(stmt, state, tools)
|
||||||
|
except Exception as e:
|
||||||
|
matched = False
|
||||||
|
for handler in try_node.handlers:
|
||||||
|
if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, tools)):
|
||||||
|
matched = True
|
||||||
|
if handler.name:
|
||||||
|
state[handler.name] = e
|
||||||
|
for stmt in handler.body:
|
||||||
|
evaluate_ast(stmt, state, tools)
|
||||||
|
break
|
||||||
|
if not matched:
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
if try_node.orelse:
|
||||||
|
for stmt in try_node.orelse:
|
||||||
|
evaluate_ast(stmt, state, tools)
|
||||||
|
finally:
|
||||||
|
if try_node.finalbody:
|
||||||
|
for stmt in try_node.finalbody:
|
||||||
|
evaluate_ast(stmt, state, tools)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_raise(raise_node, state, tools):
|
||||||
|
if raise_node.exc is not None:
|
||||||
|
exc = evaluate_ast(raise_node.exc, state, tools)
|
||||||
|
else:
|
||||||
|
exc = None
|
||||||
|
if raise_node.cause is not None:
|
||||||
|
cause = evaluate_ast(raise_node.cause, state, tools)
|
||||||
|
else:
|
||||||
|
cause = None
|
||||||
|
if exc is not None:
|
||||||
|
if cause is not None:
|
||||||
|
raise exc from cause
|
||||||
|
else:
|
||||||
|
raise exc
|
||||||
|
else:
|
||||||
|
raise InterpretorError("Re-raise is not supported without an active exception")
|
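The handler selection in `evaluate_try` picks the first `except` clause whose type matches, optionally binds the exception to a name, and re-raises when nothing matches. A simplified sketch of that control flow, replacing AST nodes with plain callables (an assumption of this example) and the `matched` flag with a `for`/`else`:

```python
def run_with_handlers(body, handlers, state):
    """handlers: list of (exception_type_or_None, bound_name_or_None, callback)."""
    try:
        body(state)
    except Exception as e:
        for exc_type, name, handler_body in handlers:
            if exc_type is None or isinstance(e, exc_type):
                if name:
                    state[name] = e
                handler_body(state)
                break
        else:
            # No handler matched: propagate the original exception.
            raise


def failing_body(state):
    state["steps"].append("body")
    raise KeyError("missing")


def on_lookup_error(state):
    state["steps"].append("handled")


state = {"steps": []}
handlers = [(ValueError, None, on_lookup_error), (LookupError, "err", on_lookup_error)]
run_with_handlers(failing_body, handlers, state)
```

Since `KeyError` is a subclass of `LookupError`, the second handler runs here and the exception is stored under `err`.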
||||||
|
|
||||||
|
|
||||||
|
def evaluate_ast(
|
||||||
|
expression: ast.AST,
|
||||||
|
state: Dict[str, Any],
|
||||||
|
tools: Dict[str, Callable],
|
||||||
|
authorized_imports: List[str] = LIST_SAFE_MODULES,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
|
Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
|
||||||
set of functions.
|
set of functions.
|
||||||
|
@ -353,6 +502,9 @@ def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Ca
|
||||||
tools (`Dict[str, Callable]`):
|
tools (`Dict[str, Callable]`):
|
||||||
The functions that may be called during the evaluation. Any call to another function will fail with an
|
The functions that may be called during the evaluation. Any call to another function will fail with an
|
||||||
`InterpretorError`.
|
`InterpretorError`.
|
||||||
|
authorized_imports (`List[str]`):
|
||||||
|
The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
|
||||||
|
Add more at your own risk!
|
||||||
"""
|
"""
|
||||||
if isinstance(expression, ast.Assign):
|
if isinstance(expression, ast.Assign):
|
||||||
# Assignment -> we evaluate the assignment which should update the state
|
# Assignment -> we evaluate the assignment which should update the state
|
||||||
|
@ -459,7 +611,7 @@ def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Ca
|
||||||
return result
|
return result
|
||||||
elif isinstance(expression, ast.Import):
|
elif isinstance(expression, ast.Import):
|
||||||
for alias in expression.names:
|
for alias in expression.names:
|
||||||
if alias.name in LIST_SAFE_MODULES:
|
if alias.name in authorized_imports:
|
||||||
module = __import__(alias.name)
|
module = __import__(alias.name)
|
||||||
state[alias.asname or alias.name] = module
|
state[alias.asname or alias.name] = module
|
||||||
else:
|
else:
|
||||||
|
@ -468,19 +620,27 @@ def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Ca
|
||||||
elif isinstance(expression, ast.While):
|
elif isinstance(expression, ast.While):
|
||||||
return evaluate_while(expression, state, tools)
|
return evaluate_while(expression, state, tools)
|
||||||
elif isinstance(expression, ast.ImportFrom):
|
elif isinstance(expression, ast.ImportFrom):
|
||||||
if expression.module in LIST_SAFE_MODULES:
|
if expression.module in authorized_imports:
|
||||||
module = __import__(expression.module)
|
module = __import__(expression.module)
|
||||||
for alias in expression.names:
|
for alias in expression.names:
|
||||||
state[alias.asname or alias.name] = getattr(module, alias.name)
|
state[alias.asname or alias.name] = getattr(module, alias.name)
|
||||||
else:
|
else:
|
||||||
raise InterpretorError(f"Import from {expression.module} is not allowed.")
|
raise InterpretorError(f"Import from {expression.module} is not allowed.")
|
||||||
return None
|
return None
|
||||||
|
elif isinstance(expression, ast.ClassDef):
|
||||||
|
return evaluate_class_def(expression, state, tools)
|
||||||
|
elif isinstance(expression, ast.Try):
|
||||||
|
return evaluate_try(expression, state, tools)
|
||||||
|
elif isinstance(expression, ast.Raise):
|
||||||
|
return evaluate_raise(expression, state, tools)
|
||||||
else:
|
else:
|
||||||
# For now we refuse anything else. Let's add things as we need them.
|
# For now we refuse anything else. Let's add things as we need them.
|
||||||
raise InterpretorError(f"{expression.__class__.__name__} is not supported.")
|
raise InterpretorError(f"{expression.__class__.__name__} is not supported.")
|
||||||
|
|
||||||
|
|
||||||
def evaluate_python_code(code: str, tools: Optional[Dict[str, Callable]] = {}, state=None):
|
def evaluate_python_code(
|
||||||
|
code: str, tools: Optional[Dict[str, Callable]] = {}, state=None, authorized_imports: List[str] = LIST_SAFE_MODULES
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
|
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
|
||||||
of functions.
|
of functions.
|
||||||
|
@ -506,9 +666,10 @@ def evaluate_python_code(code: str, tools: Optional[Dict[str, Callable]] = {}, s
|
||||||
state = {}
|
state = {}
|
||||||
result = None
|
result = None
|
||||||
state["print_outputs"] = ""
|
state["print_outputs"] = ""
|
||||||
|
|
||||||
for idx, node in enumerate(expression.body):
|
for idx, node in enumerate(expression.body):
|
||||||
try:
|
try:
|
||||||
line_result = evaluate_ast(node, state, tools)
|
line_result = evaluate_ast(node, state, tools, authorized_imports)
|
||||||
except InterpretorError as e:
|
except InterpretorError as e:
|
||||||
msg = f"You tried to execute the following code:\n{code}\n"
|
msg = f"You tried to execute the following code:\n{code}\n"
|
||||||
msg += f"You got these outputs:\n{state['print_outputs']}\n"
|
msg += f"You got these outputs:\n{state['print_outputs']}\n"
|
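The new `authorized_imports` parameter replaces the hard-coded check against `LIST_SAFE_MODULES`. The sketch below shows the same allow-list idea as a standalone pre-check; `check_imports` is a hypothetical helper that only inspects the code and never imports anything.

```python
import ast


class InterpretorError(ValueError):
    pass


LIST_SAFE_MODULES = [
    "random", "collections", "math", "time", "queue",
    "itertools", "re", "stat", "statistics", "unicodedata",
]


def check_imports(code, authorized_imports=LIST_SAFE_MODULES):
    """Return the imported module names, raising if one is not on the allow-list."""
    imported = []
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            names = [node.module]
        else:
            continue
        for name in names:
            if name not in authorized_imports:
                raise InterpretorError(f"Import of {name} is not allowed.")
            imported.append(name)
    return imported
```

Extending the list, for example `check_imports(code, LIST_SAFE_MODULES + ["os"])`, widens what the evaluated code can reach, which is the risk the docstring in the diff warns about.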
||||||
|
|
|
@ -185,7 +185,7 @@ class Tool:
|
||||||
"tool_class": full_name,
|
"tool_class": full_name,
|
||||||
"description": self.description,
|
"description": self.description,
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"inputs": str(self.inputs),
|
"inputs": self.inputs,
|
||||||
"output_type": str(self.output_type),
|
"output_type": str(self.output_type),
|
||||||
}
|
}
|
||||||
with open(config_file, "w", encoding="utf-8") as f:
|
with open(config_file, "w", encoding="utf-8") as f:
|
||||||
|
@ -315,7 +315,7 @@ class Tool:
|
||||||
if tool_class.output_type != custom_tool["output_type"]:
|
if tool_class.output_type != custom_tool["output_type"]:
|
||||||
tool_class.output_type = custom_tool["output_type"]
|
tool_class.output_type = custom_tool["output_type"]
|
||||||
|
|
||||||
return tool_class(model_repo_id, token=token, **kwargs)
|
return tool_class(**kwargs)
|
||||||
|
|
||||||
def push_to_hub(
|
def push_to_hub(
|
||||||
self,
|
self,
|
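Saving `self.inputs` directly instead of `str(self.inputs)` changes what comes back when the config is reloaded. A short round trip through `json` illustrates the difference; the `inputs` value below is a made-up example, not taken from the diff.

```python
import json

# Hypothetical tool inputs description.
inputs = {"query": {"type": "text", "description": "The search query."}}

# Old behavior: the dict is stringified before being written.
as_string = json.loads(json.dumps({"inputs": str(inputs)}))

# New behavior: the dict is written as a JSON object.
as_dict = json.loads(json.dumps({"inputs": inputs}))
```

With the stringified form, the loaded value is a Python `repr` string that would need extra parsing; with the new form, it is the original nested dictionary.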
||||||
|
|
|
@ -619,7 +619,6 @@ class GemmaDecoderLayer(nn.Module):
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -634,10 +633,6 @@ class GemmaDecoderLayer(nn.Module):
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
(see `past_key_values`).
|
(see `past_key_values`).
|
||||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||||
cache_position (`torch.LongTensor`, *optional*): position ids of cache
|
|
||||||
kwargs (`dict`, *optional*):
|
|
||||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
|
||||||
into the model
|
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
|
|
|
@ -303,7 +303,6 @@ class LlamaAttention(nn.Module):
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
@ -593,7 +592,6 @@ class LlamaSdpaAttention(LlamaAttention):
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||||
|
@ -691,7 +689,6 @@ class LlamaDecoderLayer(nn.Module):
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -706,10 +703,6 @@ class LlamaDecoderLayer(nn.Module):
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
(see `past_key_values`).
|
(see `past_key_values`).
|
||||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||||
cache_position (`torch.LongTensor`, *optional*): position ids of cache
|
|
||||||
kwargs (`dict`, *optional*):
|
|
||||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
|
||||||
into the model
|
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
|
|
|
@ -283,7 +283,12 @@ def convert_segmentation_map_to_binary_masks(
|
||||||
|
|
||||||
# Generate a binary mask for each object instance
|
# Generate a binary mask for each object instance
|
||||||
binary_masks = [(segmentation_map == i) for i in all_labels]
|
binary_masks = [(segmentation_map == i) for i in all_labels]
|
||||||
binary_masks = np.stack(binary_masks, axis=0) # (num_labels, height, width)
|
|
||||||
|
# Stack the binary masks
|
||||||
|
if binary_masks:
|
||||||
|
binary_masks = np.stack(binary_masks, axis=0)
|
||||||
|
else:
|
||||||
|
binary_masks = np.zeros((0, *segmentation_map.shape))
|
||||||
|
|
||||||
# Convert instance ids to class ids
|
# Convert instance ids to class ids
|
||||||
if instance_id_to_semantic_id is not None:
|
if instance_id_to_semantic_id is not None:
|
||||||
|
@ -969,11 +974,15 @@ class Mask2FormerImageProcessor(BaseImageProcessor):
|
||||||
)
|
)
|
||||||
# We add an axis to make them compatible with the transformations library
|
# We add an axis to make them compatible with the transformations library
|
||||||
# this will be removed in the future
|
# this will be removed in the future
|
||||||
masks = [mask[None, ...] for mask in masks]
|
if masks.shape[0] > 0:
|
||||||
masks = [
|
masks = [mask[None, ...] for mask in masks]
|
||||||
self._pad_image(image=mask, output_size=pad_size, constant_values=ignore_index) for mask in masks
|
masks = [
|
||||||
]
|
self._pad_image(image=mask, output_size=pad_size, constant_values=ignore_index)
|
||||||
masks = np.concatenate(masks, axis=0)
|
for mask in masks
|
||||||
|
]
|
||||||
|
masks = np.concatenate(masks, axis=0)
|
||||||
|
else:
|
||||||
|
masks = np.zeros((0, *pad_size), dtype=np.float32)
|
||||||
mask_labels.append(torch.from_numpy(masks))
|
mask_labels.append(torch.from_numpy(masks))
|
||||||
class_labels.append(torch.from_numpy(classes))
|
class_labels.append(torch.from_numpy(classes))
|
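The guard added around `np.stack` matters because stacking an empty list raises a `ValueError`. The sketch below isolates that fallback; `stack_binary_masks` is a hypothetical wrapper around the lines from the diff, and the sample segmentation map is illustrative.

```python
import numpy as np


def stack_binary_masks(segmentation_map, all_labels):
    """Build one binary mask per label, returning an empty (0, H, W) array when there are none."""
    binary_masks = [(segmentation_map == i) for i in all_labels]
    if binary_masks:
        return np.stack(binary_masks, axis=0)  # (num_labels, height, width)
    return np.zeros((0, *segmentation_map.shape))


segmentation_map = np.array([[0, 1, 1], [2, 2, 0]])
two_masks = stack_binary_masks(segmentation_map, [1, 2])
no_masks = stack_binary_masks(segmentation_map, [])
```

Keeping a leading dimension of zero lets later steps treat the empty case as a regular array, which is what the matching `masks.shape[0] > 0` check in the image processor depends on.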
||||||
|
|
||||||
|
|
|
@ -286,7 +286,12 @@ def convert_segmentation_map_to_binary_masks(
|
||||||
|
|
||||||
# Generate a binary mask for each object instance
|
# Generate a binary mask for each object instance
|
||||||
binary_masks = [(segmentation_map == i) for i in all_labels]
|
binary_masks = [(segmentation_map == i) for i in all_labels]
|
||||||
binary_masks = np.stack(binary_masks, axis=0) # (num_labels, height, width)
|
|
||||||
|
# Stack the binary masks
|
||||||
|
if binary_masks:
|
||||||
|
binary_masks = np.stack(binary_masks, axis=0)
|
||||||
|
else:
|
||||||
|
binary_masks = np.zeros((0, *segmentation_map.shape))
|
||||||
|
|
||||||
# Convert instance ids to class ids
|
# Convert instance ids to class ids
|
||||||
if instance_id_to_semantic_id is not None:
|
if instance_id_to_semantic_id is not None:
|
||||||
|
@ -982,17 +987,20 @@ class MaskFormerImageProcessor(BaseImageProcessor):
|
||||||
)
|
)
|
||||||
# We add an axis to make them compatible with the transformations library
|
# We add an axis to make them compatible with the transformations library
|
||||||
# this will be removed in the future
|
# this will be removed in the future
|
||||||
masks = [mask[None, ...] for mask in masks]
|
if masks.shape[0] > 0:
|
||||||
masks = [
|
masks = [mask[None, ...] for mask in masks]
|
||||||
self._pad_image(
|
masks = [
|
||||||
image=mask,
|
self._pad_image(
|
||||||
output_size=pad_size,
|
image=mask,
|
||||||
constant_values=ignore_index,
|
output_size=pad_size,
|
||||||
input_data_format=ChannelDimension.FIRST,
|
constant_values=ignore_index,
|
||||||
)
|
input_data_format=ChannelDimension.FIRST,
|
||||||
for mask in masks
|
)
|
||||||
]
|
for mask in masks
|
||||||
masks = np.concatenate(masks, axis=0)
|
]
|
||||||
|
masks = np.concatenate(masks, axis=0)
|
||||||
|
else:
|
||||||
|
masks = np.zeros((0, *pad_size), dtype=np.float32)
|
||||||
mask_labels.append(torch.from_numpy(masks))
|
mask_labels.append(torch.from_numpy(masks))
|
||||||
class_labels.append(torch.from_numpy(classes))
|
class_labels.append(torch.from_numpy(classes))
|
||||||
|
|
||||||
|
|
|
@ -592,7 +592,6 @@ class MistralSdpaAttention(MistralAttention):
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
||||||
|
@ -621,6 +620,7 @@ class MistralSdpaAttention(MistralAttention):
|
||||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, position_ids)
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
||||||
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
|
@ -691,7 +691,6 @@ class MistralDecoderLayer(nn.Module):
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -706,11 +705,8 @@ class MistralDecoderLayer(nn.Module):
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
(see `past_key_values`).
|
(see `past_key_values`).
|
||||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||||
cache_position (`torch.LongTensor`, *optional*): position ids of cache
|
|
||||||
kwargs (`dict`, *optional*):
|
|
||||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
|
||||||
into the model
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
|
@ -666,7 +666,6 @@ class OlmoDecoderLayer(nn.Module):
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -681,10 +680,6 @@ class OlmoDecoderLayer(nn.Module):
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
(see `past_key_values`).
|
(see `past_key_values`).
|
||||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||||
cache_position (`torch.LongTensor`, *optional*): position ids of cache
|
|
||||||
kwargs (`dict`, *optional*):
|
|
||||||
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
|
||||||
into the model
|
|
||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
|
|
|
@ -285,7 +285,12 @@ def convert_segmentation_map_to_binary_masks(
|
||||||
|
|
||||||
# Generate a binary mask for each object instance
|
# Generate a binary mask for each object instance
|
||||||
binary_masks = [(segmentation_map == i) for i in all_labels]
|
binary_masks = [(segmentation_map == i) for i in all_labels]
|
||||||
binary_masks = np.stack(binary_masks, axis=0) # (num_labels, height, width)
|
|
||||||
|
# Stack the binary masks
|
||||||
|
if binary_masks:
|
||||||
|
binary_masks = np.stack(binary_masks, axis=0)
|
||||||
|
else:
|
||||||
|
binary_masks = np.zeros((0, *segmentation_map.shape))
|
||||||
|
|
||||||
# Convert instance ids to class ids
|
# Convert instance ids to class ids
|
||||||
if instance_id_to_semantic_id is not None:
|
if instance_id_to_semantic_id is not None:
|
||||||
|
|
|
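Note on the `convert_segmentation_map_to_binary_masks` hunk above: `np.stack` raises `ValueError` when given an empty list, which is presumably why the new code guards the call. The following is a minimal sketch of the guarded path, not the library's implementation; `stack_binary_masks` is a hypothetical helper name used only for illustration.

```python
import numpy as np


def stack_binary_masks(segmentation_map, all_labels):
    """Hypothetical helper mirroring the guarded stacking shown in the hunk."""
    # One boolean mask per label, as in the original list comprehension.
    binary_masks = [(segmentation_map == i) for i in all_labels]
    if binary_masks:
        # Shape: (num_labels, height, width)
        return np.stack(binary_masks, axis=0)
    # No labels: return an empty array that keeps the spatial dimensions.
    return np.zeros((0, *segmentation_map.shape))


seg = np.array([[0, 1], [1, 2]])
masks = stack_binary_masks(seg, [1, 2])
empty = stack_binary_masks(seg, [])

# Without the guard, stacking an empty list fails.
try:
    np.stack([], axis=0)
except ValueError as err:
    stack_error = str(err)
```

With the guard, downstream code that indexes the first axis sees zero masks instead of an exception.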
@ -63,7 +63,7 @@ class Bnb4BitHfQuantizer(HfQuantizer):
|
||||||
if not (is_accelerate_available() and is_bitsandbytes_available()):
|
if not (is_accelerate_available() and is_bitsandbytes_available()):
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` "
|
"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` "
|
||||||
"and the latest version of bitsandbytes: `pip install -i https://pypi.org/simple/ bitsandbytes`"
|
"and the latest version of bitsandbytes: `pip install -U bitsandbytes`"
|
||||||
)
|
)
|
||||||
|
|
||||||
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
|
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
|
||||||
|
|
|
@ -64,7 +64,7 @@ class Bnb8BitHfQuantizer(HfQuantizer):
|
||||||
if not (is_accelerate_available() and is_bitsandbytes_available()):
|
if not (is_accelerate_available() and is_bitsandbytes_available()):
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` "
|
"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` "
|
||||||
"and the latest version of bitsandbytes: `pip install -i https://pypi.org/simple/ bitsandbytes`"
|
"and the latest version of bitsandbytes: `pip install -U bitsandbytes`"
|
||||||
)
|
)
|
||||||
|
|
||||||
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
|
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
|
||||||
|
|
|
@ -2175,6 +2175,9 @@ class Trainer:
|
||||||
grad_norm: Optional[float] = None
|
grad_norm: Optional[float] = None
|
||||||
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
||||||
|
|
||||||
|
if args.sanity_evaluation:
|
||||||
|
self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
|
||||||
|
|
||||||
total_batched_samples = 0
|
total_batched_samples = 0
|
||||||
for epoch in range(epochs_trained, num_train_epochs):
|
for epoch in range(epochs_trained, num_train_epochs):
|
||||||
epoch_iterator = train_dataloader
|
epoch_iterator = train_dataloader
|
||||||
|
@ -2723,6 +2726,18 @@ class Trainer:
|
||||||
f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
|
f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
|
||||||
|
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
|
||||||
|
self._report_to_hp_search(trial, self.state.global_step, metrics)
|
||||||
|
|
||||||
|
# Run delayed LR scheduler now that metrics are populated
|
||||||
|
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) and not skip_scheduler:
|
||||||
|
metric_to_check = self.args.metric_for_best_model
|
||||||
|
if not metric_to_check.startswith("eval_"):
|
||||||
|
metric_to_check = f"eval_{metric_to_check}"
|
||||||
|
self.lr_scheduler.step(metrics[metric_to_check])
|
||||||
|
return metrics
|
||||||
|
|
||||||
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
|
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
|
||||||
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
|
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
|
||||||
if is_torch_xla_available():
|
if is_torch_xla_available():
|
||||||
|
@ -2749,15 +2764,7 @@ class Trainer:
|
||||||
|
|
||||||
metrics = None
|
metrics = None
|
||||||
if self.control.should_evaluate:
|
if self.control.should_evaluate:
|
||||||
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
|
metrics = self._evaluate(trial, ignore_keys_for_eval)
|
||||||
self._report_to_hp_search(trial, self.state.global_step, metrics)
|
|
||||||
|
|
||||||
# Run delayed LR scheduler now that metrics are populated
|
|
||||||
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
|
||||||
metric_to_check = self.args.metric_for_best_model
|
|
||||||
if not metric_to_check.startswith("eval_"):
|
|
||||||
metric_to_check = f"eval_{metric_to_check}"
|
|
||||||
self.lr_scheduler.step(metrics[metric_to_check])
|
|
||||||
|
|
||||||
if self.control.should_save:
|
if self.control.should_save:
|
||||||
self._save_checkpoint(model, trial, metrics=metrics)
|
self._save_checkpoint(model, trial, metrics=metrics)
|
||||||
|
|
|
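Note on the `Trainer._evaluate` hunks above: the refactor moves the metric reporting and the delayed `ReduceLROnPlateau` step into one helper, and the sanity evaluation passes `skip_scheduler=True` so the scheduler is not stepped before training starts. The sketch below isolates only the metric-key handling and the skip flag. `RecordingScheduler` and `step_plateau_scheduler` are hypothetical names, and the `isinstance` check from the real code is left out on purpose.

```python
class RecordingScheduler:
    """Stand-in for a plateau scheduler; it only records the values it is stepped with."""

    def __init__(self):
        self.seen = []

    def step(self, value):
        self.seen.append(value)


def step_plateau_scheduler(scheduler, metrics, metric_for_best_model, skip_scheduler=False):
    """Hypothetical helper: step the scheduler on the evaluation metric unless skipped."""
    if skip_scheduler:
        return None
    metric_to_check = metric_for_best_model
    # Evaluation metrics are stored with an "eval_" prefix, so add it if missing.
    if not metric_to_check.startswith("eval_"):
        metric_to_check = f"eval_{metric_to_check}"
    scheduler.step(metrics[metric_to_check])
    return metric_to_check


scheduler = RecordingScheduler()
metrics = {"eval_loss": 0.42}

# Sanity evaluation before training: the scheduler is left untouched.
sanity_key = step_plateau_scheduler(scheduler, metrics, "loss", skip_scheduler=True)
# Regular evaluation during training: the scheduler is stepped once.
regular_key = step_plateau_scheduler(scheduler, metrics, "loss")
```

Only the second call steps the scheduler, so the pre-training pass does not count toward its plateau tracking.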
@ -771,6 +771,9 @@ class TrainingArguments:
|
||||||
rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function
|
rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function
|
||||||
that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
|
that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
|
||||||
summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
|
summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
|
||||||
|
|
||||||
|
sanity_evaluation (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to perform a sanity check to ensure that the validation step works correctly. It will be performed before the training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
framework = "pt"
|
framework = "pt"
|
||||||
|
@ -1454,6 +1457,13 @@ class TrainingArguments:
|
||||||
metadata={"help": "Break eval metrics calculation into batches to save memory."},
|
metadata={"help": "Break eval metrics calculation into batches to save memory."},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sanity_evaluation: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Whether to run through the entire `evaluation` step at the very beginning of training as a sanity check."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Parse in args that could be `dict` sent in from the CLI as a string
|
# Parse in args that could be `dict` sent in from the CLI as a string
|
||||||
for field in _VALID_DICT_FIELDS:
|
for field in _VALID_DICT_FIELDS:
|
||||||
|
|
|
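Note on the `TrainingArguments` hunks above: the new flag is declared as a dataclass field with a default of `False` and a `help` entry in its metadata. The sketch below shows that pattern in isolation with the standard library only. `SketchArguments` is a hypothetical stand-in, not the real `TrainingArguments`, and its help text is shortened.

```python
from dataclasses import dataclass, field, fields


@dataclass
class SketchArguments:
    """Hypothetical stand-in for a small slice of TrainingArguments."""

    sanity_evaluation: bool = field(
        default=False,
        metadata={"help": "Whether to run the evaluation step once before training as a sanity check."},
    )


# The flag is off unless the caller opts in.
default_args = SketchArguments()
enabled_args = SketchArguments(sanity_evaluation=True)

# The metadata stays attached to the field and can be read back, e.g. to build CLI help.
help_by_name = {f.name: f.metadata["help"] for f in fields(SketchArguments)}
```

Keeping the default at `False` means existing training scripts behave the same unless they set the flag.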
@ -353,3 +353,131 @@ if char.isalpha():
|
||||||
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
|
||||||
assert result == "Ok no one cares"
|
assert result == "Ok no one cares"
|
||||||
assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
|
assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
|
||||||
|
|
||||||
|
def test_tuple_target_in_iterator(self):
|
||||||
|
code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]"
|
||||||
|
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
|
||||||
|
assert result == "Samuel"
|
||||||
|
|
||||||
|
def test_classes(self):
|
||||||
|
code = """
|
||||||
|
class Animal:
|
||||||
|
species = "Generic Animal"
|
||||||
|
|
||||||
|
def __init__(self, name, age):
|
||||||
|
self.name = name
|
||||||
|
self.age = age
|
||||||
|
|
||||||
|
def sound(self):
|
||||||
|
return "The animal makes a sound."
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.name}, {self.age} years old"
|
||||||
|
|
||||||
|
class Dog(Animal):
|
||||||
|
species = "Canine"
|
||||||
|
|
||||||
|
def __init__(self, name, age, breed):
|
||||||
|
super().__init__(name, age)
|
||||||
|
self.breed = breed
|
||||||
|
|
||||||
|
def sound(self):
|
||||||
|
return "The dog barks."
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.name}, {self.age} years old, {self.breed}"
|
||||||
|
|
||||||
|
class Cat(Animal):
|
||||||
|
def sound(self):
|
||||||
|
return "The cat meows."
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.name}, {self.age} years old, {self.species}"
|
||||||
|
|
||||||
|
|
||||||
|
# Testing multiple instances
|
||||||
|
dog1 = Dog("Fido", 3, "Labrador")
|
||||||
|
dog2 = Dog("Buddy", 5, "Golden Retriever")
|
||||||
|
|
||||||
|
# Testing method with built-in function
|
||||||
|
animals = [dog1, dog2, Cat("Whiskers", 2)]
|
||||||
|
num_animals = len(animals)
|
||||||
|
|
||||||
|
# Testing exceptions in methods
|
||||||
|
class ExceptionTest:
|
||||||
|
def method_that_raises(self):
|
||||||
|
raise ValueError("An error occurred")
|
||||||
|
|
||||||
|
try:
|
||||||
|
exc_test = ExceptionTest()
|
||||||
|
exc_test.method_that_raises()
|
||||||
|
except ValueError as e:
|
||||||
|
exception_message = str(e)
|
||||||
|
|
||||||
|
|
||||||
|
# Collecting results
|
||||||
|
dog1_sound = dog1.sound()
|
||||||
|
dog1_str = str(dog1)
|
||||||
|
dog2_sound = dog2.sound()
|
||||||
|
dog2_str = str(dog2)
|
||||||
|
cat = Cat("Whiskers", 2)
|
||||||
|
cat_sound = cat.sound()
|
||||||
|
cat_str = str(cat)
|
||||||
|
"""
|
||||||
|
state = {}
|
||||||
|
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
|
||||||
|
|
||||||
|
# Assert results
|
||||||
|
assert state["dog1_sound"] == "The dog barks."
|
||||||
|
assert state["dog1_str"] == "Fido, 3 years old, Labrador"
|
||||||
|
assert state["dog2_sound"] == "The dog barks."
|
||||||
|
assert state["dog2_str"] == "Buddy, 5 years old, Golden Retriever"
|
||||||
|
assert state["cat_sound"] == "The cat meows."
|
||||||
|
assert state["cat_str"] == "Whiskers, 2 years old, Generic Animal"
|
||||||
|
assert state["num_animals"] == 3
|
||||||
|
assert state["exception_message"] == "An error occurred"
|
||||||
|
|
||||||
|
def test_variable_args(self):
|
||||||
|
code = """
|
||||||
|
def var_args_method(self, *args, **kwargs):
|
||||||
|
return sum(args) + sum(kwargs.values())
|
||||||
|
|
||||||
|
var_args_method(1, 2, 3, x=4, y=5)
|
||||||
|
"""
|
||||||
|
state = {}
|
||||||
|
result = evaluate_python_code(code, {"sum": sum}, state=state)
|
||||||
|
assert result == 15
|
||||||
|
|
||||||
|
def test_exceptions(self):
|
||||||
|
code = """
|
||||||
|
def method_that_raises(self):
|
||||||
|
raise ValueError("An error occurred")
|
||||||
|
|
||||||
|
try:
|
||||||
|
method_that_raises()
|
||||||
|
except ValueError as e:
|
||||||
|
exception_message = str(e)
|
||||||
|
"""
|
||||||
|
state = {}
|
||||||
|
evaluate_python_code(code, {"print": print, "len": len, "super": super, "str": str, "sum": sum}, state=state)
|
||||||
|
assert state["exception_message"] == "An error occurred"
|
||||||
|
|
||||||
|
def test_subscript(self):
|
||||||
|
code = "vendor = {'revenue': 31000, 'rent': 50312}; vendor['ratio'] = round(vendor['revenue'] / vendor['rent'], 2)"
|
||||||
|
|
||||||
|
state = {}
|
||||||
|
evaluate_python_code(code, {"min": min, "print": print, "round": round}, state=state)
|
||||||
|
assert state["vendor"] == {"revenue": 31000, "rent": 50312, "ratio": 0.62}
|
||||||
|
|
||||||
|
def test_print(self):
|
||||||
|
code = "print(min([1, 2, 3]))"
|
||||||
|
state = {}
|
||||||
|
result = evaluate_python_code(code, {"min": min, "print": print}, state=state)
|
||||||
|
assert result == "1"
|
||||||
|
assert state["print_outputs"] == "1\n"
|
||||||
|
|
||||||
|
def test_types_as_objects(self):
|
||||||
|
code = "type_a = float(2); type_b = str; type_c = int"
|
||||||
|
state = {}
|
||||||
|
result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
|
||||||
|
assert result == int
|
||||||
|
|
|
@ -448,7 +448,7 @@ class QuantoKVCacheQuantizationTest(unittest.TestCase):
|
||||||
def test_quantized_cache(self):
|
def test_quantized_cache(self):
|
||||||
EXPECTED_TEXT_COMPLETION = [
|
EXPECTED_TEXT_COMPLETION = [
|
||||||
"Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity",
|
"Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity",
|
||||||
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my burgers, my hot dogs, my sandwiches, my salads, my chicken, my fish",
|
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
||||||
]
|
]
|
||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
|
|