Compare commits

...

57 Commits

Author SHA1 Message Date
Matt 4918868afb Merge remote-tracking branch 'origin/new_chat_template_args' into new_chat_template_args 2024-05-29 18:12:52 +01:00
Matt dd54280d95 Wrap functions in {"type": "function", "function": ...} 2024-05-29 18:12:40 +01:00
Matt 656813133f
Update src/transformers/utils/chat_template_utils.py
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
2024-05-28 19:20:34 +01:00
Matt 82d90244dd Clean up enum regex 2024-05-28 16:21:22 +01:00
Matt c2404ed744 Update docs for the regex change 2024-05-28 16:18:49 +01:00
Matt fbd7af80aa Stop supporting type hints in docstring to fix bugs and simplify the regex 2024-05-28 16:17:34 +01:00
Matt 2145e9bba1 Quick test fixes 2024-05-28 14:28:20 +01:00
Matt 4550ad25b6 Document enum block 2024-05-28 14:28:20 +01:00
Matt 11ac07e545 Add an extra test for very complex defs and docstrings and clean everything up for it 2024-05-28 14:28:19 +01:00
Matt deae402554 Clean up Tuple error 2024-05-28 14:28:19 +01:00
Matt 74144d98fc Refactor docs 2024-05-28 14:28:19 +01:00
Matt f80258ac0c Refactor docs 2024-05-28 14:28:19 +01:00
Matt a8cc7d6ab2 Refactor docs 2024-05-28 14:28:19 +01:00
Matt c074c61a1e Add document type validation 2024-05-28 14:28:19 +01:00
Matt 1fa68da868 Update ruff 2024-05-28 14:28:19 +01:00
Matt 491cad643a Update error message for ... 2024-05-28 14:28:19 +01:00
Matt 4f3eafe978 Support more complex, multi-line arg docstrings 2024-05-28 14:28:19 +01:00
Matt 54fde17c7e Make regexes module-level 2024-05-28 14:28:19 +01:00
Matt e07cd5e42b Correct return value 2024-05-28 14:28:19 +01:00
Matt 587558d109 Reformat chat_template_utils 2024-05-28 14:28:18 +01:00
Matt 1a002454a1 Fix indentation 2024-05-28 14:28:18 +01:00
Matt 2f18a47c7f make fixup 2024-05-28 14:28:18 +01:00
Matt 6177c16df5 Add copyright header 2024-05-28 14:28:18 +01:00
Matt 8164ae070e Update src/transformers/utils/chat_template_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2024-05-28 14:28:18 +01:00
Matt e46d5842af Update src/transformers/utils/chat_template_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2024-05-28 14:28:18 +01:00
Matt bfbd88f5d3 Update docs/source/en/chat_templating.md
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2024-05-28 14:28:18 +01:00
Matt fbbf7b077d Update src/transformers/tokenization_utils_base.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2024-05-28 14:28:18 +01:00
Matt 21442d09c5 Add copyright header 2024-05-28 14:28:18 +01:00
Matt 5437ac84b2 Update docstring 2024-05-28 14:28:18 +01:00
Matt 22111b6df9 Add enum support to get_json_schema 2024-05-28 14:28:18 +01:00
Matt 1419bf4b6e No more decorator - we just do it implicitly! 2024-05-28 14:28:17 +01:00
Matt 2f1055d72a Add Tuple support 2024-05-28 14:28:17 +01:00
Matt dbba59f1e9 Stop putting the return type in with the other parameters 2024-05-28 14:28:17 +01:00
Matt 6148213a93 Update chat templating docs to match new format 2024-05-28 14:28:17 +01:00
Matt 7fd0da44c6 Switch to Google format docstrings 2024-05-28 14:28:17 +01:00
Matt 4912b0df77 Proper return type tests 2024-05-28 14:28:17 +01:00
Matt 1225af2f41 Support return types for the templates that need them 2024-05-28 14:28:17 +01:00
Matt 0e50045cf3 Less "anyOf" when unnecessary 2024-05-28 14:28:17 +01:00
Matt 9faf6f4e56 Fix something that was bugging me in the chat template docstring 2024-05-28 14:28:17 +01:00
Matt d17441f811 Quick test fix 2024-05-28 14:28:17 +01:00
Matt 60775f19cd add import for add_json_schema 2024-05-28 14:28:17 +01:00
Matt 079d6d5bf4 self.maxDiff = None to see the whole diff for the nested list test 2024-05-28 14:28:17 +01:00
Matt 44560a7280 Clean up the TODOs and finish the docs 2024-05-28 14:28:17 +01:00
Matt 156eed9024 Add json_schema decorator 2024-05-28 14:28:16 +01:00
Matt ee288915d7 More doc updates 2024-05-28 14:28:16 +01:00
Matt d266a99b45 More doc updates 2024-05-28 14:28:16 +01:00
Matt dae3812bee Small docstring change 2024-05-28 14:28:16 +01:00
Matt 89c0a49ba5 Update tests, proper handling of return values 2024-05-28 14:28:16 +01:00
Matt 51359bf11e Add tests, update schema generator 2024-05-28 14:28:16 +01:00
Matt 5019aa224b More cleanup 2024-05-28 14:28:16 +01:00
Matt 06f9a3437b Comments and bugfixes for the type hint parser 2024-05-28 14:28:16 +01:00
Matt 43a481affa Lots of cleanup and edge cases, looking better now 2024-05-28 14:28:16 +01:00
Matt b875a5d85f please stop committing your debug breakpoints 2024-05-28 14:28:16 +01:00
Matt 8c2a9ae41d the walrus has betrayed me 2024-05-28 14:28:16 +01:00
Matt 096061323c Lots of small fixes 2024-05-28 14:28:16 +01:00
Matt 0340cfa8fd First draft of the automatic schema generator 2024-05-28 14:28:16 +01:00
Matt 072bd22e1c First draft, still missing automatic function conversion 2024-05-28 14:28:15 +01:00
5 changed files with 1005 additions and 3 deletions

View File

@ -233,6 +233,180 @@ The sun.</s>
From here, just continue training like you would with a standard language modelling task, using the `formatted_chat` column.
## Advanced: Extra inputs to chat templates
The only argument that `apply_chat_template` requires is `messages`. However, you can pass any keyword
argument to `apply_chat_template` and it will be accessible inside the template. This gives you a lot of freedom to use
chat templates for many things. There are no restrictions on the names or the format of these arguments - you can pass
strings, lists, dicts or whatever else you want.
That said, there are some common use-cases for these extra arguments,
such as passing tools for function calling, or documents for retrieval-augmented generation. In these common cases,
we have some opinionated recommendations about what the names and formats of these arguments should be.
### Tool use / function calling
"Tool use" LLMs can choose to call functions as external tools before generating an answer. When passing tools
to a tool-use model, you can simply pass a list of functions to the `tools` argument:
```python
import datetime
def current_time():
"""Get the current local time as a string."""
return str(datetime.now())
def multiply(a: float, b: float):
"""
A function that multiplies two numbers
Args:
a: The first number to multiply
b: The second number to multiply
"""
return a * b
tools = [current_time, multiply]
model_input = tokenizer.apply_chat_template(
messages,
tools=tools
)
```
In order for this to work correctly, you should write your functions in the format above, so that they can be parsed
correctly as tools. Specifically, you should follow these rules:
- The function should have a descriptive name
- Every argument must have a type hint
- The function must have a docstring in the standard Google style (in other words, an initial function description
followed by an `Args:` block that describes the arguments, unless the function does not have any arguments.
Do not include types in the `Args:` block - put them in the type hints in the function header instead.
- The function can have a return type and a `Returns:` block in the docstring. However, these are optional
because most tool-use models ignore them.
### Understanding tool schemas
Each function you pass to the `tools` argument of `apply_chat_template` is converted into a
[JSON schema](https://json-schema.org/learn/getting-started-step-by-step. These schemas
are then passed to the model chat template. In other words, tool-use models do not see your functions directly, and they
never see the actual code inside them. What they care about is the function **definitions** and the **arguments** they
need to pass to them - they care about what the tools do and how to use them, not how they work! It is up to you
to read their outputs, detect if they have requested to use a tool, pass their arguments to the tool function, and
return the response in the chat.
Generating JSON schemas to pass to the template should be automatic and invisible as long as your functions
follow the specification above, but if you encounter problems, or you simply want more control over the conversion,
you can handle the conversion manually. Here is an example of a manual schema conversion.
```python
from transformers.utils import get_json_schema
def multiply(a: float, b: float):
"""
A function that multiplies two numbers
Args:
a: The first number to multiply
b: The second number to multiply
"""
return a * b
schema = get_json_schema(multiply)
print(schema)
```
This will yield:
```json
{
"name": "multiply",
"description": "Multiply two numbers together.",
"parameters": {
"type": "object",
"properties": {
"a": {
"type": "number",
"description": "The first number to multiply."
},
"b": {
"type": "number",
"description": "The second number to multiply."
}
},
"required": ["a", "b"]
}
}
```
If you wish, you can edit these schemas, or even write them from scratch yourself without using `get_json_schema` at
all. JSON schemas can be passed directly to the `tools` argument of
`apply_chat_template` - this gives you a lot of power to define precise schemas for more complex functions. Be careful,
though - the more complex your schemas, the more likely the model is to get confused when dealing with them! We
recommend simple function signatures where possible, keeping arguments (and especially complex, nested arguments)
to a minimum.
Here is an example of defining schemas by hand, and passing them directly to `apply_chat_template`:
```python
# A simple function that takes no arguments
current_time = {
"name": "current_time",
"description": "Get the current local time as a string.",
"parameters": {
'type': 'object',
'properties': {}
},
}
# A more complete function that takes two numerical arguments
multiply = {
"name": "multiply",
"description": "Multiply two numbers together.",
"parameters": {
"type": "object",
"properties": {
"a": {"type": "number", "description": "The first number to multiply."},
"b": {"type": "number", "description": "The second number to multiply."},
},
"required": ["a", "b"],
}
}
model_input = tokenizer.apply_chat_template(
messages,
tools = [current_time, multiply]
)
```
### Retrieval-augmented generation
"Retrieval-augmented generation" or "RAG" LLMs can search a corpus of documents for information before responding
to a query. This allows models to vastly expand their knowledge base beyond their limited context size. Our
recommendation for RAG models is that their template
should accept a `documents` argument. This should be a list of documents, where each "document"
is a single dict with `title` and `contents` keys, both of which are strings. Because this format is much simpler
than the JSON schemas used for tools, no helper functions are necessary.
Here's an example of a RAG template in action:
```python
document1 = {
"title": "The Moon: Our Age-Old Foe",
"contents": "Man has always dreamed of destroying the moon. In this essay, I shall..."
}
document2 = {
"title": "The Sun: Our Age-Old Friend",
"contents": "Although often underappreciated, the sun provides several notable benefits..."
}
model_input = tokenizer.apply_chat_template(
messages,
documents=[document1, document2]
)
```
## Advanced: How do chat templates work?
The chat template for a model is stored on the `tokenizer.chat_template` attribute. If no chat template is set, the

View File

@ -28,6 +28,7 @@ from collections.abc import Mapping, Sized
from contextlib import contextmanager
from dataclasses import dataclass
from functools import lru_cache
from inspect import isfunction
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
@ -47,6 +48,7 @@ from .utils import (
copy_func,
download_url,
extract_commit_hash,
get_json_schema,
is_flax_available,
is_jax_tensor,
is_mlx_available,
@ -1685,6 +1687,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"],
tools: Optional[List[Dict]] = None,
documents: Optional[List[Dict[str, str]]] = None,
chat_template: Optional[str] = None,
add_generation_prompt: bool = False,
tokenize: bool = True,
@ -1705,8 +1709,21 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
Args:
conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"]): A list of dicts
with "role" and "content" keys, representing the chat history so far.
chat_template (str, *optional*): A Jinja template to use for this conversion. If
this is not passed, the model's default chat template will be used instead.
tools (`List[Dict]`, *optional*):
A list of tools (callable functions) that will be accessible to the model. If the template does not
support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
giving the name, description and argument types for the tool. See our
[chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
for more information.
documents (`List[Dict[str, str]]`, *optional*):
A list of dicts representing documents that will be accessible to the model if it is performing RAG
(retrieval-augmented generation). If the template does not support RAG, this argument will have no
effect. We recommend that each document should be a dict containing "title" and "text" keys. Please
see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG)
for examples of passing documents with chat templates.
chat_template (`str`, *optional*):
A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
argument, as the model's template will be used by default.
add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate
the start of an assistant message. This is useful when you want to generate a response from the model.
Note that this argument will be passed to the chat template, and so it must be supported in the
@ -1804,6 +1821,27 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
conversations = [conversation]
is_batched = False
# We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
if tools is not None:
tool_schemas = []
for tool in tools:
if isinstance(tool, dict):
tool_schemas.append(tool)
elif isfunction(tool):
tool_schemas.append(get_json_schema(tool))
else:
raise ValueError(
"Tools should either be a JSON schema, or a callable function with type hints "
"and a docstring suitable for auto-conversion to a schema."
)
else:
tool_schemas = None
if documents is not None:
for document in documents:
if not isinstance(document, dict):
raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")
rendered = []
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
for chat in conversations:
@ -1811,7 +1849,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# Indicates it's a Conversation object
chat = chat.messages
rendered_chat = compiled_template.render(
messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs
messages=chat,
tools=tool_schemas,
documents=documents,
add_generation_prompt=add_generation_prompt,
**template_kwargs,
)
rendered.append(rendered_chat)

View File

@ -21,6 +21,7 @@ from packaging import version
from .. import __version__
from .backbone_utils import BackboneConfigMixin, BackboneMixin
from .chat_template_utils import get_json_schema
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
from .doc import (
add_code_sample_docstrings,

View File

@ -0,0 +1,309 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import re
from typing import Any, Union, get_args, get_origin, get_type_hints
BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
args_split_re = re.compile(
r"""
(?:^|\n) # Match the start of the args block, or a newline
\s*(\w+):\s* # Capture the argument name and strip spacing
(.*?)\s* # Capture the argument description, which can span multiple lines, and strip trailing spacing
(?=\n\s*\w+:|\Z) # Stop when you hit the next argument or the end of the block
""",
re.DOTALL | re.VERBOSE,
)
returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)
def _get_json_schema_type(param_type):
type_mapping = {
int: {"type": "integer"},
float: {"type": "number"},
str: {"type": "string"},
bool: {"type": "boolean"},
Any: {},
}
return type_mapping.get(param_type, {"type": "object"})
def _parse_type_hint(hint):
origin = get_origin(hint)
args = get_args(hint)
if origin is None:
return _get_json_schema_type(hint)
if origin is Union:
# If it's a union of basic types, we can express that as a simple list in the schema
if all(t in BASIC_TYPES for t in args):
return_dict = {"type": [_get_json_schema_type(t)["type"] for t in args if t not in (type(None), ...)]}
if len(return_dict["type"]) == 1:
return_dict["type"] = return_dict["type"][0]
else:
# A union of more complex types requires us to recurse into each subtype
return_dict = {
"anyOf": [_parse_type_hint(t) for t in args if t not in (type(None), ...)],
}
if len(return_dict["anyOf"]) == 1:
return_dict = return_dict["anyOf"][0]
if type(None) in args:
return_dict["nullable"] = True
return return_dict
if origin is list:
if not args:
return {"type": "array"}
# Similarly to unions, a list of basic types can be expressed as a list in the schema
if all(t in BASIC_TYPES for t in args):
items = {"type": [_get_json_schema_type(t)["type"] for t in args if t != type(None)]}
if len(items["type"]) == 1:
items["type"] = items["type"][0]
else:
# And a list of more complex types requires us to recurse into each subtype again
items = {"anyOf": [_parse_type_hint(t) for t in args if t not in (type(None), ...)]}
if len(items["anyOf"]) == 1:
items = items["anyOf"][0]
return_dict = {"type": "array", "items": items}
if type(None) in args:
return_dict["nullable"] = True
return return_dict
if origin is tuple:
if not args:
return {"type": "array"}
if len(args) == 1:
raise ValueError(
f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
"we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
"more than one element, we recommend "
"using a List[] type instead, or if it really is a single element, remove the Tuple[] wrapper and just "
"pass the element directly."
)
if ... in args:
raise ValueError(
"Conversion of '...' is not supported in Tuple type hints. "
"Use List[] types for variable-length"
" inputs instead."
)
return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}
if origin is dict:
# The JSON equivalent to a dict is 'object', which mandates that all keys are strings
# However, we can specify the type of the dict values with "additionalProperties"
out = {"type": "object"}
if len(args) == 2:
out["additionalProperties"] = _parse_type_hint(args[1])
return out
raise ValueError("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
def _convert_type_hints_to_json_schema(func):
type_hints = get_type_hints(func)
signature = inspect.signature(func)
required = []
for param_name, param in signature.parameters.items():
if param.annotation == inspect.Parameter.empty:
raise ValueError(f"Argument {param.name} is missing a type hint in function {func.__name__}")
if param.default == inspect.Parameter.empty:
required.append(param_name)
properties = {}
for param_name, param_type in type_hints.items():
properties[param_name] = _parse_type_hint(param_type)
schema = {"type": "object", "properties": properties}
if required:
schema["required"] = required
return schema
def parse_google_format_docstring(docstring):
"""
Parses a Google-style docstring to extract the function description,
argument descriptions, and return description.
Args:
docstring (str): The docstring to parse.
Returns:
The function description, arguments, and return description.
"""
# Extract the sections
description_match = description_re.search(docstring)
args_match = args_re.search(docstring)
returns_match = returns_re.search(docstring)
# Clean and store the sections
description = description_match.group(1).strip() if description_match else None
docstring_args = args_match.group(1).strip() if args_match else None
returns = returns_match.group(1).strip() if returns_match else None
# Parsing the arguments into a dictionary
if docstring_args is not None:
docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()]) # Remove blank lines
matches = args_split_re.findall(docstring_args)
args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches}
else:
args_dict = {}
return description, args_dict, returns
def get_json_schema(func):
"""
This function generates a JSON schema for a given function, based on its docstring and type hints. This is
mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires
that the function has a docstring, and that each argument has a description in the docstring, in the standard
Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint.
Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is
optional because most chat templates ignore the return value of the function.
Args:
func: The function to generate a JSON schema for.
Returns:
A dictionary containing the JSON schema for the function.
Examples:
```python
>>> def multiply(x: float, y: float):
>>> '''
>>> A function that multiplies two numbers
>>>
>>> Args:
>>> x: The first number to multiply
>>> y: The second number to multiply
>>> '''
>>> return x * y
>>>
>>> print(get_json_schema(multiply))
{
"name": "multiply",
"description": "A function that multiplies two numbers",
"parameters": {
"type": "object",
"properties": {
"x": {"type": "number", "description": "The first number to multiply"},
"y": {"type": "number", "description": "The second number to multiply"}
},
"required": ["x", "y"]
}
}
```
The general use for these schemas is that they are used to generate tool descriptions for chat templates that
support them, like so:
```python
>>> from transformers import AutoTokenizer
>>> from transformers.utils import get_json_schema
>>>
>>> def multiply(x: float, y: float):
>>> '''
>>> A function that multiplies two numbers
>>>
>>> Args:
>>> x: The first number to multiply
>>> y: The second number to multiply
>>> return x * y
>>> '''
>>>
>>> multiply_schema = get_json_schema(multiply)
>>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
>>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
>>> formatted_chat = tokenizer.apply_chat_template(
>>> messages,
>>> tools=[multiply_schema],
>>> chat_template="tool_use",
>>> return_dict=True,
>>> return_tensors="pt",
>>> add_generation_prompt=True
>>> )
>>> # The formatted chat can now be passed to model.generate()
```
Each argument description can also have an optional `(choices: ...)` block at the end, such as
`(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. Note that this will
only be parsed correctly if it is at the end of the line:
```python
>>> def drink_beverage(beverage: str):
>>> '''
>>> A function that drinks a beverage
>>>
>>> Args:
>>> beverage: The beverage to drink (choices: ["tea", "coffee"])
>>> '''
>>> pass
>>>
>>> print(get_json_schema(drink_beverage))
```
{
'name': 'drink_beverage',
'description': 'A function that drinks a beverage',
'parameters': {
'type': 'object',
'properties': {
'beverage': {
'type': 'string',
'enum': ['tea', 'coffee'],
'description': 'The beverage to drink'
}
},
'required': ['beverage']
}
}
"""
doc = inspect.getdoc(func)
if not doc:
raise ValueError(f"Cannot generate JSON schema for {func.__name__} because it has no docstring!")
doc = doc.strip()
main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc)
json_schema = _convert_type_hints_to_json_schema(func)
if (return_dict := json_schema["properties"].pop("return", None)) is not None:
if return_doc is not None: # We allow a missing return docstring since most templates ignore it
return_dict["description"] = return_doc
for arg in json_schema["properties"]:
if arg not in param_descriptions:
raise ValueError(
f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
)
desc = param_descriptions[arg]
enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
if enum_choices:
json_schema["properties"][arg]["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))]
desc = enum_choices.string[: enum_choices.start()].strip()
json_schema["properties"][arg]["description"] = desc
output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
if return_dict is not None:
output["return"] = return_dict
return {"type": "function", "function": output}

View File

@ -0,0 +1,476 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from typing import List, Optional, Tuple, Union
from transformers.utils import get_json_schema
class JsonSchemaGeneratorTest(unittest.TestCase):
def test_simple_function(self):
def fn(x: int):
"""
Test function
Args:
x: The input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input"}},
"required": ["x"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_no_arguments(self):
def fn():
"""
Test function
"""
return True
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {"type": "object", "properties": {}},
}
self.assertEqual(schema["function"], expected_schema)
def test_union(self):
def fn(x: Union[int, float]):
"""
Test function
Args:
x: The input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": ["integer", "number"], "description": "The input"}},
"required": ["x"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_optional(self):
def fn(x: Optional[int]):
"""
Test function
Args:
x: The input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input", "nullable": True}},
"required": ["x"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_default_arg(self):
def fn(x: int = 42):
"""
Test function
Args:
x: The input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {"type": "object", "properties": {"x": {"type": "integer", "description": "The input"}}},
}
self.assertEqual(schema["function"], expected_schema)
def test_nested_list(self):
def fn(x: List[List[Union[str, int]]]):
"""
Test function
Args:
x: The input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "array",
"items": {"type": "array", "items": {"type": ["string", "integer"]}},
"description": "The input",
}
},
"required": ["x"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_multiple_arguments(self):
def fn(x: int, y: str):
"""
Test function
Args:
x: The input
y: Also the input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {
"x": {"type": "integer", "description": "The input"},
"y": {"type": "string", "description": "Also the input"},
},
"required": ["x", "y"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_multiple_complex_arguments(self):
def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None):
"""
Test function
Args:
x: The input
y: Also the input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {
"x": {"type": "array", "items": {"type": ["integer", "number"]}, "description": "The input"},
"y": {
"type": ["integer", "string"],
"nullable": True,
"description": "Also the input",
},
},
"required": ["x"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_missing_docstring(self):
def fn(x: int):
return x
with self.assertRaises(ValueError):
get_json_schema(fn)
def test_missing_param_docstring(self):
def fn(x: int):
"""
Test function
"""
return x
with self.assertRaises(ValueError):
get_json_schema(fn)
def test_missing_type_hint(self):
def fn(x):
"""
Test function
Args:
x: The input
"""
return x
with self.assertRaises(ValueError):
get_json_schema(fn)
def test_return_value(self):
def fn(x: int) -> int:
"""
Test function
Args:
x: The input
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input"}},
"required": ["x"],
},
"return": {"type": "integer"},
}
self.assertEqual(schema["function"], expected_schema)
def test_return_value_docstring(self):
def fn(x: int) -> int:
"""
Test function
Args:
x: The input
Returns:
The output
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input"}},
"required": ["x"],
},
"return": {"type": "integer", "description": "The output"},
}
self.assertEqual(schema["function"], expected_schema)
def test_tuple(self):
def fn(x: Tuple[int, str]):
"""
Test function
Args:
x: The input
Returns:
The output
"""
return x
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "array",
"prefixItems": [{"type": "integer"}, {"type": "string"}],
"description": "The input",
}
},
"required": ["x"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_single_element_tuple_fails(self):
def fn(x: Tuple[int]):
"""
Test function
Args:
x: The input
Returns:
The output
"""
return x
# Single-element tuples should just be the type itself, or List[type] for variable-length inputs
with self.assertRaises(ValueError):
get_json_schema(fn)
def test_ellipsis_type_fails(self):
def fn(x: Tuple[int, ...]):
"""
Test function
Args:
x: The input
Returns:
The output
"""
return x
# Variable length inputs should be specified with List[type], not Tuple[type, ...]
with self.assertRaises(ValueError):
get_json_schema(fn)
def test_enum_extraction(self):
def fn(temperature_format: str):
"""
Test function
Args:
temperature_format: The temperature format to use (Choices: ["celsius", "fahrenheit"])
Returns:
The temperature
"""
return -40.0
# Let's see if that gets correctly parsed as an enum
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {
"temperature_format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature format to use",
}
},
"required": ["temperature_format"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_multiline_docstring_with_types(self):
def fn(x: int, y: int):
"""
Test function
Args:
x: The first input
y: The second input. This is a longer description
that spans multiple lines with indentation and stuff.
Returns:
God knows what
"""
pass
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {
"x": {"type": "integer", "description": "The first input"},
"y": {
"type": "integer",
"description": "The second input. This is a longer description that spans multiple lines with indentation and stuff.",
},
},
"required": ["x", "y"],
},
}
self.assertEqual(schema["function"], expected_schema)
def test_everything_all_at_once(self):
def fn(
x: str, y: Optional[List[Union[str, int]]], z: Tuple[Union[str, int], str] = (42, "hello")
) -> Tuple[int, str]:
"""
Test function with multiple args, and docstring args that we have to strip out.
Args:
x: The first input. It's got a big multiline
description and also contains
(choices: ["a", "b", "c"])
y: The second input. It's a big list with a single-line description.
z: The third input. It's some kind of tuple with a default arg.
Returns:
The output. The return description is also a big multiline
description that spans multiple lines.
"""
pass
schema = get_json_schema(fn)
expected_schema = {
"name": "fn",
"description": "Test function with multiple args, and docstring args that we have to strip out.",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "string",
"enum": ["a", "b", "c"],
"description": "The first input. It's got a big multiline description and also contains",
},
"y": {
"type": "array",
"items": {"type": ["string", "integer"]},
"nullable": True,
"description": "The second input. It's a big list with a single-line description.",
},
"z": {
"type": "array",
"prefixItems": [{"type": ["string", "integer"]}, {"type": "string"}],
"description": "The third input. It's some kind of tuple with a default arg.",
},
},
"required": ["x", "y"],
},
"return": {
"type": "array",
"prefixItems": [{"type": "integer"}, {"type": "string"}],
"description": "The output. The return description is also a big multiline\n description that spans multiple lines.",
},
}
self.assertEqual(schema["function"], expected_schema)