diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index dc85175588..cee8a0bcdf 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -1,6 +1,6 @@ import inspect import re -from typing import Any, Union, get_origin, get_type_hints, get_args +from typing import Any, Union, get_args, get_origin, get_type_hints BASIC_TYPES = (int, float, str, bool, Any) @@ -45,31 +45,35 @@ def _convert_type_hints_to_json_schema(func): continue properties[param_name] = _parse_type_hint(param_type) - schema = {"type": "object", "properties": properties} if required: schema["required"] = required return schema - +# TODO: Return types!! How are those even handled? Does it even matter? I should check what the different APIs do for this +# and also add tests def _parse_type_hint(hint): if (origin := get_origin(hint)) is not None: if origin is Union: # If it's a union of basic types, we can express that as a simple list in the schema if all(t in BASIC_TYPES for t in get_args(hint)): - return_dict = {"type": [_get_json_schema_type(t)["type"] for t in get_args(hint) if t != type(None)]} + return_dict = { + "type": [_get_json_schema_type(t)["type"] for t in get_args(hint) if t not in (type(None), ...)] + } if len(return_dict["type"]) == 1: return_dict["type"] = return_dict["type"][0] else: # A union of more complex types requires us to recurse into each subtype - return_dict = {"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t != type(None)],} + return_dict = { + "anyOf": [_parse_type_hint(t) for t in get_args(hint) if t not in (type(None), ...)], + } if len(return_dict["anyOf"]) == 1: return_dict = return_dict["anyOf"][0] if type(None) in get_args(hint): return_dict["nullable"] = True return return_dict - elif origin is list or origin is tuple: + elif origin is list: if not get_args(hint): return {"type": "array"} if all(t in BASIC_TYPES for t in get_args(hint)): @@ -79,13 +83,21 @@ def _parse_type_hint(hint): items["type"] = items["type"][0] else: # And a list of more complex types requires us to recurse into each subtype again - items = {"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t != type(None)]} + items = {"anyOf": [_parse_type_hint(t) for t in get_args(hint) if t not in (type(None), ...)]} if len(items["anyOf"]) == 1: items = items["anyOf"][0] return_dict = {"type": "array", "items": items} if type(None) in get_args(hint): return_dict["nullable"] = True return return_dict + elif origin is tuple: + raise ValueError( + "This helper does not parse Tuple types, as they are usually used to indicate that " + "each position is associated with a specific type, and this requires JSON schemas " + "that are not supported by most templates. We recommend " + "either using List or List[Union] instead for arguments where this is appropriate, or " + "splitting arguments with Tuple types into multiple arguments that take single inputs." + ) elif origin is dict: # The JSON equivalent to a dict is 'object', which mandates that all keys are strings # However, we can specify the type of the dict values with "additionalProperties" diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py new file mode 100644 index 0000000000..972ed779ba --- /dev/null +++ b/tests/utils/test_chat_template_utils.py @@ -0,0 +1,128 @@ +import unittest +from typing import List, Optional, Union + +from transformers.utils import get_json_schema + + +class JsonSchemaGeneratorTest(unittest.TestCase): + def test_simple_function(self): + def fn(x: int): + """ + Test function + + :param x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'integer', 'description': 'The input'}}, 'required': ['x']}} + self.assertEqual(schema, expected_schema) + + def test_union(self): + def fn(x: Union[int, float]): + """ + Test function + + :param x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': ['integer', 'number'], 'description': 'The input'}}, 'required': ['x']}} + self.assertEqual(schema, expected_schema) + + def test_optional(self): + def fn(x: Optional[int]): + """ + Test function + + :param x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'integer', 'description': 'The input', "nullable": True}}, 'required': ['x']}} + self.assertEqual(schema, expected_schema) + + def test_default_arg(self): + def fn(x: int = 42): + """ + Test function + + :param x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'integer', 'description': 'The input'}}}} + self.assertEqual(schema, expected_schema) + + def test_nested_list(self): + def fn(x: List[List[Union[int, str]]]): + """ + Test function + + :param x: The input + """ + return x + + schema = get_json_schema(fn) + expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'array', 'items': {'type': 'array', 'items': {'type': ['integer', 'string']}}, 'description': 'The input'}}, 'required': ['x']}} + self.assertEqual(schema, expected_schema) + + def test_multiple_arguments(self): + def fn(x: int, y: str): + """ + Test function + + :param x: The input + :param y: Also the input + """ + return x + + schema = get_json_schema(fn) + expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'integer', 'description': 'The input'}, 'y': {'type': 'string', 'description': 'Also the input'}}, 'required': ['x', 'y']}} + self.assertEqual(schema, expected_schema) + + def test_multiple_complex_arguments(self): + def fn(x: List[Union[int, float]], y: Optional[Union[int, str]] = None): + """ + Test function + + :param x: The input + :param y: Also the input + """ + return x + + schema = get_json_schema(fn) + expected_schema = {'name': 'fn', 'description': 'Test function', 'parameters': {'type': 'object', 'properties': {'x': {'type': 'array', 'items': {'type': ['integer', 'number']}, 'description': 'The input'}, 'y': {'anyOf': [{'type': 'integer'}, {'type': 'string'}], 'nullable': True, 'description': 'Also the input'}}, 'required': ['x']}} + self.assertEqual(schema, expected_schema) + + def test_missing_docstring(self): + def fn(x: int): + return x + + with self.assertRaises(ValueError): + get_json_schema(fn) + + def test_missing_param_docstring(self): + def fn(x: int): + """ + Test function + """ + return x + + with self.assertRaises(ValueError): + get_json_schema(fn) + + def test_missing_type_hint(self): + def fn(x): + """ + Test function + + :param x: The input + """ + return x + + with self.assertRaises(ValueError): + get_json_schema(fn)