diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 430396ceb8..7a2e2507b9 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -287,14 +287,12 @@ JSON schemas permit highly detailed parameter specifications, so you can pass in arguments. Be careful, however - we find that in practice this can degrade performance, even for state-of-the-art models. We recommend trying to keep your tool schemas simple and flat where possible. -### Automated function conversion +### Automated function conversion for tool use Although JSON schemas are precise, widely-supported and language-agnostic, they can be a bit verbose, which means that writing them can be annoying. Don't panic, though, we have a solution! -TODO Should descriptions come from the docstrings or the type hints? - -TODO Do we need to define a special format for args in the docstrings? +TODO Explain function conversion with examples ### Arguments for retrieval-augmented generation (RAG) diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 3769a0b4c7..9552296137 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -21,6 +21,7 @@ from packaging import version from .. 
"""Helpers that turn a Python function into a JSON schema for chat-template tool use.

``get_json_schema`` is the public entry point; ``transformers/utils/__init__.py``
re-exports it with ``from .chat_template_utils import get_json_schema``.
"""

import inspect
import re
import types
from typing import Any, Union, get_args, get_origin, get_type_hints


BASIC_TYPES = (int, float, str, bool, Any)

_NONE_TYPE = type(None)

# ``X | Y`` annotations (Python 3.10+) report ``types.UnionType`` as their origin
# instead of ``typing.Union``; accept both where the interpreter provides them.
_UNION_ORIGINS = tuple(
    origin for origin in (Union, getattr(types, "UnionType", None)) if origin is not None
)

# Matches reST-style ``:param name: description`` lines in a docstring.
_PARAM_PATTERN = re.compile(r":param (\w+): (.+)")

# Python scalar type -> JSON schema type name.
_SCALAR_SCHEMA_TYPES = (
    (int, "integer"),
    (float, "number"),
    (str, "string"),
    (bool, "boolean"),
)


def get_json_schema(func):
    """Build a JSON schema describing the arguments of ``func``.

    Argument types come from the function's type hints and argument
    descriptions come from ``:param name: description`` lines in its docstring.

    Args:
        func: The function to describe. It needs a docstring that documents
            every type-hinted argument.

    Returns:
        A dict of the form ``{"type": "object", "properties": {...}, "required": [...]}``
        where every entry of ``properties`` also carries a ``"description"`` key.

    Raises:
        ValueError: If ``func`` has no docstring, or if a type-hinted argument
            has no ``:param`` description in the docstring.
    """
    # ``inspect.getdoc`` returns None when there is no docstring at all.
    doc = (inspect.getdoc(func) or "").strip()
    if not doc:
        raise ValueError(f"Cannot generate JSON schema for {func.__name__} because it has no docstring!")
    param_descriptions = _get_argument_descriptions_from_docstring(doc)

    json_schema = _convert_type_hints_to_json_schema(func)
    for arg, arg_schema in json_schema["properties"].items():
        if arg not in param_descriptions:
            raise ValueError(
                f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
            )
        arg_schema["description"] = param_descriptions[arg]

    return json_schema


def _get_argument_descriptions_from_docstring(doc):
    """Return a ``{argument name: description}`` dict parsed from ``:param`` lines of ``doc``."""
    return dict(_PARAM_PATTERN.findall(doc))


def _convert_type_hints_to_json_schema(func):
    """Return the undescribed JSON schema for the type-hinted arguments of ``func``.

    ``required`` lists every parameter of the signature that has no default
    value; ``properties`` has one entry per type-hinted parameter (the return
    annotation is skipped).
    """
    type_hints = get_type_hints(func)
    signature = inspect.signature(func)

    # ``Parameter.empty`` is a sentinel, so compare by identity: ``==`` could
    # invoke a custom ``__eq__`` on an arbitrary default value.
    required = [
        param_name
        for param_name, param in signature.parameters.items()
        if param.default is inspect.Parameter.empty
    ]
    properties = {
        param_name: _parse_type_hint(param_type)
        for param_name, param_type in type_hints.items()
        if param_name != "return"
    }

    return {"type": "object", "properties": properties, "required": required}


def _parse_type_hint(hint):
    """Convert one type hint, possibly a parameterized generic, into a JSON schema fragment."""
    origin = get_origin(hint)
    if origin is None:
        return _get_json_schema_type(hint)

    args = get_args(hint)

    if origin in _UNION_ORIGINS:
        members = [_parse_type_hint(arg) for arg in args if arg is not _NONE_TYPE]
        nullable = _NONE_TYPE in args
        # A flat list of type names is only possible when every member is a
        # basic type that maps to a named JSON type (``Any`` maps to ``{}``).
        if all(arg in BASIC_TYPES for arg in args) and all("type" in member for member in members):
            return {"type": [member["type"] for member in members], "nullable": nullable}
        return {"anyOf": members, "nullable": nullable}

    if origin is list:
        if not args:  # unparameterized ``typing.List``
            return {"type": "array"}
        return {"type": "array", "items": _parse_type_hint(args[0])}

    if origin is dict:
        if len(args) < 2:  # unparameterized ``typing.Dict``
            return {"type": "object"}
        return {"type": "object", "additionalProperties": _parse_type_hint(args[1])}

    # Any other generic gets the same default as an unknown plain type, so the
    # argument is never silently dropped from the schema.
    return _get_json_schema_type(hint)


def _get_json_schema_type(param_type):
    """Map a plain Python type to a JSON schema fragment.

    ``Any`` maps to the empty schema and unrecognized types map to
    ``{"type": "object"}``. A fresh dict is returned on every call because
    ``get_json_schema`` adds a ``"description"`` key to the result.
    """
    for python_type, schema_type in _SCALAR_SCHEMA_TYPES:
        if param_type is python_type:
            return {"type": schema_type}
    if param_type is Any:
        return {}
    return {"type": "object"}