Add tests, update schema generator
This commit is contained in:
parent
5019aa224b
commit
51359bf11e
|
@ -1,6 +1,6 @@
|
|||
import inspect
|
||||
import re
|
||||
from typing import Any, Union, get_origin, get_type_hints, get_args
|
||||
from typing import Any, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
|
||||
BASIC_TYPES = (int, float, str, bool, Any)
|
||||
|
@ -45,31 +45,35 @@ def _convert_type_hints_to_json_schema(func):
|
|||
continue
|
||||
properties[param_name] = _parse_type_hint(param_type)
|
||||
|
||||
|
||||
schema = {"type": "object", "properties": properties}
|
||||
if required:
|
||||
schema["required"] = required
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
# TODO: Return types - how should they be handled, and does it even matter? Check what the different APIs do for this
|
||||
# and also add tests
|
||||
def _parse_type_hint(hint):
|
||||
if (origin := get_origin(hint)) is not None:
|
||||
if origin is Union:
|
||||
# If it's a union of basic types, we can express that as a simple list in the schema
|
||||
if all(t in BASIC_TYPES for t in get_args(hint)):
|
||||
return_dict = {"type": [_get_json_schema_type(t)["type"] for t in get_args(hint) if t != type(None)]}
|
||||
return_dict = {
|
||||
"type": [_get_json_schema_type(t)["type"] for t in get_args(hint) if t not in (type(None), ...)]
|
||||
}
|
||||
if len(return_dict["type"]) == 1:
|
||||
return_dict["type"] = return_dict["type"][0]
|
||||
else:
|
||||
# A union of more complex types requires us to recurse into each subtype
|
||||
return_dict = {"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t != type(None)],}
|
||||
return_dict = {
|
||||
"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t not in (type(None), ...)],
|
||||
}
|
||||
if len(return_dict["anyOf"]) == 1:
|
||||
return_dict = return_dict["anyOf"][0]
|
||||
if type(None) in get_args(hint):
|
||||
return_dict["nullable"] = True
|
||||
return return_dict
|
||||
elif origin is list or origin is tuple:
|
||||
elif origin is list:
|
||||
if not get_args(hint):
|
||||
return {"type": "array"}
|
||||
if all(t in BASIC_TYPES for t in get_args(hint)):
|
||||
|
@ -79,13 +83,21 @@ def _parse_type_hint(hint):
|
|||
items["type"] = items["type"][0]
|
||||
else:
|
||||
# And a list of more complex types requires us to recurse into each subtype again
|
||||
items = {"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t != type(None)]}
|
||||
items = {"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t not in (type(None), ...)]}
|
||||
if len(items["anyOf"]) == 1:
|
||||
items = items["anyOf"][0]
|
||||
return_dict = {"type": "array", "items": items}
|
||||
if type(None) in get_args(hint):
|
||||
return_dict["nullable"] = True
|
||||
return return_dict
|
||||
elif origin is tuple:
|
||||
raise ValueError(
|
||||
"This helper does not parse Tuple types, as they are usually used to indicate that "
|
||||
"each position is associated with a specific type, and this requires JSON schemas "
|
||||
"that are not supported by most templates. We recommend "
|
||||
"either using List or List[Union] instead for arguments where this is appropriate, or "
|
||||
"splitting arguments with Tuple types into multiple arguments that take single inputs."
|
||||
)
|
||||
elif origin is dict:
|
||||
# The JSON equivalent to a dict is 'object', which mandates that all keys are strings
|
||||
# However, we can specify the type of the dict values with "additionalProperties"
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
import unittest
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from transformers.utils import get_json_schema
|
||||
|
||||
|
||||
class JsonSchemaGeneratorTest(unittest.TestCase):
    """Checks the dict returned by `get_json_schema` for small annotated functions.

    Each case defines a local function called `fn`. The expected schemas below use
    that name, the docstring summary and the `:param` descriptions, so the nested
    definitions (including their docstrings) are part of the test fixture.
    """

    def test_simple_function(self):
        # One required `int` argument is expected to become an "integer" property.
        def fn(x: int):
            """
            Test function

            :param x: The input
            """
            return x

        generated = get_json_schema(fn)
        expected = {
            "name": "fn",
            "description": "Test function",
            "parameters": {
                "type": "object",
                "properties": {"x": {"type": "integer", "description": "The input"}},
                "required": ["x"],
            },
        }
        self.assertEqual(generated, expected)

    def test_union(self):
        # A union of basic types is expected as a list under "type".
        def fn(x: Union[int, float]):
            """
            Test function

            :param x: The input
            """
            return x

        generated = get_json_schema(fn)
        expected = {
            "name": "fn",
            "description": "Test function",
            "parameters": {
                "type": "object",
                "properties": {"x": {"type": ["integer", "number"], "description": "The input"}},
                "required": ["x"],
            },
        }
        self.assertEqual(generated, expected)

    def test_optional(self):
        # Optional[int] is expected as a plain "integer" flagged "nullable"; it is still required.
        def fn(x: Optional[int]):
            """
            Test function

            :param x: The input
            """
            return x

        generated = get_json_schema(fn)
        expected = {
            "name": "fn",
            "description": "Test function",
            "parameters": {
                "type": "object",
                "properties": {"x": {"type": "integer", "description": "The input", "nullable": True}},
                "required": ["x"],
            },
        }
        self.assertEqual(generated, expected)

    def test_default_arg(self):
        # An argument with a default is expected to be left out of "required"
        # (the expected schema has no "required" key at all here).
        def fn(x: int = 42):
            """
            Test function

            :param x: The input
            """
            return x

        generated = get_json_schema(fn)
        expected = {
            "name": "fn",
            "description": "Test function",
            "parameters": {
                "type": "object",
                "properties": {"x": {"type": "integer", "description": "The input"}},
            },
        }
        self.assertEqual(generated, expected)

    def test_nested_list(self):
        # Nested lists are expected as nested "array" schemas with typed "items".
        def fn(x: List[List[Union[int, str]]]):
            """
            Test function

            :param x: The input
            """
            return x

        generated = get_json_schema(fn)
        expected = {
            "name": "fn",
            "description": "Test function",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "array",
                        "items": {"type": "array", "items": {"type": ["integer", "string"]}},
                        "description": "The input",
                    }
                },
                "required": ["x"],
            },
        }
        self.assertEqual(generated, expected)

    def test_multiple_arguments(self):
        # Each argument gets its own property, and both are listed as required.
        def fn(x: int, y: str):
            """
            Test function

            :param x: The input
            :param y: Also the input
            """
            return x

        generated = get_json_schema(fn)
        expected = {
            "name": "fn",
            "description": "Test function",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {"type": "integer", "description": "The input"},
                    "y": {"type": "string", "description": "Also the input"},
                },
                "required": ["x", "y"],
            },
        }
        self.assertEqual(generated, expected)

    def test_multiple_complex_arguments(self):
        # Mixes a list-of-union argument with an optional union that defaults to None.
        def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None):
            """
            Test function

            :param x: The input
            :param y: Also the input
            """
            return x

        generated = get_json_schema(fn)
        expected = {
            "name": "fn",
            "description": "Test function",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "array",
                        "items": {"type": ["integer", "number"]},
                        "description": "The input",
                    },
                    "y": {
                        "anyOf": [{"type": "integer"}, {"type": "string"}],
                        "nullable": True,
                        "description": "Also the input",
                    },
                },
                "required": ["x"],
            },
        }
        self.assertEqual(generated, expected)

    def test_missing_docstring(self):
        # A function without any docstring is expected to be rejected with ValueError.
        def fn(x: int):
            return x

        self.assertRaises(ValueError, get_json_schema, fn)

    def test_missing_param_docstring(self):
        # A docstring that does not describe `x` is expected to be rejected with ValueError.
        def fn(x: int):
            """
            Test function
            """
            return x

        self.assertRaises(ValueError, get_json_schema, fn)

    def test_missing_type_hint(self):
        # An argument without a type hint is expected to be rejected with ValueError.
        def fn(x):
            """
            Test function

            :param x: The input
            """
            return x

        self.assertRaises(ValueError, get_json_schema, fn)
|
Loading…
Reference in New Issue