First draft of the automatic schema generator

This commit is contained in:
Matt 2024-05-03 15:16:56 +01:00
parent 072bd22e1c
commit 0340cfa8fd
3 changed files with 87 additions and 4 deletions

View File

@ -287,14 +287,12 @@ JSON schemas permit highly detailed parameter specifications, so you can pass in
arguments. Be careful, however - we find that in practice this can degrade performance, even for state-of-the-art
models. We recommend trying to keep your tool schemas simple and flat where possible.
### Automated function conversion
### Automated function conversion for tool use
Although JSON schemas are precise, widely-supported and language-agnostic, they can be a bit verbose, which means
that writing them can be annoying. Don't panic, though, we have a solution!
TODO Should descriptions come from the docstrings or the type hints?
TODO Do we need to define a special format for args in the docstrings?
TODO Explain function conversion with examples
### Arguments for retrieval-augmented generation (RAG)

View File

@ -21,6 +21,7 @@ from packaging import version
from .. import __version__
from .backbone_utils import BackboneConfigMixin, BackboneMixin
from .chat_template_utils import get_json_schema
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
from .doc import (
add_code_sample_docstrings,

View File

@ -0,0 +1,84 @@
import inspect
import re
from typing import Any, Union, get_origin, get_type_hints
BASIC_TYPES = (int, float, str, bool, Any)
def get_json_schema(func):
doc = inspect.getdoc(func).strip()
if not doc:
raise ValueError(f"Cannot generate JSON schema for {func.__name__} because it has no docstring!")
param_descriptions = _get_argument_descriptions_from_docstring(doc)
json_schema = _convert_type_hints_to_json_schema(func)
for arg in json_schema["properties"]:
if arg not in param_descriptions:
raise ValueError(
f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
)
json_schema["properties"][arg]["description"] = param_descriptions[arg]
return json_schema
def _get_argument_descriptions_from_docstring(doc):
param_pattern = r":param (\w+): (.+)"
params = re.findall(param_pattern, doc)
return dict(params)
def _convert_type_hints_to_json_schema(func):
type_hints = get_type_hints(func)
properties = {}
signature = inspect.signature(func)
required = [
param_name for param_name, param in signature.parameters.items() if param.default == inspect.Parameter.empty
]
for param_name, param_type in type_hints.items():
if param_name == "return":
continue
if origin := get_origin(param_type) is not None:
if origin is Union:
if all(t in BASIC_TYPES for t in param_type.__args__):
properties[param_name] = {
"type": [_get_json_schema_type(t)["type"] for t in param_type.__args__ if t != type(None)],
"nullable": type(None) in param_type.__args__,
}
else:
properties[param_name] = {
"anyOf": [_get_json_schema_type(t) for t in param_type.__args__ if t != type(None)],
"nullable": type(None) in param_type.__args__,
}
elif origin is list:
properties[param_name] = {"type": "array", "items": _get_json_schema_type(param_type.__args__[0])}
elif origin is dict:
properties[param_name] = {
"type": "object",
"additionalProperties": _get_json_schema_type(param_type.__args__[1]),
}
else:
properties[param_name] = _get_json_schema_type(param_type)
schema = {"type": "object", "properties": properties, "required": required}
return schema
def _get_json_schema_type(param_type):
if param_type == int:
return {"type": "integer"}
elif param_type == float:
return {"type": "number"}
elif param_type == str:
return {"type": "string"}
elif param_type == bool:
return {"type": "boolean"}
elif param_type == Any:
return {}
else:
return {"type": "object"}