def _merge_alternative_schemas(args):
    """Build one JSON schema fragment describing a group of alternative types.

    Args:
        args: The type arguments of a ``Union`` / ``List`` / ``Tuple`` hint.

    Returns:
        ``{"type": name}`` or ``{"type": [names]}`` when every argument is one of
        ``BASIC_TYPES``; otherwise the single nested schema, or ``{"anyOf": [...]}``
        when several remain. ``None`` and ``Ellipsis`` arguments never contribute
        a schema of their own.
    """
    # None is reported through "nullable" by the caller, and Ellipsis is only the
    # variadic marker of ``Tuple[X, ...]``, so neither is an alternative itself.
    members = [t for t in args if t is not type(None) and t is not Ellipsis]

    # Deliberately tested on the unfiltered args: a group that contains None is
    # not "all basic", so it keeps taking the recursive route below exactly as
    # before (this also keeps ``Optional[Any]`` away from the ["type"] lookup).
    if all(t in BASIC_TYPES for t in args):
        type_names = [_get_json_schema_type(t)["type"] for t in members]
        return {"type": type_names[0] if len(type_names) == 1 else type_names}

    schemas = [_parse_type_hint(t) for t in members]
    return schemas[0] if len(schemas) == 1 else {"anyOf": schemas}


def _parse_type_hint(hint):
    """Convert a single type hint into a JSON schema fragment.

    Args:
        hint: A type annotation, e.g. ``int``, ``Optional[str]``,
            ``List[Union[int, str]]`` or ``Dict[str, float]``.

    Returns:
        A dict holding the schema for ``hint``:

        * non-generic hints are delegated to ``_get_json_schema_type``;
        * ``Union`` hints are merged into one schema, with ``"nullable": True``
          added when ``None`` is one of the members;
        * ``list`` / ``tuple`` hints become ``{"type": "array"}``, plus an
          ``"items"`` schema when concrete element types are given;
        * ``dict`` hints become ``{"type": "object"}``, plus
          ``"additionalProperties"`` when a value type is given.

        ``None`` is returned for generic hints with any other origin.
    """
    origin = get_origin(hint)
    if origin is None:
        return _get_json_schema_type(hint)

    # Bare aliases such as ``typing.List`` / ``typing.Dict`` have an origin but
    # may expose no ``__args__``; treat them as unparameterised.
    args = getattr(hint, "__args__", ())

    if origin is Union:
        schema = _merge_alternative_schemas(args)
        if type(None) in args:
            schema["nullable"] = True
        return schema

    if origin is list or origin is tuple:
        # Without any concrete element type there is nothing to say about the
        # items, so omit the key rather than emit an empty "anyOf"/"type" list.
        if not any(t is not type(None) and t is not Ellipsis for t in args):
            return {"type": "array"}
        items = _merge_alternative_schemas(args)
        # A None among the element types means the *elements* may be null, which
        # is also what ``List[Optional[X]]`` produces through the Union branch.
        if type(None) in args:
            items["nullable"] = True
        return {"type": "array", "items": items}

    if origin is dict:
        schema = {"type": "object"}
        # Only a fully parameterised ``Dict[K, V]`` tells us the value type.
        if len(args) == 2:
            schema["additionalProperties"] = _parse_type_hint(args[1])
        return schema

    # NOTE(review): generics with any other origin (set, Callable, Literal, ...)
    # have always produced None here; kept as-is so callers see no new exception.
    return None