First draft of the automatic schema generator
This commit is contained in:
parent
072bd22e1c
commit
0340cfa8fd
|
@ -287,14 +287,12 @@ JSON schemas permit highly detailed parameter specifications, so you can pass in
|
|||
arguments. Be careful, however - we find that in practice this can degrade performance, even for state-of-the-art
|
||||
models. We recommend trying to keep your tool schemas simple and flat where possible.
|
||||
|
||||
### Automated function conversion
|
||||
### Automated function conversion for tool use
|
||||
|
||||
Although JSON schemas are precise, widely supported and language-agnostic, they can be a bit verbose, which means
|
||||
that writing them can be annoying. Don't panic, though - we have a solution!
|
||||
|
||||
TODO Should descriptions come from the docstrings or the type hints?
|
||||
|
||||
TODO Do we need to define a special format for args in the docstrings?
|
||||
TODO Explain function conversion with examples
|
||||
|
||||
### Arguments for retrieval-augmented generation (RAG)
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ from packaging import version
|
|||
|
||||
from .. import __version__
|
||||
from .backbone_utils import BackboneConfigMixin, BackboneMixin
|
||||
from .chat_template_utils import get_json_schema
|
||||
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
|
||||
from .doc import (
|
||||
add_code_sample_docstrings,
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
import inspect
|
||||
import re
|
||||
from typing import Any, Union, get_origin, get_type_hints
|
||||
|
||||
|
||||
# Python types that _get_json_schema_type maps directly to a JSON schema fragment
# (Any maps to the empty, unconstrained schema). _convert_type_hints_to_json_schema
# checks Union members against this tuple to choose between a flat "type" list and "anyOf".
BASIC_TYPES = (int, float, str, bool, Any)
|
||||
|
||||
|
||||
def get_json_schema(func):
    """Generate a JSON schema describing the arguments of ``func``.

    Argument types are taken from the function's type hints, and argument
    descriptions are taken from ``:param name: description`` lines in its
    docstring.

    Args:
        func: The function to describe. It must have a docstring containing a
            ``:param`` line for every type-hinted argument.

    Returns:
        A dict with the keys ``"type"``, ``"properties"`` and ``"required"``,
        where every entry in ``"properties"`` has a ``"description"`` added.

    Raises:
        ValueError: If ``func`` has no docstring, or if a type-hinted argument
            has no description in the docstring.
    """
    # inspect.getdoc returns None (not "") when there is no docstring, so guard
    # before calling .strip() - otherwise the caller gets an AttributeError
    # instead of the explanatory ValueError below.
    doc = (inspect.getdoc(func) or "").strip()
    if not doc:
        raise ValueError(f"Cannot generate JSON schema for {func.__name__} because it has no docstring!")
    param_descriptions = _get_argument_descriptions_from_docstring(doc)

    json_schema = _convert_type_hints_to_json_schema(func)
    for arg, arg_schema in json_schema["properties"].items():
        if arg not in param_descriptions:
            raise ValueError(
                f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
            )
        arg_schema["description"] = param_descriptions[arg]

    return json_schema
|
||||
|
||||
|
||||
def _get_argument_descriptions_from_docstring(doc):
    """Extract argument descriptions from a docstring.

    Every ``:param name: description`` occurrence is collected; the description
    runs to the end of that line. If a name appears more than once, the last
    occurrence wins.

    Args:
        doc: The docstring text to scan.

    Returns:
        A dict mapping each argument name to its description.
    """
    matcher = re.compile(r":param (\w+): (.+)")
    return {found.group(1): found.group(2) for found in matcher.finditer(doc)}
|
||||
|
||||
|
||||
def _convert_type_hints_to_json_schema(func):
    """Build a JSON schema ``object`` describing the parameters of ``func``.

    Args:
        func: The callable whose signature and type hints are inspected.

    Returns:
        A dict of the form
        ``{"type": "object", "properties": {...}, "required": [...]}``.
        ``"properties"`` has one entry per type-hinted parameter (the return
        annotation is skipped); ``"required"`` lists every parameter in the
        signature that has no default value, whether or not it is type-hinted.
    """
    type_hints = get_type_hints(func)
    signature = inspect.signature(func)

    # Compare by identity: ``Parameter.empty`` is a sentinel, and a default value
    # with a permissive ``__eq__`` must not be mistaken for "no default".
    required = [
        param_name
        for param_name, param in signature.parameters.items()
        if param.default is inspect.Parameter.empty
    ]

    properties = {
        param_name: _parse_type_hint(param_type)
        for param_name, param_type in type_hints.items()
        if param_name != "return"
    }

    return {"type": "object", "properties": properties, "required": required}


def _parse_type_hint(hint):
    """Translate a single type hint into a JSON schema fragment, recursing into generics."""
    # Note: ``origin := get_origin(hint) is not None`` would bind a bool because
    # ``:=`` has the lowest precedence, so the origin is fetched as its own statement.
    origin = get_origin(hint)
    if origin is None:
        return _get_json_schema_type(hint)

    # Unparameterised aliases such as ``typing.List`` may not expose ``__args__``.
    args = getattr(hint, "__args__", ())

    if origin is Union:
        none_type = type(None)
        members = [arg for arg in args if arg is not none_type]
        nullable = len(members) != len(args)
        # ``Any`` has no "type" key in its schema, so it cannot go in the flat list form.
        if all(arg in BASIC_TYPES and arg is not Any for arg in members):
            return {
                "type": [_get_json_schema_type(arg)["type"] for arg in members],
                "nullable": nullable,
            }
        return {"anyOf": [_parse_type_hint(arg) for arg in members], "nullable": nullable}

    if origin is list:
        schema = {"type": "array"}
        if args:
            schema["items"] = _parse_type_hint(args[0])
        return schema

    if origin is dict:
        schema = {"type": "object"}
        if len(args) > 1:
            # JSON object keys are always strings, so only the value type is described.
            schema["additionalProperties"] = _parse_type_hint(args[1])
        return schema

    # Any other generic is handed to the plain-type mapper rather than being dropped.
    return _get_json_schema_type(hint)
|
||||
|
||||
|
||||
def _get_json_schema_type(param_type):
    """Map a plain (non-generic) Python type to a JSON schema fragment.

    Args:
        param_type: The type hint to convert.

    Returns:
        ``{}`` for ``Any`` (no constraint), ``{"type": <name>}`` for ``int``,
        ``float``, ``str``, ``bool`` and ``list``, and ``{"type": "object"}``
        for everything else.
    """
    if param_type is Any:
        return {}

    # Type objects are compared by identity, so ``bool`` is never confused with
    # ``int`` despite being its subclass. A bare ``list`` is a JSON array, not an object.
    json_type_names = (
        (int, "integer"),
        (float, "number"),
        (str, "string"),
        (bool, "boolean"),
        (list, "array"),
    )
    for python_type, json_type in json_type_names:
        if param_type is python_type:
            return {"type": json_type}

    return {"type": "object"}
|
Loading…
Reference in New Issue