Compare commits
57 Commits
main
...
new_chat_t
Author | SHA1 | Date |
---|---|---|
Matt | 4918868afb | |
Matt | dd54280d95 | |
Matt | 656813133f | |
Matt | 82d90244dd | |
Matt | c2404ed744 | |
Matt | fbd7af80aa | |
Matt | 2145e9bba1 | |
Matt | 4550ad25b6 | |
Matt | 11ac07e545 | |
Matt | deae402554 | |
Matt | 74144d98fc | |
Matt | f80258ac0c | |
Matt | a8cc7d6ab2 | |
Matt | c074c61a1e | |
Matt | 1fa68da868 | |
Matt | 491cad643a | |
Matt | 4f3eafe978 | |
Matt | 54fde17c7e | |
Matt | e07cd5e42b | |
Matt | 587558d109 | |
Matt | 1a002454a1 | |
Matt | 2f18a47c7f | |
Matt | 6177c16df5 | |
Matt | 8164ae070e | |
Matt | e46d5842af | |
Matt | bfbd88f5d3 | |
Matt | fbbf7b077d | |
Matt | 21442d09c5 | |
Matt | 5437ac84b2 | |
Matt | 22111b6df9 | |
Matt | 1419bf4b6e | |
Matt | 2f1055d72a | |
Matt | dbba59f1e9 | |
Matt | 6148213a93 | |
Matt | 7fd0da44c6 | |
Matt | 4912b0df77 | |
Matt | 1225af2f41 | |
Matt | 0e50045cf3 | |
Matt | 9faf6f4e56 | |
Matt | d17441f811 | |
Matt | 60775f19cd | |
Matt | 079d6d5bf4 | |
Matt | 44560a7280 | |
Matt | 156eed9024 | |
Matt | ee288915d7 | |
Matt | d266a99b45 | |
Matt | dae3812bee | |
Matt | 89c0a49ba5 | |
Matt | 51359bf11e | |
Matt | 5019aa224b | |
Matt | 06f9a3437b | |
Matt | 43a481affa | |
Matt | b875a5d85f | |
Matt | 8c2a9ae41d | |
Matt | 096061323c | |
Matt | 0340cfa8fd | |
Matt | 072bd22e1c |
|
@ -233,6 +233,180 @@ The sun.</s>
|
|||
|
||||
From here, just continue training like you would with a standard language modelling task, using the `formatted_chat` column.
|
||||
|
||||
## Advanced: Extra inputs to chat templates
|
||||
|
||||
The only argument that `apply_chat_template` requires is `messages`. However, you can pass any keyword
|
||||
argument to `apply_chat_template` and it will be accessible inside the template. This gives you a lot of freedom to use
|
||||
chat templates for many things. There are no restrictions on the names or the format of these arguments - you can pass
|
||||
strings, lists, dicts or whatever else you want.
|
||||
|
||||
That said, there are some common use-cases for these extra arguments,
|
||||
such as passing tools for function calling, or documents for retrieval-augmented generation. In these common cases,
|
||||
we have some opinionated recommendations about what the names and formats of these arguments should be.
|
||||
|
||||
### Tool use / function calling
|
||||
|
||||
"Tool use" LLMs can choose to call functions as external tools before generating an answer. When passing tools
|
||||
to a tool-use model, you can simply pass a list of functions to the `tools` argument:
|
||||
|
||||
```python
|
||||
import datetime
|
||||
|
||||
def current_time():
|
||||
"""Get the current local time as a string."""
|
||||
return str(datetime.now())
|
||||
|
||||
def multiply(a: float, b: float):
|
||||
"""
|
||||
A function that multiplies two numbers
|
||||
|
||||
Args:
|
||||
a: The first number to multiply
|
||||
b: The second number to multiply
|
||||
"""
|
||||
return a * b
|
||||
|
||||
tools = [current_time, multiply]
|
||||
|
||||
model_input = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tools=tools
|
||||
)
|
||||
```
|
||||
|
||||
In order for this to work correctly, you should write your functions in the format above, so that they can be parsed
|
||||
correctly as tools. Specifically, you should follow these rules:
|
||||
|
||||
- The function should have a descriptive name
|
||||
- Every argument must have a type hint
|
||||
- The function must have a docstring in the standard Google style (in other words, an initial function description
|
||||
followed by an `Args:` block that describes the arguments, unless the function does not have any arguments.
|
||||
Do not include types in the `Args:` block - put them in the type hints in the function header instead.
|
||||
- The function can have a return type and a `Returns:` block in the docstring. However, these are optional
|
||||
because most tool-use models ignore them.
|
||||
|
||||
### Understanding tool schemas
|
||||
|
||||
Each function you pass to the `tools` argument of `apply_chat_template` is converted into a
|
||||
[JSON schema](https://json-schema.org/learn/getting-started-step-by-step. These schemas
|
||||
are then passed to the model chat template. In other words, tool-use models do not see your functions directly, and they
|
||||
never see the actual code inside them. What they care about is the function **definitions** and the **arguments** they
|
||||
need to pass to them - they care about what the tools do and how to use them, not how they work! It is up to you
|
||||
to read their outputs, detect if they have requested to use a tool, pass their arguments to the tool function, and
|
||||
return the response in the chat.
|
||||
|
||||
Generating JSON schemas to pass to the template should be automatic and invisible as long as your functions
|
||||
follow the specification above, but if you encounter problems, or you simply want more control over the conversion,
|
||||
you can handle the conversion manually. Here is an example of a manual schema conversion.
|
||||
|
||||
```python
|
||||
from transformers.utils import get_json_schema
|
||||
|
||||
def multiply(a: float, b: float):
|
||||
"""
|
||||
A function that multiplies two numbers
|
||||
|
||||
Args:
|
||||
a: The first number to multiply
|
||||
b: The second number to multiply
|
||||
"""
|
||||
return a * b
|
||||
|
||||
schema = get_json_schema(multiply)
|
||||
print(schema)
|
||||
```
|
||||
|
||||
This will yield:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "multiply",
|
||||
"description": "Multiply two numbers together.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {
|
||||
"type": "number",
|
||||
"description": "The first number to multiply."
|
||||
},
|
||||
"b": {
|
||||
"type": "number",
|
||||
"description": "The second number to multiply."
|
||||
}
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
If you wish, you can edit these schemas, or even write them from scratch yourself without using `get_json_schema` at
|
||||
all. JSON schemas can be passed directly to the `tools` argument of
|
||||
`apply_chat_template` - this gives you a lot of power to define precise schemas for more complex functions. Be careful,
|
||||
though - the more complex your schemas, the more likely the model is to get confused when dealing with them! We
|
||||
recommend simple function signatures where possible, keeping arguments (and especially complex, nested arguments)
|
||||
to a minimum.
|
||||
|
||||
Here is an example of defining schemas by hand, and passing them directly to `apply_chat_template`:
|
||||
|
||||
```python
|
||||
# A simple function that takes no arguments
|
||||
current_time = {
|
||||
"name": "current_time",
|
||||
"description": "Get the current local time as a string.",
|
||||
"parameters": {
|
||||
'type': 'object',
|
||||
'properties': {}
|
||||
},
|
||||
}
|
||||
|
||||
# A more complete function that takes two numerical arguments
|
||||
multiply = {
|
||||
"name": "multiply",
|
||||
"description": "Multiply two numbers together.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number", "description": "The first number to multiply."},
|
||||
"b": {"type": "number", "description": "The second number to multiply."},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
}
|
||||
}
|
||||
|
||||
model_input = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tools = [current_time, multiply]
|
||||
)
|
||||
```
|
||||
|
||||
### Retrieval-augmented generation
|
||||
|
||||
"Retrieval-augmented generation" or "RAG" LLMs can search a corpus of documents for information before responding
|
||||
to a query. This allows models to vastly expand their knowledge base beyond their limited context size. Our
|
||||
recommendation for RAG models is that their template
|
||||
should accept a `documents` argument. This should be a list of documents, where each "document"
|
||||
is a single dict with `title` and `contents` keys, both of which are strings. Because this format is much simpler
|
||||
than the JSON schemas used for tools, no helper functions are necessary.
|
||||
|
||||
Here's an example of a RAG template in action:
|
||||
|
||||
```python
|
||||
document1 = {
|
||||
"title": "The Moon: Our Age-Old Foe",
|
||||
"contents": "Man has always dreamed of destroying the moon. In this essay, I shall..."
|
||||
}
|
||||
|
||||
document2 = {
|
||||
"title": "The Sun: Our Age-Old Friend",
|
||||
"contents": "Although often underappreciated, the sun provides several notable benefits..."
|
||||
}
|
||||
|
||||
model_input = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
documents=[document1, document2]
|
||||
)
|
||||
```
|
||||
|
||||
## Advanced: How do chat templates work?
|
||||
|
||||
The chat template for a model is stored on the `tokenizer.chat_template` attribute. If no chat template is set, the
|
||||
|
|
|
@ -28,6 +28,7 @@ from collections.abc import Mapping, Sized
|
|||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from inspect import isfunction
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
@ -47,6 +48,7 @@ from .utils import (
|
|||
copy_func,
|
||||
download_url,
|
||||
extract_commit_hash,
|
||||
get_json_schema,
|
||||
is_flax_available,
|
||||
is_jax_tensor,
|
||||
is_mlx_available,
|
||||
|
@ -1685,6 +1687,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||
def apply_chat_template(
|
||||
self,
|
||||
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"],
|
||||
tools: Optional[List[Dict]] = None,
|
||||
documents: Optional[List[Dict[str, str]]] = None,
|
||||
chat_template: Optional[str] = None,
|
||||
add_generation_prompt: bool = False,
|
||||
tokenize: bool = True,
|
||||
|
@ -1705,8 +1709,21 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||
Args:
|
||||
conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"]): A list of dicts
|
||||
with "role" and "content" keys, representing the chat history so far.
|
||||
chat_template (str, *optional*): A Jinja template to use for this conversion. If
|
||||
this is not passed, the model's default chat template will be used instead.
|
||||
tools (`List[Dict]`, *optional*):
|
||||
A list of tools (callable functions) that will be accessible to the model. If the template does not
|
||||
support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
|
||||
giving the name, description and argument types for the tool. See our
|
||||
[chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
|
||||
for more information.
|
||||
documents (`List[Dict[str, str]]`, *optional*):
|
||||
A list of dicts representing documents that will be accessible to the model if it is performing RAG
|
||||
(retrieval-augmented generation). If the template does not support RAG, this argument will have no
|
||||
effect. We recommend that each document should be a dict containing "title" and "text" keys. Please
|
||||
see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG)
|
||||
for examples of passing documents with chat templates.
|
||||
chat_template (`str`, *optional*):
|
||||
A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
|
||||
argument, as the model's template will be used by default.
|
||||
add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate
|
||||
the start of an assistant message. This is useful when you want to generate a response from the model.
|
||||
Note that this argument will be passed to the chat template, and so it must be supported in the
|
||||
|
@ -1804,6 +1821,27 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||
conversations = [conversation]
|
||||
is_batched = False
|
||||
|
||||
# We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
|
||||
if tools is not None:
|
||||
tool_schemas = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict):
|
||||
tool_schemas.append(tool)
|
||||
elif isfunction(tool):
|
||||
tool_schemas.append(get_json_schema(tool))
|
||||
else:
|
||||
raise ValueError(
|
||||
"Tools should either be a JSON schema, or a callable function with type hints "
|
||||
"and a docstring suitable for auto-conversion to a schema."
|
||||
)
|
||||
else:
|
||||
tool_schemas = None
|
||||
|
||||
if documents is not None:
|
||||
for document in documents:
|
||||
if not isinstance(document, dict):
|
||||
raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")
|
||||
|
||||
rendered = []
|
||||
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
|
||||
for chat in conversations:
|
||||
|
@ -1811,7 +1849,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||
# Indicates it's a Conversation object
|
||||
chat = chat.messages
|
||||
rendered_chat = compiled_template.render(
|
||||
messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs
|
||||
messages=chat,
|
||||
tools=tool_schemas,
|
||||
documents=documents,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
**template_kwargs,
|
||||
)
|
||||
rendered.append(rendered_chat)
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ from packaging import version
|
|||
|
||||
from .. import __version__
|
||||
from .backbone_utils import BackboneConfigMixin, BackboneMixin
|
||||
from .chat_template_utils import get_json_schema
|
||||
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
|
||||
from .doc import (
|
||||
add_code_sample_docstrings,
|
||||
|
|
|
@ -0,0 +1,309 @@
|
|||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
|
||||
BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
|
||||
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
|
||||
args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
|
||||
args_split_re = re.compile(
|
||||
r"""
|
||||
(?:^|\n) # Match the start of the args block, or a newline
|
||||
\s*(\w+):\s* # Capture the argument name and strip spacing
|
||||
(.*?)\s* # Capture the argument description, which can span multiple lines, and strip trailing spacing
|
||||
(?=\n\s*\w+:|\Z) # Stop when you hit the next argument or the end of the block
|
||||
""",
|
||||
re.DOTALL | re.VERBOSE,
|
||||
)
|
||||
returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)
|
||||
|
||||
|
||||
def _get_json_schema_type(param_type):
|
||||
type_mapping = {
|
||||
int: {"type": "integer"},
|
||||
float: {"type": "number"},
|
||||
str: {"type": "string"},
|
||||
bool: {"type": "boolean"},
|
||||
Any: {},
|
||||
}
|
||||
return type_mapping.get(param_type, {"type": "object"})
|
||||
|
||||
|
||||
def _parse_type_hint(hint):
|
||||
origin = get_origin(hint)
|
||||
args = get_args(hint)
|
||||
|
||||
if origin is None:
|
||||
return _get_json_schema_type(hint)
|
||||
|
||||
if origin is Union:
|
||||
# If it's a union of basic types, we can express that as a simple list in the schema
|
||||
if all(t in BASIC_TYPES for t in args):
|
||||
return_dict = {"type": [_get_json_schema_type(t)["type"] for t in args if t not in (type(None), ...)]}
|
||||
if len(return_dict["type"]) == 1:
|
||||
return_dict["type"] = return_dict["type"][0]
|
||||
else:
|
||||
# A union of more complex types requires us to recurse into each subtype
|
||||
return_dict = {
|
||||
"anyOf": [_parse_type_hint(t) for t in args if t not in (type(None), ...)],
|
||||
}
|
||||
if len(return_dict["anyOf"]) == 1:
|
||||
return_dict = return_dict["anyOf"][0]
|
||||
if type(None) in args:
|
||||
return_dict["nullable"] = True
|
||||
return return_dict
|
||||
|
||||
if origin is list:
|
||||
if not args:
|
||||
return {"type": "array"}
|
||||
|
||||
# Similarly to unions, a list of basic types can be expressed as a list in the schema
|
||||
if all(t in BASIC_TYPES for t in args):
|
||||
items = {"type": [_get_json_schema_type(t)["type"] for t in args if t != type(None)]}
|
||||
if len(items["type"]) == 1:
|
||||
items["type"] = items["type"][0]
|
||||
else:
|
||||
# And a list of more complex types requires us to recurse into each subtype again
|
||||
items = {"anyOf": [_parse_type_hint(t) for t in args if t not in (type(None), ...)]}
|
||||
if len(items["anyOf"]) == 1:
|
||||
items = items["anyOf"][0]
|
||||
|
||||
return_dict = {"type": "array", "items": items}
|
||||
|
||||
if type(None) in args:
|
||||
return_dict["nullable"] = True
|
||||
|
||||
return return_dict
|
||||
|
||||
if origin is tuple:
|
||||
if not args:
|
||||
return {"type": "array"}
|
||||
if len(args) == 1:
|
||||
raise ValueError(
|
||||
f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
|
||||
"we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
|
||||
"more than one element, we recommend "
|
||||
"using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just "
|
||||
"pass the element directly."
|
||||
)
|
||||
if ... in args:
|
||||
raise ValueError(
|
||||
"Conversion of '...' is not supported in Tuple type hints. "
|
||||
"Use List[] types for variable-length"
|
||||
" inputs instead."
|
||||
)
|
||||
return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}
|
||||
|
||||
if origin is dict:
|
||||
# The JSON equivalent to a dict is 'object', which mandates that all keys are strings
|
||||
# However, we can specify the type of the dict values with "additionalProperties"
|
||||
out = {"type": "object"}
|
||||
if len(args) == 2:
|
||||
out["additionalProperties"] = _parse_type_hint(args[1])
|
||||
return out
|
||||
|
||||
raise ValueError("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
|
||||
|
||||
|
||||
def _convert_type_hints_to_json_schema(func):
|
||||
type_hints = get_type_hints(func)
|
||||
signature = inspect.signature(func)
|
||||
required = []
|
||||
for param_name, param in signature.parameters.items():
|
||||
if param.annotation == inspect.Parameter.empty:
|
||||
raise ValueError(f"Argument {param.name} is missing a type hint in function {func.__name__}")
|
||||
if param.default == inspect.Parameter.empty:
|
||||
required.append(param_name)
|
||||
|
||||
properties = {}
|
||||
for param_name, param_type in type_hints.items():
|
||||
properties[param_name] = _parse_type_hint(param_type)
|
||||
|
||||
schema = {"type": "object", "properties": properties}
|
||||
if required:
|
||||
schema["required"] = required
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def parse_google_format_docstring(docstring):
|
||||
"""
|
||||
Parses a Google-style docstring to extract the function description,
|
||||
argument descriptions, and return description.
|
||||
|
||||
Args:
|
||||
docstring (str): The docstring to parse.
|
||||
|
||||
Returns:
|
||||
The function description, arguments, and return description.
|
||||
"""
|
||||
|
||||
# Extract the sections
|
||||
description_match = description_re.search(docstring)
|
||||
args_match = args_re.search(docstring)
|
||||
returns_match = returns_re.search(docstring)
|
||||
|
||||
# Clean and store the sections
|
||||
description = description_match.group(1).strip() if description_match else None
|
||||
docstring_args = args_match.group(1).strip() if args_match else None
|
||||
returns = returns_match.group(1).strip() if returns_match else None
|
||||
|
||||
# Parsing the arguments into a dictionary
|
||||
if docstring_args is not None:
|
||||
docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()]) # Remove blank lines
|
||||
matches = args_split_re.findall(docstring_args)
|
||||
args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches}
|
||||
else:
|
||||
args_dict = {}
|
||||
|
||||
return description, args_dict, returns
|
||||
|
||||
|
||||
def get_json_schema(func):
|
||||
"""
|
||||
This function generates a JSON schema for a given function, based on its docstring and type hints. This is
|
||||
mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
|
||||
the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires
|
||||
that the function has a docstring, and that each argument has a description in the docstring, in the standard
|
||||
Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint.
|
||||
|
||||
Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is
|
||||
optional because most chat templates ignore the return value of the function.
|
||||
|
||||
Args:
|
||||
func: The function to generate a JSON schema for.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the JSON schema for the function.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
>>> def multiply(x: float, y: float):
|
||||
>>> '''
|
||||
>>> A function that multiplies two numbers
|
||||
>>>
|
||||
>>> Args:
|
||||
>>> x: The first number to multiply
|
||||
>>> y: The second number to multiply
|
||||
>>> '''
|
||||
>>> return x * y
|
||||
>>>
|
||||
>>> print(get_json_schema(multiply))
|
||||
{
|
||||
"name": "multiply",
|
||||
"description": "A function that multiplies two numbers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"type": "number", "description": "The first number to multiply"},
|
||||
"y": {"type": "number", "description": "The second number to multiply"}
|
||||
},
|
||||
"required": ["x", "y"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
The general use for these schemas is that they are used to generate tool descriptions for chat templates that
|
||||
support them, like so:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer
|
||||
>>> from transformers.utils import get_json_schema
|
||||
>>>
|
||||
>>> def multiply(x: float, y: float):
|
||||
>>> '''
|
||||
>>> A function that multiplies two numbers
|
||||
>>>
|
||||
>>> Args:
|
||||
>>> x: The first number to multiply
|
||||
>>> y: The second number to multiply
|
||||
>>> return x * y
|
||||
>>> '''
|
||||
>>>
|
||||
>>> multiply_schema = get_json_schema(multiply)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||
>>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
|
||||
>>> formatted_chat = tokenizer.apply_chat_template(
|
||||
>>> messages,
|
||||
>>> tools=[multiply_schema],
|
||||
>>> chat_template="tool_use",
|
||||
>>> return_dict=True,
|
||||
>>> return_tensors="pt",
|
||||
>>> add_generation_prompt=True
|
||||
>>> )
|
||||
>>> # The formatted chat can now be passed to model.generate()
|
||||
```
|
||||
|
||||
Each argument description can also have an optional `(choices: ...)` block at the end, such as
|
||||
`(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. Note that this will
|
||||
only be parsed correctly if it is at the end of the line:
|
||||
|
||||
```python
|
||||
>>> def drink_beverage(beverage: str):
|
||||
>>> '''
|
||||
>>> A function that drinks a beverage
|
||||
>>>
|
||||
>>> Args:
|
||||
>>> beverage: The beverage to drink (choices: ["tea", "coffee"])
|
||||
>>> '''
|
||||
>>> pass
|
||||
>>>
|
||||
>>> print(get_json_schema(drink_beverage))
|
||||
```
|
||||
{
|
||||
'name': 'drink_beverage',
|
||||
'description': 'A function that drinks a beverage',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'beverage': {
|
||||
'type': 'string',
|
||||
'enum': ['tea', 'coffee'],
|
||||
'description': 'The beverage to drink'
|
||||
}
|
||||
},
|
||||
'required': ['beverage']
|
||||
}
|
||||
}
|
||||
"""
|
||||
doc = inspect.getdoc(func)
|
||||
if not doc:
|
||||
raise ValueError(f"Cannot generate JSON schema for {func.__name__} because it has no docstring!")
|
||||
doc = doc.strip()
|
||||
main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc)
|
||||
|
||||
json_schema = _convert_type_hints_to_json_schema(func)
|
||||
if (return_dict := json_schema["properties"].pop("return", None)) is not None:
|
||||
if return_doc is not None: # We allow a missing return docstring since most templates ignore it
|
||||
return_dict["description"] = return_doc
|
||||
for arg in json_schema["properties"]:
|
||||
if arg not in param_descriptions:
|
||||
raise ValueError(
|
||||
f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
|
||||
)
|
||||
desc = param_descriptions[arg]
|
||||
enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
|
||||
if enum_choices:
|
||||
json_schema["properties"][arg]["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))]
|
||||
desc = enum_choices.string[: enum_choices.start()].strip()
|
||||
json_schema["properties"][arg]["description"] = desc
|
||||
|
||||
output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
|
||||
if return_dict is not None:
|
||||
output["return"] = return_dict
|
||||
return {"type": "function", "function": output}
|
|
@ -0,0 +1,476 @@
|
|||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from transformers.utils import get_json_schema
|
||||
|
||||
|
||||
class JsonSchemaGeneratorTest(unittest.TestCase):
|
||||
def test_simple_function(self):
|
||||
def fn(x: int):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "The input"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_no_arguments(self):
|
||||
def fn():
|
||||
"""
|
||||
Test function
|
||||
"""
|
||||
return True
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_union(self):
|
||||
def fn(x: Union[int, float]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": ["integer", "number"], "description": "The input"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_optional(self):
|
||||
def fn(x: Optional[int]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "The input", "nullable": True}},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_default_arg(self):
|
||||
def fn(x: int = 42):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {"type": "object", "properties": {"x": {"type": "integer", "description": "The input"}}},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_nested_list(self):
|
||||
def fn(x: List[List[Union[str, int]]]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "array",
|
||||
"items": {"type": "array", "items": {"type": ["string", "integer"]}},
|
||||
"description": "The input",
|
||||
}
|
||||
},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_multiple_arguments(self):
|
||||
def fn(x: int, y: str):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
y: Also the input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"type": "integer", "description": "The input"},
|
||||
"y": {"type": "string", "description": "Also the input"},
|
||||
},
|
||||
"required": ["x", "y"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_multiple_complex_arguments(self):
|
||||
def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
y: Also the input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"type": "array", "items": {"type": ["integer", "number"]}, "description": "The input"},
|
||||
"y": {
|
||||
"type": ["integer", "string"],
|
||||
"nullable": True,
|
||||
"description": "Also the input",
|
||||
},
|
||||
},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_missing_docstring(self):
|
||||
def fn(x: int):
|
||||
return x
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_missing_param_docstring(self):
|
||||
def fn(x: int):
|
||||
"""
|
||||
Test function
|
||||
"""
|
||||
return x
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_missing_type_hint(self):
|
||||
def fn(x):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_return_value(self):
|
||||
def fn(x: int) -> int:
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "The input"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
"return": {"type": "integer"},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_return_value_docstring(self):
|
||||
def fn(x: int) -> int:
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
|
||||
|
||||
Returns:
|
||||
The output
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "description": "The input"}},
|
||||
"required": ["x"],
|
||||
},
|
||||
"return": {"type": "integer", "description": "The output"},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_tuple(self):
|
||||
def fn(x: Tuple[int, str]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
|
||||
|
||||
Returns:
|
||||
The output
|
||||
"""
|
||||
return x
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "array",
|
||||
"prefixItems": [{"type": "integer"}, {"type": "string"}],
|
||||
"description": "The input",
|
||||
}
|
||||
},
|
||||
"required": ["x"],
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_single_element_tuple_fails(self):
|
||||
def fn(x: Tuple[int]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
|
||||
|
||||
Returns:
|
||||
The output
|
||||
"""
|
||||
return x
|
||||
|
||||
# Single-element tuples should just be the type itself, or List[type] for variable-length inputs
|
||||
with self.assertRaises(ValueError):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_ellipsis_type_fails(self):
|
||||
def fn(x: Tuple[int, ...]):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The input
|
||||
|
||||
|
||||
Returns:
|
||||
The output
|
||||
"""
|
||||
return x
|
||||
|
||||
# Variable length inputs should be specified with List[type], not Tuple[type, ...]
|
||||
with self.assertRaises(ValueError):
|
||||
get_json_schema(fn)
|
||||
|
||||
def test_enum_extraction(self):
|
||||
def fn(temperature_format: str):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
temperature_format: The temperature format to use (Choices: ["celsius", "fahrenheit"])
|
||||
|
||||
|
||||
Returns:
|
||||
The temperature
|
||||
"""
|
||||
return -40.0
|
||||
|
||||
# Let's see if that gets correctly parsed as an enum
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"temperature_format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature format to use",
|
||||
}
|
||||
},
|
||||
"required": ["temperature_format"],
|
||||
},
|
||||
}
|
||||
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_multiline_docstring_with_types(self):
|
||||
def fn(x: int, y: int):
|
||||
"""
|
||||
Test function
|
||||
|
||||
Args:
|
||||
x: The first input
|
||||
|
||||
y: The second input. This is a longer description
|
||||
that spans multiple lines with indentation and stuff.
|
||||
|
||||
Returns:
|
||||
God knows what
|
||||
"""
|
||||
pass
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {"type": "integer", "description": "The first input"},
|
||||
"y": {
|
||||
"type": "integer",
|
||||
"description": "The second input. This is a longer description that spans multiple lines with indentation and stuff.",
|
||||
},
|
||||
},
|
||||
"required": ["x", "y"],
|
||||
},
|
||||
}
|
||||
|
||||
self.assertEqual(schema["function"], expected_schema)
|
||||
|
||||
def test_everything_all_at_once(self):
|
||||
def fn(
|
||||
x: str, y: Optional[List[Union[str, int]]], z: Tuple[Union[str, int], str] = (42, "hello")
|
||||
) -> Tuple[int, str]:
|
||||
"""
|
||||
Test function with multiple args, and docstring args that we have to strip out.
|
||||
|
||||
Args:
|
||||
x: The first input. It's got a big multiline
|
||||
description and also contains
|
||||
(choices: ["a", "b", "c"])
|
||||
|
||||
y: The second input. It's a big list with a single-line description.
|
||||
|
||||
z: The third input. It's some kind of tuple with a default arg.
|
||||
|
||||
Returns:
|
||||
The output. The return description is also a big multiline
|
||||
description that spans multiple lines.
|
||||
"""
|
||||
pass
|
||||
|
||||
schema = get_json_schema(fn)
|
||||
expected_schema = {
|
||||
"name": "fn",
|
||||
"description": "Test function with multiple args, and docstring args that we have to strip out.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"x": {
|
||||
"type": "string",
|
||||
"enum": ["a", "b", "c"],
|
||||
"description": "The first input. It's got a big multiline description and also contains",
|
||||
},
|
||||
"y": {
|
||||
"type": "array",
|
||||
"items": {"type": ["string", "integer"]},
|
||||
"nullable": True,
|
||||
"description": "The second input. It's a big list with a single-line description.",
|
||||
},
|
||||
"z": {
|
||||
"type": "array",
|
||||
"prefixItems": [{"type": ["string", "integer"]}, {"type": "string"}],
|
||||
"description": "The third input. It's some kind of tuple with a default arg.",
|
||||
},
|
||||
},
|
||||
"required": ["x", "y"],
|
||||
},
|
||||
"return": {
|
||||
"type": "array",
|
||||
"prefixItems": [{"type": "integer"}, {"type": "string"}],
|
||||
"description": "The output. The return description is also a big multiline\n description that spans multiple lines.",
|
||||
},
|
||||
}
|
||||
self.assertEqual(schema["function"], expected_schema)
|
Loading…
Reference in New Issue