Add tests, update schema generator

This commit is contained in:
Matt 2024-05-07 18:35:28 +01:00
parent 5019aa224b
commit 51359bf11e
2 changed files with 147 additions and 7 deletions

View File

@ -1,6 +1,6 @@
import inspect
import re
from typing import Any, Union, get_origin, get_type_hints, get_args
from typing import Any, Union, get_args, get_origin, get_type_hints
BASIC_TYPES = (int, float, str, bool, Any)
@ -45,31 +45,35 @@ def _convert_type_hints_to_json_schema(func):
continue
properties[param_name] = _parse_type_hint(param_type)
schema = {"type": "object", "properties": properties}
if required:
schema["required"] = required
return schema
# TODO: Return types - how are these handled, and does it matter? Check what the different APIs do for this,
# and add tests.
def _parse_type_hint(hint):
if (origin := get_origin(hint)) is not None:
if origin is Union:
# If it's a union of basic types, we can express that as a simple list in the schema
if all(t in BASIC_TYPES for t in get_args(hint)):
return_dict = {"type": [_get_json_schema_type(t)["type"] for t in get_args(hint) if t != type(None)]}
return_dict = {
"type": [_get_json_schema_type(t)["type"] for t in get_args(hint) if t not in (type(None), ...)]
}
if len(return_dict["type"]) == 1:
return_dict["type"] = return_dict["type"][0]
else:
# A union of more complex types requires us to recurse into each subtype
return_dict = {"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t != type(None)],}
return_dict = {
"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t not in (type(None), ...)],
}
if len(return_dict["anyOf"]) == 1:
return_dict = return_dict["anyOf"][0]
if type(None) in get_args(hint):
return_dict["nullable"] = True
return return_dict
elif origin is list or origin is tuple:
elif origin is list:
if not get_args(hint):
return {"type": "array"}
if all(t in BASIC_TYPES for t in get_args(hint)):
@ -79,13 +83,21 @@ def _parse_type_hint(hint):
items["type"] = items["type"][0]
else:
# And a list of more complex types requires us to recurse into each subtype again
items = {"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t != type(None)]}
items = {"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t not in (type(None), ...)]}
if len(items["anyOf"]) == 1:
items = items["anyOf"][0]
return_dict = {"type": "array", "items": items}
if type(None) in get_args(hint):
return_dict["nullable"] = True
return return_dict
elif origin is tuple:
raise ValueError(
"This helper does not parse Tuple types, as they are usually used to indicate that "
"each position is associated with a specific type, and this requires JSON schemas "
"that are not supported by most templates. We recommend "
"either using List or List[Union] instead for arguments where this is appropriate, or "
"splitting arguments with Tuple types into multiple arguments that take single inputs."
)
elif origin is dict:
# The JSON equivalent to a dict is 'object', which mandates that all keys are strings
# However, we can specify the type of the dict values with "additionalProperties"

View File

@ -0,0 +1,128 @@
import unittest
from typing import List, Optional, Union
from transformers.utils import get_json_schema
class JsonSchemaGeneratorTest(unittest.TestCase):
    """Checks the JSON schema that `get_json_schema` builds from a function's type hints and docstring."""

    # A single required integer argument.
    def test_simple_function(self):
        def fn(x: int):
            """
            Test function
            :param x: The input
            """
            return x

        generated = get_json_schema(fn)
        expected = {
            "name": "fn",
            "description": "Test function",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {"type": "integer", "description": "The input"},
                },
                "required": ["x"],
            },
        }
        self.assertEqual(generated, expected)

    # A union of basic types is expected as a flat list under "type".
    def test_union(self):
        def fn(x: Union[int, float]):
            """
            Test function
            :param x: The input
            """
            return x

        generated = get_json_schema(fn)
        expected = {
            "name": "fn",
            "description": "Test function",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {"type": ["integer", "number"], "description": "The input"},
                },
                "required": ["x"],
            },
        }
        self.assertEqual(generated, expected)

    # Optional[...] is expected to add "nullable" while the argument stays required.
    def test_optional(self):
        def fn(x: Optional[int]):
            """
            Test function
            :param x: The input
            """
            return x

        generated = get_json_schema(fn)
        expected = {
            "name": "fn",
            "description": "Test function",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {"type": "integer", "description": "The input", "nullable": True},
                },
                "required": ["x"],
            },
        }
        self.assertEqual(generated, expected)

    # An argument with a default value: no "required" key is expected at all.
    def test_default_arg(self):
        def fn(x: int = 42):
            """
            Test function
            :param x: The input
            """
            return x

        generated = get_json_schema(fn)
        expected = {
            "name": "fn",
            "description": "Test function",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {"type": "integer", "description": "The input"},
                },
            },
        }
        self.assertEqual(generated, expected)

    # Nested lists are expected as nested "array" schemas.
    def test_nested_list(self):
        def fn(x: List[List[Union[int, str]]]):
            """
            Test function
            :param x: The input
            """
            return x

        generated = get_json_schema(fn)
        expected = {
            "name": "fn",
            "description": "Test function",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "array",
                        "items": {
                            "type": "array",
                            "items": {"type": ["integer", "string"]},
                        },
                        "description": "The input",
                    },
                },
                "required": ["x"],
            },
        }
        self.assertEqual(generated, expected)

    # Two required arguments, each with its own description.
    def test_multiple_arguments(self):
        def fn(x: int, y: str):
            """
            Test function
            :param x: The input
            :param y: Also the input
            """
            return x

        generated = get_json_schema(fn)
        expected = {
            "name": "fn",
            "description": "Test function",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {"type": "integer", "description": "The input"},
                    "y": {"type": "string", "description": "Also the input"},
                },
                "required": ["x", "y"],
            },
        }
        self.assertEqual(generated, expected)

    # A list of a union plus an optional union that defaults to None.
    def test_multiple_complex_arguments(self):
        def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None):
            """
            Test function
            :param x: The input
            :param y: Also the input
            """
            return x

        generated = get_json_schema(fn)
        expected = {
            "name": "fn",
            "description": "Test function",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "array",
                        "items": {"type": ["integer", "number"]},
                        "description": "The input",
                    },
                    "y": {
                        "anyOf": [{"type": "integer"}, {"type": "string"}],
                        "nullable": True,
                        "description": "Also the input",
                    },
                },
                "required": ["x"],
            },
        }
        self.assertEqual(generated, expected)

    # A function with no docstring at all must be rejected.
    def test_missing_docstring(self):
        def fn(x: int):
            return x

        self.assertRaises(ValueError, get_json_schema, fn)

    # A docstring that does not describe the argument must be rejected.
    def test_missing_param_docstring(self):
        def fn(x: int):
            """
            Test function
            """
            return x

        self.assertRaises(ValueError, get_json_schema, fn)

    # An argument without a type hint must be rejected.
    def test_missing_type_hint(self):
        def fn(x):
            """
            Test function
            :param x: The input
            """
            return x

        self.assertRaises(ValueError, get_json_schema, fn)