ob-repository-synchronize/extras/obfastapi/frame.py

1094 lines
43 KiB
Python

from .config import ConfigsUtil
from .log import LoggerFactory
import sys
import json
import base64
import threading
import importlib
import asyncio
import email.message
from uuid import uuid1, UUID
from inspect import isfunction, iscoroutinefunction
from enum import Enum
from typing import (
Any,
Callable,
Coroutine,
Generator,
Dict,
List,
Tuple,
Generic,
TypeVar,
Optional,
Sequence,
Set,
Type,
Union,
cast
)
from pydantic import BaseModel
from pydantic.main import BaseConfig
from pydantic.fields import Undefined
from pydantic.schema import encode_default, ModelField, field_schema, model_process_schema
from pydantic.generics import GenericModel
from pydantic.typing import is_callable_type
from pydantic.utils import lenient_issubclass
from pydantic.error_wrappers import ErrorWrapper
from starlette.requests import Request
from starlette.datastructures import State, FormData
from starlette.middleware import Middleware
from starlette.exceptions import HTTPException
from starlette.concurrency import run_in_threadpool
from starlette.responses import JSONResponse, Response
from starlette.status import *
from starlette.types import ASGIApp
from starlette.routing import request_response
from fastapi import FastAPI, APIRouter, Security
from fastapi.routing import BaseRoute, _prepare_response_content, APIRoute
from fastapi.params import Depends, Form
from fastapi.exceptions import RequestValidationError
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.utils import create_response_field
from fastapi.types import DecoratedCallable
from fastapi.security import APIKeyCookie
from starlette.background import BackgroundTasks
from starlette.websockets import WebSocket
from fastapi.security.oauth2 import SecurityScopes
from fastapi.dependencies.utils import (
get_dependant,
is_gen_callable,
is_async_gen_callable,
async_contextmanager_dependencies_error,
solve_generator,
is_coroutine_callable,
request_body_to_args,
request_params_to_args
)
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.encoders import DictIntStrAny, SetIntStr, jsonable_encoder
from fastapi.openapi.utils import (
Body,
get_flat_models_from_routes,
get_model_name_map,
get_openapi_path
)
from fastapi.openapi.models import OpenAPI
from fastapi.openapi.constants import (
METHODS_WITH_BODY,
REF_PREFIX,
STATUS_CODES_WITH_NO_BODY,
)
from .mysql import OBDataBaseError
Logger = LoggerFactory.create_logger(name = 'fastapi.logger')
LoggerFactory.update_global_config(
level = ConfigsUtil.get_obfastapi_config('log_level'),
path = ConfigsUtil.get_obfastapi_config('log_path'),
interval = ConfigsUtil.get_obfastapi_config('log_interval'),
backup_count = ConfigsUtil.get_obfastapi_config('log_count')
)
__all__ = ("OBAPIRouter", "OBFastAPI", "OBResponse", "Controller", "DataList", "Trace", "Logger")
async def serialize_response(
*,
field: Optional[ModelField] = None,
response_content: Any,
include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
is_coroutine: bool = True,
) -> Any:
if field:
errors = []
if issubclass(type(response_content), BaseModel):
value = response_content
else:
response_content = _prepare_response_content(
response_content,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
if is_coroutine:
value, errors_ = field.validate(response_content, {}, loc=("response",))
else:
value, errors_ = await run_in_threadpool(
field.validate, response_content, {}, loc=("response",)
)
if isinstance(errors_, ErrorWrapper):
errors.append(errors_)
elif isinstance(errors_, list):
errors.extend(errors_)
if errors:
raise ValidationError(errors, field.type_)
return jsonable_encoder(
value,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
else:
return jsonable_encoder(response_content)
async def solve_dependencies(
*,
request: Union[Request, WebSocket],
dependant: Dependant,
body: Optional[Union[Dict[str, Any], FormData]] = None,
background_tasks: Optional[BackgroundTasks] = None,
response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
) -> Tuple[
Dict[str, Any],
List[ErrorWrapper],
Optional[BackgroundTasks],
Response,
Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
]:
values: Dict[str, Any] = {}
errors: List[ErrorWrapper] = []
response = response or Response(
content=None,
status_code=200, # type: ignore
headers=None, # type: ignore # in Starlette
media_type=None, # type: ignore # in Starlette
background=None, # type: ignore # in Starlette
)
dependency_cache = dependency_cache or {}
sub_dependant: Dependant
for sub_dependant in dependant.dependencies:
sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
sub_dependant.cache_key = cast(
Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
)
call = sub_dependant.call
use_sub_dependant = sub_dependant
if (
dependency_overrides_provider
and dependency_overrides_provider.dependency_overrides
):
original_call = sub_dependant.call
call = getattr(
dependency_overrides_provider, "dependency_overrides", {}
).get(original_call, original_call)
use_path: str = sub_dependant.path # type: ignore
use_sub_dependant = get_dependant(
path=use_path,
call=call,
name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes,
)
use_sub_dependant.security_scopes = sub_dependant.security_scopes
solved_result = await solve_dependencies(
request=request,
dependant=use_sub_dependant,
body=body,
background_tasks=background_tasks,
response=response,
dependency_overrides_provider=dependency_overrides_provider,
dependency_cache=dependency_cache,
)
(
sub_values,
sub_errors,
background_tasks,
_, # the subdependency returns the same response we have
sub_dependency_cache,
) = solved_result
dependency_cache.update(sub_dependency_cache)
if sub_errors:
errors.extend(sub_errors)
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call):
stack = request.scope.get("fastapi_astack")
if stack is None:
raise RuntimeError(
async_contextmanager_dependencies_error
) # pragma: no cover
solved = await solve_generator(
call=call, stack=stack, sub_values=sub_values
)
elif is_coroutine_callable(call):
if isinstance(sub_dependant, OBDependant):
solved = await call(__request__=request, **sub_values)
else:
solved = await call(**sub_values)
else:
if isinstance(sub_dependant, OBDependant):
solved = await run_in_threadpool(call, __request__=request, **sub_values)
else:
solved = await run_in_threadpool(call, **sub_values)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:
dependency_cache[sub_dependant.cache_key] = solved
path_values, path_errors = request_params_to_args(
dependant.path_params, request.path_params
)
query_values, query_errors = request_params_to_args(
dependant.query_params, request.query_params
)
header_values, header_errors = request_params_to_args(
dependant.header_params, request.headers
)
cookie_values, cookie_errors = request_params_to_args(
dependant.cookie_params, request.cookies
)
values.update(path_values)
values.update(query_values)
values.update(header_values)
values.update(cookie_values)
errors += path_errors + query_errors + header_errors + cookie_errors
if dependant.body_params:
(
body_values,
body_errors,
) = await request_body_to_args( # body_params checked above
required_params=dependant.body_params, received_body=body
)
values.update(body_values)
errors.extend(body_errors)
if dependant.http_connection_param_name:
values[dependant.http_connection_param_name] = request
if dependant.request_param_name and isinstance(request, Request):
values[dependant.request_param_name] = request
elif dependant.websocket_param_name and isinstance(request, WebSocket):
values[dependant.websocket_param_name] = request
if dependant.background_tasks_param_name:
if background_tasks is None:
background_tasks = BackgroundTasks()
values[dependant.background_tasks_param_name] = background_tasks
if dependant.response_param_name:
values[dependant.response_param_name] = response
if dependant.security_scopes_param_name:
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes
)
return values, errors, background_tasks, response, dependency_cache
def get_field_info_schema(field: ModelField) -> Tuple[Dict[str, Any], bool]:
schema_overrides = False
# If no title is explicitly set, we don't set title in the schema for enums.
# The behaviour is the same as `BaseModel` reference, where the default title
# is in the definitions part of the schema.
schema: Dict[str, Any] = {}
if field.field_info.title or not lenient_issubclass(field.type_, Enum):
if field.field_info.title:
schema['title'] = field.field_info.title
else:
# field.field_info.title =
schema['title'] = field.name
if field.field_info.title:
schema_overrides = True
if field.field_info.description:
schema['description'] = field.field_info.description
schema_overrides = True
if (
not field.required
and not field.field_info.const
and field.default is not None
and not is_callable_type(field.outer_type_)
):
schema['default'] = encode_default(field.default)
schema_overrides = True
return schema, schema_overrides
async def run_endpoint_function(
*, request: Request, dependant: Dependant, values: Dict[str, Any], is_coroutine: bool
) -> Any:
assert dependant.call is not None, "dependant.call must be a function"
is_ob_dependant = isinstance(dependant, OBDependant)
try:
if is_ob_dependant:
values['__request__'] = request
if is_coroutine:
return await dependant.call(**values)
else:
return await run_in_threadpool(dependant.call, **values)
except OBDataBaseError as e:
Logger.exception('Database Error')
raise OBHTTPException(status_code=HTTP_400_BAD_REQUEST, msg='Database Error: %s' % e.orig)
except Exception as e:
if isinstance(e, HTTPException):
raise e
else:
Logger.exception('INTERNAL SERVER ERROR')
raise OBHTTPException(status_code=HTTP_500_INTERNAL_SERVER_ERROR, msg='INTERNAL SERVER ERROR')
finally:
if is_ob_dependant:
dependant.clear_local(request)
def support_one_api(schema):
if 'allOf' in schema and len(schema['allOf']) == 1:
for ref in schema['allOf']:
if ref.get('$ref'):
schema['$ref'] = ref.get('$ref')
break
return schema
def get_model_definitions(
*,
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
definitions: Dict[str, Dict[str, Any]] = {}
for model in flat_models:
m_schema, m_definitions, m_nested_models = model_process_schema(
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
definitions.update(m_definitions)
model_name = model_name_map[model]
definitions[model_name] = m_schema
for k in definitions:
properties = definitions[k]['properties']
for property_name in properties:
support_one_api(properties[property_name])
return definitions
def get_openapi_operation_request_body(
*,
body_field: Optional[ModelField],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Optional[Dict[str, Any]]:
if not body_field:
return None
assert isinstance(body_field, ModelField)
body_schema, _, _ = field_schema(
body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
field_info = cast(Body, body_field.field_info)
request_media_type = field_info.media_type
required = body_field.required
request_body_oai: Dict[str, Any] = {}
if required:
request_body_oai["required"] = required
request_media_content: Dict[str, Any] = {"schema": support_one_api(body_schema)}
if field_info.examples:
request_media_content["examples"] = jsonable_encoder(field_info.examples)
elif field_info.example != Undefined:
request_media_content["example"] = jsonable_encoder(field_info.example)
request_body_oai["content"] = {request_media_type: request_media_content}
return request_body_oai
def get_openapi(
*,
title: str,
version: str,
openapi_version: str = "3.0.2",
description: Optional[str] = None,
routes: Sequence[BaseRoute],
tags: Optional[List[Dict[str, Any]]] = None,
servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
) -> Dict[str, Any]:
info = {"title": title, "version": version}
if description:
info["description"] = description
output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
if servers:
output["servers"] = servers
components: Dict[str, Dict[str, Any]] = {}
paths: Dict[str, Dict[str, Any]] = {}
flat_models = get_flat_models_from_routes(routes)
model_name_map = get_model_name_map(flat_models)
definitions = get_model_definitions(
flat_models=flat_models, model_name_map=model_name_map
)
for route in routes:
if isinstance(route, APIRoute):
result = get_openapi_path(route=route, model_name_map=model_name_map)
if result:
path, security_schemes, _ = result
if path:
paths.setdefault(route.path_format, {}).update(path)
if security_schemes:
components.setdefault("securitySchemes", {}).update(
security_schemes
)
if definitions:
components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
if components:
output["components"] = components
output["paths"] = paths
if tags:
output["tags"] = tags
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore
# 注入方法
setattr(sys.modules['fastapi.routing'], 'serialize_response', serialize_response)
setattr(sys.modules['fastapi.openapi.utils'], 'get_openapi_operation_request_body', get_openapi_operation_request_body)
# pydantic 升级后可能会失效
setattr(sys.modules['pydantic.schema'], 'get_field_info_schema', get_field_info_schema)
Data = TypeVar('Data')
class OBResponse(GenericModel, Generic[Data]):
code: int = 200
data: Optional[Data] = None
msg: str = ''
success: bool = True
finished: bool = True
class DataList(GenericModel, Generic[Data]):
total: int
list: List[Data]
class Trace(BaseModel):
trace_id: UUID
def __init__(self, trace_id: Optional[UUID] = None):
super().__init__(trace_id=trace_id if trace_id else uuid1())
class OBHTTPException(HTTPException):
def __init__(
self,
status_code: int,
msg: str = '',
headers: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(status_code=status_code, detail=msg)
self.headers = headers
class ValidationError(BaseModel):
loc: List[str]
msg: str
type: str
class HTTPValidationError(OBResponse):
code: int = HTTP_422_UNPROCESSABLE_ENTITY
data: List[ValidationError]
ob_http_exception_response_field = create_response_field(name='OBResponse', type_=OBResponse)
request_validation_exception_field = create_response_field(name='HTTPValidationError', type_=HTTPValidationError)
async def request_validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
errors = []
for error in exc.errors():
error['loc'] = list(error['loc'])
errors.append(ValidationError(**error))
response = HTTPValidationError(
msg=str(exc),
data=errors,
success=False
)
response_data = await serialize_response(
field=request_validation_exception_field,
response_content=response,
)
return JSONResponse(response_data, status_code=response.code)
async def ob_http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
headers = getattr(exc, "headers", None)
detail = OBResponse(code=exc.status_code, msg=exc.detail)
response_data = await serialize_response(
field=ob_http_exception_response_field,
response_content=detail,
)
return JSONResponse(response_data, status_code=exc.status_code, headers=headers)
class User:
def __init__(
self,
*,
emp_id: Union[str, int] = '',
name: str = '',
nick: str = '',
email: str = '',
dept: str = ''
):
self.emp_id: Union[str, int] = emp_id
self.name: str = name
self.nick: str = nick
self.email: str = email
self.dept: str = dept
class Controller:
BUC_KEY_NAME = ConfigsUtil.get_obfastapi_config('buc_key') # default OBVOS_USER_SIGN
API_KEY_BUC_COOKIE = APIKeyCookie(name=BUC_KEY_NAME, auto_error=False)
def __init__(self, user=None):
self._user = user
self._trace: Trace = Trace()
def __hash__(self):
return self.trace
def __eq__(self, value):
if isinstance(value, self.__class__):
return value.__hash__() == self.__hash__()
return False
@property
def user(self):
return self._user
@property
def trace(self):
return self._trace
def b64_urlsafe_decode_with_padding(self, data: str) -> str:
missing_padding = 4 - len(data) % 4
if missing_padding:
data += '=' * missing_padding
return base64.urlsafe_b64decode(data)
def get_user(
self,
cookie_key: str = Security(API_KEY_BUC_COOKIE)
):
if cookie_key:
try:
payload_json = self.b64_urlsafe_decode_with_padding(cookie_key.split('.')[1])
payload = json.loads(payload_json)
user = payload.get('user')
if user:
self._user = User(
emp_id=user.get('empId'),
name=user.get('loginName'),
nick=user.get('nickNameCn'),
email=user.get('emailAddr'),
dept=user.get('depDesc')
)
except:
pass
if self._user:
return self._user
raise OBHTTPException(status_code=403, msg='Could not validate credentials')
class OBDependantLocal(Dict):
pass
class OBRequestHanlder(object):
def __init__(
self,
dependant: Dependant,
body_field: Optional[ModelField] = None,
status_code: int = 200,
response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse),
response_field: Optional[ModelField] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
dependency_overrides_provider: Optional[Any] = None
) -> None:
assert dependant.call is not None, "dependant.call must be a function"
self.body_field = body_field
self.is_coroutine = asyncio.iscoroutinefunction(dependant.call)
self.is_body_form = body_field and isinstance(body_field.field_info, Form)
if isinstance(response_class, DefaultPlaceholder):
actual_response_class: Type[Response] = response_class.value
else:
actual_response_class = response_class
self.dependant = dependant
self.status_code = status_code
self.dependency_overrides_provider = dependency_overrides_provider
self.response_field = response_field
self.response_model_include = response_model_include
self.response_model_exclude = response_model_exclude
self.response_model_by_alias = response_model_by_alias
self.response_model_exclude_unset = response_model_exclude_unset
self.response_model_exclude_defaults = response_model_exclude_defaults
self.response_model_exclude_none = response_model_exclude_none
self.actual_response_class = actual_response_class
self.local_controllers_map: Dict[Request, OBDependantLocal] = {}
def get_request_id(self, request):
return id(request)
def get_local_controllers(self, request: Request):
request_id = self.get_request_id(request)
if request_id not in self.local_controllers_map:
self.local_controllers_map[request_id] = OBDependantLocal()
return self.local_controllers_map[request_id]
def free_local_controllers(self, request: Request):
request_id = self.get_request_id(request)
if request_id in self.local_controllers_map:
del self.local_controllers_map[request_id]
async def __call__(self, request: Request) -> Response:
try:
body: Any = None
if self.body_field:
if self.is_body_form:
body = await request.form()
else:
body_bytes = await request.body()
if body_bytes:
json_body: Any = Undefined
content_type_value = request.headers.get("content-type")
if content_type_value:
message = email.message.Message()
message["content-type"] = content_type_value
if message.get_content_maintype() == "application":
subtype = message.get_content_subtype()
if subtype == "json" or subtype.endswith("+json"):
json_body = await request.json()
if json_body != Undefined:
body = json_body
else:
body = body_bytes
except json.JSONDecodeError as e:
raise RequestValidationError([ErrorWrapper(e, ("body", e.pos))], body=e.doc)
except Exception as e:
raise OBHTTPException(
status_code=400, msg="There was an error parsing the body"
) from e
solved_result = await solve_dependencies(
request=request,
dependant=self.dependant,
body=body,
dependency_overrides_provider=self.dependency_overrides_provider,
)
values, errors, background_tasks, sub_response, _ = solved_result
if errors:
raise RequestValidationError(errors, body=body)
else:
raw_response = await run_endpoint_function(
request=request, dependant=self.dependant, values=values, is_coroutine=self.is_coroutine
)
if isinstance(raw_response, Response):
if raw_response.background is None:
raw_response.background = background_tasks
return raw_response
response_data = await serialize_response(
field=self.response_field,
response_content=raw_response,
include=self.response_model_include,
exclude=self.response_model_exclude,
by_alias=self.response_model_by_alias,
exclude_unset=self.response_model_exclude_unset,
exclude_defaults=self.response_model_exclude_defaults,
exclude_none=self.response_model_exclude_none,
is_coroutine=self.is_coroutine,
)
response = self.actual_response_class(
content=response_data,
status_code=self.status_code,
background=background_tasks, # type: ignore # in Starlette
)
response.headers.raw.extend(sub_response.headers.raw)
if sub_response.status_code:
response.status_code = sub_response.status_code
return response
class OBDependant(Dependant):
def __init__(
self,
controller_cls:Type,
*,
request_handler: OBRequestHanlder = None,
path_params: Optional[List[ModelField]] = None,
query_params: Optional[List[ModelField]] = None,
header_params: Optional[List[ModelField]] = None,
cookie_params: Optional[List[ModelField]] = None,
body_params: Optional[List[ModelField]] = None,
dependencies: Optional[List["Dependant"]] = None,
security_schemes: Optional[List[SecurityRequirement]] = None,
name: Optional[str] = None,
call: Optional[Callable[..., Any]] = None,
request_param_name: Optional[str] = None,
websocket_param_name: Optional[str] = None,
http_connection_param_name: Optional[str] = None,
response_param_name: Optional[str] = None,
background_tasks_param_name: Optional[str] = None,
security_scopes_param_name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
use_cache: bool = True,
path: Optional[str] = None,
) -> None:
if query_params:
if query_params[0].name == 'self':
del query_params[0]
call = self._call(call)
self.controller_cls: Type = controller_cls
self.request_handler = request_handler
super().__init__(path_params=path_params, query_params=query_params, header_params=header_params, cookie_params=cookie_params, body_params=body_params, dependencies=dependencies, security_schemes=security_schemes, name=name, call=call, request_param_name=request_param_name, websocket_param_name=websocket_param_name, http_connection_param_name=http_connection_param_name, response_param_name=response_param_name, background_tasks_param_name=background_tasks_param_name, security_scopes_param_name=security_scopes_param_name, security_scopes=security_scopes, use_cache=use_cache, path=path)
def get_controller(self, request: Request):
# 避免依赖自身时重复创建导致self指针不一致
if self.request_handler:
local_controllers = self.request_handler.get_local_controllers(request)
if self.controller_cls not in local_controllers:
local_controllers[self.controller_cls] = self.controller_cls()
return local_controllers[self.controller_cls]
else:
return self.controller_cls()
def _call(self, endpoint: Optional[Callable[..., Any]]):
if iscoroutinefunction(endpoint):
async def __nt__(__request__: Request, *args, **kwargs):
if not args or not isinstance(args[0], self.controller_cls):
args = list(args)
args.insert(0, self.get_controller(request=__request__))
return await endpoint(*args, **kwargs)
else:
async def __nt__(__request__: Request, *args, **kwargs):
if not args or not isinstance(args[0], self.controller_cls):
args = list(args)
args.insert(0, self.get_controller(request=__request__))
return endpoint(*args, **kwargs)
return __nt__
def clear_local(self, request: Request):
if self.request_handler:
self.request_handler.free_local_controllers(request)
class OBAPIRoute(APIRoute):
def __init__(
self,
path: str,
endpoint: Callable[..., Any],
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
name: Optional[str] = None,
methods: Optional[Union[Set[str], List[str]]] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Union[Type[Response], DefaultPlaceholder] = Default(
JSONResponse
),
dependency_overrides_provider: Optional[Any] = None,
callbacks: Optional[List[BaseRoute]] = None,
) -> None:
super().__init__(path, endpoint, response_model=response_model, status_code=status_code, tags=tags, dependencies=dependencies, summary=summary, description=description, response_description=response_description, responses=responses, deprecated=deprecated, name=name, methods=methods, operation_id=operation_id, response_model_include=response_model_include, response_model_exclude=response_model_exclude, response_model_by_alias=response_model_by_alias, response_model_exclude_unset=response_model_exclude_unset, response_model_exclude_defaults=response_model_exclude_defaults, response_model_exclude_none=response_model_exclude_none, include_in_schema=include_in_schema, response_class=response_class, dependency_overrides_provider=dependency_overrides_provider, callbacks=callbacks)
self.response_fields[HTTP_422_UNPROCESSABLE_ENTITY] = request_validation_exception_field
threading.Thread(target=_init_controller, args=(self, self.request_hanlder)).start()
def get_route_handler(self):
self.request_hanlder = OBRequestHanlder(
dependant=self.dependant,
body_field=self.body_field,
status_code=self.status_code,
response_class=self.response_class,
response_field=self.secure_cloned_response_field,
response_model_include=self.response_model_include,
response_model_exclude=self.response_model_exclude,
response_model_by_alias=self.response_model_by_alias,
response_model_exclude_unset=self.response_model_exclude_unset,
response_model_exclude_defaults=self.response_model_exclude_defaults,
response_model_exclude_none=self.response_model_exclude_none,
dependency_overrides_provider=self.dependency_overrides_provider,
)
async def app(request: Request):
return await self.request_hanlder(request)
return app
class OBAPIRouter(APIRouter):
def __init__(
self,
*,
prefix: str = "",
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
default_response_class: Type[Response] = Default(JSONResponse),
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
callbacks: Optional[List[BaseRoute]] = None,
routes: Optional[List[BaseRoute]] = None,
redirect_slashes: bool = True,
default: Optional[ASGIApp] = None,
dependency_overrides_provider: Optional[Any] = None,
route_class: Type[APIRoute] = OBAPIRoute,
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
) -> None:
super().__init__(prefix=prefix, tags=tags, dependencies=dependencies, default_response_class=default_response_class, responses=responses, callbacks=callbacks, routes=routes, redirect_slashes=redirect_slashes, default=default, dependency_overrides_provider=dependency_overrides_provider, route_class=route_class, on_startup=on_startup, on_shutdown=on_shutdown, deprecated=deprecated, include_in_schema=include_in_schema)
def api_route(
self,
path: str,
*,
response_model: Optional[Type[Any]] = None,
status_code: int = 200,
tags: Optional[List[str]] = None,
dependencies: Optional[Sequence[Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
methods: Optional[List[str]] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_exclude: Optional[Union[SetIntStr, DictIntStrAny]] = None,
response_model_by_alias: bool = True,
response_model_exclude_unset: bool = False,
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[BaseRoute]] = None,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
def decorator(func: DecoratedCallable) -> DecoratedCallable:
self.add_api_route(
path,
func,
response_model=response_model,
status_code=status_code,
tags=tags,
dependencies=dependencies,
summary=summary,
description=description,
response_description=response_description,
responses=responses,
deprecated=deprecated,
methods=methods,
operation_id=operation_id if operation_id else func.__name__,
response_model_include=response_model_include,
response_model_exclude=response_model_exclude,
response_model_by_alias=response_model_by_alias,
response_model_exclude_unset=response_model_exclude_unset,
response_model_exclude_defaults=response_model_exclude_defaults,
response_model_exclude_none=response_model_exclude_none,
include_in_schema=include_in_schema,
response_class=response_class,
name=name,
callbacks=callbacks,
)
return func
return decorator
class OBFastAPI(FastAPI):
def __init__(
self,
*,
debug: bool = False,
routes: Optional[List[OBAPIRouter]] = None,
title: str = "FastAPI",
description: str = "",
version: str = "0.1.0",
openapi_url: Optional[str] = "/openapi.json",
openapi_tags: Optional[List[Dict[str, Any]]] = None,
servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
default_response_class: Type[Response] = Default(JSONResponse),
docs_url: Optional[str] = "/docs",
redoc_url: Optional[str] = "/redoc",
swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect",
swagger_ui_init_oauth: Optional[Dict[str, Any]] = None,
middleware: Optional[Sequence[Middleware]] = None,
exception_handlers: Optional[
Dict[
Union[int, Type[Exception]],
Callable[[Request, Any], Coroutine[Any, Any, Response]],
]
] = None,
on_startup: Optional[Sequence[Callable[[], Any]]] = None,
on_shutdown: Optional[Sequence[Callable[[], Any]]] = None,
root_prefix: str = "",
openapi_prefix: str = "",
root_path: str = "",
root_path_in_servers: bool = True,
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
callbacks: Optional[List[OBAPIRouter]] = None,
deprecated: Optional[bool] = None,
include_in_schema: bool = True,
**extra: Any,
) -> None:
self._debug: bool = debug
self.state: State = State()
self.router: OBAPIRouter = OBAPIRouter(
routes=routes,
prefix=root_prefix,
dependency_overrides_provider=self,
on_startup=on_startup,
on_shutdown=on_shutdown,
default_response_class=default_response_class,
dependencies=dependencies,
callbacks=callbacks,
deprecated=deprecated,
include_in_schema=include_in_schema,
responses=responses,
)
self.exception_handlers: Dict[
Union[int, Type[Exception]],
Callable[[Request, Any], Coroutine[Any, Any, Response]],
] = (
{} if exception_handlers is None else dict(exception_handlers)
)
self.exception_handlers.setdefault(HTTPException, ob_http_exception_handler)
self.exception_handlers.setdefault(
RequestValidationError, request_validation_exception_handler
)
self.user_middleware: List[Middleware] = (
[] if middleware is None else list(middleware)
)
self.middleware_stack: ASGIApp = self.build_middleware_stack()
self.title = title
self.description = description
self.version = version
self.servers = servers or []
self.openapi_url = openapi_url
self.openapi_tags = openapi_tags
# TODO: remove when discarding the openapi_prefix parameter
if openapi_prefix:
Logger.warning(
'"openapi_prefix" has been deprecated in favor of "root_path", which '
"follows more closely the ASGI standard, is simpler, and more "
"automatic. Check the docs at "
"https://fastapi.tiangolo.com/advanced/sub-applications/"
)
self.root_path = root_path or openapi_prefix
self.root_path_in_servers = root_path_in_servers
self.docs_url = docs_url
self.redoc_url = redoc_url
self.swagger_ui_oauth2_redirect_url = swagger_ui_oauth2_redirect_url
self.swagger_ui_init_oauth = swagger_ui_init_oauth
self.extra = extra
self.dependency_overrides: Dict[Callable[..., Any], Callable[..., Any]] = {}
self.openapi_version = "3.0.2"
if self.openapi_url:
assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'"
assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'"
self.openapi_schema: Optional[Dict[str, Any]] = None
self.setup()
def openapi(self) -> Dict[str, Any]:
if not self.openapi_schema:
self.openapi_schema = get_openapi(
title=self.title,
version=self.version,
openapi_version=self.openapi_version,
description=self.description,
routes=self.routes,
tags=self.openapi_tags,
servers=self.servers,
)
return self.openapi_schema
def _init_controller(route: OBAPIRoute, request_handler:OBRequestHanlder):
import time
while True:
try:
clz = route.endpoint.__qualname__.split('.')[0]
mod = importlib.import_module(route.endpoint.__module__)
clz = mod.__dict__[clz]
break
except:
time.sleep(0.1)
route.dependant = _clear_self(route.dependant, request_handler)
route.app = request_response(route.get_route_handler())
def _clear_self(dependant: OBDependant, request_handler:OBRequestHanlder = None) -> OBDependant:
endpoint = dependant.call
if isfunction(endpoint) and '.' in endpoint.__qualname__:
try:
clz = endpoint.__qualname__.split('.')[0]
mod = importlib.import_module(endpoint.__module__)
clz = mod.__dict__[clz]
if issubclass(clz, Controller):
dependant = OBDependant(
clz,
request_handler=request_handler,
path_params=dependant.path_params,
query_params=dependant.query_params,
header_params=dependant.header_params,
cookie_params=dependant.cookie_params,
body_params=dependant.body_params,
dependencies=dependant.dependencies,
security_schemes=dependant.security_requirements,
name=dependant.name,
call=dependant.call,
request_param_name=dependant.request_param_name,
http_connection_param_name=dependant.http_connection_param_name,
response_param_name=dependant.response_param_name,
websocket_param_name=dependant.websocket_param_name,
background_tasks_param_name=dependant.background_tasks_param_name,
security_scopes_param_name=dependant.security_scopes_param_name,
security_scopes=dependant.security_scopes,
use_cache=dependant.use_cache,
path=dependant.path
)
except:
Logger.exception("Fail to init Dependant")
if dependant.dependencies:
dependant.dependencies = [_clear_self(d, request_handler) for d in dependant.dependencies]
return dependant